Skip to content

Commit d1c34ce

Browse files
authored
Chore: Refactor the state stream interface (#4125)
1 parent a0f7566 commit d1c34ce

4 files changed

Lines changed: 181 additions & 141 deletions

File tree

sqlmesh/core/console.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,12 +1965,16 @@ def print_environments(self, environments_summary: t.Dict[str, int]) -> None:
19651965
self.log_status_update(f"Number of SQLMesh environments are: {output_str}")
19661966

19671967
def print_connection_config(self, config: ConnectionConfig, title: str = "Connection") -> None:
1968-
engine_adapter_type = config._engine_adapter
1969-
19701968
tree = Tree(f"[b]{title}:[/b]")
19711969
tree.add(f"Type: [bold cyan]{config.type_}[/bold cyan]")
19721970
tree.add(f"Catalog: [bold cyan]{config.get_catalog()}[/bold cyan]")
1973-
tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]")
1971+
1972+
try:
1973+
engine_adapter_type = config._engine_adapter
1974+
tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]")
1975+
except NotImplementedError:
1976+
# not all ConnectionConfigs have an engine adapter associated. The CloudConnectionConfig has an HTTP client instead
1977+
pass
19741978

19751979
self._print(tree)
19761980

sqlmesh/core/state_sync/common.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import itertools
77
import abc
88

9+
from dataclasses import dataclass
10+
911
from sqlmesh.core.console import Console
1012
from sqlmesh.core.dialect import schema_
1113
from sqlmesh.utils.pydantic import PydanticModel
@@ -119,23 +121,62 @@ class EnvironmentWithStatements(PydanticModel):
119121
statements: t.List[EnvironmentStatements] = []
120122

121123

124+
@dataclass
125+
class VersionsChunk:
126+
versions: Versions
127+
128+
129+
class SnapshotsChunk:
130+
def __init__(self, items: t.Iterator[Snapshot]):
131+
self.items = items
132+
133+
def __iter__(self) -> t.Iterator[Snapshot]:
134+
return self.items
135+
136+
137+
class EnvironmentsChunk:
138+
def __init__(self, items: t.Iterator[EnvironmentWithStatements]):
139+
self.items = items
140+
141+
def __iter__(self) -> t.Iterator[EnvironmentWithStatements]:
142+
return self.items
143+
144+
145+
StateStreamContents = t.Union[VersionsChunk, SnapshotsChunk, EnvironmentsChunk]
146+
147+
122148
class StateStream(abc.ABC):
123149
"""
124150
Represents a stream of state either going into the StateSync (perhaps loaded from a file)
125151
or out of the StateSync (perhaps being dumped to a file)
152+
153+
Iterating over the stream produces the following chunks:
154+
155+
VersionsChunk: The versions of the objects contained in this StateStream
156+
SnapshotsChunk: Is itself an iterator that streams Snapshot objects. Note that they should be fully populated with any relevant Intervals
157+
EnvironmentsChunk: Is itself an iterator emitting a stream of Environments with any EnvironmentStatements attached
158+
159+
The idea here is to give some structure to the stream and ensure that callers have the opportunity to process all its components while not
160+
needing to worry about the order they are emitted in
126161
"""
127162

128-
@property
129163
@abc.abstractmethod
130-
def versions(self) -> Versions:
131-
"""The versions of the objects contained in this StateStream"""
164+
def __iter__(self) -> t.Iterator[StateStreamContents]:
165+
pass
132166

133-
@property
134-
@abc.abstractmethod
135-
def snapshots(self) -> t.Iterable[Snapshot]:
136-
"""A stream of Snapshot objects. Note that they should be fully populated with any relevant Intervals"""
167+
@classmethod
168+
def from_iterators(
169+
cls: t.Type["StateStream"],
170+
versions: Versions,
171+
snapshots: t.Iterator[Snapshot],
172+
environments: t.Iterator[EnvironmentWithStatements],
173+
) -> "StateStream":
174+
class _StateStream(cls): # type: ignore
175+
def __iter__(self) -> t.Iterator[StateStreamContents]:
176+
yield VersionsChunk(versions)
137177

138-
@property
139-
@abc.abstractmethod
140-
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
141-
"""A stream of Environments with any EnvironmentStatements attached"""
178+
yield SnapshotsChunk(snapshots)
179+
180+
yield EnvironmentsChunk(environments)
181+
182+
return _StateStream()

sqlmesh/core/state_sync/db/facade.py

Lines changed: 72 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import contextlib
2020
import logging
2121
import typing as t
22-
import itertools
2322
from pathlib import Path
2423
from datetime import datetime
2524

@@ -48,6 +47,9 @@
4847
Versions,
4948
)
5049
from sqlmesh.core.state_sync.common import (
50+
EnvironmentsChunk,
51+
SnapshotsChunk,
52+
VersionsChunk,
5153
transactional,
5254
StateStream,
5355
chunk_iterable,
@@ -448,7 +450,9 @@ def rollback(self) -> None:
448450

449451
@transactional()
450452
def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream:
451-
state_sync = self
453+
versions = self.get_versions(
454+
validate=True
455+
) # will throw if the state db hasn't been created or there is a version mismatch
452456

453457
snapshot_ids_to_export: t.Set[SnapshotId] = set()
454458
selected_environments: t.List[Environment] = []
@@ -458,89 +462,84 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
458462
if not environment:
459463
raise SQLMeshError(f"No such environment: {env_name}")
460464
selected_environments.append(environment)
465+
else:
466+
selected_environments = self.get_environments()
461467

462-
for env in selected_environments:
463-
snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []])
464-
465-
def _include_snapshot(s_id: SnapshotId) -> bool:
466-
if environment_names:
467-
return s_id in snapshot_ids_to_export
468-
return True
469-
470-
class _DumpStateStream(StateStream):
471-
@property
472-
def versions(self) -> Versions:
473-
return state_sync.get_versions()
474-
475-
@property
476-
def snapshots(self) -> t.Iterable[Snapshot]:
477-
all_snapshot_ids = {
478-
s.snapshot_id
479-
for e in state_sync.get_environments()
480-
for s in e.snapshots
481-
if _include_snapshot(s.snapshot_id)
482-
}
483-
for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE):
484-
yield from state_sync.get_snapshots(chunk).values()
468+
for env in selected_environments:
469+
snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []])
485470

