Skip to content

Commit a5e67d0

Browse files
Fix: Adjust condition in table diff to filter in forward only models (#4340)
1 parent a837409 commit a5e67d0

2 files changed

Lines changed: 96 additions & 3 deletions

File tree

sqlmesh/core/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,9 +1677,9 @@ def table_diff(
16771677
target_snapshot = target_snapshots_to_name.get(model.fqn)
16781678

16791679
if target_snapshot and source_snapshot:
1680-
if (
1681-
source_snapshot.fingerprint.data_hash
1682-
!= target_snapshot.fingerprint.data_hash
1680+
if (source_snapshot.fingerprint != target_snapshot.fingerprint) and (
1681+
(source_snapshot.version != target_snapshot.version)
1682+
or (source_snapshot.is_forward_only or target_snapshot.is_forward_only)
16831683
):
16841684
# Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
16851685
# to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews

tests/core/test_table_diff.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,3 +695,96 @@ def test_data_diff_multiple_models(sushi_context_fixed_date, capsys, caplog):
695695
skip_grain_check=False,
696696
)
697697
assert len(diffs) == 0
698+
699+
700+
@pytest.mark.slow
def test_data_diff_forward_only(sushi_context_fixed_date, capsys, caplog):
    """Check that ``table_diff`` returns diffs for models changed by a forward-only plan.

    Two FULL models are created (``full_2`` selects from ``full_1``) and applied
    to ``target_dev``. The query of ``full_1`` is then replaced and applied to
    ``source_dev`` through a plan built with ``forward_only=True``. Diffing the
    two environments on ``key`` must yield one diff per model, and neither diff
    may contain any matching rows.
    """
    ctx = sushi_context_fixed_date

    # Upstream model backed by two literal rows.
    upstream_model = load_sql_based_model(
        d.parse(
            """
            MODEL (name memory.sushi.full_1, kind full, grain(key),);
            SELECT
                key,
                value,
            FROM
                (VALUES
                    (1, 3),
                    (2, 4),
                ) AS t (key, value)
            """
        ),
        dialect="snowflake",
    )
    ctx.upsert_model(upstream_model)

    # Downstream model that reads from the upstream one.
    downstream_model = load_sql_based_model(
        d.parse(
            """
            MODEL (name memory.sushi.full_2, kind full, grain(key),);
            SELECT
                key,
                value as amount,
            FROM
                memory.sushi.full_1
            """
        ),
        dialect="snowflake",
    )
    ctx.upsert_model(downstream_model)

    # Apply the initial versions of both models to the target environment.
    ctx.plan(
        "target_dev",
        no_prompts=True,
        auto_apply=True,
        skip_tests=True,
        start="2023-01-31",
        end="2023-01-31",
    )

    # Replace the upstream query with rows whose keys differ from the originals.
    upstream_fields = ctx.models['"MEMORY"."SUSHI"."FULL_1"'].dict()
    upstream_fields["query"] = exp.select("*").from_(
        "(VALUES (12, 6),(5,3),) AS t (key, value)"
    )
    ctx.upsert_model(SqlModel(**upstream_fields))

    ctx.auto_categorize_changes = CategorizerConfig(sql=AutoCategorizationMode.FULL)

    # Apply the change to the source environment with a forward-only plan.
    forward_only_plan = ctx.plan_builder(
        "source_dev", skip_tests=True, forward_only=True
    ).build()
    ctx.apply(forward_only_plan)

    # Diff every model matching the selector across the two environments.
    diffs = ctx.table_diff(
        source="source_dev",
        target="target_dev",
        on=["key"],
        select_models={"*full*"},
        skip_grain_check=False,
    )

    # Both models should be diffed.
    assert len(diffs) == 2

    # Locate each model's diff by name and compute its row-level diff.
    row_diffs = [
        next(table_diff for table_diff in diffs if tag in table_diff.source).row_diff()
        for tag in ("FULL_1", "FULL_2")
    ]

    expected_stats = {
        "join_count": 0,
        "null_grain_count": 0,
        "s_count": 2,
        "distinct_count_s": 2,
        "t_count": 2,
        "distinct_count_t": 2,
    }

    # Both diffs should report the same figures: two rows per side, none joined.
    for row_diff in row_diffs:
        assert row_diff.full_match_count == 0
        assert row_diff.full_match_pct == 0.0
        assert row_diff.s_only_count == 2
        assert row_diff.t_only_count == 2
        for stat_name, expected_value in expected_stats.items():
            assert row_diff.stats[stat_name] == expected_value
        assert row_diff.s_sample.shape == (2, 2)
        assert row_diff.t_sample.shape == (2, 2)

0 commit comments

Comments
 (0)