Skip to content

Commit 3a5953d

Browse files
committed
move more of creation into the scheduler
1 parent d3717b3 commit 3a5953d

8 files changed

Lines changed: 75 additions & 55 deletions

File tree

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
254254
start=plan.start,
255255
end=plan.end,
256256
allow_destructive_snapshots=plan.allow_destructive_models,
257+
selected_snapshot_ids=stage.selected_snapshot_ids,
257258
)
258259
if errors:
259260
raise PlanError("Plan application failed.")

sqlmesh/core/plan/stages.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,14 @@ class BackfillStage:
116116
Args:
117117
snapshot_to_intervals: Intervals to backfill. This collection can be empty in which case no backfill is needed.
118118
This can be useful to report the lack of backfills back to the user.
119+
selected_snapshot_ids: The snapshots to include in the run DAG.
119120
all_snapshots: All snapshots in the plan by name.
120121
deployability_index: Deployability index for this stage.
121122
before_promote: Whether this stage is before the promotion stage.
122123
"""
123124

124125
snapshot_to_intervals: SnapshotToIntervals
126+
selected_snapshot_ids: t.Set[SnapshotId]
125127
all_snapshots: t.Dict[str, Snapshot]
126128
deployability_index: DeployabilityIndex
127129
before_promote: bool = True
@@ -298,26 +300,13 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
298300
stages.append(CreateSnapshotRecordsStage(snapshots=plan.new_snapshots))
299301

300302
snapshots_to_create = self._get_snapshots_to_create(plan, snapshots)
301-
stages.append(
302-
PhysicalLayerSchemaCreationStage(
303-
snapshots=snapshots_to_create, deployability_index=deployability_index
304-
)
305-
)
306-
if not plan.skip_backfill and not plan.empty_backfill:
307-
# If the snapshot is selected for backfill and is not representative, then we assume
308-
# this is a paused forward-only snapshot and we need to make sure a clone has been
309-
# created for it in dev.
310-
filtered_snapshots_to_create = []
311-
for snapshot in snapshots_to_create:
312-
if (
313-
plan.is_selected_for_backfill(snapshot.name)
314-
and snapshot not in snapshots_to_intervals
315-
and snapshot.is_materialized
316-
and not deployability_index.is_representative(snapshot)
317-
):
318-
filtered_snapshots_to_create.append(snapshot)
319-
snapshots_to_create = filtered_snapshots_to_create
320303
if snapshots_to_create:
304+
stages.append(
305+
PhysicalLayerSchemaCreationStage(
306+
snapshots=snapshots_to_create, deployability_index=deployability_index
307+
)
308+
)
309+
if not needs_backfill:
321310
stages.append(
322311
self._get_physical_layer_update_stage(
323312
plan,
@@ -340,6 +329,11 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
340329
stages.append(
341330
BackfillStage(
342331
snapshot_to_intervals=missing_intervals_before_promote,
332+
selected_snapshot_ids={
333+
s_id
334+
for s_id in before_promote_snapshots
335+
if plan.is_selected_for_backfill(s_id.name)
336+
},
343337
all_snapshots=snapshots_by_name,
344338
deployability_index=deployability_index,
345339
)
@@ -349,6 +343,7 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
349343
stages.append(
350344
BackfillStage(
351345
snapshot_to_intervals={},
346+
selected_snapshot_ids=set(),
352347
all_snapshots=snapshots_by_name,
353348
deployability_index=deployability_index,
354349
)
@@ -379,6 +374,11 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
379374
stages.append(
380375
BackfillStage(
381376
snapshot_to_intervals=missing_intervals_after_promote,
377+
selected_snapshot_ids={
378+
s_id
379+
for s_id in after_promote_snapshots
380+
if plan.is_selected_for_backfill(s_id.name)
381+
},
382382
all_snapshots=snapshots_by_name,
383383
deployability_index=deployability_index,
384384
)

sqlmesh/core/scheduler.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,9 @@ def batch_intervals(
332332
merged_intervals: SnapshotToIntervals,
333333
deployability_index: t.Optional[DeployabilityIndex],
334334
environment_naming_info: EnvironmentNamingInfo,
335+
dag: t.Optional[DAG[SnapshotId]] = None,
335336
) -> t.Dict[Snapshot, Intervals]:
336-
dag = snapshots_to_dag(merged_intervals)
337+
dag = dag or snapshots_to_dag(merged_intervals)
337338

338339
snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
339340
snapshot.snapshot_id: (
@@ -413,6 +414,7 @@ def run_merged_intervals(
413414
start: t.Optional[TimeLike] = None,
414415
end: t.Optional[TimeLike] = None,
415416
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
417+
selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None,
416418
run_environment_statements: bool = False,
417419
audit_only: bool = False,
418420
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
@@ -427,14 +429,21 @@ def run_merged_intervals(
427429
start: The start of the run.
428430
end: The end of the run.
429431
allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed.
432+
selected_snapshot_ids: The snapshots to include in the run DAG. If None, all snapshots with missing intervals will be included.
430433
431434
Returns:
432435
A tuple of errors and skipped intervals.
433436
"""
434437
execution_time = execution_time or now_timestamp()
435438

