1919import contextlib
2020import logging
2121import typing as t
22- import itertools
2322from pathlib import Path
2423from datetime import datetime
2524
@@ -450,7 +449,9 @@ def rollback(self) -> None:
450449
451450 @transactional ()
452451 def export (self , environment_names : t .Optional [t .List [str ]] = None ) -> StateStream :
453- state_sync = self
452+ versions = self .get_versions (
453+ validate = True
454+ ) # will throw if the state db hasnt been created or there is a version mismatch
454455
455456 snapshot_ids_to_export : t .Set [SnapshotId ] = set ()
456457 selected_environments : t .List [Environment ] = []
@@ -460,35 +461,24 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
460461 if not environment :
461462 raise SQLMeshError (f"No such environment: { env_name } " )
462463 selected_environments .append (environment )
464+ else :
465+ selected_environments = self .get_environments ()
463466
464- for env in selected_environments :
465- snapshot_ids_to_export |= set ([s .snapshot_id for s in env .snapshots or []])
466-
467- def _include_snapshot (s_id : SnapshotId ) -> bool :
468- if environment_names :
469- return s_id in snapshot_ids_to_export
470- return True
467+ for env in selected_environments :
468+ snapshot_ids_to_export |= set ([s .snapshot_id for s in env .snapshots or []])
471469
472470 def _export_snapshots () -> t .Iterator [Snapshot ]:
473- all_snapshot_ids = {
474- s .snapshot_id
475- for e in state_sync .get_environments ()
476- for s in e .snapshots
477- if _include_snapshot (s .snapshot_id )
478- }
479- for chunk in chunk_iterable (all_snapshot_ids , SnapshotState .SNAPSHOT_BATCH_SIZE ):
480- yield from state_sync .get_snapshots (chunk ).values ()
471+ for chunk in chunk_iterable (snapshot_ids_to_export , SnapshotState .SNAPSHOT_BATCH_SIZE ):
472+ yield from self .get_snapshots (chunk ).values ()
481473
482474 def _export_environments () -> t .Iterator [EnvironmentWithStatements ]:
483- envs = selected_environments if environment_names else state_sync .get_environments ()
484-
485- for env in envs :
475+ for env in selected_environments :
486476 yield EnvironmentWithStatements (
487- environment = env , statements = state_sync .get_environment_statements (env .name )
477+ environment = env , statements = self .get_environment_statements (env .name )
488478 )
489479
490480 return StateStream .from_iterators (
491- versions = state_sync . get_versions () ,
481+ versions = versions ,
492482 snapshots = _export_snapshots (),
493483 environments = _export_environments (),
494484 )
@@ -521,21 +511,19 @@ def import_(self, stream: StateStream, clear: bool = True) -> None:
521511 for snapshot_chunk in chunk_iterable (
522512 state_chunk , SnapshotState .SNAPSHOT_BATCH_SIZE
523513 ):
524- snapshot_iterator , intervals_iterator , auto_restatments_iterator = (
525- itertools .tee (snapshot_chunk , 3 )
526- )
514+ snapshot_chunk = list (snapshot_chunk )
527515 overwrite_existing_snapshots = (
528516 not clear
529517 ) # if clear=True, all existing snapshots were dropped anyway
530518 self .snapshot_state .push_snapshots (
531- snapshot_iterator , overwrite = overwrite_existing_snapshots
519+ snapshot_chunk , overwrite = overwrite_existing_snapshots
532520 )
533- self .add_snapshots_intervals ((s .snapshot_intervals for s in intervals_iterator ))
521+ self .add_snapshots_intervals ((s .snapshot_intervals for s in snapshot_chunk ))
534522
535523 auto_restatements .update (
536524 {
537525 s .name_version : s .next_auto_restatement_ts
538- for s in auto_restatments_iterator
526+ for s in snapshot_chunk
539527 if s .next_auto_restatement_ts
540528 }
541529 )
0 commit comments