Skip to content

Commit 11ce8c8

Browse files
authored
Fix(table_diff): Make --limit per-sample and not across all samples (#4727)
1 parent a9ef531 commit 11ce8c8

3 files changed

Lines changed: 140 additions & 28 deletions

File tree

sqlmesh/core/table_diff.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sqlmesh.core.engine_adapter import EngineAdapter
2525

2626
SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key"
27+
SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type"
2728

2829

2930
class SchemaDiff(PydanticModel, frozen=True):
@@ -389,9 +390,6 @@ def _column_expr(name: str, table: str) -> exp.Expression:
389390
for c, t in matched_columns.items()
390391
]
391392

392-
def name(e: exp.Expression) -> str:
393-
return e.args["alias"].sql(identify=True)
394-
395393
source_query = (
396394
exp.select(
397395
*(exp.column(c) for c in source_schema),
@@ -581,30 +579,19 @@ def name(e: exp.Expression) -> str:
581579
.drop(index=index_cols, errors="ignore")
582580
)
583581

584-
sample_filter_cols = ["s_exists", "t_exists", "row_joined", "row_full_match"]
585-
sample_query = (
586-
exp.select(
587-
*(sample_filter_cols),
588-
*(name(c) for c in s_selects.values()),
589-
*(name(c) for c in t_selects.values()),
590-
)
591-
.from_(table)
592-
.where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
593-
.order_by(
594-
*(name(s_selects[c.name]) for c in s_index),
595-
*(name(t_selects[c.name]) for c in t_index),
596-
)
597-
.limit(self.limit)
582+
sample = self._fetch_sample(
583+
table, s_selects, s_index, t_selects, t_index, self.limit
598584
)
599-
sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)
600585

601586
joined_sample_cols = [f"s__{c}" for c in s_index_names]
602587
comparison_cols = [
603588
(f"s__{c}", f"t__{c}")
604589
for c in column_stats[column_stats["pct_match"] < 100].index
605590
]
591+
606592
for cols in comparison_cols:
607593
joined_sample_cols.extend(cols)
594+
608595
joined_renamed_cols = {
609596
c: c.split("__")[1] if c.split("__")[1] in index_cols else c
610597
for c in joined_sample_cols
@@ -638,13 +625,16 @@ def name(e: exp.Expression) -> str:
638625
)
639626
for c, n in joined_renamed_cols.items()
640627
}
641-
joined_sample = sample[sample["row_joined"] == 1][joined_sample_cols]
628+
629+
joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
630+
joined_sample_cols
631+
]
642632
joined_sample.rename(
643633
columns=joined_renamed_cols,
644634
inplace=True,
645635
)
646636

647-
s_sample = sample[(sample["s_exists"] == 1) & (sample["row_joined"] == 0)][
637+
s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
648638
[
649639
*[f"s__{c}" for c in s_index_names],
650640
*[f"s__{c}" for c in source_schema if c not in s_index_names],
@@ -654,7 +644,7 @@ def name(e: exp.Expression) -> str:
654644
columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
655645
)
656646

657-
t_sample = sample[(sample["t_exists"] == 1) & (sample["row_joined"] == 0)][
647+
t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
658648
[
659649
*[f"t__{c}" for c in t_index_names],
660650
*[f"t__{c}" for c in target_schema if c not in t_index_names],
@@ -665,8 +655,11 @@ def name(e: exp.Expression) -> str:
665655
)
666656

667657
sample.drop(
668-
columns=sample_filter_cols
669-
+ [f"s__{SQLMESH_JOIN_KEY_COL}", f"t__{SQLMESH_JOIN_KEY_COL}"],
658+
columns=[
659+
f"s__{SQLMESH_JOIN_KEY_COL}",
660+
f"t__{SQLMESH_JOIN_KEY_COL}",
661+
SQLMESH_SAMPLE_TYPE_COL,
662+
],
670663
inplace=True,
671664
)
672665

@@ -684,4 +677,75 @@ def name(e: exp.Expression) -> str:
684677
model_name=self.model_name,
685678
decimals=self.decimals,
686679
)
680+
687681
return self._row_diff
682+
683+
def _fetch_sample(
684+
self,
685+
sample_table: exp.Table,
686+
s_selects: t.Dict[str, exp.Alias],
687+
s_index: t.List[exp.Column],
688+
t_selects: t.Dict[str, exp.Alias],
689+
t_index: t.List[exp.Column],
690+
limit: int,
691+
) -> pd.DataFrame:
692+
rendered_data_column_names = [
693+
name(s) for s in list(s_selects.values()) + list(t_selects.values())
694+
]
695+
sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)
696+
697+
source_only_sample = (
698+
exp.select(
699+
exp.Literal.string("source_only").as_(sample_type), *rendered_data_column_names
700+
)
701+
.from_(sample_table)
702+
.where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)))
703+
.order_by(*(name(s_selects[c.name]) for c in s_index))
704+
.limit(limit)
705+
)
706+
707+
target_only_sample = (
708+
exp.select(
709+
exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names
710+
)
711+
.from_(sample_table)
712+
.where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)))
713+
.order_by(*(name(t_selects[c.name]) for c in t_index))
714+
.limit(limit)
715+
)
716+
717+
common_rows_sample = (
718+
exp.select(
719+
exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names
720+
)
721+
.from_(sample_table)
722+
.where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)))
723+
.order_by(
724+
*(name(s_selects[c.name]) for c in s_index),
725+
*(name(t_selects[c.name]) for c in t_index),
726+
)
727+
.limit(limit)
728+
)
729+
730+
query = (
731+
exp.Select()
732+
.with_("source_only", source_only_sample)
733+
.with_("target_only", target_only_sample)
734+
.with_("common_rows", common_rows_sample)
735+
.select(sample_type, *rendered_data_column_names)
736+
.from_("source_only")
737+
.union(
738+
exp.select(sample_type, *rendered_data_column_names).from_("target_only"),
739+
distinct=False,
740+
)
741+
.union(
742+
exp.select(sample_type, *rendered_data_column_names).from_("common_rows"),
743+
distinct=False,
744+
)
745+
)
746+
747+
return self.adapter.fetchdf(query, quote_identifiers=True)
748+
749+
750+
def name(e: exp.Expression) -> str:
751+
return e.args["alias"].sql(identify=True)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2358,8 +2358,8 @@ def test_table_diff_grain_check_multiple_keys(ctx: TestContext):
23582358
assert row_diff.stats["distinct_count_s"] == 7
23592359
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
23602360
assert row_diff.stats["distinct_count_t"] == 10
2361-
assert row_diff.s_sample.shape == (0, 3)
2362-
assert row_diff.t_sample.shape == (3, 3)
2361+
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
2362+
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)
23632363

