Skip to content

Commit 1262d1a

Browse files
committed
refactor scheduler
1 parent f4326da commit 1262d1a

4 files changed

Lines changed: 292 additions & 161 deletions

File tree

sqlmesh/core/scheduler.py

Lines changed: 143 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from dataclasses import dataclass
23
import logging
34
import typing as t
45
import time
@@ -39,7 +40,6 @@
3940
from sqlmesh.utils.date import (
4041
TimeLike,
4142
now_timestamp,
42-
to_timestamp,
4343
validate_date_range,
4444
)
4545
from sqlmesh.utils.errors import (
@@ -55,9 +55,46 @@
5555

5656
logger = logging.getLogger(__name__)
5757
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
58-
# we store snapshot name instead of snapshots/snapshotids because pydantic
59-
# is extremely slow to hash. snapshot names should be unique within a dag run
60-
SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]]
58+
59+
60+
class BaseNode:
    """Base scheduling-DAG node keyed by snapshot name.

    Snapshot names (rather than snapshot objects/ids) are stored because
    pydantic hashing is slow; names are unique within a single DAG run.
    """

    # Name of the snapshot this node operates on.
    snapshot_name: str

    def __lt__(self, other: BaseNode) -> bool:
        """Order nodes by class name first, then snapshot name, for a deterministic sort."""
        if self.__class__.__name__ != other.__class__.__name__:
            return self.__class__.__name__ < other.__class__.__name__
        return self.snapshot_name < other.snapshot_name
68+
69+
70+
@dataclass(frozen=True)
class EvaluateNode(BaseNode):
    """Scheduling node that evaluates one interval of a snapshot's batch."""

    snapshot_name: str
    interval: Interval
    batch_index: int

    def __lt__(self, other: BaseNode) -> bool:
        # Nodes of a different kind fall back to the base (class name, snapshot name) ordering.
        if not isinstance(other, EvaluateNode):
            return super().__lt__(other)
        self_key = (
            self.__class__.__name__,
            self.snapshot_name,
            self.interval,
            self.batch_index,
        )
        other_key = (
            other.__class__.__name__,
            other.snapshot_name,
            other.interval,
            other.batch_index,
        )
        return self_key < other_key
85+
86+
87+
@dataclass(frozen=True)
class CreateNode(BaseNode):
    """Scheduling node that creates a snapshot's physical table before its
    evaluation nodes run (used when batches may evaluate concurrently)."""

    snapshot_name: str
90+
91+
92+
@dataclass(frozen=True)
class DummyNode(BaseNode):
    """No-op barrier node; evaluation skips it. Used to fan in a parent
    snapshot's multiple batch intervals behind a single dependency."""

    snapshot_name: str
95+
96+
97+
# A unit of work in the scheduler DAG: evaluate an interval, create a table,
# or a no-op barrier.
SchedulingUnit = t.Union[EvaluateNode, CreateNode, DummyNode]
6198

6299

