Skip to content

Commit 21a3a20

Browse files
authored
fix!: add exp and SQL to type coerce so that it works for signals (#4236)
1 parent 930dd5d commit 21a3a20

2 files changed

Lines changed: 29 additions & 12 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,24 @@ class MacroStrTemplate(Template):
7575
EXPRESSIONS_NAME_MAP = {}
7676
SQL = t.NewType("SQL", str)
7777

78-
SUPPORTED_TYPES = {
79-
"t": t,
80-
"typing": t,
81-
"List": t.List,
82-
"Tuple": t.Tuple,
83-
"Union": t.Union,
84-
"DatetimeRanges": DatetimeRanges,
85-
}
78+
79+
@lru_cache()
80+
def get_supported_types() -> t.Dict[str, t.Any]:
81+
from sqlmesh.core.context import ExecutionContext
82+
83+
return {
84+
"t": t,
85+
"typing": t,
86+
"List": t.List,
87+
"Tuple": t.Tuple,
88+
"Union": t.Union,
89+
"DatetimeRanges": DatetimeRanges,
90+
"exp": exp,
91+
"SQL": SQL,
92+
"MacroEvaluator": MacroEvaluator,
93+
"ExecutionContext": ExecutionContext,
94+
}
95+
8696

8797
for klass in sqlglot.Parser.EXPRESSION_PARSERS:
8898
name = klass if isinstance(klass, str) else klass.__name__ # type: ignore
@@ -1305,7 +1315,7 @@ def call_macro(
13051315
bound.apply_defaults()
13061316

13071317
try:
1308-
annotations = t.get_type_hints(func, localns=SUPPORTED_TYPES)
1318+
annotations = t.get_type_hints(func, localns=get_supported_types())
13091319
except (NameError, TypeError): # forward references aren't handled
13101320
annotations = {}
13111321

tests/core/test_snapshot.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sqlmesh.core.context import Context
2222
from sqlmesh.core.dialect import parse, parse_one
2323
from sqlmesh.core.environment import EnvironmentNamingInfo
24+
from sqlmesh.core.macros import SQL
2425
from sqlmesh.core.model import (
2526
FullKind,
2627
IncrementalByTimeRangeKind,
@@ -2909,8 +2910,14 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
29092910

29102911
def test_render_signal(make_snapshot, mocker):
29112912
@signal()
2912-
def check_types(batch, env: str, default: int = 0):
2913-
if env != "in_memory" or not default == 0:
2913+
def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int = 0):
2914+
if not (
2915+
env == "in_memory"
2916+
and default == 0
2917+
and isinstance(sql, list)
2918+
and isinstance(sql[0], str)
2919+
and isinstance(table, exp.Table)
2920+
):
29142921
raise
29152922
return True
29162923

@@ -2919,7 +2926,7 @@ def check_types(batch, env: str, default: int = 0):
29192926
"""
29202927
MODEL (
29212928
name test_schema.test_model,
2922-
signals check_types(env := @gateway)
2929+
signals check_types(env := @gateway, sql := [a.b], table := b.c)
29232930
);
29242931
SELECT a FROM tbl;
29252932
"""

0 commit comments

Comments
 (0)