23642364

23652365
def test_table_diff_arbitrary_condition(ctx: TestContext):

tests/core/test_table_diff.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,8 @@ def test_grain_check(sushi_context_fixed_date):
306306
assert row_diff.stats["distinct_count_s"] == 7
307307
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
308308
assert row_diff.stats["distinct_count_t"] == 10
309-
assert row_diff.s_sample.shape == (0, 3)
310-
assert row_diff.t_sample.shape == (3, 3)
309+
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
310+
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)
311311

312312

313313
def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture):
@@ -338,7 +338,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
338338
query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
339339
summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
340340
compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1'
341-
sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20'
341+
sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"'
342342
drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
343343

344344
# make with_log_level() return the current instance of engine_adapter so we can still spy on _execute
@@ -1032,3 +1032,51 @@ def test_schema_diff_ignore_case():
10321032
exp.DataType.build("date"),
10331033
) # notice: source casing used on output
10341034
}
1035+
1036+
1037+
def test_data_diff_sample_limit():
1038+
engine_adapter = DuckDBConnectionConfig().create_engine_adapter()
1039+
1040+
columns_to_types = {"id": exp.DataType.build("int"), "name": exp.DataType.build("varchar")}
1041+
1042+
engine_adapter.create_table("src", columns_to_types)
1043+
engine_adapter.create_table("target", columns_to_types)
1044+
1045+
common_records = {}
1046+
src_only_records = {}
1047+
target_only_records = {}
1048+
1049+
for i in range(0, 10):
1050+
common_records[i] = f"common_{i}"
1051+
src_only_records[i + 20] = f"src_{i}"
1052+
target_only_records[i + 40] = f"target_{i}"
1053+
1054+
src_records = {**common_records, **src_only_records}
1055+
target_records = {**common_records, **target_only_records}
1056+
1057+
# changes
1058+
src_records[1] = "modified_source_1"
1059+
src_records[3] = "modified_source_3"
1060+
target_records[2] = "modified_target_2"
1061+
target_records[7] = "modified_target_7"
1062+
1063+
src_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in src_records.items()])
1064+
target_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in target_records.items()])
1065+
1066+
engine_adapter.insert_append("src", src_df)
1067+
engine_adapter.insert_append("target", target_df)
1068+
1069+
table_diff = TableDiff(
1070+
adapter=engine_adapter, source="src", target="target", on=["id"], limit=3
1071+
)
1072+
1073+
diff = table_diff.row_diff()
1074+
1075+
assert diff.s_only_count == 10
1076+
assert diff.t_only_count == 10
1077+
assert diff.join_count == 10
1078+
1079+
# each sample should contain :limit records
1080+
assert len(diff.s_sample) == 3
1081+
assert len(diff.t_sample) == 3
1082+
assert len(diff.joined_sample) == 3

0 commit comments

Comments (0)