|
48 | 48 | Versions, |
49 | 49 | ) |
50 | 50 | from sqlmesh.core.state_sync.common import ( |
| 51 | + EnvironmentsChunk, |
| 52 | + SnapshotsChunk, |
| 53 | + VersionsChunk, |
51 | 54 | transactional, |
52 | 55 | StateStream, |
53 | 56 | chunk_iterable, |
@@ -466,80 +469,88 @@ def _include_snapshot(s_id: SnapshotId) -> bool: |
466 | 469 | return s_id in snapshot_ids_to_export |
467 | 470 | return True |
468 | 471 |
|
469 | | - class _DumpStateStream(StateStream): |
470 | | - @property |
471 | | - def versions(self) -> Versions: |
472 | | - return state_sync.get_versions() |
473 | | - |
474 | | - @property |
475 | | - def snapshots(self) -> t.Iterable[Snapshot]: |
476 | | - all_snapshot_ids = { |
477 | | - s.snapshot_id |
478 | | - for e in state_sync.get_environments() |
479 | | - for s in e.snapshots |
480 | | - if _include_snapshot(s.snapshot_id) |
481 | | - } |
482 | | - for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE): |
483 | | - yield from state_sync.get_snapshots(chunk).values() |
| 472 | + def _export_snapshots() -> t.Iterator[Snapshot]: |
| 473 | + all_snapshot_ids = { |
| 474 | + s.snapshot_id |
| 475 | + for e in state_sync.get_environments() |
| 476 | + for s in e.snapshots |
| 477 | + if _include_snapshot(s.snapshot_id) |
| 478 | + } |
| 479 | + for chunk in chunk_iterable(all_snapshot_ids, SnapshotState.SNAPSHOT_BATCH_SIZE): |
| 480 | + yield from state_sync.get_snapshots(chunk).values() |
484 | 481 |
|
485 | | - @property |
486 | | - def environments(self) -> t.Iterable[EnvironmentWithStatements]: |
487 | | - envs = selected_environments if environment_names else state_sync.get_environments() |
| 482 | + def _export_environments() -> t.Iterator[EnvironmentWithStatements]: |
| 483 | + envs = selected_environments if environment_names else state_sync.get_environments() |
488 | 484 |
|
489 | | - for env in envs: |
490 | | - yield EnvironmentWithStatements( |
491 | | - environment=env, statements=state_sync.get_environment_statements(env.name) |
492 | | - ) |
| 485 | + for env in envs: |
| 486 | + yield EnvironmentWithStatements( |
| 487 | + environment=env, statements=state_sync.get_environment_statements(env.name) |
| 488 | + ) |
493 | 489 |
|
494 | | - return _DumpStateStream() |
| 490 | + return StateStream.from_iterators( |
| 491 | + versions=state_sync.get_versions(), |
| 492 | + snapshots=_export_snapshots(), |
| 493 | + environments=_export_environments(), |
| 494 | + ) |
495 | 495 |
|
496 | 496 | @transactional() |
497 | 497 | def import_(self, stream: StateStream, clear: bool = True) -> None: |
498 | 498 | existing_versions = self.get_versions() |
499 | 499 |
|
500 | | - # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file |
501 | | - # is compatible with our Pydantic model definitions. Patch versions don't need to match because the assumption |
502 | | - # is that they don't contain any breaking changes |
503 | | - incoming_versions = stream.versions |
504 | | - if incoming_versions.minor_sqlmesh_version != existing_versions.minor_sqlmesh_version: |
505 | | - raise SQLMeshError( |
506 | | - f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n" |
507 | | - "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import." |
508 | | - ) |
| 500 | + for state_chunk in stream: |
| 501 | + if isinstance(state_chunk, VersionsChunk): |
| 502 | + # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file |
| 503 | + # is compatible with our Pydantic model definitions. Patch versions don't need to match because the assumption |
| 504 | + # is that they don't contain any breaking changes |
| 505 | + incoming_versions = state_chunk.versions |
| 506 | + if ( |
| 507 | + incoming_versions.minor_sqlmesh_version |
| 508 | + != existing_versions.minor_sqlmesh_version |
| 509 | + ): |
| 510 | + raise SQLMeshError( |
| 511 | + f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n" |
| 512 | + "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import." |
| 513 | + ) |
509 | 514 |
|
510 | | - if clear: |
511 | | - self.reset(default_catalog=None) |
| 515 | + if clear: |
| 516 | + self.reset(default_catalog=None) |
512 | 517 |
|
513 | | - auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {} |
| 518 | + if isinstance(state_chunk, SnapshotsChunk): |
| 519 | + auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {} |
514 | 520 |
|
515 | | - for snapshot_chunk in chunk_iterable(stream.snapshots, SnapshotState.SNAPSHOT_BATCH_SIZE): |
516 | | - snapshot_iterator, intervals_iterator, auto_restatments_iterator = itertools.tee( |
517 | | - snapshot_chunk, 3 |
518 | | - ) |
519 | | - overwrite_existing_snapshots = ( |
520 | | - not clear |
521 | | - ) # if clear=True, all existing snapshots were dropped anyway |
522 | | - self.snapshot_state.push_snapshots( |
523 | | - snapshot_iterator, overwrite=overwrite_existing_snapshots |
524 | | - ) |
525 | | - self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator)) |
526 | | - |
527 | | - auto_restatements.update( |
528 | | - { |
529 | | - s.name_version: s.next_auto_restatement_ts |
530 | | - for s in auto_restatments_iterator |
531 | | - if s.next_auto_restatement_ts |
532 | | - } |
533 | | - ) |
| 521 | + for snapshot_chunk in chunk_iterable( |
| 522 | + state_chunk, SnapshotState.SNAPSHOT_BATCH_SIZE |
| 523 | + ): |
| 524 | + snapshot_iterator, intervals_iterator, auto_restatments_iterator = ( |
| 525 | + itertools.tee(snapshot_chunk, 3) |
| 526 | + ) |
| 527 | + overwrite_existing_snapshots = ( |
| 528 | + not clear |
| 529 | + ) # if clear=True, all existing snapshots were dropped anyway |
| 530 | + self.snapshot_state.push_snapshots( |
| 531 | + snapshot_iterator, overwrite=overwrite_existing_snapshots |
| 532 | + ) |
| 533 | + self.add_snapshots_intervals((s.snapshot_intervals for s in intervals_iterator)) |
| 534 | + |
| 535 | + auto_restatements.update( |
| 536 | + { |
| 537 | + s.name_version: s.next_auto_restatement_ts |
| 538 | + for s in auto_restatments_iterator |
| 539 | + if s.next_auto_restatement_ts |
| 540 | + } |
| 541 | + ) |
534 | 542 |
|
535 | | - for environment_with_statements in stream.environments: |
536 | | - environment = environment_with_statements.environment |
537 | | - self.environment_state.update_environment(environment) |
538 | | - self.environment_state.update_environment_statements( |
539 | | - environment.name, environment.plan_id, environment_with_statements.statements |
540 | | - ) |
| 543 | + self.update_auto_restatements(auto_restatements) |
541 | 544 |
|
542 | | - self.update_auto_restatements(auto_restatements) |
| 545 | + if isinstance(state_chunk, EnvironmentsChunk): |
| 546 | + for environment_with_statements in state_chunk: |
| 547 | + environment = environment_with_statements.environment |
| 548 | + self.environment_state.update_environment(environment) |
| 549 | + self.environment_state.update_environment_statements( |
| 550 | + environment.name, |
| 551 | + environment.plan_id, |
| 552 | + environment_with_statements.statements, |
| 553 | + ) |
543 | 554 |
|
544 | 555 | def state_type(self) -> str: |
545 | 556 | return self.engine_adapter.dialect |
|
0 commit comments