Skip to content

Commit 5f69e08

Browse files
committed
Move steps into a separate module
1 parent 2f09175 commit 5f69e08

2 files changed

Lines changed: 330 additions & 251 deletions

File tree

sqlmesh/core/plan/explainer.py

Lines changed: 27 additions & 251 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,30 @@
11
import abc
22
import typing as t
3+
import logging
34

4-
from dataclasses import dataclass
55
from rich.console import Console as RichConsole
66
from rich.tree import Tree
77
from sqlglot.dialects.dialect import DialectType
88
from sqlmesh.core import constants as c
99
from sqlmesh.core.console import Console, TerminalConsole, get_console
1010
from sqlmesh.core.environment import EnvironmentNamingInfo
1111
from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals
12+
from sqlmesh.core.plan import steps
1213
from sqlmesh.core.plan.evaluator import (
1314
PlanEvaluator,
14-
get_audit_only_snapshots,
15-
get_snapshots_to_create,
1615
)
1716
from sqlmesh.core.state_sync import StateReader
18-
from sqlmesh.core.scheduler import merged_missing_intervals, SnapshotToIntervals
1917
from sqlmesh.core.snapshot.definition import (
20-
DeployabilityIndex,
21-
Snapshot,
2218
SnapshotInfoMixin,
23-
SnapshotTableInfo,
24-
Interval,
2519
)
2620
from sqlmesh.utils import Verbosity, rich as srich
2721
from sqlmesh.utils.date import to_ts
2822
from sqlmesh.utils.errors import SQLMeshError
2923

3024

