Skip to content

Commit 178d59d

Browse files
add support for queries
1 parent 08e7c3b commit 178d59d

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,9 @@ def union_if(
10121012
evaluator: MacroEvaluator,
10131013
condition: exp.Expression,
10141014
type_: exp.Literal = exp.Literal.string("ALL"),
1015-
*tables: exp.Table,
1015+
*tables: exp.Table | exp.Query,
10161016
) -> exp.Query:
1017-
"""Returns a UNION of the given tables if the condition is true, otherwise returns just the first table.
1018-
The behaviour remains of only choosing columns that have the same name and type.
1017+
"""Returns a UNION of the given tables or queries if the condition is true, otherwise returns just the first.
10191018
10201019
Example:
10211020
>>> from sqlglot import parse_one
@@ -1029,11 +1028,23 @@ def union_if(
10291028
if kind not in ("ALL", "DISTINCT"):
10301029
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
10311030

1031+
result = evaluator.eval_expression(condition)
1032+
1033+
if isinstance(tables[0], exp.Query):
1034+
if not result:
1035+
# If condition is false, return just the first query
1036+
return tables[0]
1037+
return tables[0].union(*tables[1:], distinct=kind == "DISTINCT")
1038+
10321039
columns = {
10331040
column
10341041
for column, _ in reduce(
10351042
lambda a, b: a & b, # type: ignore
1036-
(evaluator.columns_to_types(table).items() for table in tables),
1043+
(
1044+
evaluator.columns_to_types(table).items()
1045+
for table in tables
1046+
if isinstance(table, exp.Table) # for mypy
1047+
),
10371048
)
10381049
}
10391050

tests/core/test_model.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def test_model_union_query(sushi_context, assert_exp_eq):
267267

268268

269269
@time_machine.travel("1996-02-10 00:00:00 UTC")
270-
def test_model_union_if_query(sushi_context, assert_exp_eq):
270+
def test_model_union_if_table(sushi_context, assert_exp_eq):
271271
@macro()
272272
def get_date(evaluator):
273273
from sqlmesh.utils.date import now
@@ -330,6 +330,43 @@ def get_date(evaluator):
330330
)
331331

332332

333+
def test_model_union_if_query(sushi_context, assert_exp_eq):
334+
expressions = d.parse(
335+
"""
336+
MODEL (
337+
name sushi.test_query,
338+
kind FULL,
339+
);
340+
341+
@union_if(True, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c')
342+
"""
343+
)
344+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
345+
346+
assert_exp_eq(
347+
sushi_context.get_model("sushi.test_query").render_query(),
348+
"""SELECT 1 AS "c" UNION ALL SELECT 2 AS "c" UNION ALL SELECT 3 AS "c"
349+
""",
350+
)
351+
352+
expressions = d.parse(
353+
"""
354+
MODEL (
355+
name sushi.test_query,
356+
kind FULL,
357+
);
358+
359+
@union_if(False, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c')
360+
"""
361+
)
362+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
363+
364+
assert_exp_eq(
365+
sushi_context.get_model("sushi.test_query").render_query(),
366+
'SELECT 1 AS "c"',
367+
)
368+
369+
333370
def test_model_validation_union_query():
334371
expressions = d.parse(
335372
"""

0 commit comments

Comments
 (0)