diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index f2042583d0..d33748630d 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -911,39 +911,40 @@ def _migrate_snapshot( ): return + deployability_index = DeployabilityIndex.all_deployable() + render_kwargs: t.Dict[str, t.Any] = dict( + engine_adapter=adapter, + snapshots=parent_snapshots_by_name(snapshot, snapshots), + runtime_stage=RuntimeStage.CREATING, + deployability_index=deployability_index, + ) target_table_name = snapshot.table_name() - if adapter.table_exists(target_table_name): - evaluation_strategy = _evaluation_strategy(snapshot, adapter) - tmp_table_name = snapshot.table_name(is_deployable=False) - logger.info( - "Migrating table schema from '%s' to '%s'", - tmp_table_name, - target_table_name, - ) - evaluation_strategy.migrate( - target_table_name=target_table_name, - source_table_name=tmp_table_name, - snapshot=snapshot, - snapshots=parent_snapshots_by_name(snapshot, snapshots), - allow_destructive_snapshots=allow_destructive_snapshots, - ) - else: - logger.info( - "Creating table '%s' for the snapshot of the forward-only model %s", - target_table_name, - snapshot.snapshot_id, - ) - deployability_index = DeployabilityIndex.all_deployable() - render_kwargs: t.Dict[str, t.Any] = dict( - engine_adapter=adapter, - snapshots=parent_snapshots_by_name(snapshot, snapshots), - runtime_stage=RuntimeStage.CREATING, - deployability_index=deployability_index, - ) - with ( - adapter.transaction(), - adapter.session(snapshot.model.render_session_properties(**render_kwargs)), - ): + + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_kwargs)), + ): + if adapter.table_exists(target_table_name): + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + tmp_table_name = snapshot.table_name(is_deployable=False) + logger.info( + "Migrating table schema from '%s' to '%s'", + tmp_table_name, + target_table_name, + ) + evaluation_strategy.migrate( + target_table_name=target_table_name, + source_table_name=tmp_table_name, + snapshot=snapshot, + snapshots=parent_snapshots_by_name(snapshot, snapshots), + allow_destructive_snapshots=allow_destructive_snapshots, + ) + else: + logger.info( + "Creating table '%s' for the snapshot of the forward-only model %s", + target_table_name, + snapshot.snapshot_id, + ) self._execute_create( snapshot=snapshot, table_name=target_table_name, diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index d572fc7e11..d131e6aa95 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -54,7 +54,7 @@ SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.definition import to_view_mapping -from sqlmesh.core.snapshot.evaluator import CustomMaterialization +from sqlmesh.core.snapshot.evaluator import CustomMaterialization, SnapshotCreationFailedError from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.date import to_timestamp from sqlmesh.utils.errors import ConfigError, SQLMeshError, DestructiveChangeError @@ -92,13 +92,16 @@ def date_kwargs() -> t.Dict[str, str]: @pytest.fixture def adapter_mock(mocker: MockerFixture): + def mock_exit(self, exc_type, exc_value, traceback): + pass + transaction_mock = mocker.Mock() transaction_mock.__enter__ = mocker.Mock() - transaction_mock.__exit__ = mocker.Mock() + transaction_mock.__exit__ = mock_exit session_mock = mocker.Mock() session_mock.__enter__ = mocker.Mock() - session_mock.__exit__ = mocker.Mock() + session_mock.__exit__ = mock_exit adapter_mock = mocker.Mock() adapter_mock.transaction.return_value = transaction_mock @@ -1160,6 +1163,7 @@ def test_migrate(mocker: MockerFixture, make_snapshot): cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock adapter = EngineAdapter(lambda: connection_mock, "") + session_spy = mocker.spy(adapter, "session") current_table = "sqlmesh__test_schema.test_schema__test_model__1" @@ -1201,6 +1205,8 @@ def columns(table_name): ] ) + session_spy.assert_called_once() + def test_migrate_missing_table(mocker: MockerFixture, make_snapshot): connection_mock = mocker.NonCallableMock() @@ -1596,7 +1602,8 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m ), ] - evaluator.create([snapshot], {}) + with pytest.raises(SnapshotCreationFailedError): + evaluator.create([snapshot], {}) adapter_mock.clone_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", @@ -2537,7 +2544,9 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot snapshot.categorize_as(SnapshotChangeCategory.BREAKING) evaluator = SnapshotEvaluator(adapter_mock) - evaluator.create([snapshot], {}) + + with pytest.raises(SnapshotCreationFailedError): + evaluator.create([snapshot], {}) adapter_mock.replace_query.assert_called_once_with( f"sqlmesh__db.db__seed__{snapshot.version}",