Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 87 additions & 23 deletions sqlmesh/core/table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlmesh.core.engine_adapter import EngineAdapter

SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key"
SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type"


class SchemaDiff(PydanticModel, frozen=True):
Expand Down Expand Up @@ -389,9 +390,6 @@ def _column_expr(name: str, table: str) -> exp.Expression:
for c, t in matched_columns.items()
]

def name(e: exp.Expression) -> str:
return e.args["alias"].sql(identify=True)

source_query = (
exp.select(
*(exp.column(c) for c in source_schema),
Expand Down Expand Up @@ -581,30 +579,19 @@ def name(e: exp.Expression) -> str:
.drop(index=index_cols, errors="ignore")
)

sample_filter_cols = ["s_exists", "t_exists", "row_joined", "row_full_match"]
sample_query = (
exp.select(
*(sample_filter_cols),
*(name(c) for c in s_selects.values()),
*(name(c) for c in t_selects.values()),
)
.from_(table)
.where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons)))
.order_by(
*(name(s_selects[c.name]) for c in s_index),
*(name(t_selects[c.name]) for c in t_index),
)
.limit(self.limit)
sample = self._fetch_sample(
table, s_selects, s_index, t_selects, t_index, self.limit
)
sample = self.adapter.fetchdf(sample_query, quote_identifiers=True)

joined_sample_cols = [f"s__{c}" for c in s_index_names]
comparison_cols = [
(f"s__{c}", f"t__{c}")
for c in column_stats[column_stats["pct_match"] < 100].index
]

for cols in comparison_cols:
joined_sample_cols.extend(cols)

joined_renamed_cols = {
c: c.split("__")[1] if c.split("__")[1] in index_cols else c
for c in joined_sample_cols
Expand Down Expand Up @@ -638,13 +625,16 @@ def name(e: exp.Expression) -> str:
)
for c, n in joined_renamed_cols.items()
}
joined_sample = sample[sample["row_joined"] == 1][joined_sample_cols]

joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][
joined_sample_cols
]
joined_sample.rename(
columns=joined_renamed_cols,
inplace=True,
)

s_sample = sample[(sample["s_exists"] == 1) & (sample["row_joined"] == 0)][
s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][
[
*[f"s__{c}" for c in s_index_names],
*[f"s__{c}" for c in source_schema if c not in s_index_names],
Expand All @@ -654,7 +644,7 @@ def name(e: exp.Expression) -> str:
columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True
)

t_sample = sample[(sample["t_exists"] == 1) & (sample["row_joined"] == 0)][
t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][
[
*[f"t__{c}" for c in t_index_names],
*[f"t__{c}" for c in target_schema if c not in t_index_names],
Expand All @@ -665,8 +655,11 @@ def name(e: exp.Expression) -> str:
)

sample.drop(
columns=sample_filter_cols
+ [f"s__{SQLMESH_JOIN_KEY_COL}", f"t__{SQLMESH_JOIN_KEY_COL}"],
columns=[
f"s__{SQLMESH_JOIN_KEY_COL}",
f"t__{SQLMESH_JOIN_KEY_COL}",
SQLMESH_SAMPLE_TYPE_COL,
],
inplace=True,
)

Expand All @@ -684,4 +677,75 @@ def name(e: exp.Expression) -> str:
model_name=self.model_name,
decimals=self.decimals,
)

return self._row_diff

def _fetch_sample(
    self,
    sample_table: exp.Table,
    s_selects: t.Dict[str, exp.Alias],
    s_index: t.List[exp.Column],
    t_selects: t.Dict[str, exp.Alias],
    t_index: t.List[exp.Column],
    limit: int,
) -> pd.DataFrame:
    """Fetch up to `limit` sample rows per mismatch category from the diff table.

    Builds one CTE per category — "source_only", "target_only" and
    "common_rows" — each tagged with a literal in the
    `__sqlmesh_sample_type` column, then UNION ALLs them so a single
    query returns every sample. Ordering by the rendered index columns
    keeps the samples deterministic across runs.
    """
    data_cols = [
        name(alias) for alias in list(s_selects.values()) + list(t_selects.values())
    ]
    type_col = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL)

    def categorized_select(
        label: str, condition: exp.Expression, order_cols: t.List[str]
    ) -> exp.Select:
        # Tag every row of this category with its label so the caller can
        # split the combined result back apart client-side.
        return (
            exp.select(exp.Literal.string(label).as_(type_col), *data_cols)
            .from_(sample_table)
            .where(condition)
            .order_by(*order_cols)
            .limit(limit)
        )

    source_order = [name(s_selects[c.name]) for c in s_index]
    target_order = [name(t_selects[c.name]) for c in t_index]

    ctes = {
        "source_only": categorized_select(
            "source_only",
            exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0)),
            source_order,
        ),
        "target_only": categorized_select(
            "target_only",
            exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0)),
            target_order,
        ),
        "common_rows": categorized_select(
            "common_rows",
            exp.and_(
                exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0)
            ),
            source_order + target_order,
        ),
    }

    # LIMIT inside each CTE caps every category independently; a bare UNION ALL
    # of the three limited sets would otherwise need per-branch subqueries.
    query: exp.Select = exp.Select()
    for cte_name, cte_select in ctes.items():
        query = query.with_(cte_name, cte_select)
    query = query.select(type_col, *data_cols).from_("source_only")
    for cte_name in ("target_only", "common_rows"):
        query = query.union(
            exp.select(type_col, *data_cols).from_(cte_name), distinct=False
        )

    return self.adapter.fetchdf(query, quote_identifiers=True)


