Skip to content

Commit b5570ec

Browse files
authored
Feat!: Ensure metadata snapshots with modified audits are still audited (#4341)
1 parent 1837bb6 commit b5570ec

7 files changed

Lines changed: 433 additions & 170 deletions

File tree

sqlmesh/core/console.py

Lines changed: 50 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -350,11 +350,14 @@ def start_evaluation_progress(
350350
batched_intervals: t.Dict[Snapshot, Intervals],
351351
environment_naming_info: EnvironmentNamingInfo,
352352
default_catalog: t.Optional[str],
353+
audit_only: bool = False,
353354
) -> None:
354-
"""Indicates that a new snapshot evaluation progress has begun."""
355+
"""Indicates that a new snapshot evaluation/auditing progress has begun."""
355356

356357
@abc.abstractmethod
357-
def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
358+
def start_snapshot_evaluation_progress(
359+
self, snapshot: Snapshot, audit_only: bool = False
360+
) -> None:
358361
"""Starts the snapshot evaluation progress."""
359362

360363
@abc.abstractmethod
@@ -366,6 +369,7 @@ def update_snapshot_evaluation_progress(
366369
duration_ms: t.Optional[int],
367370
num_audits_passed: int,
368371
num_audits_failed: int,
372+
audit_only: bool = False,
369373
) -> None:
370374
"""Updates the snapshot evaluation progress."""
371375

@@ -507,10 +511,13 @@ def start_evaluation_progress(
507511
batched_intervals: t.Dict[Snapshot, Intervals],
508512
environment_naming_info: EnvironmentNamingInfo,
509513
default_catalog: t.Optional[str],
514+
audit_only: bool = False,
510515
) -> None:
511516
pass
512517

513-
def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
518+
def start_snapshot_evaluation_progress(
519+
self, snapshot: Snapshot, audit_only: bool = False
520+
) -> None:
514521
pass
515522

516523
def update_snapshot_evaluation_progress(
@@ -521,6 +528,7 @@ def update_snapshot_evaluation_progress(
521528
duration_ms: t.Optional[int],
522529
num_audits_passed: int,
523530
num_audits_failed: int,
531+
audit_only: bool = False,
524532
) -> None:
525533
pass
526534

@@ -891,11 +899,12 @@ def start_evaluation_progress(
891899
batched_intervals: t.Dict[Snapshot, Intervals],
892900
environment_naming_info: EnvironmentNamingInfo,
893901
default_catalog: t.Optional[str],
902+
audit_only: bool = False,
894903
) -> None:
895-
"""Indicates that a new snapshot evaluation progress has begun."""
904+
"""Indicates that a new snapshot evaluation/auditing progress has begun."""
896905
if not self.evaluation_progress_live:
897906
self.evaluation_total_progress = make_progress_bar(
898-
"Executing model batches", self.console
907+
"Executing model batches" if not audit_only else "Auditing models", self.console
899908
)
900909

901910
self.evaluation_model_progress = Progress(
@@ -916,8 +925,9 @@ def start_evaluation_progress(
916925
batch_sizes = {
917926
snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()
918927
}
928+
message = "Executing" if not audit_only else "Auditing"
919929
self.evaluation_total_task = self.evaluation_total_progress.add_task(
920-
"Executing models...", total=sum(batch_sizes.values())
930+
f"{message} models...", total=sum(batch_sizes.values())
921931
)
922932

923933
# determine column widths
@@ -943,15 +953,17 @@ def start_evaluation_progress(
943953
self.environment_naming_info = environment_naming_info
944954
self.default_catalog = default_catalog
945955

946-
def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
956+
def start_snapshot_evaluation_progress(
957+
self, snapshot: Snapshot, audit_only: bool = False
958+
) -> None:
947959
if self.evaluation_model_progress and snapshot.name not in self.evaluation_model_tasks:
948960
display_name = snapshot.display_name(
949961
self.environment_naming_info,
950962
self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
951963
dialect=self.dialect,
952964
)
953965
self.evaluation_model_tasks[snapshot.name] = self.evaluation_model_progress.add_task(
954-
f"Evaluating {display_name}...",
966+
f"{'Evaluating' if not audit_only else 'Auditing'} {display_name}...",
955967
view_name=display_name,
956968
total=self.evaluation_model_batch_sizes[snapshot],
957969
)
@@ -964,6 +976,7 @@ def update_snapshot_evaluation_progress(
964976
duration_ms: t.Optional[int],
965977
num_audits_passed: int,
966978
num_audits_failed: int,
979+
audit_only: bool = False,
967980
) -> None:
968981
"""Update the snapshot evaluation progress."""
969982
if (
@@ -1003,7 +1016,7 @@ def update_snapshot_evaluation_progress(
10031016
self.evaluation_column_widths["duration"]
10041017
)
10051018

1006-
msg = f"{batch} {display_name} {annotation} {duration}".replace(
1019+
msg = f"{f'{batch} ' if not audit_only else ''}{display_name} {annotation} {duration}".replace(
10071020
self.AUDIT_PASS_MARK, self.GREEN_AUDIT_PASS_MARK
10081021
)
10091022

@@ -1015,7 +1028,10 @@ def update_snapshot_evaluation_progress(
10151028

10161029
model_task_id = self.evaluation_model_tasks[snapshot.name]
10171030
self.evaluation_model_progress.update(model_task_id, refresh=True, advance=1)
1018-
if self.evaluation_model_progress._tasks[model_task_id].completed >= total_batches:
1031+
if (
1032+
self.evaluation_model_progress._tasks[model_task_id].completed >= total_batches
1033+
or audit_only
1034+
):
10191035
self.evaluation_model_progress.remove_task(model_task_id)
10201036

10211037
def stop_evaluation_progress(self, success: bool = True) -> None:
@@ -3208,14 +3224,17 @@ def start_evaluation_progress(
32083224
batched_intervals: t.Dict[Snapshot, Intervals],
32093225
environment_naming_info: EnvironmentNamingInfo,
32103226
default_catalog: t.Optional[str],
3227+
audit_only: bool = False,
32113228
) -> None:
32123229
self.evaluation_model_batch_sizes = {
32133230
snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()
32143231
}
32153232
self.evaluation_environment_naming_info = environment_naming_info
32163233
self.default_catalog = default_catalog
32173234

3218-
def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
3235+
def start_snapshot_evaluation_progress(
3236+
self, snapshot: Snapshot, audit_only: bool = False
3237+
) -> None:
32193238
if not self.evaluation_batch_progress.get(snapshot.snapshot_id):
32203239
display_name = snapshot.display_name(
32213240
self.evaluation_environment_naming_info,
@@ -3235,8 +3254,14 @@ def update_snapshot_evaluation_progress(
32353254
duration_ms: t.Optional[int],
32363255
num_audits_passed: int,
32373256
num_audits_failed: int,
3257+
audit_only: bool = False,
32383258
) -> None:
32393259
view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id]
3260+
3261+
if audit_only:
3262+
print(f"Completed Auditing {view_name}")
3263+
return
3264+
32403265
total_batches = self.evaluation_model_batch_sizes[snapshot]
32413266

32423267
loaded_batches += 1
@@ -3378,13 +3403,17 @@ def start_evaluation_progress(
33783403
batched_intervals: t.Dict[Snapshot, Intervals],
33793404
environment_naming_info: EnvironmentNamingInfo,
33803405
default_catalog: t.Optional[str],
3406+
audit_only: bool = False,
33813407
) -> None:
3408+
message = "evaluation" if not audit_only else "auditing"
33823409
self._write(
3383-
f"Starting evaluation for {sum(len(intervals) for intervals in batched_intervals.values())} snapshots"
3410+
f"Starting {message} for {sum(len(intervals) for intervals in batched_intervals.values())} snapshots"
33843411
)
33853412

3386-
def start_snapshot_evaluation_progress(
    self, snapshot: Snapshot, audit_only: bool = False
) -> None:
    """Log that processing of a snapshot has started.

    Writes "Evaluating <name>" for a normal evaluation run, or
    "Auditing <name>" when only the snapshot's audits are being executed.
    """
    verb = "Evaluating" if not audit_only else "Auditing"
    self._write(f"{verb} {snapshot.name}")
33883417

33893418
def update_snapshot_evaluation_progress(
33903419
self,
@@ -3394,10 +3423,14 @@ def update_snapshot_evaluation_progress(
33943423
duration_ms: t.Optional[int],
33953424
num_audits_passed: int,
33963425
num_audits_failed: int,
3426+
audit_only: bool = False,
33973427
) -> None:
3398-
self._write(
3399-
f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3400-
)
3428+
message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3429+
3430+
if audit_only:
3431+
message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
3432+
3433+
self._write(message)
34013434

34023435
def stop_evaluation_progress(self, success: bool = True) -> None:
34033436
self._write(f"Stopping evaluation with success={success}")

sqlmesh/core/model/definition.py

Lines changed: 32 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1069,6 +1069,37 @@ def _data_hash_values(self) -> t.List[str]:
10691069

10701070
return data # type: ignore
10711071

1072+
def _audit_metadata_hash_values(self) -> t.List[str]:
    """Collect the hashable metadata fields contributed by this node's audits.

    Audits are processed in name order so the result is deterministic.
    A built-in audit contributes its name plus each argument name and the
    rendered argument value; a user-defined audit contributes its rendered
    query (falling back to the raw query) and its key settings.
    """
    from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS

    values: t.List[str] = []

    for name, args in sorted(self.audits, key=lambda pair: pair[0]):
        values.append(name)
        if name in BUILT_IN_AUDITS:
            # Built-in audits are fully identified by their arguments.
            for arg_name, arg_value in args.items():
                values.append(arg_name)
                values.append(gen(arg_value))
        else:
            definition = self.audit_definitions[name]
            rendered = self.render_audit_query(
                definition, **t.cast(t.Dict[str, t.Any], args)
            )
            values.append(gen(rendered or definition.query))
            values.append(definition.dialect)
            values.append(str(definition.skip))
            values.append(str(definition.blocking))

    return values
1099+
1100+
def audit_metadata_hash(self) -> str:
    """Return a stable hash of this node's audit-related metadata.

    Used to detect whether a metadata-only change modified the audits
    themselves (in which case the audits still need to be executed).
    """
    return hash_data(self._audit_metadata_hash_values())
1102+
10721103
@property
10731104
def metadata_hash(self) -> str:
10741105
"""
@@ -1078,8 +1109,6 @@ def metadata_hash(self) -> str:
10781109
The metadata hash for the node.
10791110
"""
10801111
if self._metadata_hash is None:
1081-
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
1082-
10831112
metadata = [
10841113
self.dialect,
10851114
self.owner,
@@ -1100,29 +1129,9 @@ def metadata_hash(self) -> str:
11001129
str(self.allow_partials),
11011130
gen(self.session_properties_) if self.session_properties_ else None,
11021131
*[gen(g) for g in self.grains],
1132+
*self._audit_metadata_hash_values(),
11031133
]
11041134

1105-
for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
1106-
metadata.append(audit_name)
1107-
if audit_name in BUILT_IN_AUDITS:
1108-
for arg_name, arg_value in audit_args.items():
1109-
metadata.append(arg_name)
1110-
metadata.append(gen(arg_value))
1111-
else:
1112-
audit = self.audit_definitions[audit_name]
1113-
query = (
1114-
self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args))
1115-
or audit.query
1116-
)
1117-
metadata.extend(
1118-
[
1119-
gen(query),
1120-
audit.dialect,
1121-
str(audit.skip),
1122-
str(audit.blocking),
1123-
]
1124-
)
1125-
11261135
for key, value in (self.virtual_properties or {}).items():
11271136
metadata.append(key)
11281137
metadata.append(gen(value))

sqlmesh/core/plan/evaluator.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def evaluate(
128128
execution_time=plan.execution_time,
129129
)
130130

131+
self._run_audits_for_metadata_snapshots(plan, new_snapshots)
132+
131133
push_completion_status = self._push(plan, snapshots, deployability_index_for_creation)
132134
if push_completion_status.is_nothing_to_do:
133135
self.console.log_status_update(
@@ -545,6 +547,54 @@ def _restatement_intervals_across_all_environments(
545547

546548
return set(snapshots_to_restate.values())
547549

550+
def _run_audits_for_metadata_snapshots(
    self,
    plan: EvaluatablePlan,
    new_snapshots: t.Dict[SnapshotId, Snapshot],
) -> None:
    """Audit metadata-only snapshots whose audits were modified.

    A metadata-only change normally skips evaluation, which would also skip
    auditing. If the metadata change altered the audits themselves, the new
    audits must still be executed, so such snapshots are detected here by
    comparing audit metadata hashes against the previous snapshot version
    and audited via the scheduler.

    Args:
        plan: The plan being applied; provides the environment and interval
            bounds for the audit run.
        new_snapshots: All newly created snapshots, keyed by snapshot id.

    Raises:
        PlanError: If any of the executed audits fail.
    """
    # Only model snapshots categorized as metadata-only changes are relevant.
    metadata_snapshots = [
        snapshot
        for snapshot in new_snapshots.values()
        if snapshot.is_metadata and snapshot.is_model and snapshot.evaluatable
    ]

    # Bulk load all the previous snapshot versions in one state-sync call.
    previous_snapshots = self.state_sync.get_snapshots(
        [
            s.previous_version.snapshot_id(s.name)
            for s in metadata_snapshots
            if s.previous_version
        ]
    )

    # Look each previous snapshot up by id instead of zipping the two
    # collections: snapshots without a previous version are excluded from the
    # bulk fetch and the returned dict's ordering is not guaranteed to match
    # the input order, so a positional zip could pair a snapshot with the
    # wrong predecessor and compare the wrong audit hashes.
    audit_snapshots: t.Dict[SnapshotId, Snapshot] = {}
    for snapshot in metadata_snapshots:
        if not snapshot.previous_version or not snapshot.model.audits:
            continue
        previous_snapshot = previous_snapshots.get(
            snapshot.previous_version.snapshot_id(snapshot.name)
        )
        if previous_snapshot is None:
            continue
        if (
            snapshot.model.audit_metadata_hash()
            != previous_snapshot.model.audit_metadata_hash()
        ):
            audit_snapshots[snapshot.snapshot_id] = snapshot

    if not audit_snapshots:
        return

    # If there are any snapshots to be audited, reuse the scheduler's
    # internals to audit them.
    scheduler = self.create_scheduler(audit_snapshots.values())
    completion_status = scheduler.audit(
        plan.environment,
        plan.start,
        plan.end,
        execution_time=plan.execution_time,
        end_bounded=plan.end_bounded,
        interval_end_per_model=plan.interval_end_per_model,
    )

    if completion_status.is_failure:
        raise PlanError("Plan application failed.")
597+
548598

549599
def update_intervals_for_new_snapshots(
550600
snapshots: t.Collection[Snapshot], state_sync: StateSync

0 commit comments

Comments (0)