diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index ac4c9c71cc..6a91b22dfb 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -24,6 +24,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key" +SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type" class SchemaDiff(PydanticModel, frozen=True): @@ -389,9 +390,6 @@ def _column_expr(name: str, table: str) -> exp.Expression: for c, t in matched_columns.items() ] - def name(e: exp.Expression) -> str: - return e.args["alias"].sql(identify=True) - source_query = ( exp.select( *(exp.column(c) for c in source_schema), @@ -581,30 +579,19 @@ def name(e: exp.Expression) -> str: .drop(index=index_cols, errors="ignore") ) - sample_filter_cols = ["s_exists", "t_exists", "row_joined", "row_full_match"] - sample_query = ( - exp.select( - *(sample_filter_cols), - *(name(c) for c in s_selects.values()), - *(name(c) for c in t_selects.values()), - ) - .from_(table) - .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons))) - .order_by( - *(name(s_selects[c.name]) for c in s_index), - *(name(t_selects[c.name]) for c in t_index), - ) - .limit(self.limit) + sample = self._fetch_sample( + table, s_selects, s_index, t_selects, t_index, self.limit ) - sample = self.adapter.fetchdf(sample_query, quote_identifiers=True) joined_sample_cols = [f"s__{c}" for c in s_index_names] comparison_cols = [ (f"s__{c}", f"t__{c}") for c in column_stats[column_stats["pct_match"] < 100].index ] + for cols in comparison_cols: joined_sample_cols.extend(cols) + joined_renamed_cols = { c: c.split("__")[1] if c.split("__")[1] in index_cols else c for c in joined_sample_cols @@ -638,13 +625,16 @@ def name(e: exp.Expression) -> str: ) for c, n in joined_renamed_cols.items() } - joined_sample = sample[sample["row_joined"] == 1][joined_sample_cols] + + joined_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][ + joined_sample_cols + ] joined_sample.rename( columns=joined_renamed_cols, inplace=True, ) - s_sample = sample[(sample["s_exists"] == 1) & (sample["row_joined"] == 0)][ + s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][ [ *[f"s__{c}" for c in s_index_names], *[f"s__{c}" for c in source_schema if c not in s_index_names], @@ -654,7 +644,7 @@ def name(e: exp.Expression) -> str: columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True ) - t_sample = sample[(sample["t_exists"] == 1) & (sample["row_joined"] == 0)][ + t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][ [ *[f"t__{c}" for c in t_index_names], *[f"t__{c}" for c in target_schema if c not in t_index_names], @@ -665,8 +655,11 @@ def name(e: exp.Expression) -> str: ) sample.drop( - columns=sample_filter_cols - + [f"s__{SQLMESH_JOIN_KEY_COL}", f"t__{SQLMESH_JOIN_KEY_COL}"], + columns=[ + f"s__{SQLMESH_JOIN_KEY_COL}", + f"t__{SQLMESH_JOIN_KEY_COL}", + SQLMESH_SAMPLE_TYPE_COL, + ], inplace=True, ) @@ -684,4 +677,75 @@ def name(e: exp.Expression) -> str: model_name=self.model_name, decimals=self.decimals, ) + return self._row_diff + + def _fetch_sample( + self, + sample_table: exp.Table, + s_selects: t.Dict[str, exp.Alias], + s_index: t.List[exp.Column], + t_selects: t.Dict[str, exp.Alias], + t_index: t.List[exp.Column], + limit: int, + ) -> pd.DataFrame: + rendered_data_column_names = [ + name(s) for s in list(s_selects.values()) + list(t_selects.values()) + ] + sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL) + + source_only_sample = ( + exp.select( + exp.Literal.string("source_only").as_(sample_type), *rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0))) + .order_by(*(name(s_selects[c.name]) for c in s_index)) + .limit(limit) + ) + + target_only_sample = ( + exp.select( + exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0))) + .order_by(*(name(t_selects[c.name]) for c in t_index)) + .limit(limit) + ) + + common_rows_sample = ( + exp.select( + exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0))) + .order_by( + *(name(s_selects[c.name]) for c in s_index), + *(name(t_selects[c.name]) for c in t_index), + ) + .limit(limit) + ) + + query = ( + exp.Select() + .with_("source_only", source_only_sample) + .with_("target_only", target_only_sample) + .with_("common_rows", common_rows_sample) + .select(sample_type, *rendered_data_column_names) + .from_("source_only") + .union( + exp.select(sample_type, *rendered_data_column_names).from_("target_only"), + distinct=False, + ) + .union( + exp.select(sample_type, *rendered_data_column_names).from_("common_rows"), + distinct=False, + ) + ) + + return self.adapter.fetchdf(query, quote_identifiers=True) + + +def name(e: exp.Expression) -> str: + return e.args["alias"].sql(identify=True) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 6509eb447c..873d25547f 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2358,8 +2358,8 @@ def test_table_diff_grain_check_multiple_keys(ctx: TestContext): assert row_diff.stats["distinct_count_s"] == 7 assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"] assert row_diff.stats["distinct_count_t"] == 10 - assert row_diff.s_sample.shape == (0, 3) - assert row_diff.t_sample.shape == (3, 3) + assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) + assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) def test_table_diff_arbitrary_condition(ctx: TestContext): diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index bf491d77a7..a9b56650c7 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -306,8 +306,8 @@ def test_grain_check(sushi_context_fixed_date): assert row_diff.stats["distinct_count_s"] == 7 assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"] assert row_diff.stats["distinct_count_t"] == 10 - assert row_diff.s_sample.shape == (0, 3) - assert row_diff.t_sample.shape == (3, 3) + assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) + assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture): @@ -338,7 +338,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"' compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1' - sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20' + sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"' drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"' # make with_log_level() return the current instance of engine_adapter so we can still spy on _execute @@ -1032,3 +1032,51 @@ def test_schema_diff_ignore_case(): exp.DataType.build("date"), ) # notice: source casing used on output } + + +def test_data_diff_sample_limit(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = {"id": exp.DataType.build("int"), "name": exp.DataType.build("varchar")} + + engine_adapter.create_table("src", columns_to_types) + engine_adapter.create_table("target", columns_to_types) + + common_records = {} + src_only_records = {} + target_only_records = {} + + for i in range(0, 10): + common_records[i] = f"common_{i}" + src_only_records[i + 20] = f"src_{i}" + target_only_records[i + 40] = f"target_{i}" + + src_records = {**common_records, **src_only_records} + target_records = {**common_records, **target_only_records} + + # changes + src_records[1] = "modified_source_1" + src_records[3] = "modified_source_3" + target_records[2] = "modified_target_2" + target_records[7] = "modified_target_7" + + src_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in src_records.items()]) + target_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in target_records.items()]) + + engine_adapter.insert_append("src", src_df) + engine_adapter.insert_append("target", target_df) + + table_diff = TableDiff( + adapter=engine_adapter, source="src", target="target", on=["id"], limit=3 + ) + + diff = table_diff.row_diff() + + assert diff.s_only_count == 10 + assert diff.t_only_count == 10 + assert diff.join_count == 10 + + # each sample should contain :limit records + assert len(diff.s_sample) == 3 + assert len(diff.t_sample) == 3 + assert len(diff.joined_sample) == 3