Skip to content

Commit aafdbf4

Browse files
Feat: Add union if macro for conditional union of tables
1 parent 51fe510 commit aafdbf4

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,53 @@ def union(
10071007
)
10081008

10091009

1010+
@macro()
1011+
def union_if(
1012+
evaluator: MacroEvaluator,
1013+
condition: exp.Expression,
1014+
type_: exp.Literal = exp.Literal.string("ALL"),
1015+
*tables: exp.Table,
1016+
) -> exp.Query:
1017+
"""Returns a UNION of the given tables if the condition is true, otherwise returns just the first table.
1018+
The behaviour remains of only choosing columns that have the same name and type.
1019+
1020+
Example:
1021+
>>> from sqlglot import parse_one
1022+
>>> from sqlglot.schema import MappingSchema
1023+
>>> from sqlmesh.core.macros import MacroEvaluator
1024+
>>> sql = "@UNION_IF(@start_ts = '2025-01-01 00:00:00', 'distinct', foo, bar)"
1025+
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql()
1026+
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
1027+
"""
1028+
kind = type_.name.upper()
1029+
if kind not in ("ALL", "DISTINCT"):
1030+
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
1031+
1032+
columns = {
1033+
column
1034+
for column, _ in reduce(
1035+
lambda a, b: a & b, # type: ignore
1036+
(evaluator.columns_to_types(table).items() for table in tables),
1037+
)
1038+
}
1039+
1040+
projections = [
1041+
exp.cast(column, type_, dialect=evaluator.dialect).as_(column)
1042+
for column, type_ in evaluator.columns_to_types(tables[0]).items()
1043+
if column in columns
1044+
]
1045+
1046+
result = evaluator.eval_expression(condition)
1047+
if not result:
1048+
# If condition is false, return just the first table with proper column casting
1049+
return exp.select(*projections).from_(tables[0])
1050+
1051+
return reduce(
1052+
lambda a, b: a.union(b, distinct=kind == "DISTINCT"), # type: ignore
1053+
[exp.select(*projections).from_(t) for t in tables],
1054+
)
1055+
1056+
10101057
@macro()
10111058
def haversine_distance(
10121059
_: MacroEvaluator,

tests/core/test_model.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,70 @@ def test_model_union_query(sushi_context, assert_exp_eq):
266266
)
267267

268268

269+
@time_machine.travel("1996-02-10 00:00:00 UTC")
270+
def test_model_union_if_query(sushi_context, assert_exp_eq):
271+
@macro()
272+
def get_date(evaluator):
273+
from sqlmesh.utils.date import now
274+
275+
return f"'{now().date()}'"
276+
277+
expressions = d.parse(
278+
"""
279+
MODEL (
280+
name sushi.test_1,
281+
kind FULL,
282+
);
283+
284+
@union_if(@get_date() == '1996-02-10', 'all', sushi.marketing, sushi.marketing)
285+
"""
286+
)
287+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
288+
assert_exp_eq(
289+
sushi_context.get_model("sushi.test_1").render_query(),
290+
"""SELECT
291+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
292+
CAST("marketing"."status" AS TEXT) AS "status",
293+
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
294+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
295+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
296+
FROM "memory"."sushi"."marketing" AS "marketing"
297+
UNION ALL
298+
SELECT
299+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
300+
CAST("marketing"."status" AS TEXT) AS "status",
301+
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
302+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
303+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
304+
FROM "memory"."sushi"."marketing" AS "marketing"
305+
""",
306+
)
307+
308+
expressions = d.parse(
309+
"""
310+
MODEL (
311+
name sushi.test_2,
312+
kind FULL,
313+
);
314+
315+
@union_if(@get_date() > '1996-02-10', 'all', sushi.marketing, sushi.marketing)
316+
"""
317+
)
318+
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
319+
assert_exp_eq(
320+
sushi_context.get_model("sushi.test_2").render_query(),
321+
"""
322+
SELECT
323+
CAST("marketing"."customer_id" AS INT) AS "customer_id",
324+
CAST("marketing"."status" AS TEXT) AS "status",
325+
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
326+
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
327+
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
328+
FROM "memory"."sushi"."marketing" AS "marketing"
329+
""",
330+
)
331+
332+
269333
def test_model_validation_union_query():
270334
expressions = d.parse(
271335
"""

0 commit comments

Comments
 (0)