Skip to content

Commit 90f8dd6

Browse files
committed
Chore: detect invalid signal refs and raise early
1 parent ebf8f47 commit 90f8dd6

2 files changed

Lines changed: 36 additions & 8 deletions

File tree

sqlmesh/core/model/definition.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,22 +2488,27 @@ def _create_model(
24882488

24892489
model.audit_definitions.update(audit_definitions)
24902490

2491-
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
2491+
# Any macro referenced in audits or signals needs to be treated as metadata-only
2492+
statements.extend((audit.query, True) for audit in audit_definitions.values())
24922493

24932494
# Ensure that all audits referenced in the model are defined
2495+
from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS
2496+
24942497
available_audits = BUILT_IN_AUDITS.keys() | model.audit_definitions.keys()
2495-
for referenced_audit, *_ in model.audits:
2498+
for referenced_audit, audit_args in model.audits:
24962499
if referenced_audit not in available_audits:
24972500
raise_config_error(f"Audit '{referenced_audit}' is undefined", location=path)
24982501

2499-
# Any macro referenced in audits or signals needs to be treated as metadata-only
2500-
statements.extend((audit.query, True) for audit in audit_definitions.values())
2501-
for _, audit_args in model.audits:
25022502
statements.extend(
25032503
(audit_arg_expression, True) for audit_arg_expression in audit_args.values()
25042504
)
25052505

2506-
for _, kwargs in model.signals:
2506+
signal_definitions = signal_definitions or UniqueKeyDict("signals")
2507+
2508+
for referenced_signal, kwargs in model.signals:
2509+
if referenced_signal and referenced_signal not in signal_definitions:
2510+
raise_config_error(f"Signal '{referenced_signal}' is undefined", location=path)
2511+
25072512
statements.extend((signal_kwarg, True) for signal_kwarg in kwargs.values())
25082513

25092514
python_env = make_python_env(
@@ -2523,7 +2528,7 @@ def _create_model(
25232528
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
25242529

25252530
for signal_name, _ in model.signals:
2526-
if signal_definitions and signal_name in signal_definitions:
2531+
if signal_name and signal_name in signal_definitions:
25272532
func = signal_definitions[signal_name].func
25282533
setattr(func, c.SQLMESH_METADATA, True)
25292534
build_env(func, env=env, name=signal_name, path=module_path)

tests/core/test_model.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4919,6 +4919,10 @@ def test_signals():
49194919
model = load_sql_based_model(expressions)
49204920
assert model.signals[0][1] == {"arg": exp.Literal.number(1)}
49214921

4922+
@signal()
4923+
def my_signal(batch):
4924+
return True
4925+
49224926
expressions = d.parse(
49234927
"""
49244928
MODEL (
@@ -4946,7 +4950,10 @@ def test_signals():
49464950
"""
49474951
)
49484952

4949-
model = load_sql_based_model(expressions)
4953+
model = load_sql_based_model(
4954+
expressions,
4955+
signal_definitions={"my_signal": signal.get_registry()["my_signal"]},
4956+
)
49504957
assert model.signals == [
49514958
(
49524959
"my_signal",
@@ -9837,6 +9844,22 @@ def test_invalid_audit_reference():
98379844
load_sql_based_model(expressions)
98389845

98399846

9847+
def test_invalid_signal_reference():
9848+
sql = """
9849+
MODEL (
9850+
name test,
9851+
signals (s())
9852+
);
9853+
9854+
SELECT
9855+
1 AS id
9856+
"""
9857+
expressions = d.parse(sql)
9858+
9859+
with pytest.raises(ConfigError, match="Signal 's' is undefined"):
9860+
load_sql_based_model(expressions)
9861+
9862+
98409863
def test_scd_time_data_type_does_not_cause_diff_after_deserialization() -> None:
98419864
for dialect in (
98429865
"athena",

0 commit comments

Comments
 (0)