@@ -267,103 +267,98 @@ def test_model_union_query(sushi_context, assert_exp_eq):
267267
268268
269269@time_machine .travel ("1996-02-10 00:00:00 UTC" )
270- def test_model_union_if_table (sushi_context , assert_exp_eq ):
270+ @pytest .mark .parametrize (
271+ "test_id, condition, union_type, table_count, expected_result" ,
272+ [
273+ # Test case 1: Basic conditional union - True condition
274+ (
275+ "test_1" ,
276+ "@get_date() == '1996-02-10'" ,
277+ "'all'" ,
278+ 2 ,
279+ lambda expected_select : f"{ expected_select } \n UNION ALL\n { expected_select } \n " ,
280+ ),
281+ # Test case 2: False condition - should return just first table
282+ (
283+ "test_2" ,
284+ "@get_date() > '1996-02-10'" ,
285+ "'all'" ,
286+ 2 ,
287+ lambda expected_select : f"{ expected_select } \n " ,
288+ ),
289+ # Test case 3: Multiple tables in union
290+ (
291+ "test_3" ,
292+ "@get_date() == '1996-02-10'" ,
293+ "'all'" ,
294+ 3 ,
295+ lambda expected_select : f"{ expected_select } \n UNION ALL\n { expected_select } \n UNION ALL\n { expected_select } \n " ,
296+ ),
297+ # Test case 4: DISTINCT type
298+ (
299+ "test_4" ,
300+ "@get_date() == '1996-02-10'" ,
301+ "'distinct'" ,
302+ 2 ,
303+ lambda expected_select : f"{ expected_select } \n UNION\n { expected_select } \n " ,
304+ ),
305+ # Test case 5: Complex condition
306+ (
307+ "test_5" ,
308+ "@get_date() = '1996-02-10' and 1=1 or @get_date() > '1996-02-10'" ,
309+ "'distinct'" ,
310+ 2 ,
311+ lambda expected_select : f"{ expected_select } \n UNION\n { expected_select } \n " ,
312+ ),
313+ # Test case 6: Missing union type (defaults to ALL)
314+ (
315+ "test_6" ,
316+ "@get_date() == '1996-02-10'" ,
317+ "" ,
318+ 2 ,
319+ lambda expected_select : f"{ expected_select } \n UNION ALL\n { expected_select } \n " ,
320+ ),
321+ ],
322+ )
323+ def test_model_union_conditional (
324+ sushi_context , assert_exp_eq , test_id , condition , union_type , table_count , expected_result
325+ ):
271326 @macro ()
272327 def get_date (evaluator ):
273328 from sqlmesh .utils .date import now
274329
275330 return f"'{ now ().date ()} '"
276331
277- expressions = d .parse (
278- """
279- MODEL (
280- name sushi.test_1,
281- kind FULL,
282- );
283-
284- @union_if(@get_date() == '1996-02-10', 'all', sushi.marketing, sushi.marketing)
285- """
286- )
287- sushi_context .upsert_model (load_sql_based_model (expressions , default_catalog = "memory" ))
288- assert_exp_eq (
289- sushi_context .get_model ("sushi.test_1" ).render_query (),
290- """SELECT
332+ expected_select = """SELECT
291333 CAST("marketing"."customer_id" AS INT) AS "customer_id",
292334 CAST("marketing"."status" AS TEXT) AS "status",
293- CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
294- CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
295- CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
296- FROM "memory"."sushi"."marketing" AS "marketing"
297- UNION ALL
298- SELECT
299- CAST("marketing"."customer_id" AS INT) AS "customer_id",
300- CAST("marketing"."status" AS TEXT) AS "status",
301- CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
302- CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
303- CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
304- FROM "memory"."sushi"."marketing" AS "marketing"
305- """ ,
306- )
307-
308- expressions = d .parse (
309- """
310- MODEL (
311- name sushi.test_2,
312- kind FULL,
313- );
314-
315- @union_if(@get_date() > '1996-02-10', 'all', sushi.marketing, sushi.marketing)
316- """
317- )
318- sushi_context .upsert_model (load_sql_based_model (expressions , default_catalog = "memory" ))
319- assert_exp_eq (
320- sushi_context .get_model ("sushi.test_2" ).render_query (),
321- """
322- SELECT
323- CAST("marketing"."customer_id" AS INT) AS "customer_id",
324- CAST("marketing"."status" AS TEXT) AS "status",
325- CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at",
335+ CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at",
326336 CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from",
327337 CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to"
328338FROM "memory"."sushi"."marketing" AS "marketing"
329- """ ,
330- )
331-
332-
333- def test_model_union_if_query (sushi_context , assert_exp_eq ):
334- expressions = d .parse (
335- """
336- MODEL (
337- name sushi.test_query,
338- kind FULL,
339- );
339+ """
340340
341- @union_if(True, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c')
342- """
343- )
344- sushi_context .upsert_model (load_sql_based_model (expressions , default_catalog = "memory" ))
341+ # Create tables argument list based on table_count
342+ tables = ", " .join (["sushi.marketing" ] * table_count )
345343
346- assert_exp_eq (
347- sushi_context .get_model ("sushi.test_query" ).render_query (),
348- """SELECT 1 AS "c" UNION ALL SELECT 2 AS "c" UNION ALL SELECT 3 AS "c"
349- """ ,
350- )
344+ # Handle the missing union_type case
345+ union_type_arg = f", { union_type } " if union_type else ""
351346
352347 expressions = d .parse (
353- """
348+ f """
354349 MODEL (
355- name sushi.test_query ,
350+ name sushi.{ test_id } ,
356351 kind FULL,
357352 );
358353
359- @union_if(False, 'all', 'select 1 as c', 'select 2 as c', 'select 3 as c' )
354+ @union( { condition } { union_type_arg } , { tables } )
360355 """
361356 )
362357 sushi_context .upsert_model (load_sql_based_model (expressions , default_catalog = "memory" ))
363358
364359 assert_exp_eq (
365- sushi_context .get_model ("sushi.test_query " ).render_query (),
366- 'SELECT 1 AS "c"' ,
360+ sushi_context .get_model (f "sushi.{ test_id } " ).render_query (),
361+ expected_result ( expected_select ) ,
367362 )
368363
369364
0 commit comments