Skip to content

Commit 2ef4d09

Browse files
committed
Chore: Refactor state stream
1 parent 561e4fd commit 2ef4d09

4 files changed

Lines changed: 181 additions & 129 deletions

File tree

sqlmesh/core/console.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,12 +1965,16 @@ def print_environments(self, environments_summary: t.Dict[str, int]) -> None:
19651965
self.log_status_update(f"Number of SQLMesh environments are: {output_str}")
19661966

19671967
def print_connection_config(self, config: ConnectionConfig, title: str = "Connection") -> None:
1968-
engine_adapter_type = config._engine_adapter
1969-
19701968
tree = Tree(f"[b]{title}:[/b]")
19711969
tree.add(f"Type: [bold cyan]{config.type_}[/bold cyan]")
19721970
tree.add(f"Catalog: [bold cyan]{config.get_catalog()}[/bold cyan]")
1973-
tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]")
1971+
1972+
try:
1973+
engine_adapter_type = config._engine_adapter
1974+
tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]")
1975+
except NotImplementedError:
1976+
# not all ConnectionConfig's have an engine adapter associated. The CloudConnectionConfig has a HTTP client instead
1977+
pass
19741978

19751979
self._print(tree)
19761980

sqlmesh/core/state_sync/common.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import itertools
77
import abc
88

9+
from dataclasses import dataclass
10+
911
from sqlmesh.core.console import Console
1012
from sqlmesh.core.dialect import schema_
1113
from sqlmesh.utils.pydantic import PydanticModel
@@ -102,23 +104,62 @@ class EnvironmentWithStatements(PydanticModel):
102104
statements: t.List[EnvironmentStatements] = []
103105

104106

107+
@dataclass
108+
class VersionsChunk:
109+
versions: Versions
110+
111+
112+
class SnapshotsChunk:
113+
def __init__(self, items: t.Iterator[Snapshot]):
114+
self.items = items
115+
116+
def __iter__(self) -> t.Iterator[Snapshot]:
117+
return self.items
118+
119+
120+
class EnvironmentsChunk:
121+
def __init__(self, items: t.Iterator[EnvironmentWithStatements]):
122+
self.items = items
123+
124+
def __iter__(self) -> t.Iterator[EnvironmentWithStatements]:
125+
return self.items
126+
127+
128+
StateStreamContents = t.Union[VersionsChunk, SnapshotsChunk, EnvironmentsChunk]
129+
130+
105131
class StateStream(abc.ABC):
106132
"""
107133
Represents a stream of state either going into the StateSync (perhaps loaded from a file)
108134
or out of the StateSync (perhaps being dumped to a file)
135+
136+
Iterating over the stream produces the following chunks:
137+
138+
VersionsChunk: The versions of the objects contained in this StateStream
139+
SnapshotsChunk: Is itself an iterator that streams Snapshot objects. Note that they should be fully populated with any relevant Intervals
140+
EnvironmentsChunk: Is itself an iterator emitting a stream of Environments with any EnvironmentStatements attached
141+
142+
The idea here is to give some structure to the stream and ensure that callers have the opportunity to process all its components while not
143+
needing to worry about the order they are emitted in
109144
"""
110145

111-
@property
112146
@abc.abstractmethod
113-
def versions(self) -> Versions:
114-
"""The versions of the objects contained in this StateStream"""
147+
def __iter__(self) -> t.Iterator[StateStreamContents]:
148+
pass
115149

116-
@property
117-
@abc.abstractmethod
118-
def snapshots(self) -> t.Iterable[Snapshot]:
119-
"""A stream of Snapshot objects. Note that they should be fully populated with any relevant Intervals"""
150+
@classmethod
151+
def from_iterators(
152+
cls: t.Type["StateStream"],
153+
versions: Versions,
154+
snapshots: t.Iterator[Snapshot],
155+
environments: t.Iterator[EnvironmentWithStatements],
156+
) -> "StateStream":
157+
class _StateStream(cls): # type: ignore
158+
def __iter__(self) -> t.Iterator[StateStreamContents]:
159+
yield VersionsChunk(versions)
120160

121-
@property
122-
@abc.abstractmethod
123-
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
124-
"""A stream of Environments with any EnvironmentStatements attached"""
161+
yield SnapshotsChunk(snapshots)
162+
163+
yield EnvironmentsChunk(environments)
164+
165+
return _StateStream()

sqlmesh/core/state_sync/db/facade.py

Lines changed: 72 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
Versions,
4949
)
5050
from sqlmesh.core.state_sync.common import (
51+
EnvironmentsChunk,
52+
SnapshotsChunk,
53+
VersionsChunk,
5154
transactional,
5255
StateStream,
5356
chunk_iterable,
@@ -466,80 +469,88 @@ def _include_snapshot(s_id: SnapshotId) -> bool:
466469
return s_id in snapshot_ids_to_export
467470
return True
468471

469-
class _DumpStateStream(StateStream):
470-
@property
471-
def versions(self) -> Versions:
472-
return state_sync.get_versions()
473-
474-
@property
475-
def snapshots(self) -> t.Iterable[Snapshot]:
476-
all_snapshot_ids = {
477-
s.snapshot_id
478-
for e in state_sync.get_environments()
479-
for s in e.snapshots
480-
if _include_snapshot(s.snapshot_id)
481-
}
482-
for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE):
483-
yield from state_sync.get_snapshots(chunk).values()
472+
def _export_snapshots() -> t.Iterator[Snapshot]:
473+
all_snapshot_ids = {
474+
s.snapshot_id
475+
for e in state_sync.get_environments()
476+
for s in e.snapshots
477+
if _include_snapshot(s.snapshot_id)
478+
}
479+
for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE):
480+
yield from state_sync.get_snapshots(chunk).values()
484481

