Skip to content

Commit 7c40682

Browse files
committed
Feat: Ensure audits run even if adding them is a metadata change
1 parent 3271ae1 commit 7c40682

4 files changed

Lines changed: 211 additions & 64 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,34 @@ def _data_hash_values(self) -> t.List[str]:
10691069

10701070
return data # type: ignore
10711071

1072+
def _audit_metadata(self) -> t.List[str]:
    """Return a deterministic list of strings describing this model's audits.

    The result feeds into the metadata hash, so the ordering must be stable:
    audits are sorted by name before being serialized.

    NOTE(review): for built-in audits, argument order follows ``audit_args``
    insertion order — presumably stable for a given model definition; confirm
    this cannot vary between loads, since it would change the hash.
    """
    # Imported locally (as in the original metadata_hash code) to avoid an
    # import cycle with the audit module.
    from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS

    metadata = []

    # Sort by audit name so the serialized form is order-independent.
    for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
        metadata.append(audit_name)
        if audit_name in BUILT_IN_AUDITS:
            # Built-in audits are identified by name + rendered arguments only.
            for arg_name, arg_value in audit_args.items():
                metadata.append(arg_name)
                metadata.append(gen(arg_value))
        else:
            # User-defined audit: hash the fully rendered query (falling back
            # to the raw query if rendering returns nothing) plus the flags
            # that affect execution semantics.
            audit = self.audit_definitions[audit_name]
            query = (
                self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args))
                or audit.query
            )
            metadata.extend(
                [
                    gen(query),
                    audit.dialect,
                    str(audit.skip),
                    str(audit.blocking),
                ]
            )

    return metadata
1099+
10721100
@property
10731101
def metadata_hash(self) -> str:
10741102
"""
@@ -1078,8 +1106,6 @@ def metadata_hash(self) -> str:
10781106
The metadata hash for the node.
10791107
"""
10801108
if self._metadata_hash is None:
1081-
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
1082-
10831109
metadata = [
10841110
self.dialect,
10851111
self.owner,
@@ -1100,29 +1126,9 @@ def metadata_hash(self) -> str:
11001126
str(self.allow_partials),
11011127
gen(self.session_properties_) if self.session_properties_ else None,
11021128
*[gen(g) for g in self.grains],
1129+
*self._audit_metadata(),
11031130
]
11041131

1105-
for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
1106-
metadata.append(audit_name)
1107-
if audit_name in BUILT_IN_AUDITS:
1108-
for arg_name, arg_value in audit_args.items():
1109-
metadata.append(arg_name)
1110-
metadata.append(gen(arg_value))
1111-
else:
1112-
audit = self.audit_definitions[audit_name]
1113-
query = (
1114-
self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args))
1115-
or audit.query
1116-
)
1117-
metadata.extend(
1118-
[
1119-
gen(query),
1120-
audit.dialect,
1121-
str(audit.skip),
1122-
str(audit.blocking),
1123-
]
1124-
)
1125-
11261132
for key, value in (self.virtual_properties or {}).items():
11271133
metadata.append(key)
11281134
metadata.append(gen(value))

sqlmesh/core/plan/evaluator.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
SnapshotInfoLike,
3535
SnapshotTableInfo,
3636
)
37+
from sqlmesh.core.snapshot.definition import SnapshotChangeCategory, parent_snapshots_by_name
3738
from sqlmesh.utils import CompletionStatus
3839
from sqlmesh.core.state_sync import StateSync
3940
from sqlmesh.core.state_sync.base import PromotionResult
4041
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4142
from sqlmesh.utils.errors import PlanError
4243
from sqlmesh.utils.dag import DAG
4344
from sqlmesh.utils.date import now
45+
from sqlmesh.utils.hashing import hash_data
4446

4547
logger = logging.getLogger(__name__)
4648

