Skip to content

Commit f44332e

Browse files
committed
Fix edge case
1 parent b9e925f commit f44332e

3 files changed

Lines changed: 76 additions & 45 deletions

File tree

sqlmesh/core/model/common.py

Lines changed: 59 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,11 @@
2424

2525
if t.TYPE_CHECKING:
2626
from sqlglot.dialects.dialect import DialectType
27+
from sqlmesh.utils import registry_decorator
2728
from sqlmesh.utils.jinja import MacroReference
2829

30+
MacroCallable = registry_decorator
31+
2932

3033
def make_python_env(
3134
expressions: t.Union[exp.Expression, t.List[exp.Expression]],
@@ -43,56 +46,74 @@ def make_python_env(
4346
python_env = {} if python_env is None else python_env
4447
variables = variables or {}
4548
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
46-
used_macros = {}
49+
used_macros: t.Dict[
50+
str,
51+
t.Tuple[t.Union[Executable | MacroCallable], t.Optional[bool]],
52+
] = {}
4753
used_variables = (used_variables or set()).copy()
4854

4955
expressions = ensure_list(expressions)
5056
for expression in expressions:
51-
if not isinstance(expression, d.Jinja):
52-
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier):
53-
if macro_func_or_var.__class__ is d.MacroFunc:
54-
name = macro_func_or_var.this.name.lower()
55-
if name in macros:
56-
used_macro = macros[name]
57-
if callable(used_macro) and expression.meta.get("metadata_only"):
58-
setattr(used_macro.func, c.SQLMESH_METADATA, True)
59-
60-
used_macros[name] = used_macro
61-
if name == c.VAR:
62-
args = macro_func_or_var.this.expressions
63-
if len(args) < 1:
64-
raise_config_error("Macro VAR requires at least one argument", path)
65-
if not args[0].is_string:
66-
raise_config_error(
67-
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
68-
path,
69-
)
70-
used_variables.add(args[0].this.lower())
71-
elif macro_func_or_var.__class__ is d.MacroVar:
72-
name = macro_func_or_var.name.lower()
73-
if name in macros:
74-
used_macros[name] = macros[name]
75-
elif name in variables:
76-
used_variables.add(name)
77-
elif (
78-
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
79-
) and "@" in macro_func_or_var.name:
80-
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
81-
macro_func_or_var.name
82-
):
83-
var_name = braced_identifier or identifier
84-
if var_name in variables:
85-
used_variables.add(var_name)
57+
if isinstance(expression, d.Jinja):
58+
continue
59+
60+
for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier):
61+
if macro_func_or_var.__class__ is d.MacroFunc:
62+
name = macro_func_or_var.this.name.lower()
63+
if name not in macros:
64+
continue
65+
66+
# If this macro has been seen before as a non-metadata macro, prioritize that
67+
used_macros[name] = (
68+
macros[name],
69+
(used_macros.get(name) or (None, expression.meta.get("is_metadata")))[1],
70+
)
71+
if name == c.VAR:
72+
args = macro_func_or_var.this.expressions
73+
if len(args) < 1:
74+
raise_config_error("Macro VAR requires at least one argument", path)
75+
if not args[0].is_string:
76+
raise_config_error(
77+
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
78+
path,
79+
)
80+
used_variables.add(args[0].this.lower())
81+
elif macro_func_or_var.__class__ is d.MacroVar:
82+
name = macro_func_or_var.name.lower()
83+
if name in macros:
84+
# If this macro has been seen before as a non-metadata macro, prioritize that
85+
used_macros[name] = (
86+
macros[name],
87+
(used_macros.get(name) or (None, expression.meta.get("is_metadata")))[1],
88+
)
89+
elif name in variables:
90+
used_variables.add(name)
91+
elif (
92+
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
93+
) and "@" in macro_func_or_var.name:
94+
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
95+
macro_func_or_var.name
96+
):
97+
var_name = braced_identifier or identifier
98+
if var_name in variables:
99+
used_variables.add(var_name)
86100

87101
for macro_ref in jinja_macro_references or set():
88102
if macro_ref.package is None and macro_ref.name in macros:
89-
used_macros[macro_ref.name] = macros[macro_ref.name]
103+
used_macros[macro_ref.name] = (macros[macro_ref.name], None)
90104

