1010from sqlmesh .core .model import load_sql_based_model
1111from sqlmesh .core .model .definition import SqlModel
1212from sqlmesh .utils .errors import SQLMeshError
13+ from sqlmesh .core .table_diff import TableDiff
1314
1415from tests .core .engine_adapter import to_sql_calls
1516
@@ -21,6 +22,16 @@ def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter:
2122 return make_mocked_engine_adapter (AthenaEngineAdapter )
2223
2324
@pytest.fixture
def table_diff(adapter: AthenaEngineAdapter) -> TableDiff:
    """Build a TableDiff over the mocked Athena adapter.

    The diff compares ``source_table`` against ``target_table`` and joins
    the two on the single ``id`` column.
    """
    join_keys = ["id"]
    diff = TableDiff(adapter=adapter, source="source_table", target="target_table", on=join_keys)
    return diff
33+
34+
2435@pytest .mark .parametrize (
2536 "config_s3_warehouse_location,table_properties,table,expected_location" ,
2637 [
@@ -483,3 +494,74 @@ def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter):
483494 # Trino syntax - CTAS
484495 """CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['MONTH(business_date)', 'BUCKET(colb, 4)', 'colc'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\" """ ,
485496 ]
497+
498+
@pytest.mark.parametrize(
    "source_format, target_format, expected_temp_format, expect_error",
    [
        ("hive", "hive", None, False),
        ("iceberg", "hive", None, True),  # Expect error for mismatched formats
        ("hive", "iceberg", None, True),  # Expect error for mismatched formats
        ("iceberg", "iceberg", "iceberg", False),
        (None, "iceberg", None, True),  # Source doesn't exist or type unknown, target is iceberg
        (
            "iceberg",
            None,
            "iceberg",
            True,
        ),  # Target doesn't exist or type unknown, source is iceberg
        (None, "hive", None, False),  # Source doesn't exist or type unknown, target is hive
        ("hive", None, None, False),  # Target doesn't exist or type unknown, source is hive
        (None, None, None, False),  # Both don't exist or types unknown
    ],
)
def test_table_diff_temp_table_format(
    table_diff: TableDiff,
    mocker: MockerFixture,
    source_format: t.Optional[str],
    target_format: t.Optional[str],
    expected_temp_format: t.Optional[str],
    expect_error: bool,
):
    """Check the ``table_format`` kwarg that ``row_diff`` passes to ``temp_table``.

    Args:
        table_diff: TableDiff under test (source ``source_table``, target ``target_table``).
        mocker: pytest-mock fixture used to stub the adapter.
        source_format: Value the mocked ``_query_table_type`` returns for the source table.
        target_format: Value the mocked ``_query_table_type`` returns for the target table.
        expected_temp_format: ``table_format`` kwarg expected on the ``temp_table`` call, or
            ``None`` when the kwarg must be absent. Ignored when ``expect_error`` is True,
            because that branch returns before this value is read.
        expect_error: Whether ``row_diff`` must raise ``SQLMeshError`` for the combination.
    """
    adapter = t.cast(AthenaEngineAdapter, table_diff.adapter)

    # Mock _query_table_type to return specified formats
    def mock_query_table_type(table_name: exp.Table) -> t.Optional[str]:
        if table_name.name == "source_table":
            return source_format
        if table_name.name == "target_table":
            return target_format
        return "hive"  # Default for other tables if any

    mocker.patch.object(adapter, "_query_table_type", side_effect=mock_query_table_type)

    # Mock temp_table to capture kwargs
    mock_temp_table = mocker.patch.object(adapter, "temp_table", autospec=True)
    mock_temp_table.return_value.__enter__.return_value = exp.to_table("diff_table")

    # Mock fetchdf and other calls made within row_diff to avoid actual DB interaction
    mocker.patch.object(adapter, "fetchdf", return_value=pd.DataFrame())
    mocker.patch.object(adapter, "get_data_objects", return_value=[])
    mocker.patch.object(adapter, "columns", return_value={"id": exp.DataType.build("int")})

    if expect_error:
        # `match` is applied as a regular expression, so the literal dots are
        # escaped to pin the exact message instead of "any character".
        with pytest.raises(
            SQLMeshError,
            match=r"do not match for Athena\. Diffing between different table formats is not supported\.",
        ):
            table_diff.row_diff()
        mock_temp_table.assert_not_called()  # temp_table should not be called if formats mismatch
        return

    try:
        table_diff.row_diff()
    except Exception:
        # Deliberate best-effort: the stubs above return placeholder data, so
        # row_diff may fail after temp_table has been entered. Only the
        # temp_table call arguments asserted below matter for these cases.
        pass

    mock_temp_table.assert_called_once()
    _, called_kwargs = mock_temp_table.call_args

    if expected_temp_format:
        assert called_kwargs.get("table_format") == expected_temp_format
    else:
        assert "table_format" not in called_kwargs
0 commit comments