Skip to content

Commit 5334251

Browse files
revise logic to extend current union; add unit tests
1 parent 696f3d8 commit 5334251

2 files changed

Lines changed: 100 additions & 130 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 32 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -970,93 +970,68 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr
970970
@macro()
971971
def union(
972972
evaluator: MacroEvaluator,
973-
type_: exp.Literal = exp.Literal.string("ALL"),
974-
*tables: exp.Table,
973+
*args: exp.Expression,
975974
) -> exp.Query:
976975
"""Returns a UNION of the given tables. Only choosing columns that have the same name and type.
977976
977+
Args:
978+
evaluator: MacroEvaluator that invoked the macro
979+
args: Variable arguments that can be:
980+
- First argument can be a condition (exp.Condition)
981+
- A union type ('ALL' or 'DISTINCT') as exp.Literal
982+
- Tables (exp.Table)
983+
978984
Example:
979985
>>> from sqlglot import parse_one
980986
>>> from sqlglot.schema import MappingSchema
981987
>>> from sqlmesh.core.macros import MacroEvaluator
982988
>>> sql = "@UNION('distinct', foo, bar)"
983989
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql()
984990
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
985-
"""
986-
kind = type_.name.upper()
987-
if kind not in ("ALL", "DISTINCT"):
988-
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
989-
990-
columns = {
991-
column
992-
for column, _ in reduce(
993-
lambda a, b: a & b, # type: ignore
994-
(evaluator.columns_to_types(table).items() for table in tables),
995-
)
996-
}
997-
998-
projections = [
999-
exp.cast(column, type_, dialect=evaluator.dialect).as_(column)
1000-
for column, type_ in evaluator.columns_to_types(tables[0]).items()
1001-
if column in columns
1002-
]
1003-
1004-
return reduce(
1005-
lambda a, b: a.union(b, distinct=kind == "DISTINCT"), # type: ignore
1006-
[exp.select(*projections).from_(t) for t in tables],
1007-
)
1008-
1009-
1010-
@macro()
1011-
def union_if(
1012-
evaluator: MacroEvaluator,
1013-
condition: exp.Expression,
1014-
type_: exp.Literal = exp.Literal.string("ALL"),
1015-
*tables: exp.Table | exp.Query,
1016-
) -> exp.Query:
1017-
"""Returns a UNION of the given tables or queries if the condition is true, otherwise returns just the first.
1018-
1019-
Example:
1020-
>>> from sqlglot import parse_one
1021-
>>> from sqlglot.schema import MappingSchema
1022-
>>> from sqlmesh.core.macros import MacroEvaluator
1023-
>>> sql = "@UNION_IF(True, 'distinct', foo, bar)"
991+
>>> sql = "@UNION(True, 'distinct', foo, bar)"
1024992
>>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql()
1025993
'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar'
1026994
"""
995+
996+
if not args:
997+
raise SQLMeshError("At least one table is required for UNION.")
998+
999+
arg_idx = 0
1000+
# Check for condition
1001+
condition = evaluator.eval_expression(args[arg_idx])
1002+
if isinstance(condition, bool):
1003+
arg_idx += 1
1004+
if arg_idx >= len(args):
1005+
raise SQLMeshError("Expected more arguments after condition.")
1006+
1007+
# Check for union type
1008+
type_ = exp.Literal.string("ALL")
1009+
if arg_idx < len(args) and isinstance(args[arg_idx], exp.Literal):
1010+
type_ = args[arg_idx] # type: ignore
1011+
arg_idx += 1
10271012
kind = type_.name.upper()
10281013
if kind not in ("ALL", "DISTINCT"):
10291014
raise SQLMeshError(f"Invalid type '{type_}'. Expected 'ALL' or 'DISTINCT'.")
10301015

1031-
result = evaluator.eval_expression(condition)
1032-
1033-
if isinstance(tables[0], exp.Query):
1034-
if not result:
1035-
# If condition is false, return just the first query
1036-
return tables[0]
1037-
return tables[0].union(*tables[1:], distinct=kind == "DISTINCT")
1016+
# Remaining args should be tables
1017+
tables = args[arg_idx:]
10381018

10391019
columns = {
10401020
column
10411021
for column, _ in reduce(
10421022
lambda a, b: a & b, # type: ignore
1043-
(
1044-
evaluator.columns_to_types(table).items()
1045-
for table in tables
1046-
if isinstance(table, exp.Table) # for mypy
1047-
),
1023+
(evaluator.columns_to_types(table).items() for table in tables), # type: ignore
10481024
)
10491025
}
10501026

10511027
projections = [
10521028
exp.cast(column, type_, dialect=evaluator.dialect).as_(column)
1053-
for column, type_ in evaluator.columns_to_types(tables[0]).items()
1029+
for column, type_ in evaluator.columns_to_types(tables[0]).items() # type: ignore
10541030
if column in columns
10551031
]
10561032

1057-
result = evaluator.eval_expression(condition)
1058-
if not result:
1059-
# If condition is false, return just the first table with proper column casting
1033+
# Skip the union if condition is False
1034+
if condition == False:
10601035
return exp.select(*projections).from_(tables[0])
10611036

