Skip to content

Commit 76f52e6

Browse files
authored
Add iceberg support to table_diff (#4441)
1 parent d565154 commit 76f52e6

2 files changed

Lines changed: 104 additions & 1 deletion

File tree

sqlmesh/core/table_diff.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88

99
from sqlmesh.core.dialect import to_schema
1010
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
11+
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
1112
from sqlglot import exp, parse_one
1213
from sqlglot.helper import ensure_list
1314
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1415
from sqlglot.optimizer.qualify_columns import quote_identifiers
1516
from sqlglot.optimizer.scope import find_all_in_scope
1617

1718
from sqlmesh.utils.pydantic import PydanticModel
19+
from sqlmesh.utils.errors import SQLMeshError
1820

1921
if t.TYPE_CHECKING:
2022
from sqlmesh.core._typing import TableName
@@ -431,7 +433,26 @@ def name(e: exp.Expression) -> str:
431433
schema = to_schema(temp_schema, dialect=self.dialect)
432434
temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True)
433435

434-
with self.adapter.temp_table(query, name=temp_table) as table:
436+
temp_table_kwargs = {}
437+
if isinstance(self.adapter, AthenaEngineAdapter):
438+
# Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that
439+
# the formats be the same for the source, target, and temp tables.
440+
source_table_type = self.adapter._query_table_type(self.source_table)
441+
target_table_type = self.adapter._query_table_type(self.target_table)
442+
443+
if source_table_type == "iceberg" and target_table_type == "iceberg":
444+
temp_table_kwargs["table_format"] = "iceberg"
445+
# Sets the temp table's format to Iceberg.
446+
# If neither source nor target table is Iceberg, it defaults to Hive (Athena's default).
447+
elif source_table_type == "iceberg" or target_table_type == "iceberg":
448+
raise SQLMeshError(
449+
f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' "
450+
f"do not match for Athena. Diffing between different table formats is not supported."
451+
)
452+
453+
with self.adapter.temp_table(
454+
query, name=temp_table, columns_to_types=None, **temp_table_kwargs
455+
) as table:
435456
summary_sums = [
436457
exp.func("SUM", "s_exists").as_("s_count"),
437458
exp.func("SUM", "t_exists").as_("t_count"),

tests/core/engine_adapter/test_athena.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlmesh.core.model import load_sql_based_model
1111
from sqlmesh.core.model.definition import SqlModel
1212
from sqlmesh.utils.errors import SQLMeshError
13+
from sqlmesh.core.table_diff import TableDiff
1314

1415
from tests.core.engine_adapter import to_sql_calls
1516

@@ -21,6 +22,16 @@ def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter:
2122
return make_mocked_engine_adapter(AthenaEngineAdapter)
2223

2324

25+
@pytest.fixture
def table_diff(adapter: AthenaEngineAdapter) -> TableDiff:
    """Build a ``TableDiff`` over the ``adapter`` fixture.

    The diff compares ``source_table`` against ``target_table`` and joins the
    two on the single ``id`` column.
    """
    join_keys = ["id"]
    return TableDiff(adapter=adapter, source="source_table", target="target_table", on=join_keys)
33+
34+
2435
@pytest.mark.parametrize(
2536
"config_s3_warehouse_location,table_properties,table,expected_location",
2637
[
@@ -483,3 +494,74 @@ def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter):
483494
# Trino syntax - CTAS
484495
"""CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['MONTH(business_date)', 'BUCKET(colb, 4)', 'colc'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\"""",
485496
]
497+
498+
499+
@pytest.mark.parametrize(
    "source_format, target_format, expected_temp_format, expect_error",
    [
        ("hive", "hive", None, False),
        ("iceberg", "hive", None, True),  # Expect error for mismatched formats
        ("hive", "iceberg", None, True),  # Expect error for mismatched formats
        ("iceberg", "iceberg", "iceberg", False),
        (None, "iceberg", None, True),  # Source doesn't exist or type unknown, target is iceberg
        # Target doesn't exist or type unknown, source is iceberg. No temp table is
        # created on the error path, so there is no expected temp format.
        ("iceberg", None, None, True),
        (None, "hive", None, False),  # Source doesn't exist or type unknown, target is hive
        ("hive", None, None, False),  # Target doesn't exist or type unknown, source is hive
        (None, None, None, False),  # Both don't exist or types unknown
    ],
)
def test_table_diff_temp_table_format(
    table_diff: TableDiff,
    mocker: MockerFixture,
    source_format: t.Optional[str],
    target_format: t.Optional[str],
    expected_temp_format: t.Optional[str],
    expect_error: bool,
):
    """Check which table format ``row_diff`` requests for its temp table on Athena.

    The adapter's ``_query_table_type`` is patched to report ``source_format`` and
    ``target_format`` for the source and target tables. The test then verifies that:

    * when exactly one side is Iceberg, ``row_diff`` raises ``SQLMeshError`` and
      never calls ``temp_table``;
    * when both sides are Iceberg, ``temp_table`` receives ``table_format="iceberg"``;
    * otherwise ``temp_table`` is called without a ``table_format`` kwarg.

    Args:
        table_diff: The ``TableDiff`` under test.
        mocker: pytest-mock fixture used to patch adapter methods.
        source_format: Table type reported for ``source_table`` (``None`` = unknown).
        target_format: Table type reported for ``target_table`` (``None`` = unknown).
        expected_temp_format: ``table_format`` kwarg expected on the ``temp_table``
            call, or ``None`` when the kwarg must be absent. Only checked when
            ``expect_error`` is false.
        expect_error: Whether ``row_diff`` must raise the format-mismatch error.
    """
    adapter = t.cast(AthenaEngineAdapter, table_diff.adapter)

    # Mock _query_table_type to return the parametrized formats
    def mock_query_table_type(table_name: exp.Table) -> t.Optional[str]:
        if table_name.name == "source_table":
            return source_format
        if table_name.name == "target_table":
            return target_format
        return "hive"  # Default for other tables if any

    mocker.patch.object(adapter, "_query_table_type", side_effect=mock_query_table_type)

    # Mock temp_table so the kwargs it is called with can be inspected
    mock_temp_table = mocker.patch.object(adapter, "temp_table", autospec=True)
    mock_temp_table.return_value.__enter__.return_value = exp.to_table("diff_table")

    # Mock fetchdf and other calls made within row_diff to avoid actual DB interaction
    mocker.patch.object(adapter, "fetchdf", return_value=pd.DataFrame())
    mocker.patch.object(adapter, "get_data_objects", return_value=[])
    mocker.patch.object(adapter, "columns", return_value={"id": exp.DataType.build("int")})

    if expect_error:
        # `match` is applied as a regular expression, so the dots are escaped to
        # match the message literally.
        with pytest.raises(
            SQLMeshError,
            match=r"do not match for Athena\. Diffing between different table formats is not supported\.",
        ):
            table_diff.row_diff()
        # This must sit outside the `with` block: a statement after the raising
        # call inside the block would never execute.
        mock_temp_table.assert_not_called()  # temp_table should not be called if formats mismatch
        return

    try:
        table_diff.row_diff()
    except Exception:
        # Deliberate best effort: with the adapter mocked, whatever row_diff does
        # after entering the temp table may fail. Only the temp_table call args
        # matter for the non-error cases.
        pass

    mock_temp_table.assert_called_once()
    _, called_kwargs = mock_temp_table.call_args

    if expected_temp_format:
        assert called_kwargs.get("table_format") == expected_temp_format
    else:
        assert "table_format" not in called_kwargs

0 commit comments

Comments
 (0)