@@ -115,6 +117,10 @@ def evaluate(
115117
after_promote_snapshots = all_names - before_promote_snapshots
116118
deployability_index_for_evaluation = DeployabilityIndex.all_deployable()
117119

120+
self._run_audits_for_metadata_snapshots(
121+
new_snapshots, plan, deployability_index_for_evaluation
122+
)
123+
118124
execute_environment_statements(
119125
adapter=self.snapshot_evaluator.adapter,
120126
environment_statements=plan.environment_statements or [],
@@ -541,6 +547,72 @@ def _restatement_intervals_across_all_environments(
541547

542548
return set(snapshots_to_restate.values())
543549

550+
def _run_audits_for_metadata_snapshots(
    self,
    new_snapshots: t.Dict[SnapshotId, Snapshot],
    plan: EvaluatablePlan,
    deployability_index: DeployabilityIndex,
) -> None:
    """Run audits for metadata-only snapshots whose audit definitions changed.

    A metadata change normally skips evaluation, so a newly added or modified
    audit would otherwise never run against already-materialized intervals.
    This detects such snapshots by comparing audit-metadata hashes against the
    previous version and audits them over the previous snapshot's intervals.

    Args:
        new_snapshots: New snapshots in the plan, keyed by snapshot id.
        plan: The plan being applied; supplies environment and execution time.
        deployability_index: Determines deployability during auditing.

    Raises:
        PlanError: If any of the executed audits fails (errors are logged
            to the console first, then a single aggregate error is raised).
    """
    to_be_audited_snapshots = []

    for snapshot in new_snapshots.values():
        # Only metadata-categorized model snapshots that have a previous
        # version are candidates — there must be prior intervals to audit.
        if (
            snapshot.change_category != SnapshotChangeCategory.METADATA
            or not snapshot.previous_version
            or not snapshot.is_model
        ):
            continue

        previous_snapshot_id = snapshot.previous_version.snapshot_id(snapshot.name)
        previous_snapshot = self.state_sync.get_snapshots([previous_snapshot_id])[
            previous_snapshot_id
        ]

        new_audits = snapshot.model._audit_metadata()

        # Compare the audit metadata hashes to detect whether there was a
        # change in the audits field between the two versions.
        previous_audit_hash = hash_data(previous_snapshot.model._audit_metadata())
        current_audit_hash = hash_data(new_audits)

        # Only audit when the audits actually changed AND the new version
        # still has audits (a pure removal leaves nothing to run).
        if previous_audit_hash != current_audit_hash and new_audits:
            to_be_audited_snapshots.append((snapshot, previous_snapshot))

    if not to_be_audited_snapshots:
        return

    # NOTE(review): reaches into the scheduler's private _audit_snapshot —
    # consider promoting it to a public method on Scheduler.
    scheduler = self.create_scheduler(new_snapshots.values())
    raise_plan_error = False
    for to_be_audited_snapshot, previous_snapshot in to_be_audited_snapshots:
        parent_snapshots = parent_snapshots_by_name(to_be_audited_snapshot, new_snapshots)

        # The previous snapshot is the snapshot before the metadata change
        # and contains the latest intervals that we should use for the new audit
        for interval in previous_snapshot.intervals:
            start, end = interval

            try:
                scheduler._audit_snapshot(
                    to_be_audited_snapshot,
                    environment_naming_info=plan.environment.naming_info,
                    snapshots=parent_snapshots,
                    start=start,
                    end=end,
                    execution_time=plan.execution_time,
                    deployability_index=deployability_index,
                )
            except Exception as e:
                # Simulate a node execution failure with the audit error passed as the
                # cause in order to reuse log_failed_models
                error = NodeExecutionFailedError(
                    (to_be_audited_snapshot.name, ((start, end), -1))
                )
                error.__cause__ = e
                self.console.log_failed_models([error])
                raise_plan_error = True

    if raise_plan_error:
        raise PlanError("Plan application failed.")
615+
544616

545617
def update_intervals_for_new_snapshots(
546618
snapshots: t.Collection[Snapshot], state_sync: StateSync

sqlmesh/core/scheduler.py

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ def evaluate(
149149
execution_time: TimeLike,
150150
deployability_index: DeployabilityIndex,
151151
batch_index: int,
152+
environment_naming_info: EnvironmentNamingInfo,
152153
**kwargs: t.Any,
153-
) -> t.Tuple[t.List[AuditResult], t.List[AuditError]]:
154+
) -> t.List[AuditResult]:
154155
"""Evaluate a snapshot and add the processed interval to the state sync.
155156
156157
Args:
@@ -182,8 +183,9 @@ def evaluate(
182183
batch_index=batch_index,
183184
**kwargs,
184185
)
185-
audit_results = self.snapshot_evaluator.audit(
186+
audit_results = self._audit_snapshot(
186187
snapshot=snapshot,
188+
environment_naming_info=environment_naming_info,
187189
start=start,
188190
end=end,
189191
execution_time=execution_time,
@@ -193,32 +195,8 @@ def evaluate(
193195
**kwargs,
194196
)
195197

196-
audit_errors_to_raise: t.List[AuditError] = []
197-
audit_errors_to_warn: t.List[AuditError] = []
198-
for audit_result in (result for result in audit_results if result.count):
199-
error = AuditError(
200-
audit_name=audit_result.audit.name,
201-
audit_args=audit_result.audit_args,
202-
model=snapshot.model_or_none,
203-
count=t.cast(int, audit_result.count),
204-
query=t.cast(exp.Query, audit_result.query),
205-
adapter_dialect=self.snapshot_evaluator.adapter.dialect,
206-
)
207-
self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, error)
208-
if is_deployable and snapshot.node.owner:
209-
self.notification_target_manager.notify_user(
210-
NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error
211-
)
212-
if audit_result.blocking:
213-
audit_errors_to_raise.append(error)
214-
else:
215-
audit_errors_to_warn.append(error)
216-
217-
if audit_errors_to_raise:
218-
raise NodeAuditsErrors(audit_errors_to_raise)
219-
220198
self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
221-
return audit_results, audit_errors_to_warn
199+
return audit_results
222200

223201
def run(
224202
self,
@@ -465,30 +443,19 @@ def evaluate_node(node: SchedulingUnit) -> None:
465443
evaluation_duration_ms: t.Optional[int] = None
466444

467445
audit_results: t.List[AuditResult] = []
468-
audit_errors_to_warn: t.List[AuditError] = []
469446
try:
470447
assert execution_time # mypy
471448
assert deployability_index # mypy
472-
audit_results, audit_errors_to_warn = self.evaluate(
449+
audit_results = self.evaluate(
473450
snapshot=snapshot,
451+
environment_naming_info=environment_naming_info,
474452
start=start,
475453
end=end,
476454
execution_time=execution_time,
477455
deployability_index=deployability_index,
478456
batch_index=batch_idx,
479457
)
480458

481-
for audit_error in audit_errors_to_warn:
482-
display_name = snapshot.display_name(
483-
environment_naming_info,
484-
self.default_catalog,
485-
self.snapshot_evaluator.adapter.dialect,
486-
)
487-
self.console.log_warning(
488-
f"\n{display_name}: {audit_error}.",
489-
f"{audit_error}. Audit query:\n{audit_error.query.sql(audit_error.adapter_dialect)}",
490-
)
491-
492459
evaluation_duration_ms = now_timestamp() - execution_start_ts
493460
finally:
494461
num_audits = len(audit_results)
@@ -583,6 +550,68 @@ def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
583550
)
584551
return dag
585552

553+
def _audit_snapshot(
    self,
    snapshot: Snapshot,
    environment_naming_info: EnvironmentNamingInfo,
    deployability_index: DeployabilityIndex,
    snapshots: t.Dict[str, Snapshot],
    start: t.Optional[TimeLike] = None,
    end: t.Optional[TimeLike] = None,
    execution_time: t.Optional[TimeLike] = None,
    wap_id: t.Optional[str] = None,
    **kwargs: t.Any,
) -> t.List[AuditResult]:
    """Run a snapshot's audits and dispatch notifications for failures.

    Blocking audit failures are raised as a single ``NodeAuditsErrors``;
    non-blocking failures are logged to the console as warnings. Extracted
    from ``evaluate`` so it can also be invoked for metadata-only changes.

    Args:
        snapshot: The snapshot to audit.
        environment_naming_info: Used to render the snapshot's display name
            in warning messages.
        deployability_index: Determines whether the snapshot is deployable
            (owner notifications only fire for deployable snapshots).
        snapshots: Parent snapshots by name, passed to the evaluator.
        start / end / execution_time: Interval being audited.
        wap_id: Optional write-audit-publish id, forwarded to the evaluator.

    Returns:
        The list of audit results (including passing ones).

    Raises:
        NodeAuditsErrors: If any blocking audit reported failing rows.
    """
    is_deployable = deployability_index.is_deployable(snapshot)

    audit_results = self.snapshot_evaluator.audit(
        snapshot=snapshot,
        start=start,
        end=end,
        execution_time=execution_time,
        snapshots=snapshots,
        deployability_index=deployability_index,
        wap_id=wap_id,
        **kwargs,
    )

    audit_errors_to_raise: t.List[AuditError] = []
    audit_errors_to_warn: t.List[AuditError] = []
    # A non-zero count means the audit query returned failing rows.
    for audit_result in (result for result in audit_results if result.count):
        error = AuditError(
            audit_name=audit_result.audit.name,
            audit_args=audit_result.audit_args,
            model=snapshot.model_or_none,
            count=t.cast(int, audit_result.count),
            query=t.cast(exp.Query, audit_result.query),
            adapter_dialect=self.snapshot_evaluator.adapter.dialect,
        )
        # Always notify the configured targets; additionally notify the node
        # owner directly, but only for deployable snapshots.
        self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, error)
        if is_deployable and snapshot.node.owner:
            self.notification_target_manager.notify_user(
                NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error
            )
        if audit_result.blocking:
            audit_errors_to_raise.append(error)
        else:
            audit_errors_to_warn.append(error)

    if audit_errors_to_raise:
        raise NodeAuditsErrors(audit_errors_to_raise)

    # Non-blocking failures: surface as console warnings but do not raise.
    for audit_error in audit_errors_to_warn:
        display_name = snapshot.display_name(
            environment_naming_info,
            self.default_catalog,
            self.snapshot_evaluator.adapter.dialect,
        )
        self.console.log_warning(
            f"\n{display_name}: {audit_error}.",
            f"{audit_error}. Audit query:\n{audit_error.query.sql(audit_error.adapter_dialect)}",
        )

    return audit_results
614+
586615

587616
def compute_interval_params(
588617
snapshots: t.Collection[Snapshot],

tests/core/test_context.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import date, timedelta
66
from tempfile import TemporaryDirectory
77
from unittest.mock import PropertyMock, call, patch
8+
from IPython.utils.capture import capture_output
89

910
import time_machine
1011
import pytest
@@ -1906,7 +1907,7 @@ def create_log_view(evaluator, view_name):
19061907
assert log_schema["my_schema"][0] == "db__dev"
19071908

19081909

1909-
def test_plan_audit_intervals(tmp_path: pathlib.Path, capsys, caplog):
1910+
def test_plan_audit_intervals(tmp_path: pathlib.Path, caplog):
19101911
ctx = Context(
19111912
paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))
19121913
)
@@ -2048,3 +2049,42 @@ def test_audit():
20482049
context.plan(no_prompts=True, auto_apply=True)
20492050

20502051
assert context.audit(models=["dummy"], start="2020-01-01", end="2020-01-01") is True
2052+
2053+
2054+
@use_terminal_console
def test_audits_running_on_metadata_changes(tmp_path: pathlib.Path):
    """Adding an audit is a metadata-only change, yet the audit must still run.

    Each scenario creates a model without audits, applies a plan, then adds an
    audit (a metadata change) and verifies the next plan fails because the new
    audit is executed immediately against the existing intervals.
    """

    # Fixed typo: was `setup_senario`.
    def setup_scenario(model_before: str, model_after: str):
        models_dir = pathlib.Path("models")
        create_temp_file(tmp_path, pathlib.Path(models_dir, "test.sql"), model_before)

        # Create first snapshot
        context = Context(paths=tmp_path, config=Config())
        context.plan("prod", no_prompts=True, auto_apply=True)

        # Create second (metadata) snapshot
        create_temp_file(tmp_path, pathlib.Path(models_dir, "test.sql"), model_after)
        context.load()

        # The plan must fail because the newly added audit is executed as part
        # of applying the metadata change.
        with capture_output() as output:
            with pytest.raises(PlanError):
                context.plan("prod", no_prompts=True, auto_apply=True)

        assert 'Failed models\n\n "model"' in output.stdout

        return output

    # Ensure incorrect audits (bad data, incorrect definition, etc.) are evaluated immediately
    output = setup_scenario(
        "MODEL (name model); SELECT NULL AS col",
        "MODEL (name model, audits (not_null(columns=[col]))); SELECT NULL AS col",
    )
    assert "'not_null' audit error: 1 row failed" in output.stdout

    # An audit referencing a nonexistent column must surface the engine error.
    output = setup_scenario(
        "MODEL (name model); SELECT NULL AS col",
        "MODEL (name model, audits (not_null(columns=[this_col_does_not_exist]))); SELECT NULL AS col",
    )
    assert (
        'Binder Error: Referenced column "this_col_does_not_exist" not found in \nFROM clause!'
        in output.stdout
    )

0 commit comments

Comments
 (0)