Skip to content

Commit 7761dd6

Browse files
committed
Add directly modified and restatement triggers
1 parent c4c56b3 commit 7761dd6

7 files changed

Lines changed: 136 additions & 25 deletions

File tree

sqlmesh/core/console.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress(
38303830
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
38313831
if snapshot_evaluation_triggers.select_snapshot_triggers:
38323832
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833-
3834-
if snapshot_evaluation_triggers:
3835-
if snapshot_evaluation_triggers.auto_restatement_triggers:
3836-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3837-
if snapshot_evaluation_triggers.select_snapshot_triggers:
3838-
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833+
if snapshot_evaluation_triggers.directly_modified_triggers:
3834+
message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}"
3835+
if snapshot_evaluation_triggers.restatement_triggers:
3836+
message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}"
38393837

38403838
if audit_only:
38413839
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2306,7 +2306,7 @@ def check_intervals(
23062306
if select_models:
23072307
selected, _ = self._select_models_for_run(select_models, True, snapshots.values())
23082308
else:
2309-
selected = t.cast(t.Set[str], snapshots.keys())
2309+
selected = set(snapshots.keys())
23102310

23112311
results = {}
23122312
execution_context = self.execution_context(snapshots=snapshots)

sqlmesh/core/plan/builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def build(self) -> Plan:
293293
else DeployabilityIndex.all_deployable()
294294
)
295295

296-
restatements = self._build_restatements(
296+
restatements, restatement_triggers = self._build_restatements(
297297
dag,
298298
earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time),
299299
)
@@ -330,6 +330,7 @@ def build(self) -> Plan:
330330
indirectly_modified=indirectly_modified,
331331
deployability_index=deployability_index,
332332
restatements=restatements,
333+
restatement_triggers=restatement_triggers,
333334
start_override_per_model=self._start_override_per_model,
334335
end_override_per_model=end_override_per_model,
335336
selected_models_to_backfill=self._backfill_models,
@@ -352,14 +353,14 @@ def _build_dag(self) -> DAG[SnapshotId]:
352353

353354
def _build_restatements(
354355
self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike
355-
) -> t.Dict[SnapshotId, Interval]:
356+
) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]:
356357
restate_models = self._restate_models
357358
if restate_models == set():
358359
# This is a warning, but we print it as an error because the Console lacks an API for warnings.
359360
self._console.log_error(
360361
"Provided restated models do not match any models. No models will be included in plan."
361362
)
362-
return {}
363+
return {}, {}
363364

