1919import contextlib
2020import logging
2121import typing as t
22- import itertools
2322from pathlib import Path
2423from datetime import datetime
2524
4847 Versions ,
4948)
5049from sqlmesh .core .state_sync .common import (
50+ EnvironmentsChunk ,
51+ SnapshotsChunk ,
52+ VersionsChunk ,
5153 transactional ,
5254 StateStream ,
5355 chunk_iterable ,
@@ -448,7 +450,9 @@ def rollback(self) -> None:
448450
449451 @transactional ()
450452 def export (self , environment_names : t .Optional [t .List [str ]] = None ) -> StateStream :
451- state_sync = self
453+ versions = self .get_versions (
454+ validate = True
455+ ) # will throw if the state db hasnt been created or there is a version mismatch
452456
453457 snapshot_ids_to_export : t .Set [SnapshotId ] = set ()
454458 selected_environments : t .List [Environment ] = []
@@ -458,89 +462,84 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
458462 if not environment :
459463 raise SQLMeshError (f"No such environment: { env_name } " )
460464 selected_environments .append (environment )
465+ else :
466+ selected_environments = self .get_environments ()
461467
462- for env in selected_environments :
463- snapshot_ids_to_export |= set ([s .snapshot_id for s in env .snapshots or []])
464-
465- def _include_snapshot (s_id : SnapshotId ) -> bool :
466- if environment_names :
467- return s_id in snapshot_ids_to_export
468- return True
469-
470- class _DumpStateStream (StateStream ):
471- @property
472- def versions (self ) -> Versions :
473- return state_sync .get_versions ()
474-
475- @property
476- def snapshots (self ) -> t .Iterable [Snapshot ]:
477- all_snapshot_ids = {
478- s .snapshot_id
479- for e in state_sync .get_environments ()
480- for s in e .snapshots
481- if _include_snapshot (s .snapshot_id )
482- }
483- for chunk in chunk_iterable (all_snapshot_ids , SnapshotState .SNAPSHOT_BATCH_SIZE ):
484- yield from state_sync .get_snapshots (chunk ).values ()
468+ for env in selected_environments :
469+ snapshot_ids_to_export |= set ([s .snapshot_id for s in env .snapshots or []])
485470
486- @ property
487- def environments ( self ) -> t . Iterable [ EnvironmentWithStatements ] :
488- envs = selected_environments if environment_names else state_sync . get_environments ()
471+ def _export_snapshots () -> t . Iterator [ Snapshot ]:
472+ for chunk in chunk_iterable ( snapshot_ids_to_export , SnapshotState . SNAPSHOT_BATCH_SIZE ) :
473+ yield from self . get_snapshots ( chunk ). values ()
489474
490- for env in envs :
491- yield EnvironmentWithStatements (
492- environment = env , statements = state_sync .get_environment_statements (env .name )
493- )
475+ def _export_environments () -> t .Iterator [EnvironmentWithStatements ]:
476+ for env in selected_environments :
477+ yield EnvironmentWithStatements (
478+ environment = env , statements = self .get_environment_statements (env .name )
479+ )
494480
495- return _DumpStateStream ()
481+ return StateStream .from_iterators (
482+ versions = versions ,
483+ snapshots = _export_snapshots (),
484+ environments = _export_environments (),
485+ )
496486
497487 @transactional ()
498488 def import_ (self , stream : StateStream , clear : bool = True ) -> None :
499489 existing_versions = self .get_versions ()
500490
501- # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
502- # is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption
503- # is that they dont contain any breaking changes
504- incoming_versions = stream .versions
505- if incoming_versions .minor_sqlmesh_version != existing_versions .minor_sqlmesh_version :
506- raise SQLMeshError (
507- f"SQLMesh version mismatch. You are running '{ existing_versions .sqlmesh_version } ' but the state file was created with '{ incoming_versions .sqlmesh_version } '.\n "
508- "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
509- )
510-
511- if clear :
512- self .reset (default_catalog = None )
513-
514- auto_restatements : t .Dict [SnapshotNameVersion , t .Optional [int ]] = {}
515-
516- for snapshot_chunk in chunk_iterable (stream .snapshots , SnapshotState .SNAPSHOT_BATCH_SIZE ):
517- snapshot_iterator , intervals_iterator , auto_restatments_iterator = itertools .tee (
518- snapshot_chunk , 3
519- )
520- overwrite_existing_snapshots = (
521- not clear
522- ) # if clear=True, all existing snapshots were dropped anyway
523- self .snapshot_state .push_snapshots (
524- snapshot_iterator , overwrite = overwrite_existing_snapshots
525- )
526- self .add_snapshots_intervals ((s .snapshot_intervals for s in intervals_iterator ))
491+ for state_chunk in stream :
492+ if isinstance (state_chunk , VersionsChunk ):
493+ # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file
494+ # is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption
495+ # is that they dont contain any breaking changes
496+ incoming_versions = state_chunk .versions
497+ if (
498+ incoming_versions .minor_sqlmesh_version
499+ != existing_versions .minor_sqlmesh_version
500+ ):
501+ raise SQLMeshError (
502+ f"SQLMesh version mismatch. You are running '{ existing_versions .sqlmesh_version } ' but the state file was created with '{ incoming_versions .sqlmesh_version } '.\n "
503+ "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import."
504+ )
527505
528- auto_restatements .update (
529- {
530- s .name_version : s .next_auto_restatement_ts
531- for s in auto_restatments_iterator
532- if s .next_auto_restatement_ts
533- }
534- )
506+ if clear :
507+ self .reset (default_catalog = None )
508+
509+ if isinstance (state_chunk , SnapshotsChunk ):
510+ auto_restatements : t .Dict [SnapshotNameVersion , t .Optional [int ]] = {}
511+
512+ for snapshot_chunk in chunk_iterable (
513+ state_chunk , SnapshotState .SNAPSHOT_BATCH_SIZE
514+ ):
515+ snapshot_chunk = list (snapshot_chunk )
516+ overwrite_existing_snapshots = (
517+ not clear
518+ ) # if clear=True, all existing snapshots were dropped anyway
519+ self .snapshot_state .push_snapshots (
520+ snapshot_chunk , overwrite = overwrite_existing_snapshots
521+ )
522+ self .add_snapshots_intervals ((s .snapshot_intervals for s in snapshot_chunk ))
523+
524+ auto_restatements .update (
525+ {
526+ s .name_version : s .next_auto_restatement_ts
527+ for s in snapshot_chunk
528+ if s .next_auto_restatement_ts
529+ }
530+ )
535531
536- for environment_with_statements in stream .environments :
537- environment = environment_with_statements .environment
538- self .environment_state .update_environment (environment )
539- self .environment_state .update_environment_statements (
540- environment .name , environment .plan_id , environment_with_statements .statements
541- )
532+ self .update_auto_restatements (auto_restatements )
542533
543- self .update_auto_restatements (auto_restatements )
534+ if isinstance (state_chunk , EnvironmentsChunk ):
535+ for environment_with_statements in state_chunk :
536+ environment = environment_with_statements .environment
537+ self .environment_state .update_environment (environment )
538+ self .environment_state .update_environment_statements (
539+ environment .name ,
540+ environment .plan_id ,
541+ environment_with_statements .statements ,
542+ )
544543
545544 def state_type (self ) -> str :
546545 return self .engine_adapter .dialect
0 commit comments