Skip to content

Commit fc7fe1b

Browse files
authored
Fix!: make metadata_only a transitive property in Python objects (#4080)
1 parent e0f922f commit fc7fe1b

8 files changed

Lines changed: 584 additions & 111 deletions

File tree

sqlmesh/core/model/common.py

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,17 @@
2424

2525
if t.TYPE_CHECKING:
2626
from sqlglot.dialects.dialect import DialectType
27+
from sqlmesh.utils import registry_decorator
2728
from sqlmesh.utils.jinja import MacroReference
2829

30+
MacroCallable = registry_decorator
31+
2932

3033
def make_python_env(
31-
expressions: t.Union[exp.Expression, t.List[exp.Expression]],
34+
expressions: t.Union[
35+
exp.Expression,
36+
t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]],
37+
],
3238
jinja_macro_references: t.Optional[t.Set[MacroReference]],
3339
module_path: Path,
3440
macros: MacroRegistry,
@@ -42,53 +48,79 @@ def make_python_env(
4248
) -> t.Dict[str, Executable]:
4349
python_env = {} if python_env is None else python_env
4450
variables = variables or {}
45-
env: t.Dict[str, t.Any] = {}
46-
used_macros = {}
51+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
52+
used_macros: t.Dict[
53+
str,
54+
t.Tuple[t.Union[Executable | MacroCallable], t.Optional[bool]],
55+
] = {}
4756
used_variables = (used_variables or set()).copy()
4857

4958
expressions = ensure_list(expressions)
50-
for expression in expressions:
51-
if not isinstance(expression, d.Jinja):
52-
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier):
53-
if macro_func_or_var.__class__ is d.MacroFunc:
54-
name = macro_func_or_var.this.name.lower()
55-
if name in macros:
56-
used_macros[name] = macros[name]
57-
if name == c.VAR:
58-
args = macro_func_or_var.this.expressions
59-
if len(args) < 1:
60-
raise_config_error("Macro VAR requires at least one argument", path)
61-
if not args[0].is_string:
62-
raise_config_error(
63-
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
64-
path,
65-
)
66-
used_variables.add(args[0].this.lower())
67-
elif macro_func_or_var.__class__ is d.MacroVar:
68-
name = macro_func_or_var.name.lower()
69-
if name in macros:
70-
used_macros[name] = macros[name]
71-
elif name in variables:
72-
used_variables.add(name)
73-
elif (
74-
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
75-
) and "@" in macro_func_or_var.name:
76-
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
77-
macro_func_or_var.name
78-
):
79-
var_name = braced_identifier or identifier
80-
if var_name in variables:
81-
used_variables.add(var_name)
59+
for expression_metadata in expressions:
60+
if isinstance(expression_metadata, tuple):
61+
expression, is_metadata = expression_metadata
62+
else:
63+
expression, is_metadata = expression_metadata, None
64+
65+
if isinstance(expression, d.Jinja):
66+
continue
67+
68+
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier):
69+
if macro_func_or_var.__class__ is d.MacroFunc:
70+
name = macro_func_or_var.this.name.lower()
71+
if name not in macros:
72+
continue
73+
74+
# If this macro has been seen before as a non-metadata macro, prioritize that
75+
used_macros[name] = (
76+
macros[name],
77+
used_macros.get(name, (None, is_metadata))[1],
78+
)
79+
if name == c.VAR:
80+
args = macro_func_or_var.this.expressions
81+
if len(args) < 1:
82+
raise_config_error("Macro VAR requires at least one argument", path)
83+
if not args[0].is_string:
84+
raise_config_error(
85+
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
86+
path,
87+
)
88+
used_variables.add(args[0].this.lower())
89+
elif macro_func_or_var.__class__ is d.MacroVar:
90+
name = macro_func_or_var.name.lower()
91+
if name in macros:
92+
# If this macro has been seen before as a non-metadata macro, prioritize that
93+
used_macros[name] = (
94+
macros[name],
95+
used_macros.get(name, (None, is_metadata))[1],
96+
)
97+
elif name in variables:
98+
used_variables.add(name)
99+
elif (
100+
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
101+
) and "@" in macro_func_or_var.name:
102+
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
103+
macro_func_or_var.name
104+
):
105+
var_name = braced_identifier or identifier
106+
if var_name in variables:
107+
used_variables.add(var_name)
82108