364365
restatements: t.Dict[SnapshotId, Interval] = {}
365366
forward_only_preview_needed = self._forward_only_preview_needed
@@ -383,7 +384,7 @@ def _build_restatements(
383384
is_preview = True
384385

385386
if not restate_models:
386-
return {}
387+
return {}, {}
387388

388389
start = self._start or earliest_interval_start
389390
end = self._end or now()
@@ -393,6 +394,7 @@ def _build_restatements(
393394
if model_fqn not in self._model_fqn_to_snapshot:
394395
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")
395396

397+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
396398
# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
397399
# restatement range that its downstream dependencies all expand their restatement ranges as well.
398400
for s_id in dag:
@@ -428,6 +430,13 @@ def _build_restatements(
428430
logger.info("Skipping restatement for model '%s'", snapshot.name)
429431
continue
430432

433+
if snapshot.name in restate_models:
434+
restatement_triggers[s_id] = [s_id]
435+
if restating_parents:
436+
restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [
437+
s.snapshot_id for s in restating_parents
438+
]
439+
431440
possible_intervals = {
432441
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
433442
}
@@ -456,7 +465,7 @@ def _build_restatements(
456465

457466
restatements[s_id] = (snapshot_start, snapshot_end)
458467

459-
return restatements
468+
return restatements, restatement_triggers
460469

461470
def _build_directly_and_indirectly_modified(
462471
self, dag: DAG[SnapshotId]

sqlmesh/core/plan/definition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class Plan(PydanticModel, frozen=True):
5858

5959
deployability_index: DeployabilityIndex
6060
restatements: t.Dict[SnapshotId, Interval]
61+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
6162
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6263
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6364

@@ -256,6 +257,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
256257
skip_backfill=self.skip_backfill,
257258
empty_backfill=self.empty_backfill,
258259
restatements={s.name: i for s, i in self.restatements.items()},
260+
restatement_triggers=self.restatement_triggers,
259261
is_dev=self.is_dev,
260262
allow_destructive_models=self.allow_destructive_models,
261263
forward_only=self.forward_only,
@@ -298,6 +300,7 @@ class EvaluatablePlan(PydanticModel):
298300
skip_backfill: bool
299301
empty_backfill: bool
300302
restatements: t.Dict[str, Interval]
303+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
301304
is_dev: bool
302305
allow_destructive_models: t.Set[str]
303306
forward_only: bool

sqlmesh/core/plan/evaluator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
SnapshotCreationFailedError,
3838
SnapshotNameVersion,
3939
)
40+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
4041
from sqlmesh.utils import to_snake_case
4142
from sqlmesh.core.state_sync import StateSync
4243
from sqlmesh.utils import CorrelationId
@@ -234,6 +235,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234235
self.console.log_success("SKIP: No model batches to execute")
235236
return
236237

238+
directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
239+
for parent, children in plan.indirectly_modified_snapshots.items():
240+
parent_id = stage.all_snapshots[parent].snapshot_id
241+
directly_modified_triggers[parent_id] = directly_modified_triggers.get(
242+
parent_id, []
243+
) + [parent_id]
244+
for child in children:
245+
directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [
246+
parent_id
247+
]
248+
directly_modified_triggers = {
249+
k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items()
250+
}
251+
snapshot_evaluation_triggers = {
252+
s_id: SnapshotEvaluationTriggers(
253+
directly_modified_triggers=directly_modified_triggers.get(s_id, []),
254+
restatement_triggers=plan.restatement_triggers.get(s_id, []),
255+
)
256+
for s_id in [s.snapshot_id for s in stage.all_snapshots.values()]
257+
}
258+
237259
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
238260
errors, _ = scheduler.run_merged_intervals(
239261
merged_intervals=stage.snapshot_to_intervals,

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ class SnapshotEvaluationTriggers(PydanticModel):
332332
cron_ready: t.Optional[bool] = None
333333
auto_restatement_triggers: t.List[SnapshotId] = []
334334
select_snapshot_triggers: t.List[SnapshotId] = []
335+
directly_modified_triggers: t.List[SnapshotId] = []
336+
restatement_triggers: t.List[SnapshotId] = []
335337

336338

337339
class SnapshotInfoMixin(ModelKindMixin):

tests/core/test_integration.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
from sqlmesh import CustomMaterialization
29+
import sqlmesh
2930
from sqlmesh.cli.project_init import init_example_project
3031
from sqlmesh.core import constants as c
3132
from sqlmesh.core import dialect as d
@@ -1859,26 +1860,97 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18591860
context, plan = init_and_plan_context("examples/sushi")
18601861
context.apply(plan)
18611862

1863+
# modify 3 models
1864+
# - 2 breaking changes for testing plan directly modified triggers
1865+
# - 1 adding an auto-restatement for subsequent `run` test
1866+
marketing = context.get_model("sushi.marketing")
1867+
marketing_kwargs = {
1868+
**marketing.dict(),
1869+
"query": d.parse_one(
1870+
f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1871+
),
1872+
}
1873+
context.upsert_model(SqlModel.parse_obj(marketing_kwargs))
1874+
1875+
customers = context.get_model("sushi.customers")
1876+
customers_kwargs = {
1877+
**customers.dict(),
1878+
"query": d.parse_one(
1879+
f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1880+
),
1881+
}
1882+
context.upsert_model(SqlModel.parse_obj(customers_kwargs))
1883+
18621884
# add auto restatement to orders
1863-
model = context.get_model("sushi.orders")
1864-
kind = {
1865-
**model.kind.dict(),
1885+
orders = context.get_model("sushi.orders")
1886+
orders_kind = {
1887+
**orders.kind.dict(),
18661888
"auto_restatement_cron": "@hourly",
18671889
}
1868-
kwargs = {
1869-
**model.dict(),
1870-
"kind": kind,
1890+
orders_kwargs = {
1891+
**orders.dict(),
1892+
"kind": orders_kind,
18711893
}
1872-
context.upsert_model(PythonModel.parse_obj(kwargs))
1873-
plan = context.plan_builder(skip_tests=True).build()
1874-
context.apply(plan)
1894+
context.upsert_model(PythonModel.parse_obj(orders_kwargs))
18751895

1876-
# Mock run_merged_intervals to capture triggers arg
1877-
scheduler = context.scheduler()
1878-
run_merged_intervals_mock = mocker.patch.object(
1879-
scheduler, "run_merged_intervals", return_value=([], [])
1896+
spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals")
1897+
1898+
context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full())
1899+
1900+
# PLAN: directly modified triggers
1901+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1902+
actual_triggers_name = {
1903+
k.name: sorted([s.name for s in v.directly_modified_triggers])
1904+
for k, v in actual_triggers.items()
1905+
if v.directly_modified_triggers
1906+
}
1907+
marketing_name = '"memory"."sushi"."marketing"'
1908+
customers_name = '"memory"."sushi"."customers"'
1909+
marketing_customers_names = sorted([marketing_name, customers_name])
1910+
children_names = [
1911+
f'"memory"."sushi"."{model}"'
1912+
for model in {
1913+
"waiter_as_customer_by_day",
1914+
"active_customers",
1915+
"count_customers_active",
1916+
"count_customers_inactive",
1917+
}
1918+
]
1919+
assert actual_triggers_name == {
1920+
marketing_name: [marketing_name],
1921+
customers_name: [customers_name],
1922+
**{k: marketing_customers_names for k in children_names},
1923+
}
1924+
1925+
# PLAN: restatement triggers
1926+
spy.reset_mock()
1927+
context.plan(
1928+
restate_models=[
1929+
'"memory"."sushi"."marketing"',
1930+
'"memory"."sushi"."order_items"',
1931+
'"memory"."sushi"."waiter_revenue_by_day"',
1932+
],
1933+
auto_apply=True,
1934+
no_prompts=True,
18801935
)
18811936

1937+
order_items_name = '"memory"."sushi"."order_items"'
1938+
waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"'
1939+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1940+
actual_triggers_name = {
1941+
k.name: sorted([s.name for s in v.restatement_triggers])
1942+
for k, v in actual_triggers.items()
1943+
if v.restatement_triggers
1944+
}
1945+
assert actual_triggers_name == {
1946+
waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name],
1947+
order_items_name: [order_items_name],
1948+
'"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name],
1949+
'"memory"."sushi"."customer_revenue_by_day"': [order_items_name],
1950+
'"memory"."sushi"."customer_revenue_lifetime"': [order_items_name],
1951+
}
1952+
1953+
# RUN: select and auto-restatement triggers
18821954
# User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
18831955
selected_models = {"top_waiters", "waiter_revenue_by_day"}
18841956
selected_models_auto_upstream = {"order_items", "orders", "items"}
@@ -1889,6 +1961,11 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18891961
f'"memory"."sushi"."{model}"' for model in selected_models
18901962
}
18911963

1964+
scheduler = context.scheduler()
1965+
run_merged_intervals_mock = mocker.patch.object(
1966+
scheduler, "run_merged_intervals", return_value=([], [])
1967+
)
1968+
18921969
with time_machine.travel("2023-01-09 00:00:01 UTC"):
18931970
scheduler.run(
18941971
environment=c.PROD,

0 commit comments

Comments
 (0)