Skip to content

Commit 1b7fad4

Browse files
committed
Switch to using restatements
1 parent 93a905c commit 1b7fad4

3 files changed

Lines changed: 219 additions & 152 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,10 @@ def _audit_metadata_hash_values(self) -> t.List[str]:
10971097

10981098
return metadata
10991099

1100+
def audit_metadata_hash(self) -> t.Tuple[t.List[str], str]:
    """Return the audit metadata hash inputs together with their hash.

    Returns:
        A 2-tuple: the list produced by ``_audit_metadata_hash_values()`` and
        the result of passing that list to ``hash_data``.
    """
    audit_values = self._audit_metadata_hash_values()
    audit_digest = hash_data(audit_values)
    return audit_values, audit_digest
1103+
11001104
@property
11011105
def metadata_hash(self) -> str:
11021106
"""

sqlmesh/core/plan/evaluator.py

Lines changed: 38 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,13 @@
3535
SnapshotTableInfo,
3636
SnapshotCreationFailedError,
3737
)
38-
from sqlmesh.core.snapshot.definition import SnapshotChangeCategory, parent_snapshots_by_name
3938
from sqlmesh.utils import CompletionStatus
4039
from sqlmesh.core.state_sync import StateSync
4140
from sqlmesh.core.state_sync.base import PromotionResult
4241
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4342
from sqlmesh.utils.errors import PlanError
4443
from sqlmesh.utils.dag import DAG
4544
from sqlmesh.utils.date import now
46-
from sqlmesh.utils.hashing import hash_data
4745

4846
logger = logging.getLogger(__name__)
4947

@@ -118,9 +116,7 @@ def evaluate(
118116
after_promote_snapshots = all_names - before_promote_snapshots
119117
deployability_index_for_evaluation = DeployabilityIndex.all_deployable()
120118

121-
self._run_audits_for_metadata_snapshots(
122-
new_snapshots, plan, deployability_index_for_evaluation
123-
)
119+
self._run_audits_for_metadata_snapshots(plan, snapshots, new_snapshots)
124120

125121
execute_environment_statements(
126122
adapter=self.snapshot_evaluator.adapter,
@@ -553,69 +549,56 @@ def _restatement_intervals_across_all_environments(
553549

554550
def _run_audits_for_metadata_snapshots(
    self,
    plan: EvaluatablePlan,
    snapshots: t.Dict[SnapshotId, Snapshot],
    new_snapshots: t.Dict[SnapshotId, Snapshot],
) -> None:
    """Audit new metadata-only model snapshots whose audit definitions changed.

    A new snapshot is considered when it is a metadata change on an evaluatable
    model and has a previous version. Its audit metadata hash is compared with
    the one of its previous snapshot (loaded from state); when they differ and
    the new snapshot defines audits, the snapshot is handed to the scheduler's
    ``audit`` entry point as a restatement spanning its own intervals.

    Args:
        plan: The plan being evaluated; supplies the environment, date range
            and execution time passed to the scheduler.
        snapshots: All snapshots used to create the scheduler.
        new_snapshots: The snapshots introduced by this plan, which are the
            candidates for auditing.

    Raises:
        PlanError: If the scheduler reports a failed completion status.
    """
    # Step 1: Keep only evaluatable model snapshots categorized as metadata changes,
    # remembering the id of each one's previous snapshot. Snapshots without a previous
    # version have nothing to compare against and are skipped here, so every candidate
    # stays explicitly paired with its own predecessor.
    previous_id_by_snapshot = []
    for snapshot in new_snapshots.values():
        if not snapshot.is_metadata or not snapshot.is_model or not snapshot.evaluatable:
            continue
        if not snapshot.previous_version:
            continue

        previous_id_by_snapshot.append(
            (snapshot, snapshot.previous_version.snapshot_id(snapshot.name))
        )

    if not previous_id_by_snapshot:
        return

    # Step 2: Bulk load their previous snapshots from state. The result is kept as a
    # mapping keyed by snapshot id rather than zipped positionally, because the order
    # of its values is not guaranteed to match the order of the requested ids.
    previous_snapshots = self.state_sync.get_snapshots(
        [previous_id for _, previous_id in previous_id_by_snapshot]
    )

    # Step 3: Compare the audit metadata hashes to determine if there was a change in the audits field
    to_be_audited_snapshots = {}
    for snapshot, previous_id in previous_id_by_snapshot:
        previous_snapshot = previous_snapshots.get(previous_id)
        if previous_snapshot is None:
            # The previous snapshot is not available in state; nothing to compare with.
            continue

        new_audits, new_audits_hash = snapshot.model.audit_metadata_hash()
        _, previous_audit_hash = previous_snapshot.model.audit_metadata_hash()

        if previous_audit_hash == new_audits_hash or not new_audits:
            continue

        intervals = snapshot.intervals
        if not intervals:
            # min()/max() raise ValueError on an empty sequence; with no intervals
            # there is no range to audit.
            continue

        snapshot_start = min(interval[0] for interval in intervals)
        snapshot_end = max(interval[1] for interval in intervals)
        to_be_audited_snapshots[snapshot.snapshot_id] = (snapshot_start, snapshot_end)

    if not to_be_audited_snapshots:
        return

    # Step 4: If there are any snapshots to be audited, we'll reuse the scheduler's
    # internals to audit them by utilizing the restatement logic
    scheduler = self.create_scheduler(snapshots.values())
    completion_status = scheduler.audit(
        plan.environment,
        plan.start,
        plan.end,
        execution_time=plan.execution_time,
        restatements=to_be_audited_snapshots,
        end_bounded=plan.end_bounded,
        interval_end_per_model=plan.interval_end_per_model,
    )

    if completion_status.is_failure:
        raise PlanError("Plan application failed.")
619602

620603

621604
def update_intervals_for_new_snapshots(

0 commit comments

Comments
 (0)