10621037
return reduce(

tests/core/test_model.py

Lines changed: 68 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -267,103 +267,98 @@ def test_model_union_query(sushi_context, assert_exp_eq):
267267

268268

269269
@time_machine.travel("1996-02-10 00:00:00 UTC")
270-
def test_model_union_if_table(sushi_context, assert_exp_eq):
270+
@pytest.mark.parametrize(
271+
"test_id, condition, union_type, table_count, expected_result",
272+
[
273+
# Test case 1: Basic conditional union - True condition
274+
(
275+
"test_1",
276+
"@get_date() == '1996-02-10'",
277+
"'all'",
278+
2,
279+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n",
280+
),
281+
# Test case 2: False condition - should return just first table
282+
(
283+
"test_2",
284+
"@get_date() > '1996-02-10'",
285+
"'all'",
286+
2,
287+
lambda expected_select: f"{expected_select}\n",
288+
),
289+
# Test case 3: Multiple tables in union
290+
(
291+
"test_3",
292+
"@get_date() == '1996-02-10'",
293+
"'all'",
294+
3,
295+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n",
296+
),
297+
# Test case 4: DISTINCT type
298+
(
299+
"test_4",
300+
"@get_date() == '1996-02-10'",
301+
"'distinct'",
302+
2,
303+
lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n",
304+
),
305+
# Test case 5: Complex condition
306+
(
307+
"test_5",
308+
"@get_date() = '1996-02-10' and 1=1 or @get_date() > '1996-02-10'",
309+
"'distinct'",
310+
2,
311+
lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n",
312+
),
313+
# Test case 6: Missing union type (defaults to ALL)
314+
(
315+
"test_6",
316+
"@get_date() == '1996-02-10'",
317+
"",
318+
2,
319+
lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n",
320+
),
321+
],
322+
)
323+
def test_model_union_conditional(
324+
sushi_context, assert_exp_eq, test_id, condition, union_type, table_count, expected_result
325+
):
271326
@macro()
272327
def get_date(evaluator):
273328
from sqlmesh.utils.date import now
274329

275330
return f"'{now().date()}'"
276331

277-
expressions = d.parse(
278-
"""
279-
MODEL (
280-
name sushi.test_1,
281-
kind FULL,
282-
);
283-
284-
@union_if(@get_date() == '1996-02-10', 'all', sushi.marketing, sushi.marketing)
285-
"""
286-
)
287-
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
288-
assert_exp_eq(
289-
sushi_context.get_model("sushi.test_1").render_query(),
290-
"""SELECT
332+
expected_select = """SELECT
291333
CAST("marketing"."customer_id" AS INT) AS "customer_id",
292334
CAST("marketing"."status" AS TEXT) AS "status",
293-
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
294-
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
295-
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
296-
FROM "memory"."sushi"."marketing" AS "marketing"
297-
UNION ALL
298-
SELECT
299-
CAST("marketing"."customer_id" AS INT) AS "customer_id",
300-
CAST("marketing"."status" AS TEXT) AS "status",
301-
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
302-
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
303-
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
304-
FROM "memory"."sushi"."marketing" AS "marketing"
305-
""",
306-
)
307-
308-
expressions = d.parse(
309-
"""
310-
MODEL (
311-
name sushi.test_2,
312-
kind FULL,
313-
);
314-
315-
@union_if(@get_date() > '1996-02-10', 'all', sushi.marketing, sushi.marketing)
316-
"""
317-
)
318-
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
319-
assert_exp_eq(
320-
sushi_context.get_model("sushi.test_2").render_query(),
321-
"""
322-
SELECT
323-
CAST("marketing"."customer_id" AS INT) AS "customer_id",
324-
CAST("marketing"."status" AS TEXT) AS "status",
325-
CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
335+
CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at",
326336
CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
327337
CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
328338
FROM "memory"."sushi"."marketing" AS "marketing"
329-
""",
330-
)
331-
332-
333-
def test_model_union_if_query(sushi_context, assert_exp_eq):
334-
expressions = d.parse(
335-
"""
336-
MODEL (
337-
name sushi.test_query,
338-
kind FULL,
339-
);
339+
"""
340340

341-
@union_if(True, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c')
342-
"""
343-
)
344-
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
341+
# Create tables argument list based on table_count
342+
tables = ", ".join(["sushi.marketing"] * table_count)
345343

346-
assert_exp_eq(
347-
sushi_context.get_model("sushi.test_query").render_query(),
348-
"""SELECT 1 AS "c" UNION ALL SELECT 2 AS "c" UNION ALL SELECT 3 AS "c"
349-
""",
350-
)
344+
# Handle the missing union_type case
345+
union_type_arg = f", {union_type}" if union_type else ""
351346

352347
expressions = d.parse(
353-
"""
348+
f"""
354349
MODEL (
355-
name sushi.test_query,
350+
name sushi.{test_id},
356351
kind FULL,
357352
);
358353
359-
@union_if(False, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c')
354+
@union({condition}{union_type_arg}, {tables})
360355
"""
361356
)
362357
sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
363358

364359
assert_exp_eq(
365-
sushi_context.get_model("sushi.test_query").render_query(),
366-
'SELECT 1 AS "c"',
360+
sushi_context.get_model(f"sushi.{test_id}").render_query(),
361+
expected_result(expected_select),
367362
)
368363

369364

0 commit comments

Comments
 (0)