Skip to content

Commit b48005a

Browse files
committed
Feat: Ensure audits run even if adding them is a metadata change
1 parent 0531201 commit b48005a

4 files changed

Lines changed: 212 additions & 64 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,34 @@ def _data_hash_values(self) -> t.List[str]:
10691069

10701070
return data # type: ignore
10711071

1072+
def _audit_metadata(self) -> t.List[str]:
    """Collect the hash-relevant fields contributed by this model's audits.

    Built-in audits contribute their name plus every argument name and the
    rendered SQL of the argument value; user-defined audits contribute their
    rendered (or raw) query, dialect, and skip/blocking flags. Audits are
    visited in name order so the result is deterministic.
    """
    from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS

    fields: t.List[str] = []

    for name, args in sorted(self.audits, key=lambda a: a[0]):
        fields.append(name)
        if name in BUILT_IN_AUDITS:
            # Built-in audits are identified by their arguments alone.
            for key, value in args.items():
                fields.append(key)
                fields.append(gen(value))
        else:
            audit = self.audit_definitions[name]
            # Prefer the fully rendered query; fall back to the raw definition.
            rendered = self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], args))
            fields += [
                gen(rendered or audit.query),
                audit.dialect,
                str(audit.skip),
                str(audit.blocking),
            ]

    return fields
1099+
10721100
@property
10731101
def metadata_hash(self) -> str:
10741102
"""
@@ -1078,8 +1106,6 @@ def metadata_hash(self) -> str:
10781106
The metadata hash for the node.
10791107
"""
10801108
if self._metadata_hash is None:
1081-
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
1082-
10831109
metadata = [
10841110
self.dialect,
10851111
self.owner,
@@ -1100,29 +1126,9 @@ def metadata_hash(self) -> str:
11001126
str(self.allow_partials),
11011127
gen(self.session_properties_) if self.session_properties_ else None,
11021128
*[gen(g) for g in self.grains],
1129+
*self._audit_metadata(),
11031130
]
11041131

1105-
for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]):
1106-
metadata.append(audit_name)
1107-
if audit_name in BUILT_IN_AUDITS:
1108-
for arg_name, arg_value in audit_args.items():
1109-
metadata.append(arg_name)
1110-
metadata.append(gen(arg_value))
1111-
else:
1112-
audit = self.audit_definitions[audit_name]
1113-
query = (
1114-
self.render_audit_query(audit, **t.cast(t.Dict[str, t.Any], audit_args))
1115-
or audit.query
1116-
)
1117-
metadata.extend(
1118-
[
1119-
gen(query),
1120-
audit.dialect,
1121-
str(audit.skip),
1122-
str(audit.blocking),
1123-
]
1124-
)
1125-
11261132
for key, value in (self.virtual_properties or {}).items():
11271133
metadata.append(key)
11281134
metadata.append(gen(value))

sqlmesh/core/plan/evaluator.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535
SnapshotTableInfo,
3636
SnapshotCreationFailedError,
3737
)
38+
from sqlmesh.core.snapshot.definition import SnapshotChangeCategory, parent_snapshots_by_name
3839
from sqlmesh.utils import CompletionStatus
3940
from sqlmesh.core.state_sync import StateSync
4041
from sqlmesh.core.state_sync.base import PromotionResult
4142
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4243
from sqlmesh.utils.errors import PlanError
4344
from sqlmesh.utils.dag import DAG
4445
from sqlmesh.utils.date import now
46+
from sqlmesh.utils.hashing import hash_data
4547

4648
logger = logging.getLogger(__name__)
4749

