@@ -332,8 +332,9 @@ def batch_intervals(
332332 merged_intervals : SnapshotToIntervals ,
333333 deployability_index : t .Optional [DeployabilityIndex ],
334334 environment_naming_info : EnvironmentNamingInfo ,
335+ dag : t .Optional [DAG [SnapshotId ]] = None ,
335336 ) -> t .Dict [Snapshot , Intervals ]:
336- dag = snapshots_to_dag (merged_intervals )
337+ dag = dag or snapshots_to_dag (merged_intervals )
337338
338339 snapshot_intervals : t .Dict [SnapshotId , t .Tuple [Snapshot , t .List [Interval ]]] = {
339340 snapshot .snapshot_id : (
@@ -413,6 +414,7 @@ def run_merged_intervals(
413414 start : t .Optional [TimeLike ] = None ,
414415 end : t .Optional [TimeLike ] = None ,
415416 allow_destructive_snapshots : t .Optional [t .Set [str ]] = None ,
417+ selected_snapshot_ids : t .Optional [t .Set [SnapshotId ]] = None ,
416418 run_environment_statements : bool = False ,
417419 audit_only : bool = False ,
418420 ) -> t .Tuple [t .List [NodeExecutionFailedError [SchedulingUnit ]], t .List [SchedulingUnit ]]:
@@ -427,14 +429,21 @@ def run_merged_intervals(
427429 start: The start of the run.
428430 end: The end of the run.
429431 allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed.
432+ selected_snapshot_ids: The snapshots to include in the run DAG. If None, all snapshots with missing intervals will be included.
430433
431434 Returns:
432435 A tuple of errors and skipped intervals.
433436 """
434437 execution_time = execution_time or now_timestamp ()
435438
439+ selected_snapshots = [self .snapshots [sid ] for sid in (selected_snapshot_ids or set ())]
440+ if not selected_snapshots :
441+ selected_snapshots = list (merged_intervals )
442+
443+ snapshot_dag = snapshots_to_dag (selected_snapshots )
444+
436445 batched_intervals = self .batch_intervals (
437- merged_intervals , deployability_index , environment_naming_info
446+ merged_intervals , deployability_index , environment_naming_info , dag = snapshot_dag
438447 )
439448
440449 self .console .start_evaluation_progress (
@@ -447,11 +456,13 @@ def run_merged_intervals(
447456 snapshots_to_create = {
448457 s .snapshot_id
449458 for s in self .snapshot_evaluator .get_snapshots_to_create (
450- merged_intervals . keys () , deployability_index
459+ selected_snapshots , deployability_index
451460 )
452461 }
453462
454- dag = self ._dag (batched_intervals , snapshots_to_create = snapshots_to_create )
463+ dag = self ._dag (
464+ batched_intervals , snapshot_dag = snapshot_dag , snapshots_to_create = snapshots_to_create
465+ )
455466
456467 if run_environment_statements :
457468 environment_statements = self .state_sync .get_environment_statements (
@@ -575,12 +586,14 @@ def evaluate_node(node: SchedulingUnit) -> None:
575586 def _dag (
576587 self ,
577588 batches : SnapshotToIntervals ,
589+ snapshot_dag : t .Optional [DAG [SnapshotId ]] = None ,
578590 snapshots_to_create : t .Optional [t .Set [SnapshotId ]] = None ,
579591 ) -> DAG [SchedulingUnit ]:
580592 """Builds a DAG of snapshot intervals to be evaluated.
581593
582594 Args:
583595 batches: The batches of snapshots and intervals to evaluate.
596+ snapshot_dag: The DAG of all snapshots.
584597 snapshots_to_create: The snapshots with missing physical tables.
585598
586599 Returns:
@@ -591,20 +604,24 @@ def _dag(
591604 snapshot .name : intervals for snapshot , intervals in batches .items ()
592605 }
593606 snapshots_to_create = snapshots_to_create or set ()
607+ original_snapshots_to_create = snapshots_to_create .copy ()
594608
609+ snapshot_dag = snapshot_dag or snapshots_to_dag (batches )
595610 dag = DAG [SchedulingUnit ]()
596611
597- for snapshot , intervals in batches . items () :
598- if not intervals :
599- continue
612+ for snapshot_id in snapshot_dag :
613+ snapshot = self . snapshots_by_name [ snapshot_id . name ]
614+ intervals = intervals_per_snapshot . get ( snapshot . name , [])
600615
601616 upstream_dependencies : t .List [SchedulingUnit ] = []
602617
603618 for p_sid in snapshot .parents :
604619 if p_sid in self .snapshots :
605620 p_intervals = intervals_per_snapshot .get (p_sid .name , [])
606621
607- if len (p_intervals ) > 1 :
622+ if not p_intervals and p_sid in original_snapshots_to_create :
623+ upstream_dependencies .append (CreateNode (snapshot_name = p_sid .name ))
624+ elif len (p_intervals ) > 1 :
608625 upstream_dependencies .append (DummyNode (snapshot_name = p_sid .name ))
609626 else :
610627 for i , interval in enumerate (p_intervals ):
@@ -620,14 +637,16 @@ def _dag(
620637 batch_concurrency = 1
621638
622639 create_node : t .Optional [CreateNode ] = None
623- if snapshot .snapshot_id in snapshots_to_create and (
640+ if snapshot .snapshot_id in original_snapshots_to_create and (
624641 snapshot .is_incremental_by_time_range
625642 or ((not batch_concurrency or batch_concurrency > 1 ) and batch_size )
643+ or not intervals
626644 ):
627645 # Add a separate node for table creation in case when there multiple concurrent
628- # evaluation nodes.
646+ # evaluation nodes or when there are no intervals to evaluate .
629647 create_node = CreateNode (snapshot_name = snapshot .name )
630648 dag .add (create_node , upstream_dependencies )
649+ snapshots_to_create .remove (snapshot .snapshot_id )
631650
632651 for i , interval in enumerate (intervals ):
633652 node = EvaluateNode (snapshot_name = snapshot .name , interval = interval , batch_index = i )
0 commit comments