Skip to content

Commit 46dbc3e

Browse files
committed
Add directly modified and restatement triggers
1 parent c4952c6 commit 46dbc3e

7 files changed

Lines changed: 136 additions & 25 deletions

File tree

sqlmesh/core/console.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress(
38303830
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
38313831
if snapshot_evaluation_triggers.select_snapshot_triggers:
38323832
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833-
3834-
if snapshot_evaluation_triggers:
3835-
if snapshot_evaluation_triggers.auto_restatement_triggers:
3836-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3837-
if snapshot_evaluation_triggers.select_snapshot_triggers:
3838-
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833+
if snapshot_evaluation_triggers.directly_modified_triggers:
3834+
message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}"
3835+
if snapshot_evaluation_triggers.restatement_triggers:
3836+
message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}"
38393837

38403838
if audit_only:
38413839
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2309,7 +2309,7 @@ def check_intervals(
23092309
if select_models:
23102310
selected, _ = self._select_models_for_run(select_models, True, snapshots.values())
23112311
else:
2312-
selected = t.cast(t.Set[str], snapshots.keys())
2312+
selected = set(snapshots.keys())
23132313

23142314
results = {}
23152315
execution_context = self.execution_context(snapshots=snapshots)

sqlmesh/core/plan/builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def build(self) -> Plan:
293293
else DeployabilityIndex.all_deployable()
294294
)
295295

296-
restatements = self._build_restatements(
296+
restatements, restatement_triggers = self._build_restatements(
297297
dag,
298298
earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time),
299299
)
@@ -330,6 +330,7 @@ def build(self) -> Plan:
330330
indirectly_modified=indirectly_modified,
331331
deployability_index=deployability_index,
332332
restatements=restatements,
333+
restatement_triggers=restatement_triggers,
333334
start_override_per_model=self._start_override_per_model,
334335
end_override_per_model=end_override_per_model,
335336
selected_models_to_backfill=self._backfill_models,
@@ -352,14 +353,14 @@ def _build_dag(self) -> DAG[SnapshotId]:
352353

353354
def _build_restatements(
354355
self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike
355-
) -> t.Dict[SnapshotId, Interval]:
356+
) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]:
356357
restate_models = self._restate_models
357358
if restate_models == set():
358359
# This is a warning, but we log it as an error because the Console lacks an API for warnings.
359360
self._console.log_error(
360361
"Provided restated models do not match any models. No models will be included in plan."
361362
)
362-
return {}
363+
return {}, {}
363364