63100
class Scheduler:
@@ -162,6 +199,7 @@ def evaluate(
162199
batch_index: int,
163200
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
164201
allow_destructive_snapshots: t.Optional[t.Set[str]] = None,
202+
target_table_exists: t.Optional[bool] = None,
165203
**kwargs: t.Any,
166204
) -> t.List[AuditResult]:
167205
"""Evaluate a snapshot and add the processed interval to the state sync.
@@ -175,6 +213,7 @@ def evaluate(
175213
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
176214
batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it
177215
auto_restatement_enabled: Whether to enable auto restatements.
216+
target_table_exists: Whether the target table exists. If None, the table will be checked for existence.
178217
kwargs: Additional kwargs to pass to the renderer.
179218
180219
Returns:
@@ -195,6 +234,7 @@ def evaluate(
195234
allow_destructive_snapshots=allow_destructive_snapshots,
196235
deployability_index=deployability_index,
197236
batch_index=batch_index,
237+
target_table_exists=target_table_exists,
198238
**kwargs,
199239
)
200240
audit_results = self._audit_snapshot(
@@ -404,7 +444,14 @@ def run_merged_intervals(
404444
audit_only=audit_only,
405445
)
406446

407-
dag = self._dag(batched_intervals)
447+
snapshots_to_create = {
448+
s.snapshot_id
449+
for s in self.snapshot_evaluator.get_snapshots_to_create(
450+
merged_intervals.keys(), deployability_index
451+
)
452+
}
453+
454+
dag = self._dag(batched_intervals, snapshots_to_create=snapshots_to_create)
408455

409456
if run_environment_statements:
410457
environment_statements = self.state_sync.get_environment_statements(
@@ -425,55 +472,63 @@ def run_merged_intervals(
425472
def evaluate_node(node: SchedulingUnit) -> None:
426473
if circuit_breaker and circuit_breaker():
427474
raise CircuitBreakerError()
428-
429-
snapshot_name, ((start, end), batch_idx) = node
430-
if batch_idx == -1:
475+
if isinstance(node, DummyNode):
431476
return
432-
snapshot = self.snapshots_by_name[snapshot_name]
433-
434-
self.console.start_snapshot_evaluation_progress(snapshot)
435-
436-
execution_start_ts = now_timestamp()
437-
evaluation_duration_ms: t.Optional[int] = None
438477

439-
audit_results: t.List[AuditResult] = []
440-
try:
441-
assert execution_time # mypy
442-
assert deployability_index # mypy
443-
444-
if audit_only:
445-
audit_results = self._audit_snapshot(
446-
snapshot=snapshot,
447-
environment_naming_info=environment_naming_info,
448-
deployability_index=deployability_index,
449-
snapshots=self.snapshots_by_name,
450-
start=start,
451-
end=end,
452-
execution_time=execution_time,
453-
)
454-
else:
455-
audit_results = self.evaluate(
456-
snapshot=snapshot,
457-
environment_naming_info=environment_naming_info,
458-
start=start,
459-
end=end,
460-
execution_time=execution_time,
461-
deployability_index=deployability_index,
462-
batch_index=batch_idx,
463-
allow_destructive_snapshots=allow_destructive_snapshots,
478+
snapshot = self.snapshots_by_name[node.snapshot_name]
479+
480+
if isinstance(node, EvaluateNode):
481+
self.console.start_snapshot_evaluation_progress(snapshot)
482+
execution_start_ts = now_timestamp()
483+
evaluation_duration_ms: t.Optional[int] = None
484+
start, end = node.interval
485+
486+
audit_results: t.List[AuditResult] = []
487+
try:
488+
assert execution_time # mypy
489+
assert deployability_index # mypy
490+
491+
if audit_only:
492+
audit_results = self._audit_snapshot(
493+
snapshot=snapshot,
494+
environment_naming_info=environment_naming_info,
495+
deployability_index=deployability_index,
496+
snapshots=self.snapshots_by_name,
497+
start=start,
498+
end=end,
499+
execution_time=execution_time,
500+
)
501+
else:
502+
audit_results = self.evaluate(
503+
snapshot=snapshot,
504+
environment_naming_info=environment_naming_info,
505+
start=start,
506+
end=end,
507+
execution_time=execution_time,
508+
deployability_index=deployability_index,
509+
batch_index=node.batch_index,
510+
allow_destructive_snapshots=allow_destructive_snapshots,
511+
target_table_exists=snapshot.snapshot_id not in snapshots_to_create,
512+
)
513+
514+
evaluation_duration_ms = now_timestamp() - execution_start_ts
515+
finally:
516+
num_audits = len(audit_results)
517+
num_audits_failed = sum(1 for result in audit_results if result.count)
518+
self.console.update_snapshot_evaluation_progress(
519+
snapshot,
520+
batched_intervals[snapshot][node.batch_index],
521+
node.batch_index,
522+
evaluation_duration_ms,
523+
num_audits - num_audits_failed,
524+
num_audits_failed,
464525
)
465-
466-
evaluation_duration_ms = now_timestamp() - execution_start_ts
467-
finally:
468-
num_audits = len(audit_results)
469-
num_audits_failed = sum(1 for result in audit_results if result.count)
470-
self.console.update_snapshot_evaluation_progress(
471-
snapshot,
472-
batched_intervals[snapshot][batch_idx],
473-
batch_idx,
474-
evaluation_duration_ms,
475-
num_audits - num_audits_failed,
476-
num_audits_failed,
526+
elif isinstance(node, CreateNode):
527+
self.snapshot_evaluator.create_snapshot(
528+
snapshot=snapshot,
529+
snapshots=self.snapshots_by_name,
530+
deployability_index=deployability_index,
531+
allow_destructive_snapshots=allow_destructive_snapshots or set(),
477532
)
478533

479534
try:
@@ -486,7 +541,9 @@ def evaluate_node(node: SchedulingUnit) -> None:
486541
)
487542
self.console.stop_evaluation_progress(success=not errors)
488543

489-
skipped_snapshots = {i[0] for i in skipped_intervals}
544+
skipped_snapshots = {
545+
i.snapshot_name for i in skipped_intervals if isinstance(i, EvaluateNode)
546+
}
490547
self.console.log_skipped_models(skipped_snapshots)
491548
for skipped in skipped_snapshots:
492549
logger.info(f"SKIPPED snapshot {skipped}\n")
@@ -515,11 +572,16 @@ def evaluate_node(node: SchedulingUnit) -> None:
515572

516573
self.state_sync.recycle()
517574

518-
def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
575+
def _dag(
576+
self,
577+
batches: SnapshotToIntervals,
578+
snapshots_to_create: t.Optional[t.Set[SnapshotId]] = None,
579+
) -> DAG[SchedulingUnit]:
519580
"""Builds a DAG of snapshot intervals to be evaluated.
520581
521582
Args:
522583
batches: The batches of snapshots and intervals to evaluate.
584+
snapshots_to_create: The snapshots with missing physical tables.
523585
524586
Returns:
525587
A DAG of snapshot intervals to be evaluated.
@@ -528,46 +590,64 @@ def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
528590
intervals_per_snapshot = {
529591
snapshot.name: intervals for snapshot, intervals in batches.items()
530592
}
593+
snapshots_to_create = snapshots_to_create or set()
531594

