Skip to content

Commit bf054e4

Browse files
committed
Treat macro refs in on_virtual_update stmts as metadata-only, fix inline audits, add another warning
1 parent f44332e commit bf054e4

6 files changed

Lines changed: 241 additions & 87 deletions

sqlmesh/core/model/common.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131

3232

3333
def make_python_env(
34-
expressions: t.Union[exp.Expression, t.List[exp.Expression]],
34+
expressions: t.Union[
35+
exp.Expression,
36+
t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]],
37+
],
3538
jinja_macro_references: t.Optional[t.Set[MacroReference]],
3639
module_path: Path,
3740
macros: MacroRegistry,
@@ -53,7 +56,12 @@ def make_python_env(
5356
used_variables = (used_variables or set()).copy()
5457

5558
expressions = ensure_list(expressions)
56-
for expression in expressions:
59+
for expression_metadata in expressions:
60+
if isinstance(expression_metadata, tuple):
61+
expression, is_metadata = expression_metadata
62+
else:
63+
expression, is_metadata = expression_metadata, None
64+
5765
if isinstance(expression, d.Jinja):
5866
continue
5967

@@ -66,7 +74,7 @@ def make_python_env(
6674
# If this macro has been seen before as a non-metadata macro, prioritize that
6775
used_macros[name] = (
6876
macros[name],
69-
(used_macros.get(name) or (None, expression.meta.get("is_metadata")))[1],
77+
(used_macros.get(name) or (None, is_metadata))[1],
7078
)
7179
if name == c.VAR:
7280
args = macro_func_or_var.this.expressions
@@ -84,7 +92,7 @@ def make_python_env(
8492
# If this macro has been seen before as a non-metadata macro, prioritize that
8593
used_macros[name] = (
8694
macros[name],
87-
(used_macros.get(name) or (None, expression.meta.get("is_metadata")))[1],
95+
(used_macros.get(name) or (None, is_metadata))[1],
8896
)
8997
elif name in variables:
9098
used_variables.add(name)

sqlmesh/core/model/definition.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,14 +2035,7 @@ def load_sql_based_model(
20352035
continue
20362036

20372037
prop_name = prop.name.lower()
2038-
if (
2039-
prop_name
2040-
in {
2041-
"signals",
2042-
"audits",
2043-
}
2044-
| PROPERTIES
2045-
):
2038+
if prop_name in {"signals", "audits"} | PROPERTIES:
20462039
unrendered_properties[prop_name] = prop.args.get("value")
20472040
elif (
20482041
prop.name.lower() == "kind"
@@ -2404,8 +2397,10 @@ def _create_model(
24042397
statements.append(kwargs["query"])
24052398
if "post_statements" in kwargs:
24062399
statements.extend(kwargs["post_statements"])
2400+
2401+
# Macros extracted from these statements need to be treated as metadata only
24072402
if "on_virtual_update" in kwargs:
2408-
statements.extend(kwargs["on_virtual_update"])
2403+
statements.extend((stmt, True) for stmt in kwargs["on_virtual_update"])
24092404

24102405
# This is done to allow variables like @gateway to be used in these properties
24112406
# since rendering shifted from load time to run time.
@@ -2444,11 +2439,15 @@ def _create_model(
24442439
raise_config_error(str(ex), location=path)
24452440
raise
24462441

2447-
audit_definitions = audit_definitions or {}
2448-
inline_audits = inline_audits or {}
2449-
audit_definitions = {**audit_definitions, **inline_audits}
2442+
audit_definitions = {
2443+
**(audit_definitions or {}),
2444+
**(inline_audits or {}),
2445+
}
24502446

2451-
used_audits = set(inline_audits)
2447+
# TODO: default_audits needs to be merged with model.audits; the former's arguments
2448+
# are silently dropped today because we add them in audit_definitions. We also need
2449+
# to check for duplicates when we implement this merging logic.
2450+
used_audits: t.Set[str] = set()
24522451
used_audits.update(audit_name for audit_name, _ in default_audits or [])
24532452
used_audits.update(audit_name for audit_name, _ in model.audits)
24542453

@@ -2460,16 +2459,15 @@ def _create_model(
24602459

24612460
model.audit_definitions.update(audit_definitions)
24622461

2463-
statements.extend(audit.query for audit in audit_definitions.values())
2462+
# Any macro referenced in audits or signals needs to be treated as metadata-only
2463+
statements.extend((audit.query, True) for audit in audit_definitions.values())
24642464
for _, audit_args in model.audits:
2465-
for audit_arg_expression in audit_args.values():
2466-
audit_arg_expression.meta["is_metadata"] = True
2467-
statements.append(audit_arg_expression)
2465+
statements.extend(
2466+
(audit_arg_expression, True) for audit_arg_expression in audit_args.values()
2467+
)
24682468

24692469
for _, kwargs in model.signals:
2470-
for signal_kwarg in kwargs.values():
2471-
signal_kwarg.meta["is_metadata"] = True
2472-
statements.append(signal_kwarg)
2470+
statements.extend((signal_kwarg, True) for signal_kwarg in kwargs.values())
24732471

24742472
python_env = python_env or {}
24752473

sqlmesh/migrations/v0078_detect_diff_caused_py_metadata_flag_propagation_and_warn.py

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""
2+
This migration script has two purposes:
3+
4+
1) Mark all python env macros referenced in audits, signals or on_virtual_update statements
5+
as metadata, unless they're referenced elsewhere in the model and they're not metadata-only.
6+
7+
2) Warn if there are both metadata and non-metadata references in the python environment of a model.
8+
9+
The metadata status for macros and signals is now transitive, i.e. every dependency of a
10+
metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.
11+
12+
This means that global references of metadata objects may now be excluded from the
13+
data hash calculation because of their new metadata status, which would lead to a
14+
diff. This script detects the possibility for such a diff and warns users ahead of time.
15+
"""
16+
17+
import json
18+
19+
from sqlglot import exp
20+
21+
import sqlmesh.core.dialect as d
22+
from sqlmesh.core.console import get_console
23+
24+
25+
def migrate(state_sync, **kwargs):  # type: ignore
    """Scan stored snapshots and warn about a possible metadata-related diff.

    Each row of the ``_snapshots`` table (schema-qualified when ``state_sync.schema``
    is set) is decoded from JSON and its ``node`` inspected. A warning is logged, and
    the function returns immediately, for the first node that either:

    * has python env entries both with and without a truthy ``is_metadata`` flag, or
    * references, from its on_virtual_update statements, audit arguments, signal
      arguments or audit definition queries, a macro whose python env entry is a
      dict without a truthy ``is_metadata`` flag.

    Nodes whose ``source_type`` is "audit" are skipped. Nothing is written back.

    NOTE(review): block nesting was reconstructed from a source whose indentation
    was flattened -- confirm against the original file.
    """
    adapter = state_sync.engine_adapter
    schema = state_sync.schema
    table_name = "_snapshots"
    snapshots_table = f"{schema}.{table_name}" if schema else table_name

    common_msg = (
        "Since the metadata status is now propagated transitively, this means that the next plan "
        "command may detect unexpected changes and prompt about backfilling this model, or others, "
        "for the same reason. If this is a concern, consider running a forward-only plan instead: "
        "https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans.\n"
    )

    query = exp.select("snapshot").from_(snapshots_table)
    for (raw_snapshot,) in adapter.fetchall(query, quote_identifiers=True):
        node = json.loads(raw_snapshot)["node"]

        # Standalone audits don't have a data hash, so they're unaffected
        if node.get("source_type") == "audit":
            continue

        name = node["name"]
        python_env = node.get("python_env") or {}

        # One boolean per python env entry: True for metadata, False otherwise.
        metadata_flags = [bool(entry.get("is_metadata")) for entry in python_env.values()]
        if True in metadata_flags and False in metadata_flags:
            get_console().log_warning(
                f"Model '{name}' references both metadata and non-metadata functions (macros or signals). "
                + common_msg
            )
            return

        dialect = node.get("dialect")

        # Every expression gathered here belongs to a metadata-only part of the model.
        statements = []

        on_virtual_update = node.get("on_virtual_update")
        if on_virtual_update:
            statements.extend(parse_expression(on_virtual_update, dialect))

        audit_calls = func_call_validator(node.get("audits") or [])
        for _, audit_args in audit_calls:
            statements.extend(audit_args.values())

        signal_calls = func_call_validator(node.get("signals") or [], is_signal=True)
        for _, signal_args in signal_calls:
            statements.extend(signal_args.values())

        audit_definitions = node.get("audit_definitions")
        if audit_definitions:
            for audit in audit_definitions.values():
                statements.append(parse_expression(audit["query"], audit["dialect"]))

        for macro_name in extract_used_macros(statements):
            serialized_macro = python_env.get(macro_name)
            if not isinstance(serialized_macro, dict) or serialized_macro.get("is_metadata"):
                continue

            get_console().log_warning(
                f"Model '{name}' references macro '{macro_name}' which is now implicitly treated as metadata-only. "
                + common_msg
            )
            return
97+
98+
99+
def extract_used_macros(expressions):
    """Return the lowercased names of the macro functions called in ``expressions``.

    Jinja expressions are skipped. A node is counted only when its class is
    exactly ``d.MacroFunc``; instances of subclasses are ignored.
    """
    names = set()
    for node in expressions:
        if not isinstance(node, d.Jinja):
            names.update(
                call.this.name.lower()
                for call in node.find_all(d.MacroFunc)
                if call.__class__ is d.MacroFunc
            )

    return names
110+
111+
112+
def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into ``(name, args)`` pairs.

    Args:
        v: A list whose entries are either a dict of arguments (for audits the
            name is stored under the "name" key) or a ``(name, args)`` pair given
            as a tuple or list.
        is_signal: When True, dict entries get the name "" and every key of the
            dict is kept as an argument.

    Returns:
        A list of ``(lowercased name, args)`` tuples. String argument values are
        parsed with ``d.parse_one``; all other values are passed through as-is.

    Raises:
        AssertionError: If ``v`` is not a list, or an entry is neither a dict nor
            a tuple/list.
        KeyError: If a dict entry has no "name" key and ``is_signal`` is False.
    """
    assert isinstance(v, list)

    calls = []
    for entry in v:
        if isinstance(entry, dict):
            # Pop from a shallow copy so the caller's (deserialized) entry is left
            # intact and the function can safely be called again on the same data.
            args = dict(entry)
            name = "" if is_signal else args.pop("name")
        else:
            assert isinstance(entry, (tuple, list))
            name, args = entry

        parsed_args = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        calls.append((name.lower(), parsed_args))

    return calls
131+
132+
133+
def parse_expression(v, dialect):
    """Parse a SQL string, or a list of SQL strings, with ``d.parse_one``.

    Args:
        v: ``None``, a single string, or a list of strings.
        dialect: Dialect forwarded to ``d.parse_one``.

    Returns:
        ``None`` when ``v`` is ``None``, a list of parsed expressions when ``v`` is
        a list, otherwise the single parsed expression.
    """
    if v is None:
        return None

    if not isinstance(v, list):
        assert isinstance(v, str)
        return d.parse_one(v, dialect=dialect)

    return [d.parse_one(sql, dialect=dialect) for sql in v]

tests/core/test_audit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def test_load_inline_audits(assert_exp_eq):
849849
MODEL (
850850
name db.table,
851851
dialect spark,
852-
audits(does_not_exceed_threshold)
852+
audits(does_not_exceed_threshold, assert_positive_id)
853853
);
854854
855855
SELECT id FROM tbl;
@@ -871,7 +871,7 @@ def test_load_inline_audits(assert_exp_eq):
871871
)
872872

873873
model = load_sql_based_model(expressions)
874-
assert len(model.audits) == 1
874+
assert len(model.audits) == 2
875875
assert len(model.audits_with_args) == 2
876876
assert isinstance(model.audit_definitions["assert_positive_id"], ModelAudit)
877877
assert isinstance(model.audit_definitions["does_not_exceed_threshold"], ModelAudit)

0 commit comments

Comments
 (0)