Skip to content

Commit fc61b1c

Browse files
committed
Fix: Create a transaction and a session when migrating a snapshot
1 parent 82f2447 commit fc61b1c

2 files changed

Lines changed: 41 additions & 34 deletions

File tree

sqlmesh/core/snapshot/evaluator.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -911,39 +911,40 @@ def _migrate_snapshot(
911911
):
912912
return
913913

914+
deployability_index = DeployabilityIndex.all_deployable()
915+
render_kwargs: t.Dict[str, t.Any] = dict(
916+
engine_adapter=adapter,
917+
snapshots=parent_snapshots_by_name(snapshot, snapshots),
918+
runtime_stage=RuntimeStage.CREATING,
919+
deployability_index=deployability_index,
920+
)
914921
target_table_name = snapshot.table_name()
915-
if adapter.table_exists(target_table_name):
916-
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
917-
tmp_table_name = snapshot.table_name(is_deployable=False)
918-
logger.info(
919-
"Migrating table schema from '%s' to '%s'",
920-
tmp_table_name,
921-
target_table_name,
922-
)
923-
evaluation_strategy.migrate(
924-
target_table_name=target_table_name,
925-
source_table_name=tmp_table_name,
926-
snapshot=snapshot,
927-
snapshots=parent_snapshots_by_name(snapshot, snapshots),
928-
allow_destructive_snapshots=allow_destructive_snapshots,
929-
)
930-
else:
931-
logger.info(
932-
"Creating table '%s' for the snapshot of the forward-only model %s",
933-
target_table_name,
934-
snapshot.snapshot_id,
935-
)
936-
deployability_index = DeployabilityIndex.all_deployable()
937-
render_kwargs: t.Dict[str, t.Any] = dict(
938-
engine_adapter=adapter,
939-
snapshots=parent_snapshots_by_name(snapshot, snapshots),
940-
runtime_stage=RuntimeStage.CREATING,
941-
deployability_index=deployability_index,
942-
)
943-
with (
944-
adapter.transaction(),
945-
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
946-
):
922+
923+
with (
924+
adapter.transaction(),
925+
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
926+
):
927+
if adapter.table_exists(target_table_name):
928+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
929+
tmp_table_name = snapshot.table_name(is_deployable=False)
930+
logger.info(
931+
"Migrating table schema from '%s' to '%s'",
932+
tmp_table_name,
933+
target_table_name,
934+
)
935+
evaluation_strategy.migrate(
936+
target_table_name=target_table_name,
937+
source_table_name=tmp_table_name,
938+
snapshot=snapshot,
939+
snapshots=parent_snapshots_by_name(snapshot, snapshots),
940+
allow_destructive_snapshots=allow_destructive_snapshots,
941+
)
942+
else:
943+
logger.info(
944+
"Creating table '%s' for the snapshot of the forward-only model %s",
945+
target_table_name,
946+
snapshot.snapshot_id,
947+
)
947948
self._execute_create(
948949
snapshot=snapshot,
949950
table_name=target_table_name,

tests/core/test_snapshot_evaluator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,16 @@ def date_kwargs() -> t.Dict[str, str]:
9292

9393
@pytest.fixture
9494
def adapter_mock(mocker: MockerFixture):
95+
def mock_exit(self, exc_type, exc_value, traceback):
96+
pass
97+
9598
transaction_mock = mocker.Mock()
9699
transaction_mock.__enter__ = mocker.Mock()
97-
transaction_mock.__exit__ = mocker.Mock()
100+
transaction_mock.__exit__ = mock_exit
98101

99102
session_mock = mocker.Mock()
100103
session_mock.__enter__ = mocker.Mock()
101-
session_mock.__exit__ = mocker.Mock()
104+
session_mock.__exit__ = mock_exit
102105

103106
adapter_mock = mocker.Mock()
104107
adapter_mock.transaction.return_value = transaction_mock
@@ -1160,6 +1163,7 @@ def test_migrate(mocker: MockerFixture, make_snapshot):
11601163
cursor_mock = mocker.Mock()
11611164
connection_mock.cursor.return_value = cursor_mock
11621165
adapter = EngineAdapter(lambda: connection_mock, "")
1166+
session_spy = mocker.spy(adapter, "session")
11631167

11641168
current_table = "sqlmesh__test_schema.test_schema__test_model__1"
11651169

@@ -1201,6 +1205,8 @@ def columns(table_name):
12011205
]
12021206
)
12031207

1208+
session_spy.assert_called_once()
1209+
12041210

12051211
def test_migrate_missing_table(mocker: MockerFixture, make_snapshot):
12061212
connection_mock = mocker.NonCallableMock()

0 commit comments

Comments
 (0)