Skip to content

Commit 1735a20

Browse files
authored
Fix: Support coercion of SQL literal expressions into literal types for macro arguments (#4643)
1 parent 4ada5f8 commit 1735a20

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,26 @@ def _coerce(
14521452
if base is SQL and isinstance(expr, exp.Expression):
14531453
return expr.sql(dialect)
14541454

1455+
if base is t.Literal:
1456+
if not isinstance(expr, (exp.Literal, exp.Boolean)):
1457+
raise SQLMeshError(
1458+
f"{base_err_msg} Coercion to {base} requires a literal expression."
1459+
)
1460+
literal_type_args = t.get_args(typ)
1461+
try:
1462+
for literal_type_arg in literal_type_args:
1463+
expr_is_bool = isinstance(expr.this, bool)
1464+
literal_is_bool = isinstance(literal_type_arg, bool)
1465+
if (expr_is_bool and literal_is_bool and literal_type_arg == expr.this) or (
1466+
not expr_is_bool
1467+
and not literal_is_bool
1468+
and str(literal_type_arg) == str(expr.this)
1469+
):
1470+
return type(literal_type_arg)(expr.this)
1471+
except Exception:
1472+
raise SQLMeshError(base_err_msg)
1473+
raise SQLMeshError(base_err_msg)
1474+
14551475
if isinstance(expr, base):
14561476
return expr
14571477
if issubclass(base, exp.Expression):

tests/core/test_macros.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def test_default_arg_coercion(
9696
def test_select_macro(evaluator):
9797
return "SELECT 1 AS col"
9898

99+
@macro()
100+
def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b", 1, True]):
101+
if isinstance(a, exp.Expression):
102+
raise SQLMeshError("Coercion failed")
103+
return f"'{a}'"
104+
99105
return MacroEvaluator(
100106
"hive",
101107
{"test": Executable(name="test", payload="def test(_):\n return 'test'")},
@@ -1087,3 +1093,33 @@ def test_macro_with_spaces():
10871093
("d.@z", 'd.a."b c"'),
10881094
):
10891095
assert evaluator.transform(parse_one(sql)).sql() == expected
1096+
1097+
1098+
def test_macro_coerce_literal_type(macro_evaluator):
1099+
expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_a')")
1100+
assert macro_evaluator.transform(expression).sql() == "'test_literal_a'"
1101+
1102+
expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_b')")
1103+
assert macro_evaluator.transform(expression).sql() == "'test_literal_b'"
1104+
1105+
expression = d.parse_one("@TEST_LITERAL_TYPE(1)")
1106+
assert macro_evaluator.transform(expression).sql() == "'1'"
1107+
1108+
expression = d.parse_one("@TEST_LITERAL_TYPE(True)")
1109+
assert macro_evaluator.transform(expression).sql() == "'True'"
1110+
1111+
expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_c')")
1112+
with pytest.raises(MacroEvalError, match=".*Coercion failed"):
1113+
macro_evaluator.transform(expression)
1114+
1115+
expression = d.parse_one("@TEST_LITERAL_TYPE(2)")
1116+
with pytest.raises(MacroEvalError, match=".*Coercion failed"):
1117+
macro_evaluator.transform(expression)
1118+
1119+
expression = d.parse_one("@TEST_LITERAL_TYPE(False)")
1120+
with pytest.raises(MacroEvalError, match=".*Coercion failed"):
1121+
macro_evaluator.transform(expression)
1122+
1123+
expression = d.parse_one("@TEST_LITERAL_TYPE(1.0)")
1124+
with pytest.raises(MacroEvalError, match=".*Coercion failed"):
1125+
macro_evaluator.transform(expression)

0 commit comments

Comments
 (0)