485-
@property
486-
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
487-
envs = selected_environments if environment_names else state_sync.get_environments()
482+
def _export_environments() -> t.Iterator[EnvironmentWithStatements]:
483+
envs = selected_environments if environment_names else state_sync.get_environments()
488484

489-
for env in envs:
490-
yield EnvironmentWithStatements(
491-
environment=env, statements=state_sync.get_environment_statements(env.name)
492-
)
485+
for env in envs:
486+
yield EnvironmentWithStatements(
487+
environment=env, statements=state_sync.get_environment_statements(env.name)
488+
)
493489

494-
return _DumpStateStream()
490+
return StateStream.from_iterators(
491+
versions=state_sync.get_versions(),
492+
snapshots=_export_snapshots(),
493+
environments=_export_environments(),
494+
)
495495

496496
@transactional()
497497
def import_(self, stream: StateStream, clear: bool = True) -> None:
498498
existing_versions = self.get_versions()
499499

500-
# SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
501-
# is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption
502-
# is that they dont contain any breaking changes
503-
incoming_versions = stream.versions
504-
if incoming_versions.minor_sqlmesh_version != existing_versions.minor_sqlmesh_version:
505-
raise SQLMeshError(
506-
f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n"
507-
"Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
508-
)
500+
for state_chunk in stream:
501+
if isinstance(state_chunk, VersionsChunk):
502+
# SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
503+
# is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption
504+
# is that they dont contain any breaking changes
505+
incoming_versions = state_chunk.versions
506+
if (
507+
incoming_versions.minor_sqlmesh_version
508+
!= existing_versions.minor_sqlmesh_version
509+
):
510+
raise SQLMeshError(
511+
f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n"
512+
"Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
513+
)
509514

510-
if clear:
511-
self.reset(default_catalog=None)
515+
if clear:
516+
self.reset(default_catalog=None)
512517

513-
auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {}
518+
if isinstance(state_chunk, SnapshotsChunk):
519+
auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {}
514520

515-
for snapshot_chunk in chunk_iterable(stream.snapshots, SnapshotState.SNAPSHOT_BATCH_SIZE):
516-
snapshot_iterator, intervals_iterator, auto_restatments_iterator = itertools.tee(
517-
snapshot_chunk, 3
518-
)
519-
overwrite_existing_snapshots = (
520-
not clear
521-
) # if clear=True, all existing snapshots were dropped anyway
522-
self.snapshot_state.push_snapshots(
523-
snapshot_iterator, overwrite=overwrite_existing_snapshots
524-
)
525-
self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator))
526-
527-
auto_restatements.update(
528-
{
529-
s.name_version: s.next_auto_restatement_ts
530-
for s in auto_restatments_iterator
531-
if s.next_auto_restatement_ts
532-
}
533-
)
521+
for snapshot_chunk in chunk_iterable(
522+
state_chunk, SnapshotState.SNAPSHOT_BATCH_SIZE
523+
):
524+
snapshot_iterator, intervals_iterator, auto_restatments_iterator = (
525+
itertools.tee(snapshot_chunk, 3)
526+
)
527+
overwrite_existing_snapshots = (
528+
not clear
529+
) # if clear=True, all existing snapshots were dropped anyway
530+
self.snapshot_state.push_snapshots(
531+
snapshot_iterator, overwrite=overwrite_existing_snapshots
532+
)
533+
self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator))
534+
535+
auto_restatements.update(
536+
{
537+
s.name_version: s.next_auto_restatement_ts
538+
for s in auto_restatments_iterator
539+
if s.next_auto_restatement_ts
540+
}
541+
)
534542

535-
for environment_with_statements in stream.environments:
536-
environment = environment_with_statements.environment
537-
self.environment_state.update_environment(environment)
538-
self.environment_state.update_environment_statements(
539-
environment.name, environment.plan_id, environment_with_statements.statements
540-
)
543+
self.update_auto_restatements(auto_restatements)
541544

542-
self.update_auto_restatements(auto_restatements)
545+
if isinstance(state_chunk, EnvironmentsChunk):
546+
for environment_with_statements in state_chunk:
547+
environment = environment_with_statements.environment
548+
self.environment_state.update_environment(environment)
549+
self.environment_state.update_environment_statements(
550+
environment.name,
551+
environment.plan_id,
552+
environment_with_statements.statements,
553+
)
543554

544555
def state_type(self) -> str:
545556
return self.engine_adapter.dialect

0 commit comments

Comments
 (0)