25+
logger = logging.getLogger(__name__)
26+
27+
3128
class PlanExplainer(PlanEvaluator):
3229
def __init__(
3330
self,
@@ -42,244 +39,16 @@ def __init__(
4239
def evaluate(
4340
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
4441
) -> None:
45-
new_snapshots = {s.snapshot_id: s for s in plan.new_snapshots}
46-
stored_snapshots = self.state_reader.get_snapshots(plan.environment.snapshots)
47-
snapshots = {**new_snapshots, **stored_snapshots}
48-
snapshots_by_name = {s.name: s for s in snapshots.values()}
49-
50-
all_selected_for_backfill_snapshots = {
51-
s.snapshot_id for s in snapshots.values() if plan.is_selected_for_backfill(s.name)
52-
}
53-
54-
deployability_index = DeployabilityIndex.create(snapshots, start=plan.start)
55-
deployability_index_for_creation = deployability_index
56-
if plan.is_dev:
57-
before_promote_snapshots = all_selected_for_backfill_snapshots
58-
after_promote_snapshots = set()
59-
snapshots_with_schema_migration = []
60-
else:
61-
before_promote_snapshots = {
62-
s.snapshot_id
63-
for s in snapshots.values()
64-
if deployability_index.is_representative(s)
65-
and plan.is_selected_for_backfill(s.name)
66-
}
67-
after_promote_snapshots = all_selected_for_backfill_snapshots - before_promote_snapshots
68-
deployability_index = DeployabilityIndex.all_deployable()
69-
70-
snapshots_with_schema_migration = [
71-
s
72-
for s in snapshots.values()
73-
if s.is_paused
74-
and s.is_materialized
75-
and not deployability_index_for_creation.is_representative(s)
76-
]
77-
78-
snapshots_to_intervals = self._missing_intervals(
79-
plan, snapshots_by_name, deployability_index
80-
)
81-
82-
steps: t.List[PlanStep] = []
83-
84-
before_all = [
85-
statement
86-
for environment_statements in plan.environment_statements or []
87-
for statement in environment_statements.before_all
88-
]
89-
if before_all:
90-
steps.append(BeforeAllStep(statements=before_all))
91-
92-
snapshots_to_create = [
93-
s
94-
for s in get_snapshots_to_create(plan, snapshots)
95-
if s in snapshots_to_intervals and s.is_model and not s.is_symbolic
96-
]
97-
if snapshots_to_create:
98-
steps.append(
99-
PhysicalLayerUpdateStep(
100-
snapshots=snapshots_to_create,
101-
deployability_index=deployability_index_for_creation,
102-
)
103-
)
104-
105-
audit_only_snapshots = get_audit_only_snapshots(new_snapshots, self.state_reader)
106-
if audit_only_snapshots:
107-
steps.append(AuditOnlyRunStep(snapshots=list(audit_only_snapshots.values())))
108-
109-
if plan.restatements and not plan.is_dev:
110-
snapshot_intervals_to_restate = {}
111-
for name, interval in plan.restatements.items():
112-
restated_snapshot = snapshots_by_name[name]
113-
restated_snapshot.remove_interval(interval)
114-
snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
115-
steps.append(RestatementStep(snapshot_intervals=snapshot_intervals_to_restate))
116-
117-
if before_promote_snapshots and not plan.empty_backfill and not plan.skip_backfill:
118-
missing_intervals_before_promote = {
119-
s: i
120-
for s, i in snapshots_to_intervals.items()
121-
if s.snapshot_id in before_promote_snapshots
122-
}
123-
if missing_intervals_before_promote:
124-
steps.append(
125-
BackfillStep(
126-
snapshot_to_intervals=missing_intervals_before_promote,
127-
deployability_index=deployability_index,
128-
)
129-
)
130-
131-
steps.append(UpdateEnvironmentRecordStep())
132-
133-
if snapshots_with_schema_migration:
134-
steps.append(MigrateSchemasStep(snapshots=snapshots_with_schema_migration))
135-
136-
if after_promote_snapshots and not plan.empty_backfill and not plan.skip_backfill:
137-
missing_intervals_after_promote = {
138-
s: i
139-
for s, i in snapshots_to_intervals.items()
140-
if s.snapshot_id in after_promote_snapshots
141-
}
142-
if missing_intervals_after_promote:
143-
steps.append(
144-
BackfillStep(
145-
snapshot_to_intervals=missing_intervals_after_promote,
146-
deployability_index=deployability_index,
147-
)
148-
)
149-
150-
promoted_snapshots, demoted_snapshots = self._get_promoted_demoted_snapshots(plan)
151-
if promoted_snapshots or demoted_snapshots:
152-
steps.append(
153-
UpdateVirtualLayerStep(
154-
promoted_snapshots=promoted_snapshots,
155-
demoted_snapshots=demoted_snapshots,
156-
deployability_index=deployability_index,
157-
)
158-
)
159-
160-
after_all = [
161-
statement
162-
for environment_statements in plan.environment_statements or []
163-
for statement in environment_statements.after_all
164-
]
165-
if after_all:
166-
steps.append(AfterAllStep(statements=after_all))
167-
42+
plan_steps = steps.build_plan_steps(plan, self.state_reader, self.default_catalog)
16843
explainer_console = _get_explainer_console(
16944
self.console, plan.environment, self.default_catalog
17045
)
171-
explainer_console.explain(steps)
172-
173-
def _get_promoted_demoted_snapshots(
174-
self, plan: EvaluatablePlan
175-
) -> t.Tuple[t.Set[SnapshotTableInfo], t.Set[SnapshotTableInfo]]:
176-
existing_environment = self.state_reader.get_environment(plan.environment.name)
177-
if existing_environment:
178-
snapshots_by_name = {s.name: s for s in existing_environment.snapshots}
179-
demoted_snapshot_names = {s.name for s in existing_environment.promoted_snapshots} - {
180-
s.name for s in plan.environment.promoted_snapshots
181-
}
182-
demoted_snapshots = {snapshots_by_name[name] for name in demoted_snapshot_names}
183-
else:
184-
demoted_snapshots = set()
185-
promoted_snapshots = set(plan.environment.promoted_snapshots)
186-
if existing_environment and plan.environment.can_partially_promote(existing_environment):
187-
promoted_snapshots -= set(existing_environment.promoted_snapshots)
188-
189-
def _snapshot_filter(snapshot: SnapshotTableInfo) -> bool:
190-
return snapshot.is_model and not snapshot.is_symbolic
191-
192-
return {s for s in promoted_snapshots if _snapshot_filter(s)}, {
193-
s for s in demoted_snapshots if _snapshot_filter(s)
194-
}
195-
196-
def _missing_intervals(
197-
self,
198-
plan: EvaluatablePlan,
199-
snapshots_by_name: t.Dict[str, Snapshot],
200-
deployability_index: DeployabilityIndex,
201-
) -> SnapshotToIntervals:
202-
return merged_missing_intervals(
203-
snapshots=snapshots_by_name.values(),
204-
start=plan.start,
205-
end=plan.end,
206-
execution_time=plan.execution_time,
207-
restatements={
208-
snapshots_by_name[name].snapshot_id: interval
209-
for name, interval in plan.restatements.items()
210-
},
211-
deployability_index=deployability_index,
212-
end_bounded=plan.end_bounded,
213-
interval_end_per_model=plan.interval_end_per_model,
214-
)
215-
216-
217-
@dataclass
218-
class BeforeAllStep:
219-
statements: t.List[str]
220-
221-
222-
@dataclass
223-
class AfterAllStep:
224-
statements: t.List[str]
225-
226-
227-
@dataclass
228-
class PhysicalLayerUpdateStep:
229-
snapshots: t.List[Snapshot]
230-
deployability_index: DeployabilityIndex
231-
232-
233-
@dataclass
234-
class AuditOnlyRunStep:
235-
snapshots: t.List[Snapshot]
236-
237-
238-
@dataclass
239-
class RestatementStep:
240-
snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
241-
242-
243-
@dataclass
244-
class BackfillStep:
245-
snapshot_to_intervals: SnapshotToIntervals
246-
deployability_index: DeployabilityIndex
247-
before_promote: bool = True
248-
249-
250-
@dataclass
251-
class MigrateSchemasStep:
252-
snapshots: t.List[Snapshot]
253-
254-
255-
@dataclass
256-
class UpdateVirtualLayerStep:
257-
promoted_snapshots: t.Set[SnapshotTableInfo]
258-
demoted_snapshots: t.Set[SnapshotTableInfo]
259-
deployability_index: DeployabilityIndex
260-
261-
262-
@dataclass
263-
class UpdateEnvironmentRecordStep:
264-
pass
265-
266-
267-
PlanStep = t.Union[
268-
BeforeAllStep,
269-
AfterAllStep,
270-
PhysicalLayerUpdateStep,
271-
AuditOnlyRunStep,
272-
RestatementStep,
273-
BackfillStep,
274-
MigrateSchemasStep,
275-
UpdateVirtualLayerStep,
276-
UpdateEnvironmentRecordStep,
277-
]
46+
explainer_console.explain(plan_steps)
27847

27948

28049
class ExplainerConsole(abc.ABC):
28150
@abc.abstractmethod
282-
def explain(self, steps: t.List[PlanStep]) -> None:
51+
def explain(self, steps: t.List[steps.PlanStep]) -> None:
28352
pass
28453

28554

@@ -301,31 +70,35 @@ def __init__(
30170
self.verbosity = verbosity
30271
self.console: RichConsole = console or srich.console
30372

304-
def explain(self, steps: t.List[PlanStep]) -> None:
73+
def explain(self, steps: t.List[steps.PlanStep]) -> None:
30574
tree = Tree("[bold]Explained plan[/bold]")
30675
for step in steps:
30776
handler_name = f"visit_{_to_snake_case(step.__class__.__name__)}"
30877
if not hasattr(self, handler_name):
309-
raise SQLMeshError(f"Unexpected step: {step.__class__.__name__}")
78+
logger.error("Unexpected step: %s", step.__class__.__name__)
79+
continue
31080
handler = getattr(self, handler_name)
31181
result = handler(step)
31282
if result:
31383
tree.add(self._limit_tree(result))
31484
self.console.print(tree)
31585

316-
def visit_before_all_step(self, step: BeforeAllStep) -> Tree:
86+
def visit_before_all_step(self, step: steps.BeforeAllStep) -> Tree:
31787
tree = Tree("[bold]Execute before all statements[/bold]")
31888
for statement in step.statements:
31989
tree.add(statement)
32090
return tree
32191

322-
def visit_after_all_step(self, step: AfterAllStep) -> Tree:
92+
def visit_after_all_step(self, step: steps.AfterAllStep) -> Tree:
32393
tree = Tree("[bold]Execute after all statements[/bold]")
32494
for statement in step.statements:
32595
tree.add(statement)
32696
return tree
32797

328-
def visit_physical_layer_update_step(self, step: PhysicalLayerUpdateStep) -> Tree:
98+
def visit_physical_layer_update_step(self, step: steps.PhysicalLayerUpdateStep) -> Tree:
99+
if not step.snapshots:
100+
return Tree("[bold]SKIP: No physical layer updates to perform[/bold]")
101+
329102
tree = Tree("[bold]Validate SQL and create physical tables if they do not exist[/bold]")
330103
for snapshot in step.snapshots:
331104
is_deployable = (
@@ -363,21 +136,24 @@ def visit_physical_layer_update_step(self, step: PhysicalLayerUpdateStep) -> Tre
363136
tree.add(model_tree)
364137
return tree
365138

366-
def visit_audit_only_run_step(self, step: AuditOnlyRunStep) -> Tree:
139+
def visit_audit_only_run_step(self, step: steps.AuditOnlyRunStep) -> Tree:
367140
tree = Tree("[bold]Audit-only execution[/bold]")
368141
for snapshot in step.snapshots:
369142
display_name = self._display_name(snapshot)
370143
tree.add(display_name)
371144
return tree
372145

373-
def visit_restatement_step(self, step: RestatementStep) -> Tree:
146+
def visit_restatement_step(self, step: steps.RestatementStep) -> Tree:
374147
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
375148
for snapshot_table_info, interval in step.snapshot_intervals.items():
376149
display_name = self._display_name(snapshot_table_info)
377150
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
378151
return tree
379152

380-
def visit_backfill_step(self, step: BackfillStep) -> Tree:
153+
def visit_backfill_step(self, step: steps.BackfillStep) -> Tree:
154+
if not step.snapshot_to_intervals:
155+
return Tree("[bold]SKIP: No model batches to execute[/bold]")
156+
381157
tree = Tree(
382158
"[bold]Backfill models by running their queries and run standalone audits[/bold]"
383159
)
@@ -426,7 +202,7 @@ def visit_backfill_step(self, step: BackfillStep) -> Tree:
426202
tree.add(f"{display_name} \[standalone audit]")
427203
return tree
428204

429-
def visit_migrate_schemas_step(self, step: MigrateSchemasStep) -> Tree:
205+
def visit_migrate_schemas_step(self, step: steps.MigrateSchemasStep) -> Tree:
430206
tree = Tree(
431207
"[bold]Update schemas (add, drop, alter columns) of production physical tables to reflect forward-only changes[/bold]"
432208
)
@@ -436,7 +212,7 @@ def visit_migrate_schemas_step(self, step: MigrateSchemasStep) -> Tree:
436212
tree.add(f"{display_name} -> {table_name}")
437213
return tree
438214

439-
def visit_update_virtual_layer_step(self, step: UpdateVirtualLayerStep) -> Tree:
215+
def visit_virtual_layer_update_step(self, step: steps.VirtualLayerUpdateStep) -> Tree:
440216
tree = Tree(
441217
f"[bold]Update the virtual layer for environment '{self.environment_naming_info.name}'[/bold]"
442218
)
@@ -461,8 +237,8 @@ def visit_update_virtual_layer_step(self, step: UpdateVirtualLayerStep) -> Tree:
461237
tree.add(self._limit_tree(demote_tree))
462238
return tree
463239

464-
def visit_update_environment_record_step(
465-
self, step: UpdateEnvironmentRecordStep
240+
def visit_environment_record_update_step(
241+
self, step: steps.EnvironmentRecordUpdateStep
466242
) -> t.Optional[Tree]:
467243
return None
468244

0 commit comments

Comments
 (0)