Skip to content

Commit ed46961

Browse files
committed
PR feedback
1 parent 7e4fd48 commit ed46961

6 files changed

Lines changed: 94 additions & 24 deletions

File tree

docs/concepts/state.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,14 @@ The state file is a simple `json` file that looks like:
7474
/* object for every Virtual Data Environment in the project. key = environment name, value = environment details */
7575
"environments": {
7676
"prod": {
77-
"..."
77+
/* information about the environment itself */
78+
"environment": {
79+
"..."
80+
},
81+
/* information about any before_all / after_all statements for this environment */
82+
"statements": [
83+
"..."
84+
]
7885
}
7986
}
8087
}

sqlmesh/core/console.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,8 @@ def stop_state_export(self, success: bool, output_file: Path) -> None:
11291129
self.state_export_progress.stop()
11301130
self.state_export_progress = None
11311131

1132+
self.log_status_update("")
1133+
11321134
if success:
11331135
self.log_success(f"State exported successfully to '{output_file.as_posix()}'")
11341136
else:
@@ -1252,6 +1254,8 @@ def stop_state_import(self, success: bool, input_file: Path) -> None:
12521254
self.state_import_progress.stop()
12531255
self.state_import_progress = None
12541256

1257+
self.log_status_update("")
1258+
12551259
if success:
12561260
self.log_success(f"State imported successfully from '{input_file.as_posix()}'")
12571261
else:

sqlmesh/core/state_sync/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
from sqlmesh.core.console import Console
1010
from sqlmesh.core.dialect import schema_
11-
from sqlmesh.core.environment import Environment
11+
from sqlmesh.utils.pydantic import PydanticModel
12+
from sqlmesh.core.environment import Environment, EnvironmentStatements
1213
from sqlmesh.utils.errors import SQLMeshError
1314
from sqlmesh.core.snapshot import Snapshot
1415

