|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import typing as t |
| 4 | +import json |
4 | 5 | from collections import Counter |
5 | 6 | from datetime import timedelta |
6 | 7 | from unittest import mock |
@@ -3995,6 +3996,59 @@ def test_empty_backfill_new_model(init_and_plan_context: t.Callable): |
3995 | 3996 | assert snapshot.intervals[-1][1] <= to_timestamp("2023-01-08") |
3996 | 3997 |
|
3997 | 3998 |
|
| 3999 | +@time_machine.travel("2023-01-08 15:00:00 UTC") |
| 4000 | +@pytest.mark.parametrize("forward_only", [False, True]) |
| 4001 | +def test_plan_repairs_unrenderable_snapshot_state( |
| 4002 | + init_and_plan_context: t.Callable, forward_only: bool |
| 4003 | +): |
| 4004 | + context, plan = init_and_plan_context("examples/sushi") |
| 4005 | + context.apply(plan) |
| 4006 | + |
| 4007 | + target_snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") |
| 4008 | + assert target_snapshot |
| 4009 | + |
| 4010 | + # Manually corrupt the snapshot's query |
| 4011 | + raw_snapshot = context.state_sync.state_sync.engine_adapter.fetchone( |
| 4012 | + f"SELECT snapshot FROM sqlmesh._snapshots WHERE name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'" |
| 4013 | + )[0] # type: ignore |
| 4014 | + parsed_snapshot = json.loads(raw_snapshot) |
| 4015 | + parsed_snapshot["node"]["query"] = "SELECT @missing_macro()" |
| 4016 | + context.state_sync.state_sync.engine_adapter.update_table( |
| 4017 | + "sqlmesh._snapshots", |
| 4018 | + {"snapshot": json.dumps(parsed_snapshot)}, |
| 4019 | + f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'", |
| 4020 | + ) |
| 4021 | + |
| 4022 | + context.clear_caches() |
| 4023 | + |
| 4024 | + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ |
| 4025 | + target_snapshot.snapshot_id |
| 4026 | + ] |
| 4027 | + with pytest.raises(Exception): |
| 4028 | + target_snapshot_in_state.model.render_query_or_raise() |
| 4029 | + |
| 4030 | + # Repair the snapshot by creating a new version of it |
| 4031 | + context.upsert_model(target_snapshot.model.name, stamp="repair") |
| 4032 | + target_snapshot = context.get_snapshot(target_snapshot.name) |
| 4033 | + |
| 4034 | + plan_builder = context.plan_builder("prod", forward_only=forward_only) |
| 4035 | + plan = plan_builder.build() |
| 4036 | + assert plan.directly_modified == {target_snapshot.snapshot_id} |
| 4037 | + if not forward_only: |
| 4038 | + assert {i.snapshot_id for i in plan.missing_intervals} == {target_snapshot.snapshot_id} |
| 4039 | + plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) |
| 4040 | + plan = plan_builder.build() |
| 4041 | + |
| 4042 | + context.apply(plan) |
| 4043 | + |
| 4044 | + context.clear_caches() |
| 4045 | + assert context.get_snapshot(target_snapshot.name).model.render_query_or_raise() |
| 4046 | + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ |
| 4047 | + target_snapshot.snapshot_id |
| 4048 | + ] |
| 4049 | + assert target_snapshot_in_state.model.render_query_or_raise() |
| 4050 | + |
| 4051 | + |
3998 | 4052 | @time_machine.travel("2023-01-08 15:00:00 UTC") |
3999 | 4053 | def test_dbt_requirements(sushi_dbt_context: Context): |
4000 | 4054 | assert set(sushi_dbt_context.requirements) == {"dbt-core", "dbt-duckdb"} |
|
0 commit comments