diff --git a/examples/sushi/models/orders.py b/examples/sushi/models/orders.py index aa0e04559f..8d8718a3e3 100644 --- a/examples/sushi/models/orders.py +++ b/examples/sushi/models/orders.py @@ -36,6 +36,7 @@ "end_ts": "int", "event_date": "date", }, + signals=[("test_signal", {"arg": 1})], ) def execute( context: ExecutionContext, diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 7f90c0de63..41bb8a0aef 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -672,6 +672,7 @@ def _load_python_models( default_catalog=self.context.default_catalog, infer_names=self.config.model_naming.infer_names, audit_definitions=audits, + signal_definitions=signals, default_catalog_per_gateway=self.context.default_catalog_per_gateway, ): if model.enabled: diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 0151e9ec76..3b78efc636 100644 --- a/sqlmesh/core/model/decorator.py +++ b/sqlmesh/core/model/decorator.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import DialectType from sqlmesh.core.macros import MacroRegistry +from sqlmesh.core.signal import SignalRegistry from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.core import constants as c from sqlmesh.core.dialect import MacroFunc, parse_one @@ -48,23 +49,24 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: self.kwargs = kwargs # Make sure that argument values are expressions in order to pass validation in ModelMeta. - calls = self.kwargs.pop("audits", []) - self.kwargs["audits"] = [ - ( - (call, {}) - if isinstance(call, str) - else ( - call[0], - { - arg_key: exp.convert( - tuple(arg_value) if isinstance(arg_value, list) else arg_value - ) - for arg_key, arg_value in call[1].items() - }, + for function_call_attribute in ("audits", "signals"): + calls = self.kwargs.pop(function_call_attribute, []) + self.kwargs[function_call_attribute] = [ + ( + (call, {}) + if isinstance(call, str) + else ( + call[0], + { + arg_key: exp.convert( + tuple(arg_value) if isinstance(arg_value, list) else arg_value + ) + for arg_key, arg_value in call[1].items() + }, + ) ) - ) - for call in calls - ] + for call in calls + ] if "default_catalog" in kwargs: raise ConfigError("`default_catalog` cannot be set on a per-model basis.") @@ -142,6 +144,7 @@ def model( defaults: t.Optional[t.Dict[str, t.Any]] = None, macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, + signal_definitions: t.Optional[SignalRegistry] = None, audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, dialect: t.Optional[str] = None, time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, @@ -223,6 +226,7 @@ def model( "macros": macros, "jinja_macros": jinja_macros, "audit_definitions": audit_definitions, + "signal_definitions": signal_definitions, "blueprint_variables": blueprint_variables, **rendered_fields, } diff --git a/tests/core/test_context.py b/tests/core/test_context.py index e08f5346ea..213c4cec2b 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -2122,7 +2122,7 @@ def test_check_intervals(sushi_context, mocker): intervals = sushi_context.check_intervals(environment=None, no_signals=False, select_models=[]) min_intervals = 19 - assert spy.call_count == 1 + assert spy.call_count == 2 assert len(intervals) >= min_intervals for i in intervals.values(): diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 7c65f25889..a1f9034481 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -5303,6 +5303,30 @@ def my_signal(batch): ) +def test_load_python_model_with_signals(): + @signal() + def always_true(batch): + return True + + @model( + name="model_with_signal", + kind="full", + columns={'"COL"': "int"}, + signals=[("always_true", {})], + ) + def model_with_signal(context, **kwargs): + return pd.DataFrame([{"COL": 1}]) + + models = model.get_registry()["model_with_signal"].models( + get_variables=lambda _: {}, + path=Path("."), + module_path=Path("."), + signal_definitions=signal.get_registry(), + ) + assert len(models) == 1 + assert models[0].signals == [("always_true", {})] + + def test_null_column_type(): expressions = d.parse( """