def name(e: exp.Expression) -> str:
    """Render an aliased expression's alias as a quoted SQL identifier."""
    alias_node = e.args["alias"]
    return alias_node.sql(identify=True)
4 changes: 2 additions & 2 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2358,8 +2358,8 @@ def test_table_diff_grain_check_multiple_keys(ctx: TestContext):
assert row_diff.stats["distinct_count_s"] == 7
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
assert row_diff.stats["distinct_count_t"] == 10
assert row_diff.s_sample.shape == (0, 3)
assert row_diff.t_sample.shape == (3, 3)
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)


def test_table_diff_arbitrary_condition(ctx: TestContext):
Expand Down
54 changes: 51 additions & 3 deletions tests/core/test_table_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def test_grain_check(sushi_context_fixed_date):
assert row_diff.stats["distinct_count_s"] == 7
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
assert row_diff.stats["distinct_count_t"] == 10
assert row_diff.s_sample.shape == (0, 3)
assert row_diff.t_sample.shape == (3, 3)
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)


def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture):
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1'
sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20'
sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"'
drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'

# make with_log_level() return the current instance of engine_adapter so we can still spy on _execute
Expand Down Expand Up @@ -1032,3 +1032,51 @@ def test_schema_diff_ignore_case():
exp.DataType.build("date"),
) # notice: source casing used on output
}


def test_data_diff_sample_limit():
    """Verify each sample category (source-only, target-only, joined) is capped at `limit`."""
    adapter = DuckDBConnectionConfig().create_engine_adapter()

    schema = {"id": exp.DataType.build("int"), "name": exp.DataType.build("varchar")}
    adapter.create_table("src", schema)
    adapter.create_table("target", schema)

    # 10 shared ids, 10 ids unique to each side (offset so the ranges never collide)
    common = {i: f"common_{i}" for i in range(10)}
    src_only = {i + 20: f"src_{i}" for i in range(10)}
    target_only = {i + 40: f"target_{i}" for i in range(10)}

    source_rows = {**common, **src_only}
    target_rows = {**common, **target_only}

    # introduce value mismatches within the common rows
    source_rows[1] = "modified_source_1"
    source_rows[3] = "modified_source_3"
    target_rows[2] = "modified_target_2"
    target_rows[7] = "modified_target_7"

    adapter.insert_append(
        "src",
        pd.DataFrame.from_records([{"id": k, "name": v} for k, v in source_rows.items()]),
    )
    adapter.insert_append(
        "target",
        pd.DataFrame.from_records([{"id": k, "name": v} for k, v in target_rows.items()]),
    )

    diff = TableDiff(
        adapter=adapter, source="src", target="target", on=["id"], limit=3
    ).row_diff()

    assert diff.s_only_count == 10
    assert diff.t_only_count == 10
    assert diff.join_count == 10

    # each sample should contain at most `limit` records
    assert len(diff.s_sample) == 3
    assert len(diff.t_sample) == 3
    assert len(diff.joined_sample) == 3