Skip to content

Commit c4952c6

Browse files
committed
Collect selected snapshot triggers
1 parent bbed5ac commit c4952c6

8 files changed

Lines changed: 140 additions & 51 deletions

File tree

sqlmesh/core/console.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3819,7 +3819,17 @@ def update_snapshot_evaluation_progress(
38193819
audit_only: bool = False,
38203820
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
38213821
) -> None:
3822-
message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3822+
message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3823+
3824+
if snapshot_evaluation_triggers:
3825+
if snapshot_evaluation_triggers.ignore_cron_flag is not None:
3826+
message += f" | ignore_cron_flag={snapshot_evaluation_triggers.ignore_cron_flag}"
3827+
if snapshot_evaluation_triggers.cron_ready is not None:
3828+
message += f" | cron_ready={snapshot_evaluation_triggers.cron_ready}"
3829+
if snapshot_evaluation_triggers.auto_restatement_triggers:
3830+
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3831+
if snapshot_evaluation_triggers.select_snapshot_triggers:
3832+
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
38233833

38243834
if snapshot_evaluation_triggers:
38253835
if snapshot_evaluation_triggers.auto_restatement_triggers:
@@ -3828,7 +3838,7 @@ def update_snapshot_evaluation_progress(
38283838
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
38293839

38303840
if audit_only:
3831-
message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3841+
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
38323842

38333843
self._write(message)
38343844

sqlmesh/core/context.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2898,9 +2898,10 @@ def _select_models_for_run(
28982898
dag.add(fqn, model.depends_on)
28992899
model_selector = self._new_selector(models=models, dag=dag)
29002900
result = set(model_selector.expand_model_selections(select_models))
2901-
if not no_auto_upstream:
2902-
result_with_upstream = set(dag.subdag(*result))
2903-
return result, result_with_upstream - result
2901+
if no_auto_upstream:
2902+
return result, set()
2903+
result_with_upstream = set(dag.subdag(*result))
2904+
return result_with_upstream, result_with_upstream - result
29042905

29052906
@cached_property
29062907
def _project_type(self) -> str:

sqlmesh/core/plan/stages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def _missing_intervals(
553553
snapshots_by_name: t.Dict[str, Snapshot],
554554
deployability_index: DeployabilityIndex,
555555
) -> SnapshotToIntervals:
556-
return merged_missing_intervals(
556+
missing_intervals, _ = merged_missing_intervals(
557557
snapshots=snapshots_by_name.values(),
558558
start=plan.start,
559559
end=plan.end,
@@ -568,6 +568,7 @@ def _missing_intervals(
568568
start_override_per_model=plan.start_override_per_model,
569569
end_override_per_model=plan.end_override_per_model,
570570
)
571+
return missing_intervals
571572

572573
def _get_audit_only_snapshots(
573574
self, new_snapshots: t.Dict[SnapshotId, Snapshot]

sqlmesh/core/scheduler.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def merged_missing_intervals(
147147
ignore_cron: bool = False,
148148
end_bounded: bool = False,
149149
selected_snapshots: t.Optional[t.Set[str]] = None,
150-
) -> SnapshotToIntervals:
150+
) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
151151
"""Find the largest contiguous date interval parameters based only on what is missing.
152152
153153
For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -167,8 +167,11 @@ def merged_missing_intervals(
167167
end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback,
168168
allow_partials, and other attributes that could cause the intervals to exceed the target end date.
169169
selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run.
170+
171+
Returns:
172+
A tuple of (1) a dict of all snapshots that need to be run, with their associated interval params, and (2) a list of the IDs of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes).
170173
"""
171-
snapshots_to_intervals = merged_missing_intervals(
174+
snapshots_to_intervals, snapshots_naive_cron_ready = merged_missing_intervals(
172175
snapshots=self.snapshot_per_version.values(),
173176
start=start,
174177
end=end,
@@ -186,7 +189,7 @@ def merged_missing_intervals(
186189
snapshots_to_intervals = {
187190
s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots
188191
}
189-
return snapshots_to_intervals
192+
return snapshots_to_intervals, snapshots_naive_cron_ready
190193

191194
def evaluate(
192195
self,
@@ -755,7 +758,7 @@ def _run_or_audit(
755758
{s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()}
756759
)
757760

758-
merged_intervals = self.merged_missing_intervals(
761+
merged_intervals, snapshots_naive_cron_ready = self.merged_missing_intervals(
759762
start,
760763
end,
761764
execution_time,
@@ -770,9 +773,7 @@ def _run_or_audit(
770773
if not merged_intervals:
771774
return CompletionStatus.NOTHING_TO_DO
772775

773-
merged_intervals_snapshots = {
774-
snapshot.snapshot_id: snapshot for snapshot in merged_intervals.keys()
775-
}
776+
merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals}
776777
select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
777778
if selected_snapshots and selected_snapshots_auto_upstream:
778779
# actually selected snapshots are their own triggers
@@ -788,24 +789,25 @@ def _run_or_audit(
788789
]
789790
}
790791

791-
# trace upstream by reversing dag of all snapshots to evaluate
792-
reversed_intervals_dag = snapshots_to_dag(merged_intervals_snapshots.values()).reversed
793-
for s_id in reversed_intervals_dag:
794-
if s_id not in select_snapshot_triggers:
795-
triggers = []
796-
for parent_s_id in merged_intervals_snapshots[s_id].parents:
797-
triggers.extend(select_snapshot_triggers[parent_s_id])
792+
# trace upstream by walking downstream on reversed dag
793+
reversed_dag = snapshots_to_dag(self.snapshots.values()).reversed
794+
for s_id in reversed_dag:
795+
if s_id in merged_intervals_snapshots:
796+
triggers = select_snapshot_triggers.get(s_id, [])
797+
for parent_s_id in reversed_dag.graph.get(s_id, set()):
798+
triggers.extend(select_snapshot_triggers.get(parent_s_id, []))
798799
select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers))
799800

800801
all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = {
801802
s_id: SnapshotEvaluationTriggers(
802-
ignore_cron=ignore_cron,
803+
ignore_cron_flag=ignore_cron,
804+
cron_ready=s_id in snapshots_naive_cron_ready,
803805
auto_restatement_triggers=auto_restatement_triggers.get(s_id, []),
804806
select_snapshot_triggers=select_snapshot_triggers.get(s_id, []),
805807
)
806808
for s_id in merged_intervals_snapshots
807-
if ignore_cron or s_id in auto_restatement_triggers or s_id in select_snapshot_triggers
808809
}
810+
809811
errors, _ = self.run_merged_intervals(
810812
merged_intervals=merged_intervals,
811813
deployability_index=deployability_index,
@@ -967,7 +969,7 @@ def merged_missing_intervals(
967969
end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
968970
ignore_cron: bool = False,
969971
end_bounded: bool = False,
970-
) -> SnapshotToIntervals:
972+
) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
971973
"""Find the largest contiguous date interval parameters based only on what is missing.
972974
973975
For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -1017,7 +1019,7 @@ def compute_interval_params(
10171019
end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
10181020
ignore_cron: bool = False,
10191021
end_bounded: bool = False,
1020-
) -> SnapshotToIntervals:
1022+
) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
10211023
"""Find the largest contiguous date interval parameters based only on what is missing.
10221024
10231025
For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -1039,7 +1041,7 @@ def compute_interval_params(
10391041
allow_partials, and other attributes that could cause the intervals to exceed the target end date.
10401042
10411043
Returns:
1042-
A dict containing all snapshots needing to be run with their associated interval params.
1044+
A tuple of (1) a dict of all snapshots that need to be run, with their associated interval params, and (2) a list of the IDs of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes).
10431045
"""
10441046
snapshot_merged_intervals = {}
10451047

@@ -1067,7 +1069,11 @@ def compute_interval_params(
10671069
contiguous_batch.append((next_batch[0][0], next_batch[-1][-1]))
10681070
snapshot_merged_intervals[snapshot] = contiguous_batch
10691071

1070-
return snapshot_merged_intervals
1072+
snapshots_naive_cron_ready = [
1073+
snap.snapshot_id for snap in missing_intervals(snapshots, execution_time=execution_time)
1074+
]
1075+
1076+
return snapshot_merged_intervals, snapshots_naive_cron_ready
10711077

10721078

10731079
def interval_diff(

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,10 @@ def table_name_for_environment(
328328

329329

330330
class SnapshotEvaluationTriggers(PydanticModel):
331-
ignore_cron: bool
331+
ignore_cron_flag: t.Optional[bool] = None
332+
cron_ready: t.Optional[bool] = None
332333
auto_restatement_triggers: t.List[SnapshotId] = []
333334
select_snapshot_triggers: t.List[SnapshotId] = []
334-
directly_modified_triggers: t.List[SnapshotId] = []
335-
manual_restatement_triggers: t.List[SnapshotId] = []
336335

337336

338337
class SnapshotInfoMixin(ModelKindMixin):

0 commit comments

Comments
 (0)