364365
restatements: t.Dict[SnapshotId, Interval] = {}
365366
forward_only_preview_needed = self._forward_only_preview_needed
@@ -383,7 +384,7 @@ def _build_restatements(
383384
is_preview = True
384385

385386
if not restate_models:
386-
return {}
387+
return {}, {}
387388

388389
start = self._start or earliest_interval_start
389390
end = self._end or now()
@@ -393,6 +394,7 @@ def _build_restatements(
393394
if model_fqn not in self._model_fqn_to_snapshot:
394395
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")
395396

397+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
396398
# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
397399
# restatement range that its downstream dependencies all expand their restatement ranges as well.
398400
for s_id in dag:
@@ -428,6 +430,13 @@ def _build_restatements(
428430
logger.info("Skipping restatement for model '%s'", snapshot.name)
429431
continue
430432

433+
if snapshot.name in restate_models:
434+
restatement_triggers[s_id] = [s_id]
435+
if restating_parents:
436+
restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [
437+
s.snapshot_id for s in restating_parents
438+
]
439+
431440
possible_intervals = {
432441
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
433442
}
@@ -456,7 +465,7 @@ def _build_restatements(
456465

457466
restatements[s_id] = (snapshot_start, snapshot_end)
458467

459-
return restatements
468+
return restatements, restatement_triggers
460469

461470
def _build_directly_and_indirectly_modified(
462471
self, dag: DAG[SnapshotId]

sqlmesh/core/plan/definition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class Plan(PydanticModel, frozen=True):
5858

5959
deployability_index: DeployabilityIndex
6060
restatements: t.Dict[SnapshotId, Interval]
61+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
6162
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6263
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6364

@@ -256,6 +257,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
256257
skip_backfill=self.skip_backfill,
257258
empty_backfill=self.empty_backfill,
258259
restatements={s.name: i for s, i in self.restatements.items()},
260+
restatement_triggers=self.restatement_triggers,
259261
is_dev=self.is_dev,
260262
allow_destructive_models=self.allow_destructive_models,
261263
forward_only=self.forward_only,
@@ -298,6 +300,7 @@ class EvaluatablePlan(PydanticModel):
298300
skip_backfill: bool
299301
empty_backfill: bool
300302
restatements: t.Dict[str, Interval]
303+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
301304
is_dev: bool
302305
allow_destructive_models: t.Set[str]
303306
forward_only: bool

sqlmesh/core/plan/evaluator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
SnapshotCreationFailedError,
3838
SnapshotNameVersion,
3939
)
40+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
4041
from sqlmesh.utils import to_snake_case
4142
from sqlmesh.core.state_sync import StateSync
4243
from sqlmesh.utils import CorrelationId
@@ -244,6 +245,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
244245
self.console.log_success("SKIP: No model batches to execute")
245246
return
246247

248+
directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
249+
for parent, children in plan.indirectly_modified_snapshots.items():
250+
parent_id = stage.all_snapshots[parent].snapshot_id
251+
directly_modified_triggers[parent_id] = directly_modified_triggers.get(
252+
parent_id, []
253+
) + [parent_id]
254+
for child in children:
255+
directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [
256+
parent_id
257+
]
258+
directly_modified_triggers = {
259+
k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items()
260+
}
261+
snapshot_evaluation_triggers = {
262+
s_id: SnapshotEvaluationTriggers(
263+
directly_modified_triggers=directly_modified_triggers.get(s_id, []),
264+
restatement_triggers=plan.restatement_triggers.get(s_id, []),
265+
)
266+
for s_id in [s.snapshot_id for s in stage.all_snapshots.values()]
267+
}
268+
247269
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
248270
errors, _ = scheduler.run_merged_intervals(
249271
merged_intervals=stage.snapshot_to_intervals,

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,8 @@ class SnapshotEvaluationTriggers(PydanticModel):
332332
cron_ready: t.Optional[bool] = None
333333
auto_restatement_triggers: t.List[SnapshotId] = []
334334
select_snapshot_triggers: t.List[SnapshotId] = []
335+
directly_modified_triggers: t.List[SnapshotId] = []
336+
restatement_triggers: t.List[SnapshotId] = []
335337

336338

337339
class SnapshotInfoMixin(ModelKindMixin):

tests/core/test_integration.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
from sqlmesh import CustomMaterialization
29+
import sqlmesh
2930
from sqlmesh.cli.project_init import init_example_project
3031
from sqlmesh.core import constants as c
3132
from sqlmesh.core import dialect as d
@@ -1867,26 +1868,97 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18671868
context, plan = init_and_plan_context("examples/sushi")
18681869
context.apply(plan)
18691870

1871+
# modify 3 models
1872+
# - 2 breaking changes for testing plan directly modified triggers
1873+
# - 1 adding an auto-restatement for subsequent `run` test
1874+
marketing = context.get_model("sushi.marketing")
1875+
marketing_kwargs = {
1876+
**marketing.dict(),
1877+
"query": d.parse_one(
1878+
f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1879+
),
1880+
}
1881+
context.upsert_model(SqlModel.parse_obj(marketing_kwargs))
1882+
1883+
customers = context.get_model("sushi.customers")
1884+
customers_kwargs = {
1885+
**customers.dict(),
1886+
"query": d.parse_one(
1887+
f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1888+
),
1889+
}
1890+
context.upsert_model(SqlModel.parse_obj(customers_kwargs))
1891+
18701892
# add auto restatement to orders
1871-
model = context.get_model("sushi.orders")
1872-
kind = {
1873-
**model.kind.dict(),
1893+
orders = context.get_model("sushi.orders")
1894+
orders_kind = {
1895+
**orders.kind.dict(),
18741896
"auto_restatement_cron": "@hourly",
18751897
}
1876-
kwargs = {
1877-
**model.dict(),
1878-
"kind": kind,
1898+
orders_kwargs = {
1899+
**orders.dict(),
1900+
"kind": orders_kind,
18791901
}
1880-
context.upsert_model(PythonModel.parse_obj(kwargs))
1881-
plan = context.plan_builder(skip_tests=True).build()
1882-
context.apply(plan)
1902+
context.upsert_model(PythonModel.parse_obj(orders_kwargs))
18831903

1884-
# Mock run_merged_intervals to capture triggers arg
1885-
scheduler = context.scheduler()
1886-
run_merged_intervals_mock = mocker.patch.object(
1887-
scheduler, "run_merged_intervals", return_value=([], [])
1904+
spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals")
1905+
1906+
context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full())
1907+
1908+
# PLAN: directly modified triggers
1909+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1910+
actual_triggers_name = {
1911+
k.name: sorted([s.name for s in v.directly_modified_triggers])
1912+
for k, v in actual_triggers.items()
1913+
if v.directly_modified_triggers
1914+
}
1915+
marketing_name = '"memory"."sushi"."marketing"'
1916+
customers_name = '"memory"."sushi"."customers"'
1917+
marketing_customers_names = sorted([marketing_name, customers_name])
1918+
children_names = [
1919+
f'"memory"."sushi"."{model}"'
1920+
for model in {
1921+
"waiter_as_customer_by_day",
1922+
"active_customers",
1923+
"count_customers_active",
1924+
"count_customers_inactive",
1925+
}
1926+
]
1927+
assert actual_triggers_name == {
1928+
marketing_name: [marketing_name],
1929+
customers_name: [customers_name],
1930+
**{k: marketing_customers_names for k in children_names},
1931+
}
1932+
1933+
# PLAN: restatement triggers
1934+
spy.reset_mock()
1935+
context.plan(
1936+
restate_models=[
1937+
'"memory"."sushi"."marketing"',
1938+
'"memory"."sushi"."order_items"',
1939+
'"memory"."sushi"."waiter_revenue_by_day"',
1940+
],
1941+
auto_apply=True,
1942+
no_prompts=True,
18881943
)
18891944

1945+
order_items_name = '"memory"."sushi"."order_items"'
1946+
waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"'
1947+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1948+
actual_triggers_name = {
1949+
k.name: sorted([s.name for s in v.restatement_triggers])
1950+
for k, v in actual_triggers.items()
1951+
if v.restatement_triggers
1952+
}
1953+
assert actual_triggers_name == {
1954+
waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name],
1955+
order_items_name: [order_items_name],
1956+
'"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name],
1957+
'"memory"."sushi"."customer_revenue_by_day"': [order_items_name],
1958+
'"memory"."sushi"."customer_revenue_lifetime"': [order_items_name],
1959+
}
1960+
1961+
# RUN: select and auto-restatement triggers
18901962
# User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
18911963
selected_models = {"top_waiters", "waiter_revenue_by_day"}
18921964
selected_models_auto_upstream = {"order_items", "orders", "items"}
@@ -1897,6 +1969,11 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18971969
f'"memory"."sushi"."{model}"' for model in selected_models
18981970
}
18991971

1972+
scheduler = context.scheduler()
1973+
run_merged_intervals_mock = mocker.patch.object(
1974+
scheduler, "run_merged_intervals", return_value=([], [])
1975+
)
1976+
19001977
with time_machine.travel("2023-01-09 00:00:01 UTC"):
19011978
scheduler.run(
19021979
environment=c.PROD,

0 commit comments

Comments
 (0)