|
1 | 1 | import pytest |
2 | 2 | from pathlib import Path |
3 | | -from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync |
| 3 | +from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync, CachingStateSync |
4 | 4 | from sqlmesh.core.state_sync.export_import import export_state, import_state |
5 | 5 | from sqlmesh.utils.errors import SQLMeshError |
6 | 6 | from sqlmesh.core import constants as c |
@@ -566,3 +566,52 @@ def test_roundtrip_includes_auto_restatements( |
566 | 566 |
|
567 | 567 | plan = context.plan(skip_tests=True) |
568 | 568 | assert not plan.has_changes |
| 569 | + |
| 570 | + |
| 571 | +def test_roundtrip_includes_environment_statements(tmp_path: Path) -> None: |
| 572 | + config = Config( |
| 573 | + gateways={ |
| 574 | + "main": GatewayConfig( |
| 575 | + connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")), |
| 576 | + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), |
| 577 | + ) |
| 578 | + }, |
| 579 | + default_gateway="main", |
| 580 | + model_defaults=ModelDefaultsConfig( |
| 581 | + dialect="duckdb", |
| 582 | + ), |
| 583 | + before_all=["select 1 as before_all"], |
| 584 | + after_all=["select 2 as after_all"], |
| 585 | + ) |
| 586 | + |
| 587 | + context = Context(paths=tmp_path, config=config) |
| 588 | + context.plan(auto_apply=True) |
| 589 | + |
| 590 | + state_file = tmp_path / "state_dump.json" |
| 591 | + context.export_state(state_file) |
| 592 | + |
| 593 | + environments = json.loads(state_file.read_text(encoding="utf8"))["environments"] |
| 594 | + |
| 595 | + assert environments["prod"]["statements"][0]["before_all"][0] == "select 1 as before_all" |
| 596 | + assert environments["prod"]["statements"][0]["after_all"][0] == "select 2 as after_all" |
| 597 | + |
| 598 | + assert not context.plan().has_changes |
| 599 | + |
| 600 | + state_sync = context.state_sync |
| 601 | + assert isinstance(state_sync, CachingStateSync) |
| 602 | + assert isinstance(state_sync.state_sync, EngineAdapterStateSync) |
| 603 | + |
| 604 | + # show state destroyed |
| 605 | + state_sync.state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) # type: ignore |
| 606 | + with pytest.raises(SQLMeshError, match=r"Please run a migration"): |
| 607 | + state_sync.get_versions(validate=True) |
| 608 | + |
| 609 | + state_sync.migrate(default_catalog=None) |
| 610 | + import_state(state_sync, state_file) |
| 611 | + |
| 612 | + assert not context.plan().has_changes |
| 613 | + |
| 614 | + environment_statements = state_sync.get_environment_statements("prod") |
| 615 | + assert len(environment_statements) == 1 |
| 616 | + assert environment_statements[0].before_all[0] == "select 1 as before_all" |
| 617 | + assert environment_statements[0].after_all[0] == "select 2 as after_all" |
0 commit comments