439+
selected_snapshots = [self.snapshots[sid] for sid in (selected_snapshot_ids or set())]
440+
if not selected_snapshots:
441+
selected_snapshots = list(merged_intervals)
442+
443+
snapshot_dag = snapshots_to_dag(selected_snapshots)
444+
436445
batched_intervals = self.batch_intervals(
437-
merged_intervals, deployability_index, environment_naming_info
446+
merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
438447
)
439448

440449
self.console.start_evaluation_progress(
@@ -447,11 +456,13 @@ def run_merged_intervals(
447456
snapshots_to_create = {
448457
s.snapshot_id
449458
for s in self.snapshot_evaluator.get_snapshots_to_create(
450-
merged_intervals.keys(), deployability_index
459+
selected_snapshots, deployability_index
451460
)
452461
}
453462

454-
dag = self._dag(batched_intervals, snapshots_to_create=snapshots_to_create)
463+
dag = self._dag(
464+
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
465+
)
455466

456467
if run_environment_statements:
457468
environment_statements = self.state_sync.get_environment_statements(
@@ -575,12 +586,14 @@ def evaluate_node(node: SchedulingUnit) -> None:
575586
def _dag(
576587
self,
577588
batches: SnapshotToIntervals,
589+
snapshot_dag: t.Optional[DAG[SnapshotId]] = None,
578590
snapshots_to_create: t.Optional[t.Set[SnapshotId]] = None,
579591
) -> DAG[SchedulingUnit]:
580592
"""Builds a DAG of snapshot intervals to be evaluated.
581593
582594
Args:
583595
batches: The batches of snapshots and intervals to evaluate.
596+
snapshot_dag: The DAG of all snapshots.
584597
snapshots_to_create: The snapshots with missing physical tables.
585598
586599
Returns:
@@ -591,20 +604,24 @@ def _dag(
591604
snapshot.name: intervals for snapshot, intervals in batches.items()
592605
}
593606
snapshots_to_create = snapshots_to_create or set()
607+
original_snapshots_to_create = snapshots_to_create.copy()
594608

609+
snapshot_dag = snapshot_dag or snapshots_to_dag(batches)
595610
dag = DAG[SchedulingUnit]()
596611

597-
for snapshot, intervals in batches.items():
598-
if not intervals:
599-
continue
612+
for snapshot_id in snapshot_dag:
613+
snapshot = self.snapshots_by_name[snapshot_id.name]
614+
intervals = intervals_per_snapshot.get(snapshot.name, [])
600615

601616
upstream_dependencies: t.List[SchedulingUnit] = []
602617

603618
for p_sid in snapshot.parents:
604619
if p_sid in self.snapshots:
605620
p_intervals = intervals_per_snapshot.get(p_sid.name, [])
606621

607-
if len(p_intervals) > 1:
622+
if not p_intervals and p_sid in original_snapshots_to_create:
623+
upstream_dependencies.append(CreateNode(snapshot_name=p_sid.name))
624+
elif len(p_intervals) > 1:
608625
upstream_dependencies.append(DummyNode(snapshot_name=p_sid.name))
609626
else:
610627
for i, interval in enumerate(p_intervals):
@@ -620,14 +637,16 @@ def _dag(
620637
batch_concurrency = 1
621638

622639
create_node: t.Optional[CreateNode] = None
623-
if snapshot.snapshot_id in snapshots_to_create and (
640+
if snapshot.snapshot_id in original_snapshots_to_create and (
624641
snapshot.is_incremental_by_time_range
625642
or ((not batch_concurrency or batch_concurrency > 1) and batch_size)
643+
or not intervals
626644
):
627645
# Add a separate node for table creation in case when there multiple concurrent
628-
# evaluation nodes.
646+
# evaluation nodes or when there are no intervals to evaluate.
629647
create_node = CreateNode(snapshot_name=snapshot.name)
630648
dag.add(create_node, upstream_dependencies)
649+
snapshots_to_create.remove(snapshot.snapshot_id)
631650

632651
for i, interval in enumerate(intervals):
633652
node = EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=i)

sqlmesh/core/snapshot/definition.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,15 @@ def categorize_as(self, category: SnapshotChangeCategory, forward_only: bool = F
10431043
# If the model has a pinned version then use that.
10441044
self.version = self.model.physical_version
10451045
elif is_no_rebuild and self.previous_version:
1046+
self.version = self.previous_version.data_version.version
1047+
elif self.is_model and self.model.forward_only and not self.previous_version:
1048+
# If this is a new model then use a deterministic version, independent of the fingerprint.
1049+
self.version = hash_data([self.name, *self.model.kind.data_hash_values])
1050+
else:
1051+
self.version = self.fingerprint.to_version()
1052+
1053+
if is_no_rebuild and self.previous_version:
10461054
previous_version = self.previous_version
1047-
self.version = previous_version.data_version.version
10481055
self.physical_schema_ = previous_version.physical_schema
10491056
self.table_naming_convention = previous_version.table_naming_convention
10501057
if self.is_materialized and (category.is_indirect_non_breaking or category.is_metadata):
@@ -1054,11 +1061,6 @@ def categorize_as(self, category: SnapshotChangeCategory, forward_only: bool = F
10541061
or previous_version.fingerprint.to_version()
10551062
)
10561063
self.dev_table_suffix = previous_version.data_version.dev_table_suffix
1057-
elif self.is_model and self.model.forward_only and not self.previous_version:
1058-
# If this is a new model then use a deterministic version, independent of the fingerprint.
1059-
self.version = hash_data([self.name, *self.model.kind.data_hash_values])
1060-
else:
1061-
self.version = self.fingerprint.to_version()
10621064

10631065
self.change_category = category
10641066
self.forward_only = forward_only
@@ -1603,9 +1605,7 @@ def create(
16031605
)
16041606
else:
16051607
children_deployable = False
1606-
if not snapshots[node].is_paused or (
1607-
snapshot.is_indirect_non_breaking and snapshot.intervals
1608-
):
1608+
if not snapshots[node].is_paused:
16091609
representative_shared_version_ids.add(node)
16101610

16111611
deployability_mapping[node] = this_deployable

sqlmesh/core/snapshot/evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -695,9 +695,11 @@ def _evaluate_snapshot(
695695
target_table_name = snapshot.table_name(is_deployable=is_snapshot_deployable)
696696
# https://github.com/TobikoData/sqlmesh/issues/2609
697697
# If there are no existing intervals yet; only consider this a first insert for the first snapshot in the batch
698-
is_first_insert = not _intervals(snapshot, deployability_index) and batch_index == 0
699698
if target_table_exists is None:
700699
target_table_exists = adapter.table_exists(target_table_name)
700+
is_first_insert = (
701+
not _intervals(snapshot, deployability_index) or not target_table_exists
702+
) and batch_index == 0
701703

702704
common_render_kwargs = dict(
703705
start=start,
@@ -749,7 +751,6 @@ def _evaluate_snapshot(
749751
allow_destructive_snapshots=allow_destructive_snapshots,
750752
)
751753
else:
752-
is_first_insert = True
753754
if model.annotated or model.is_seed or model.kind.is_scd_type_2:
754755
self._execute_create(
755756
snapshot=snapshot,

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,11 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory):
19141914
],
19151915
personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()],
19161916
)
1917+
config.before_all = [
1918+
f"CREATE SCHEMA IF NOT EXISTS {raw_test_schema}",
1919+
f"DROP VIEW IF EXISTS {raw_test_schema}.demographics",
1920+
f"CREATE VIEW {raw_test_schema}.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)",
1921+
]
19171922

19181923
# To enable parallelism in integration tests
19191924
config.gateways = {ctx.gateway: config.gateways[ctx.gateway]}
@@ -2132,6 +2137,8 @@ def validate_comments(
21322137
}
21332138

21342139
for model_name, comment in comments.items():
2140+
if not model_name in layer_models:
2141+
continue
21352142
layer_table_name = layer_models[model_name]["table_name"]
21362143
table_kind = "VIEW" if layer_models[model_name]["is_view"] else "BASE TABLE"
21372144

tests/core/test_integration.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5890,7 +5890,7 @@ def get_default_catalog_and_non_tables(
58905890
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
58915891
assert len(prod_views) == 16
58925892
assert len(dev_views) == 16
5893-
assert len(user_default_tables) == 15
5893+
assert len(user_default_tables) == 16
58945894
assert len(non_default_tables) == 0
58955895
assert state_metadata.schemas == ["sqlmesh"]
58965896
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
@@ -5910,7 +5910,7 @@ def get_default_catalog_and_non_tables(
59105910
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
59115911
assert len(prod_views) == 16
59125912
assert len(dev_views) == 32
5913-
assert len(user_default_tables) == 15
5913+
assert len(user_default_tables) == 16
59145914
assert len(non_default_tables) == 0
59155915
assert state_metadata.schemas == ["sqlmesh"]
59165916
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
@@ -5931,7 +5931,7 @@ def get_default_catalog_and_non_tables(
59315931
) = get_default_catalog_and_non_tables(metadata, context.default_catalog)
59325932
assert len(prod_views) == 16
59335933
assert len(dev_views) == 16
5934-
assert len(user_default_tables) == 15
5934+
assert len(user_default_tables) == 16
59355935
assert len(non_default_tables) == 0
59365936
assert state_metadata.schemas == ["sqlmesh"]
59375937
assert {x.sql() for x in state_metadata.qualified_tables}.issuperset(
@@ -6902,17 +6902,7 @@ def plan_with_output(ctx: Context, environment: str):
69026902
assert "New environment `dev` will be created from `prod`" in output.stdout
69036903
assert "Differences from the `prod` environment" in output.stdout
69046904

6905-
assert (
6906-
"""MODEL (
6907-
name test.a,
6908-
+ owner test,
6909-
kind FULL
6910-
)
6911-
SELECT
6912-
- 5 AS col
6913-
+ 10 AS col"""
6914-
in output.stdout
6915-
)
6905+
assert "Directly Modified: test__dev.a" in output.stdout
69166906

69176907
# Case 6: Ensure that target environment and create_from environment are not the same
69186908
output = plan_with_output(ctx, "prod")

tests/utils/test_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def use_terminal_console(func):
8282
def test_wrapper(*args, **kwargs):
8383
orig_console = get_console()
8484
try:
85-
set_console(TerminalConsole())
85+
new_console = TerminalConsole()
86+
new_console.console.no_color = True
87+
set_console(new_console)
8688
func(*args, **kwargs)
8789
finally:
8890
set_console(orig_console)

0 commit comments

Comments (0)