@@ -116,6 +118,10 @@ def evaluate(
116118
after_promote_snapshots = all_names - before_promote_snapshots
117119
deployability_index_for_evaluation = DeployabilityIndex.all_deployable()
118120

121+
self._run_audits_for_metadata_snapshots(
122+
new_snapshots, plan, deployability_index_for_evaluation
123+
)
124+
119125
execute_environment_statements(
120126
adapter=self.snapshot_evaluator.adapter,
121127
environment_statements=plan.environment_statements or [],
@@ -545,6 +551,72 @@ def _restatement_intervals_across_all_environments(
545551

546552
return set(snapshots_to_restate.values())
547553

554+
def _run_audits_for_metadata_snapshots(
    self,
    new_snapshots: t.Dict[SnapshotId, Snapshot],
    plan: EvaluatablePlan,
    deployability_index: DeployabilityIndex,
) -> None:
    """Run audits for metadata-only model changes whose audits changed.

    Adding or editing an audit is categorized as a metadata change, which
    normally skips evaluation entirely. To make sure new or changed audits
    are still enforced, run them against the intervals of the previous
    (pre-change) snapshot version.

    Args:
        new_snapshots: The snapshots being created by the plan.
        plan: The plan being applied.
        deployability_index: Determines snapshots that are deployable.

    Raises:
        PlanError: If any of the re-run audits fail.
    """
    # Metadata-only model snapshots paired with their previous version's id.
    candidates = [
        (snapshot, snapshot.previous_version.snapshot_id(snapshot.name))
        for snapshot in new_snapshots.values()
        if snapshot.change_category == SnapshotChangeCategory.METADATA
        and snapshot.previous_version
        and snapshot.is_model
    ]
    if not candidates:
        return

    # Fetch all previous snapshots in a single state sync round trip instead
    # of one call per snapshot.
    previous_by_id = self.state_sync.get_snapshots([s_id for _, s_id in candidates])

    to_be_audited_snapshots = []
    for snapshot, previous_snapshot_id in candidates:
        previous_snapshot = previous_by_id[previous_snapshot_id]

        new_audits = snapshot.model._audit_metadata()

        # Compare the audit metadata hashes to determine if there was a change
        previous_audit_hash = hash_data(previous_snapshot.model._audit_metadata())
        current_audit_hash = hash_data(new_audits)

        if previous_audit_hash != current_audit_hash and new_audits:
            to_be_audited_snapshots.append((snapshot, previous_snapshot))

    if not to_be_audited_snapshots:
        return

    scheduler = self.create_scheduler(new_snapshots.values())
    raise_plan_error = False
    for to_be_audited_snapshot, previous_snapshot in to_be_audited_snapshots:
        parent_snapshots = parent_snapshots_by_name(to_be_audited_snapshot, new_snapshots)

        # The previous snapshot is the snapshot before the metadata change
        # and contains the latest intervals that we should use for the new audit
        for start, end in previous_snapshot.intervals:
            try:
                scheduler._audit_snapshot(
                    to_be_audited_snapshot,
                    environment_naming_info=plan.environment.naming_info,
                    snapshots=parent_snapshots,
                    start=start,
                    end=end,
                    execution_time=plan.execution_time,
                    deployability_index=deployability_index,
                )
            except Exception as e:
                # Simulate a node execution failure with the audit error passed as the
                # cause in order to reuse log_failed_models
                error = NodeExecutionFailedError(
                    (to_be_audited_snapshot.name, ((start, end), -1))
                )
                error.__cause__ = e
                self.console.log_failed_models([error])
                raise_plan_error = True

    if raise_plan_error:
        raise PlanError("Plan application failed.")
619+
548620

549621
def update_intervals_for_new_snapshots(
550622
snapshots: t.Collection[Snapshot], state_sync: StateSync

sqlmesh/core/scheduler.py

Lines changed: 70 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ def evaluate(
149149
execution_time: TimeLike,
150150
deployability_index: DeployabilityIndex,
151151
batch_index: int,
152+
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
152153
**kwargs: t.Any,
153-
) -> t.Tuple[t.List[AuditResult], t.List[AuditError]]:
154+
) -> t.List[AuditResult]:
154155
"""Evaluate a snapshot and add the processed interval to the state sync.
155156
156157
Args:
@@ -182,8 +183,9 @@ def evaluate(
182183
batch_index=batch_index,
183184
**kwargs,
184185
)
185-
audit_results = self.snapshot_evaluator.audit(
186+
audit_results = self._audit_snapshot(
186187
snapshot=snapshot,
188+
environment_naming_info=environment_naming_info,
187189
start=start,
188190
end=end,
189191
execution_time=execution_time,
@@ -193,32 +195,8 @@ def evaluate(
193195
**kwargs,
194196
)
195197

196-
audit_errors_to_raise: t.List[AuditError] = []
197-
audit_errors_to_warn: t.List[AuditError] = []
198-
for audit_result in (result for result in audit_results if result.count):
199-
error = AuditError(
200-
audit_name=audit_result.audit.name,
201-
audit_args=audit_result.audit_args,
202-
model=snapshot.model_or_none,
203-
count=t.cast(int, audit_result.count),
204-
query=t.cast(exp.Query, audit_result.query),
205-
adapter_dialect=self.snapshot_evaluator.adapter.dialect,
206-
)
207-
self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, error)
208-
if is_deployable and snapshot.node.owner:
209-
self.notification_target_manager.notify_user(
210-
NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error
211-
)
212-
if audit_result.blocking:
213-
audit_errors_to_raise.append(error)
214-
else:
215-
audit_errors_to_warn.append(error)
216-
217-
if audit_errors_to_raise:
218-
raise NodeAuditsErrors(audit_errors_to_raise)
219-
220198
self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
221-
return audit_results, audit_errors_to_warn
199+
return audit_results
222200

223201
def run(
224202
self,
@@ -465,30 +443,19 @@ def evaluate_node(node: SchedulingUnit) -> None:
465443
evaluation_duration_ms: t.Optional[int] = None
466444

467445
audit_results: t.List[AuditResult] = []
468-
audit_errors_to_warn: t.List[AuditError] = []
469446
try:
470447
assert execution_time # mypy
471448
assert deployability_index # mypy
472-
audit_results, audit_errors_to_warn = self.evaluate(
449+
audit_results = self.evaluate(
473450
snapshot=snapshot,
451+
environment_naming_info=environment_naming_info,
474452
start=start,
475453
end=end,
476454
execution_time=execution_time,
477455
deployability_index=deployability_index,
478456
batch_index=batch_idx,
479457
)
480458

481-
for audit_error in audit_errors_to_warn:
482-
display_name = snapshot.display_name(
483-
environment_naming_info,
484-
self.default_catalog,
485-
self.snapshot_evaluator.adapter.dialect,
486-
)
487-
self.console.log_warning(
488-
f"\n{display_name}: {audit_error}.",
489-
f"{audit_error}. Audit query:\n{audit_error.query.sql(audit_error.adapter_dialect)}",
490-
)
491-
492459
evaluation_duration_ms = now_timestamp() - execution_start_ts
493460
finally:
494461
num_audits = len(audit_results)
@@ -583,6 +550,69 @@ def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
583550
)
584551
return dag
585552

553+
def _audit_snapshot(
    self,
    snapshot: Snapshot,
    deployability_index: DeployabilityIndex,
    snapshots: t.Dict[str, Snapshot],
    start: t.Optional[TimeLike] = None,
    end: t.Optional[TimeLike] = None,
    execution_time: t.Optional[TimeLike] = None,
    wap_id: t.Optional[str] = None,
    environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
    **kwargs: t.Any,
) -> t.List[AuditResult]:
    """Audit a snapshot and surface any failures.

    Every failing audit triggers a notification (and a user notification for
    deployable snapshots with an owner). Blocking failures are raised as
    NodeAuditsErrors; non-blocking ones are logged as warnings when an
    environment naming info is provided.
    """
    results = self.snapshot_evaluator.audit(
        snapshot=snapshot,
        start=start,
        end=end,
        execution_time=execution_time,
        snapshots=snapshots,
        deployability_index=deployability_index,
        wap_id=wap_id,
        **kwargs,
    )

    deployable = deployability_index.is_deployable(snapshot)
    blocking_errors: t.List[AuditError] = []
    warning_errors: t.List[AuditError] = []

    for result in results:
        if not result.count:
            continue
        error = AuditError(
            audit_name=result.audit.name,
            audit_args=result.audit_args,
            model=snapshot.model_or_none,
            count=t.cast(int, result.count),
            query=t.cast(exp.Query, result.query),
            adapter_dialect=self.snapshot_evaluator.adapter.dialect,
        )
        self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, error)
        if deployable and snapshot.node.owner:
            self.notification_target_manager.notify_user(
                NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error
            )
        (blocking_errors if result.blocking else warning_errors).append(error)

    if blocking_errors:
        raise NodeAuditsErrors(blocking_errors)

    if environment_naming_info:
        for audit_error in warning_errors:
            display_name = snapshot.display_name(
                environment_naming_info,
                self.default_catalog,
                self.snapshot_evaluator.adapter.dialect,
            )
            self.console.log_warning(
                f"\n{display_name}: {audit_error}.",
                f"{audit_error}. Audit query:\n{audit_error.query.sql(audit_error.adapter_dialect)}",
            )

    return results
615+
586616

587617
def compute_interval_params(
588618
snapshots: t.Collection[Snapshot],

tests/core/test_context.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import date, timedelta
66
from tempfile import TemporaryDirectory
77
from unittest.mock import PropertyMock, call, patch
8+
from IPython.utils.capture import capture_output
89

910
import time_machine
1011
import pytest
@@ -1929,7 +1930,7 @@ def create_log_view(evaluator, view_name):
19291930
assert log_schema["my_schema"][0] == "db__dev"
19301931

19311932

1932-
def test_plan_audit_intervals(tmp_path: pathlib.Path, capsys, caplog):
1933+
def test_plan_audit_intervals(tmp_path: pathlib.Path, caplog):
19331934
ctx = Context(
19341935
paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))
19351936
)
@@ -2071,3 +2072,42 @@ def test_audit():
20712072
context.plan(no_prompts=True, auto_apply=True)
20722073

20732074
assert context.audit(models=["dummy"], start="2020-01-01", end="2020-01-01") is True
2075+
2076+
2077+
@use_terminal_console
def test_audits_running_on_metadata_changes(tmp_path: pathlib.Path):
    """Audits added via a metadata-only change must still run and fail the plan."""

    # Fix: helper was misspelled `setup_senario`; it's a local name, so the
    # rename is safe and callers outside this test are unaffected.
    def setup_scenario(model_before: str, model_after: str):
        """Plan model_before into prod, then re-plan with model_after and
        assert the metadata-only change fails the plan with audit errors."""
        models_dir = pathlib.Path("models")
        create_temp_file(tmp_path, pathlib.Path(models_dir, "test.sql"), model_before)

        # Create first snapshot
        context = Context(paths=tmp_path, config=Config())
        context.plan("prod", no_prompts=True, auto_apply=True)

        # Create second (metadata) snapshot
        create_temp_file(tmp_path, pathlib.Path(models_dir, "test.sql"), model_after)
        context.load()

        with capture_output() as output:
            with pytest.raises(PlanError):
                context.plan("prod", no_prompts=True, auto_apply=True)

        assert 'Failed models\n\n "model"' in output.stdout

        return output

    # Ensure incorrect audits (bad data, incorrect definition, etc.) are evaluated immediately
    output = setup_scenario(
        "MODEL (name model); SELECT NULL AS col",
        "MODEL (name model, audits (not_null(columns=[col]))); SELECT NULL AS col",
    )
    assert "'not_null' audit error: 1 row failed" in output.stdout

    output = setup_scenario(
        "MODEL (name model); SELECT NULL AS col",
        "MODEL (name model, audits (not_null(columns=[this_col_does_not_exist]))); SELECT NULL AS col",
    )
    assert (
        'Binder Error: Referenced column "this_col_does_not_exist" not found in \nFROM clause!'
        in output.stdout
    )

0 commit comments

Comments
 (0)