@@ -96,6 +97,11 @@ def chunk_iterable(iterable: t.Iterable[T], size: int = 10) -> t.Iterable[t.Iter
9697
yield itertools.chain([first], itertools.islice(iterator, size - 1))
9798

9899

100+
class EnvironmentWithStatements(PydanticModel):
101+
environment: Environment
102+
statements: t.List[EnvironmentStatements] = []
103+
104+
99105
class StateStream(abc.ABC):
100106
"""
101107
Represents a stream of state either going into the StateSync (perhaps loaded from a file)
@@ -114,5 +120,5 @@ def snapshots(self) -> t.Iterable[Snapshot]:
114120

115121
@property
116122
@abc.abstractmethod
117-
def environments(self) -> t.Iterable[Environment]:
118-
"""A stream of Environment objects"""
123+
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
124+
"""A stream of Environments with any EnvironmentStatements attached"""

sqlmesh/core/state_sync/db/facade.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
transactional,
5252
StateStream,
5353
chunk_iterable,
54+
EnvironmentWithStatements,
5455
)
5556
from sqlmesh.core.state_sync.db.interval import IntervalState
5657
from sqlmesh.core.state_sync.db.environment import EnvironmentState
@@ -482,11 +483,13 @@ def snapshots(self) -> t.Iterable[Snapshot]:
482483
yield from state_sync.get_snapshots(chunk).values()
483484

484485
@property
485-
def environments(self) -> t.Iterable[Environment]:
486-
if environment_names:
487-
yield from selected_environments
488-
else:
489-
yield from state_sync.get_environments()
486+
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
487+
envs = selected_environments if environment_names else state_sync.get_environments()
488+
489+
for env in envs:
490+
yield EnvironmentWithStatements(
491+
environment=env, statements=state_sync.get_environment_statements(env.name)
492+
)
490493

491494
return _DumpStateStream()
492495

@@ -515,7 +518,7 @@ def import_(self, stream: StateStream, clear: bool = True) -> None:
515518
)
516519
overwrite_existing_snapshots = (
517520
not clear
518-
) # if clear=True, all existing snapshjots were dropped anyway
521+
) # if clear=True, all existing snapshots were dropped anyway
519522
self.snapshot_state.push_snapshots(
520523
snapshot_iterator, overwrite=overwrite_existing_snapshots
521524
)
@@ -529,12 +532,12 @@ def import_(self, stream: StateStream, clear: bool = True) -> None:
529532
}
530533
)
531534

532-
existing_environments = set(self.get_environments_summary().keys()) if not clear else set()
533-
for environment in stream.environments:
534-
if not clear and environment.name in existing_environments:
535-
self.environment_state.update_environment(environment)
536-
else:
537-
self.promote(environment)
535+
for environment_with_statements in stream.environments:
536+
environment = environment_with_statements.environment
537+
self.environment_state.update_environment(environment)
538+
self.environment_state.update_environment_statements(
539+
environment.name, environment.plan_id, environment_with_statements.statements
540+
)
538541

539542
self.update_auto_restatements(auto_restatements)
540543

sqlmesh/core/state_sync/export_import.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from sqlmesh.core.state_sync import StateSync
44
from sqlmesh.core.snapshot import Snapshot
55
from sqlmesh.utils.date import now, to_tstz
6-
from sqlmesh.core.environment import Environment
76
from sqlmesh.utils.pydantic import _expression_encoder
87
from sqlmesh.core.state_sync import Versions
9-
from sqlmesh.core.state_sync.common import StateStream
8+
from sqlmesh.core.state_sync.common import StateStream, EnvironmentWithStatements
109
from sqlmesh.core.console import Console
1110
from pathlib import Path
1211
from sqlmesh.core.console import NoopConsole
@@ -45,7 +44,7 @@ def snapshots(self) -> t.Iterable[Snapshot]:
4544
return iter(snapshots.values())
4645

4746
@property
48-
def environments(self) -> t.Iterable[Environment]:
47+
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
4948
return []
5049

5150
return _LocalStateStream()
@@ -71,11 +70,11 @@ def _dump_snapshots(
7170

7271
@streamable_dict
7372
def _dump_environments(
74-
environment_stream: t.Iterable[Environment],
73+
environment_stream: t.Iterable[EnvironmentWithStatements],
7574
) -> t.Iterator[t.Tuple[str, t.Any]]:
7675
console.update_state_export_progress(environment_count=0)
7776
for idx, env in enumerate(environment_stream):
78-
yield env.name, _dump_pydantic_model(env)
77+
yield env.environment.name, _dump_pydantic_model(env)
7978
console.update_state_export_progress(environment_count=idx + 1)
8079

8180
@streamable_dict
@@ -129,12 +128,14 @@ def snapshots(self) -> t.Iterable[Snapshot]:
129128
console.update_state_import_progress(snapshots_complete=True)
130129

131130
@property
132-
def environments(self) -> t.Iterable[Environment]:
131+
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
133132
stream = data()["environments"]
134133

135134
console.update_state_import_progress(environment_count=0)
136135
for idx, (_, raw_environment) in enumerate(stream.items()):
137-
environment = Environment.model_validate(to_standard_types(raw_environment))
136+
environment = EnvironmentWithStatements.model_validate(
137+
to_standard_types(raw_environment)
138+
)
138139
yield environment
139140
console.update_state_import_progress(environment_count=idx + 1)
140141

tests/core/state_sync/test_export_import.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from pathlib import Path
3-
from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync
3+
from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync, CachingStateSync
44
from sqlmesh.core.state_sync.export_import import export_state, import_state
55
from sqlmesh.utils.errors import SQLMeshError
66
from sqlmesh.core import constants as c
@@ -566,3 +566,52 @@ def test_roundtrip_includes_auto_restatements(
566566

567567
plan = context.plan(skip_tests=True)
568568
assert not plan.has_changes
569+
570+
571+
def test_roundtrip_includes_environment_statements(tmp_path: Path) -> None:
572+
config = Config(
573+
gateways={
574+
"main": GatewayConfig(
575+
connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")),
576+
state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")),
577+
)
578+
},
579+
default_gateway="main",
580+
model_defaults=ModelDefaultsConfig(
581+
dialect="duckdb",
582+
),
583+
before_all=["select 1 as before_all"],
584+
after_all=["select 2 as after_all"],
585+
)
586+
587+
context = Context(paths=tmp_path, config=config)
588+
context.plan(auto_apply=True)
589+
590+
state_file = tmp_path / "state_dump.json"
591+
context.export_state(state_file)
592+
593+
environments = json.loads(state_file.read_text(encoding="utf8"))["environments"]
594+
595+
assert environments["prod"]["statements"][0]["before_all"][0] == "select 1 as before_all"
596+
assert environments["prod"]["statements"][0]["after_all"][0] == "select 2 as after_all"
597+
598+
assert not context.plan().has_changes
599+
600+
state_sync = context.state_sync
601+
assert isinstance(state_sync, CachingStateSync)
602+
assert isinstance(state_sync.state_sync, EngineAdapterStateSync)
603+
604+
# show state destroyed
605+
state_sync.state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) # type: ignore
606+
with pytest.raises(SQLMeshError, match=r"Please run a migration"):
607+
state_sync.get_versions(validate=True)
608+
609+
state_sync.migrate(default_catalog=None)
610+
import_state(state_sync, state_file)
611+
612+
assert not context.plan().has_changes
613+
614+
environment_statements = state_sync.get_environment_statements("prod")
615+
assert len(environment_statements) == 1
616+
assert environment_statements[0].before_all[0] == "select 1 as before_all"
617+
assert environment_statements[0].after_all[0] == "select 2 as after_all"

0 commit comments

Comments
 (0)