83109
for macro_ref in jinja_macro_references or set():
84110
if macro_ref.package is None and macro_ref.name in macros:
85-
used_macros[macro_ref.name] = macros[macro_ref.name]
111+
used_macros[macro_ref.name] = (macros[macro_ref.name], None)
86112

87-
for name, used_macro in used_macros.items():
113+
for name, (used_macro, is_metadata) in used_macros.items():
88114
if isinstance(used_macro, Executable):
89115
python_env[name] = used_macro
90116
elif not hasattr(used_macro, c.SQLMESH_BUILTIN) and name not in python_env:
91-
build_env(used_macro.func, env=env, name=name, path=module_path)
117+
build_env(
118+
used_macro.func,
119+
env=env,
120+
name=name,
121+
path=module_path,
122+
is_metadata_obj=is_metadata,
123+
)
92124

93125
python_env.update(serialize_env(env, path=module_path))
94126
return _add_variables_to_python_env(

sqlmesh/core/model/decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def model(
125125
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
126126
) -> Model:
127127
"""Get the model registered by this function."""
128-
env: t.Dict[str, t.Any] = {}
128+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
129129
entrypoint = self.func.__name__
130130

131131
if not self.name_provided and not infer_names:

sqlmesh/core/model/definition.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,14 +2035,7 @@ def load_sql_based_model(
20352035
continue
20362036

20372037
prop_name = prop.name.lower()
2038-
if (
2039-
prop_name
2040-
in {
2041-
"signals",
2042-
"audits",
2043-
}
2044-
| PROPERTIES
2045-
):
2038+
if prop_name in {"signals", "audits"} | PROPERTIES:
20462039
unrendered_properties[prop_name] = prop.args.get("value")
20472040
elif (
20482041
prop.name.lower() == "kind"
@@ -2404,14 +2397,17 @@ def _create_model(
24042397
statements.append(kwargs["query"])
24052398
if "post_statements" in kwargs:
24062399
statements.extend(kwargs["post_statements"])
2400+
2401+
# Macros extracted from these statements need to be treated as metadata only
24072402
if "on_virtual_update" in kwargs:
2408-
statements.extend(kwargs["on_virtual_update"])
2403+
statements.extend((stmt, True) for stmt in kwargs["on_virtual_update"])
24092404

2410-
# to allow variables like @gateway to be used in these properties
2411-
# since rendering shifted from load time to run time
2405+
# This is done to allow variables like @gateway to be used in these properties
2406+
# since rendering shifted from load time to run time.
2407+
# Note: we check for Tuple since that's what we expect from _resolve_properties
24122408
for property_name in PROPERTIES:
2413-
if property_values := kwargs.get(property_name):
2414-
statements.extend(property_values)
2409+
if isinstance(property_values := kwargs.get(property_name), exp.Tuple):
2410+
statements.extend(property_values.expressions)
24152411

24162412
jinja_macro_references, used_variables = extract_macro_references_and_variables(
24172413
*(gen(e) for e in statements)
@@ -2443,11 +2439,15 @@ def _create_model(
24432439
raise_config_error(str(ex), location=path)
24442440
raise
24452441

2446-
audit_definitions = audit_definitions or {}
2447-
inline_audits = inline_audits or {}
2448-
audit_definitions = {**audit_definitions, **inline_audits}
2442+
audit_definitions = {
2443+
**(audit_definitions or {}),
2444+
**(inline_audits or {}),
2445+
}
24492446

2450-
used_audits = set(inline_audits)
2447+
# TODO: default_audits needs to be merged with model.audits; the former's arguments
2448+
# are silently dropped today because we add them in audit_definitions. We also need
2449+
# to check for duplicates when we implement this merging logic.
2450+
used_audits: t.Set[str] = set()
24512451
used_audits.update(audit_name for audit_name, _ in default_audits or [])
24522452
used_audits.update(audit_name for audit_name, _ in model.audits)
24532453

@@ -2459,12 +2459,15 @@ def _create_model(
24592459

24602460
model.audit_definitions.update(audit_definitions)
24612461

2462-
statements.extend(audit.query for audit in audit_definitions.values())
2462+
# Any macro referenced in audits or signals needs to be treated as metadata-only
2463+
statements.extend((audit.query, True) for audit in audit_definitions.values())
24632464
for _, audit_args in model.audits:
2464-
statements.extend(audit_args.values())
2465+
statements.extend(
2466+
(audit_arg_expression, True) for audit_arg_expression in audit_args.values()
2467+
)
24652468

24662469
for _, kwargs in model.signals:
2467-
statements.extend(kwargs.values())
2470+
statements.extend((signal_kwarg, True) for signal_kwarg in kwargs.values())
24682471

24692472
python_env = python_env or {}
24702473

@@ -2482,7 +2485,7 @@ def _create_model(
24822485
dialect=dialect,
24832486
)
24842487

2485-
env: t.Dict[str, t.Any] = {}
2488+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
24862489

24872490
for signal_name, _ in model.signals:
24882491
if signal_definitions and signal_name in signal_definitions:
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""
2+
This script's goal is to warn users if there is both a metadata and non-metadata reference in
3+
the python environment of a model. Additionally, it warns them if there's a macro referenced
4+
in a used audit's query, in the argument list of the audits and signals properties, or in an
5+
on_virtual_update statement.
6+
7+
Context:
8+
9+
The metadata status for macros and signals is now transitive, i.e. every dependency of a
10+
metadata macro or signal is also metadata, unless it is referenced by a non-metadata object.
11+
12+
This means that global references of metadata objects may now be excluded from the data hash
13+
calculation because of their new metadata status, which would lead to a diff.
14+
15+
Additionally, we now implicitly treat macro refs in the aforementioned statements as "metadata-only",
16+
even though they may not be marked as such by a user. This may also lead to a diff.
17+
"""
18+
19+
import json
20+
21+
from sqlglot import exp
22+
23+
import sqlmesh.core.dialect as d
24+
from sqlmesh.core.console import get_console
25+
26+
27+
def migrate(state_sync, **kwargs):  # type: ignore
    """Log a single warning if a stored snapshot may show a diff after upgrading.

    Every row of the ``_snapshots`` table (schema-qualified when
    ``state_sync.schema`` is set) is inspected. Nodes whose ``source_type`` is
    ``"audit"`` are skipped. The warning is logged, and the function returns
    immediately, as soon as either of these holds for a node:

    * its ``python_env`` contains both entries with a truthy ``is_metadata``
      and entries without one, or
    * a macro used in its ``on_virtual_update`` statements, audit/signal
      arguments or audit definition queries maps to a ``python_env`` dict
      entry without a truthy ``is_metadata``.

    Nothing is written to the state database by this function.
    """
    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = "_snapshots"
    if schema:
        snapshots_table = f"{schema}.{snapshots_table}"

    warning = (
        "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact "
        "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` "
        "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new "
        "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these "
        "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. "
        "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n"
    )

    def warn_user():
        # Single place that emits the warning; callers return right after it.
        get_console().log_warning(warning)

    rows = engine_adapter.fetchall(
        exp.select("snapshot").from_(snapshots_table), quote_identifiers=True
    )
    for (raw_snapshot,) in rows:
        node = json.loads(raw_snapshot)["node"]

        # Standalone audits don't have a data hash, so they're unaffected
        if node.get("source_type") == "audit":
            continue

        python_env = node.get("python_env") or {}

        # First check: a mix of metadata and non-metadata entries in the env.
        seen_metadata = False
        seen_non_metadata = False
        for serialized in python_env.values():
            if serialized.get("is_metadata"):
                seen_metadata = True
            else:
                seen_non_metadata = True

            if seen_metadata and seen_non_metadata:
                warn_user()
                return

        dialect = node.get("dialect")
        metadata_statements = []

        # Second check is best-effort: any failure while parsing or walking the
        # statements is deliberately ignored so the migration itself never breaks.
        try:
            on_virtual_update = node.get("on_virtual_update")
            if on_virtual_update:
                metadata_statements.extend(parse_expression(on_virtual_update, dialect))

            for _, audit_args in func_call_validator(node.get("audits") or []):
                metadata_statements.extend(audit_args.values())

            for _, signal_args in func_call_validator(node.get("signals") or [], is_signal=True):
                metadata_statements.extend(signal_args.values())

            audit_definitions = node.get("audit_definitions")
            if audit_definitions:
                metadata_statements += [
                    parse_expression(audit["query"], audit["dialect"])
                    for audit in audit_definitions.values()
                ]

            for macro_name in extract_used_macros(metadata_statements):
                serialized_macro = python_env.get(macro_name)
                if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"):
                    warn_user()
                    return
        except Exception:
            pass
98+
99+
100+
def extract_used_macros(expressions):
    """Return the lowercased names of the macro functions called in `expressions`.

    Jinja expressions are ignored. Only nodes whose class is exactly
    `d.MacroFunc` contribute a name; instances of subclasses are excluded.
    """
    macro_names = set()
    non_jinja = (candidate for candidate in expressions if not isinstance(candidate, d.Jinja))

    for candidate in non_jinja:
        macro_names.update(
            node.this.name.lower()
            for node in candidate.find_all(d.MacroFunc)
            # Exact class match on purpose: subclasses of MacroFunc are not counted.
            if node.__class__ is d.MacroFunc
        )

    return macro_names
111+
112+
113+
def func_call_validator(v, is_signal=False):
    """Normalize serialized audit/signal calls into `(name, parsed_args)` pairs.

    Args:
        v: A list whose entries are either dicts of arguments (for audits the
            dict also carries the call name under the "name" key) or 2-item
            tuples/lists of `(name, args_dict)`.
        is_signal: When True, dict entries are treated as nameless and the name
            is set to the empty string; any "name" key stays in the arguments.

    Returns:
        A list of `(lowercased_name, args)` tuples, where string argument values
        have been parsed with `d.parse_one` and other values are kept as-is.

    Raises:
        AssertionError: If `v` is not a list, or an entry is neither a dict nor
            a tuple/list.
        KeyError: If a dict entry has no "name" key and `is_signal` is False.
    """
    assert isinstance(v, list)

    audits = []
    for entry in v:
        if isinstance(entry, dict):
            # Work on a shallow copy: popping "name" from the caller's dict would
            # mutate the input and make a second validation of it raise KeyError.
            args = dict(entry)
            name = "" if is_signal else args.pop("name")
        else:
            assert isinstance(entry, (tuple, list))
            name, args = entry

        parsed_audit = {
            key: d.parse_one(value) if isinstance(value, str) else value
            for key, value in args.items()
        }
        audits.append((name.lower(), parsed_audit))

    return audits
132+
133+
134+
def parse_expression(v, dialect):
    """Parse serialized SQL into expressions using `d.parse_one`.

    Args:
        v: None, a single SQL string, or a list of SQL strings.
        dialect: Dialect forwarded to `d.parse_one`.

    Returns:
        None when `v` is None, a list of parsed expressions when `v` is a list,
        otherwise the single parsed expression.

    Raises:
        AssertionError: If `v` is neither None, a list, nor a string.
    """
    if isinstance(v, list):
        return [d.parse_one(item, dialect=dialect) for item in v]

    if v is not None:
        assert isinstance(v, str)
        return d.parse_one(v, dialect=dialect)

    return None

0 commit comments

Comments
 (0)