Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,39 +911,40 @@ def _migrate_snapshot(
):
return

deployability_index = DeployabilityIndex.all_deployable()
render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
runtime_stage=RuntimeStage.CREATING,
deployability_index=deployability_index,
)
target_table_name = snapshot.table_name()
if adapter.table_exists(target_table_name):
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
tmp_table_name = snapshot.table_name(is_deployable=False)
logger.info(
"Migrating table schema from '%s' to '%s'",
tmp_table_name,
target_table_name,
)
evaluation_strategy.migrate(
target_table_name=target_table_name,
source_table_name=tmp_table_name,
snapshot=snapshot,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
allow_destructive_snapshots=allow_destructive_snapshots,
)
else:
logger.info(
"Creating table '%s' for the snapshot of the forward-only model %s",
target_table_name,
snapshot.snapshot_id,
)
deployability_index = DeployabilityIndex.all_deployable()
render_kwargs: t.Dict[str, t.Any] = dict(
engine_adapter=adapter,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
runtime_stage=RuntimeStage.CREATING,
deployability_index=deployability_index,
)
with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
):

with (
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
):
if adapter.table_exists(target_table_name):
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
tmp_table_name = snapshot.table_name(is_deployable=False)
logger.info(
"Migrating table schema from '%s' to '%s'",
tmp_table_name,
target_table_name,
)
evaluation_strategy.migrate(
target_table_name=target_table_name,
source_table_name=tmp_table_name,
snapshot=snapshot,
snapshots=parent_snapshots_by_name(snapshot, snapshots),
allow_destructive_snapshots=allow_destructive_snapshots,
)
else:
logger.info(
"Creating table '%s' for the snapshot of the forward-only model %s",
target_table_name,
snapshot.snapshot_id,
)
self._execute_create(
snapshot=snapshot,
table_name=target_table_name,
Expand Down
19 changes: 14 additions & 5 deletions tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
SnapshotTableCleanupTask,
)
from sqlmesh.core.snapshot.definition import to_view_mapping
from sqlmesh.core.snapshot.evaluator import CustomMaterialization
from sqlmesh.core.snapshot.evaluator import CustomMaterialization, SnapshotCreationFailedError
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.date import to_timestamp
from sqlmesh.utils.errors import ConfigError, SQLMeshError, DestructiveChangeError
Expand Down Expand Up @@ -92,13 +92,16 @@ def date_kwargs() -> t.Dict[str, str]:

@pytest.fixture
def adapter_mock(mocker: MockerFixture):
def mock_exit(self, exc_type, exc_value, traceback):
pass

transaction_mock = mocker.Mock()
transaction_mock.__enter__ = mocker.Mock()
transaction_mock.__exit__ = mocker.Mock()
transaction_mock.__exit__ = mock_exit

session_mock = mocker.Mock()
session_mock.__enter__ = mocker.Mock()
session_mock.__exit__ = mocker.Mock()
session_mock.__exit__ = mock_exit

adapter_mock = mocker.Mock()
adapter_mock.transaction.return_value = transaction_mock
Expand Down Expand Up @@ -1160,6 +1163,7 @@ def test_migrate(mocker: MockerFixture, make_snapshot):
cursor_mock = mocker.Mock()
connection_mock.cursor.return_value = cursor_mock
adapter = EngineAdapter(lambda: connection_mock, "")
session_spy = mocker.spy(adapter, "session")

current_table = "sqlmesh__test_schema.test_schema__test_model__1"

Expand Down Expand Up @@ -1201,6 +1205,8 @@ def columns(table_name):
]
)

session_spy.assert_called_once()


def test_migrate_missing_table(mocker: MockerFixture, make_snapshot):
connection_mock = mocker.NonCallableMock()
Expand Down Expand Up @@ -1596,7 +1602,8 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m
),
]

evaluator.create([snapshot], {})
with pytest.raises(SnapshotCreationFailedError):
evaluator.create([snapshot], {})

adapter_mock.clone_table.assert_called_once_with(
f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev",
Expand Down Expand Up @@ -2537,7 +2544,9 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

evaluator = SnapshotEvaluator(adapter_mock)
evaluator.create([snapshot], {})

with pytest.raises(SnapshotCreationFailedError):
evaluator.create([snapshot], {})

adapter_mock.replace_query.assert_called_once_with(
f"sqlmesh__db.db__seed__{snapshot.version}",
Expand Down