Skip to content

Commit 100d76b

Browse files
committed
PR feedback
1 parent cf8ea2a commit 100d76b

1 file changed

Lines changed: 16 additions & 28 deletions

File tree

sqlmesh/core/state_sync/db/facade.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import contextlib
2020
import logging
2121
import typing as t
22-
import itertools
2322
from pathlib import Path
2423
from datetime import datetime
2524

@@ -450,7 +449,9 @@ def rollback(self) -> None:
450449

451450
@transactional()
452451
def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream:
453-
state_sync = self
452+
versions = self.get_versions(
453+
validate=True
454+
) # will throw if the state db hasnt been created or there is a version mismatch
454455

455456
snapshot_ids_to_export: t.Set[SnapshotId] = set()
456457
selected_environments: t.List[Environment] = []
@@ -460,35 +461,24 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
460461
if not environment:
461462
raise SQLMeshError(f"No such environment: {env_name}")
462463
selected_environments.append(environment)
464+
else:
465+
selected_environments = self.get_environments()
463466

464-
for env in selected_environments:
465-
snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []])
466-
467-
def _include_snapshot(s_id: SnapshotId) -> bool:
468-
if environment_names:
469-
return s_id in snapshot_ids_to_export
470-
return True
467+
for env in selected_environments:
468+
snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []])
471469

472470
def _export_snapshots() -> t.Iterator[Snapshot]:
473-
all_snapshot_ids = {
474-
s.snapshot_id
475-
for e in state_sync.get_environments()
476-
for s in e.snapshots
477-
if _include_snapshot(s.snapshot_id)
478-
}
479-
for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE):
480-
yield from state_sync.get_snapshots(chunk).values()
471+
for chunk in chunk_iterable(snapshot_ids_to_export, SnapshotState.SNAPSHOT_BATCH_SIZE):
472+
yield from self.get_snapshots(chunk).values()
481473

482474
def _export_environments() -> t.Iterator[EnvironmentWithStatements]:
483-
envs = selected_environments if environment_names else state_sync.get_environments()
484-
485-
for env in envs:
475+
for env in selected_environments:
486476
yield EnvironmentWithStatements(
487-
environment=env, statements=state_sync.get_environment_statements(env.name)
477+
environment=env, statements=self.get_environment_statements(env.name)
488478
)
489479

490480
return StateStream.from_iterators(
491-
versions=state_sync.get_versions(),
481+
versions=versions,
492482
snapshots=_export_snapshots(),
493483
environments=_export_environments(),
494484
)
@@ -521,21 +511,19 @@ def import_(self, stream: StateStream, clear: bool = True) -> None:
521511
for snapshot_chunk in chunk_iterable(
522512
state_chunk, SnapshotState.SNAPSHOT_BATCH_SIZE
523513
):
524-
snapshot_iterator, intervals_iterator, auto_restatments_iterator = (
525-
itertools.tee(snapshot_chunk, 3)
526-
)
514+
snapshot_chunk = list(snapshot_chunk)
527515
overwrite_existing_snapshots = (
528516
not clear
529517
) # if clear=True, all existing snapshots were dropped anyway
530518
self.snapshot_state.push_snapshots(
531-
snapshot_iterator, overwrite=overwrite_existing_snapshots
519+
snapshot_chunk, overwrite=overwrite_existing_snapshots
532520
)
533-
self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator))
521+
self.add_snapshots_intervals((s.snapshot_intervals for s in snapshot_chunk))
534522

535523
auto_restatements.update(
536524
{
537525
s.name_version: s.next_auto_restatement_ts
538-
for s in auto_restatments_iterator
526+
for s in snapshot_chunk
539527
if s.next_auto_restatement_ts
540528
}
541529
)

0 commit comments

Comments
 (0)