532595
dag = DAG[SchedulingUnit]()
533-
terminal_node = ((to_timestamp(0), to_timestamp(0)), -1)
534596

535597
for snapshot, intervals in batches.items():
536598
if not intervals:
537599
continue
538600

539-
upstream_dependencies = []
601+
upstream_dependencies: t.List[SchedulingUnit] = []
540602

541603
for p_sid in snapshot.parents:
542604
if p_sid in self.snapshots:
543605
p_intervals = intervals_per_snapshot.get(p_sid.name, [])
544606

545607
if len(p_intervals) > 1:
546-
upstream_dependencies.append((p_sid.name, terminal_node))
608+
upstream_dependencies.append(DummyNode(snapshot_name=p_sid.name))
547609
else:
548610
for i, interval in enumerate(p_intervals):
549-
upstream_dependencies.append((p_sid.name, (interval, i)))
611+
upstream_dependencies.append(
612+
EvaluateNode(
613+
snapshot_name=p_sid.name, interval=interval, batch_index=i
614+
)
615+
)
550616

551617
batch_concurrency = snapshot.node.batch_concurrency
552618
if snapshot.depends_on_past:
553619
batch_concurrency = 1
554620

621+
create_node: t.Optional[CreateNode] = None
622+
if (
623+
batch_concurrency
624+
and batch_concurrency > 1
625+
and snapshot.snapshot_id in snapshots_to_create
626+
):
627+
# Add a separate node for table creation when there are multiple concurrent
628+
# evaluation nodes.
629+
create_node = CreateNode(snapshot_name=snapshot.name)
630+
555631
for i, interval in enumerate(intervals):
556-
node = (snapshot.name, (interval, i))
632+
node = EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=i)
557633
dag.add(node, upstream_dependencies)
558634

559635
if len(intervals) > 1:
560-
dag.add((snapshot.name, terminal_node), [node])
636+
dag.add(DummyNode(snapshot_name=snapshot.name), [node])
637+
638+
if create_node:
639+
dag.add(node, [create_node])
561640

562641
if batch_concurrency and i >= batch_concurrency:
563642
batch_idx_to_wait_for = i - batch_concurrency
564643
dag.add(
565644
node,
566645
[
567-
(
568-
snapshot.name,
569-
(intervals[batch_idx_to_wait_for], batch_idx_to_wait_for),
570-
)
646+
EvaluateNode(
647+
snapshot_name=snapshot.name,
648+
interval=intervals[batch_idx_to_wait_for],
649+
batch_index=batch_idx_to_wait_for,
650+
),
571651
],
572652
)
573653
return dag

sqlmesh/core/snapshot/evaluator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,11 +2213,7 @@ def insert(
22132213
render_kwargs: t.Dict[str, t.Any],
22142214
**kwargs: t.Any,
22152215
) -> None:
2216-
deployability_index = (
2217-
kwargs.get("deployability_index") or DeployabilityIndex.all_deployable()
2218-
)
22192216
snapshot = kwargs["snapshot"]
2220-
snapshots = kwargs["snapshots"]
22212217

22222218
if (
22232219
not snapshot.is_materialized_view

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def push_plan(context: Context, plan: Plan) -> None:
271271
context.default_catalog,
272272
)
273273
deployability_index = DeployabilityIndex.create(context.snapshots.values())
274-
evaluatable_plan = plan.to_evaluatable()
274+
evaluatable_plan = plan.to_evaluatable().copy(update={"skip_backfill": True})
275275
stages = plan_stages.build_plan_stages(
276276
evaluatable_plan, context.state_sync, context.default_catalog
277277
)

0 commit comments

Comments
 (0)