91-
for name, used_macro in used_macros.items():
105+
for name, (used_macro, is_metadata) in used_macros.items():
92106
if isinstance(used_macro, Executable):
93107
python_env[name] = used_macro
94108
elif not hasattr(used_macro, c.SQLMESH_BUILTIN) and name not in python_env:
109+
used_macro_func = used_macro.func
110+
previous_is_metadata = getattr(used_macro_func, c.SQLMESH_METADATA, None)
111+
112+
if is_metadata:
113+
setattr(used_macro_func, c.SQLMESH_METADATA, is_metadata)
114+
95115
build_env(used_macro.func, env=env, name=name, path=module_path)
116+
setattr(used_macro_func, c.SQLMESH_METADATA, previous_is_metadata)
96117

97118
python_env.update(serialize_env(env, path=module_path))
98119
return _add_variables_to_python_env(

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2463,12 +2463,12 @@ def _create_model(
24632463
statements.extend(audit.query for audit in audit_definitions.values())
24642464
for _, audit_args in model.audits:
24652465
for audit_arg_expression in audit_args.values():
2466-
audit_arg_expression.meta["metadata_only"] = True
2466+
audit_arg_expression.meta["is_metadata"] = True
24672467
statements.append(audit_arg_expression)
24682468

24692469
for _, kwargs in model.signals:
24702470
for signal_kwarg in kwargs.values():
2471-
signal_kwarg.meta["metadata_only"] = True
2471+
signal_kwarg.meta["is_metadata"] = True
24722472
statements.append(signal_kwarg)
24732473

24742474
python_env = python_env or {}

tests/core/test_model.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8811,15 +8811,16 @@ def test_macros_referenced_in_audits_or_signals_are_metadata_only(tmp_path: Path
88118811
name test_model,
88128812
kind FULL,
88138813
signals (
8814-
test_signal_always_true(arg := @m1())
8814+
test_signal_always_true(arg1 := @m1(), arg2 := @non_metadata_macro())
88158815
),
88168816
audits (
88178817
unique_values(columns := @m2())
88188818
),
88198819
);
88208820
88218821
SELECT
8822-
1 AS c
8822+
1 AS c1,
8823+
@non_metadata_macro() AS c2,
88238824
"""
88248825
)
88258826

@@ -8837,7 +8838,11 @@ def m1(evaluator):
88378838
88388839
@macro()
88398840
def m2(evaluator):
8840-
return exp.column("c")"""
8841+
return exp.column("c")
8842+
8843+
@macro()
8844+
def non_metadata_macro(evaluator):
8845+
return 1"""
88418846

88428847
test_macros = tmp_path / "macros/test_macros.py"
88438848
test_macros.parent.mkdir(parents=True, exist_ok=True)
@@ -8852,7 +8857,7 @@ def bar():
88528857
pass
88538858
88548859
@signal()
8855-
def test_signal_always_true(batch, arg):
8860+
def test_signal_always_true(batch, arg1, arg2):
88568861
bar()
88578862
return True"""
88588863

@@ -8869,14 +8874,19 @@ def test_signal_always_true(batch, arg):
88698874

88708875
python_env = model.python_env
88718876

8872-
assert len(python_env) == 6
8877+
assert len(python_env) == 7
88738878
assert (python_env.get("test_signal_always_true") or empty_executable).is_metadata
88748879
assert (python_env.get("bar") or empty_executable).is_metadata
88758880
assert (python_env.get("m1") or empty_executable).is_metadata
88768881
assert (python_env.get("baz") or empty_executable).is_metadata
88778882
assert (python_env.get("m2") or empty_executable).is_metadata
88788883
assert (python_env.get("exp") or empty_executable).is_metadata
88798884

8885+
# non_metadata_macro is referenced in the signal, which makes that reference "metadata only",
8886+
# but it's also referenced in the model's query and the macro itself is not "metadata only",
8887+
# so the corresponding executable needs to be included in the data hash calculation
8888+
assert not (python_env.get("non_metadata_macro") or empty_executable).is_metadata
8889+
88808890

88818891
def test_scd_type_2_full_history_restatement():
88828892
assert ModelKindName.SCD_TYPE_2.full_history_restatement_only is True

0 commit comments

Comments
 (0)