486-
@property
487-
def environments(self) -> t.Iterable[EnvironmentWithStatements]:
488-
envs = selected_environments if environment_names else state_sync.get_environments()
471+
def _export_snapshots() -> t.Iterator[Snapshot]:
472+
for chunk in chunk_iterable(snapshot_ids_to_export, SnapshotState.SNAPSHOT_BATCH_SIZE):
473+
yield from self.get_snapshots(chunk).values()
489474

490-
for env in envs:
491-
yield EnvironmentWithStatements(
492-
environment=env, statements=state_sync.get_environment_statements(env.name)
493-
)
475+
def _export_environments() -> t.Iterator[EnvironmentWithStatements]:
476+
for env in selected_environments:
477+
yield EnvironmentWithStatements(
478+
environment=env, statements=self.get_environment_statements(env.name)
479+
)
494480

495-
return _DumpStateStream()
481+
return StateStream.from_iterators(
482+
versions=versions,
483+
snapshots=_export_snapshots(),
484+
environments=_export_environments(),
485+
)
496486

497487
@transactional()
498488
def import_(self, stream: StateStream, clear: bool = True) -> None:
499489
existing_versions = self.get_versions()
500490

501-
# SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
502-
# is compatible with our Pydantic model definitions. Patch versions don't need to match because the assumption
503-
# is that they don't contain any breaking changes
504-
incoming_versions = stream.versions
505-
if incoming_versions.minor_sqlmesh_version != existing_versions.minor_sqlmesh_version:
506-
raise SQLMeshError(
507-
f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n"
508-
"Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
509-
)
510-
511-
if clear:
512-
self.reset(default_catalog=None)
513-
514-
auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {}
515-
516-
for snapshot_chunk in chunk_iterable(stream.snapshots, SnapshotState.SNAPSHOT_BATCH_SIZE):
517-
snapshot_iterator, intervals_iterator, auto_restatments_iterator = itertools.tee(
518-
snapshot_chunk, 3
519-
)
520-
overwrite_existing_snapshots = (
521-
not clear
522-
) # if clear=True, all existing snapshots were dropped anyway
523-
self.snapshot_state.push_snapshots(
524-
snapshot_iterator, overwrite=overwrite_existing_snapshots
525-
)
526-
self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator))
491+
for state_chunk in stream:
492+
if isinstance(state_chunk, VersionsChunk):
493+
# SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
494+
# is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption
495+
# is that they dont contain any breaking changes
496+
incoming_versions = state_chunk.versions
497+
if (
498+
incoming_versions.minor_sqlmesh_version
499+
!= existing_versions.minor_sqlmesh_version
500+
):
501+
raise SQLMeshError(
502+
f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n"
503+
"Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
504+
)
527505

528-
auto_restatements.update(
529-
{
530-
s.name_version: s.next_auto_restatement_ts
531-
for s in auto_restatments_iterator
532-
if s.next_auto_restatement_ts
533-
}
534-
)
506+
if clear:
507+
self.reset(default_catalog=None)
508+
509+
if isinstance(state_chunk, SnapshotsChunk):
510+
auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {}
511+
512+
for snapshot_chunk in chunk_iterable(
513+
state_chunk, SnapshotState.SNAPSHOT_BATCH_SIZE
514+
):
515+
snapshot_chunk = list(snapshot_chunk)
516+
overwrite_existing_snapshots = (
517+
not clear
518+
) # if clear=True, all existing snapshots were dropped anyway
519+
self.snapshot_state.push_snapshots(
520+
snapshot_chunk, overwrite=overwrite_existing_snapshots
521+
)
522+
self.add_snapshots_intervals((s.snapshot_intervals for s in snapshot_chunk))
523+
524+
auto_restatements.update(
525+
{
526+
s.name_version: s.next_auto_restatement_ts
527+
for s in snapshot_chunk
528+
if s.next_auto_restatement_ts
529+
}
530+
)
535531

536-
for environment_with_statements in stream.environments:
537-
environment = environment_with_statements.environment
538-
self.environment_state.update_environment(environment)
539-
self.environment_state.update_environment_statements(
540-
environment.name, environment.plan_id, environment_with_statements.statements
541-
)
532+
self.update_auto_restatements(auto_restatements)
542533

543-
self.update_auto_restatements(auto_restatements)
534+
if isinstance(state_chunk, EnvironmentsChunk):
535+
for environment_with_statements in state_chunk:
536+
environment = environment_with_statements.environment
537+
self.environment_state.update_environment(environment)
538+
self.environment_state.update_environment_statements(
539+
environment.name,
540+
environment.plan_id,
541+
environment_with_statements.statements,
542+
)
544543

545544
def state_type(self) -> str:
546545
return self.engine_adapter.dialect

0 commit comments

Comments
 (0)