11from __future__ import annotations
2+ from dataclasses import dataclass
23import logging
34import typing as t
45import time
3940from sqlmesh .utils .date import (
4041 TimeLike ,
4142 now_timestamp ,
42- to_timestamp ,
4343 validate_date_range ,
4444)
4545from sqlmesh .utils .errors import (
5555
# Module-level logger for scheduler diagnostics.
logger = logging.getLogger(__name__)

# Maps each snapshot to the list of intervals scheduled for evaluation.
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
58- # we store snapshot name instead of snapshots/snapshotids because pydantic
59- # is extremely slow to hash. snapshot names should be unique within a dag run
60- SchedulingUnit = t .Tuple [str , t .Tuple [Interval , int ]]
58+
59+
class BaseNode:
    """Base class for units of work in the scheduling DAG.

    Nodes carry the snapshot *name* rather than snapshots/snapshot ids because
    pydantic objects are extremely slow to hash; snapshot names should be
    unique within a DAG run.
    """

    snapshot_name: str

    def __lt__(self, other: BaseNode) -> bool:
        """Provide a stable ordering: by node class name, then snapshot name."""
        if not isinstance(other, BaseNode):
            # Follow the rich-comparison protocol: let Python try the reflected
            # operation (or raise TypeError) instead of failing with an
            # AttributeError on an unrelated type.
            return NotImplemented
        return (self.__class__.__name__, self.snapshot_name) < (
            other.__class__.__name__,
            other.snapshot_name,
        )
68+
69+
@dataclass(frozen=True)
class EvaluateNode(BaseNode):
    """DAG node that evaluates one interval batch of a snapshot."""

    snapshot_name: str
    interval: Interval
    batch_index: int

    def __lt__(self, other: BaseNode) -> bool:
        """Order evaluate nodes by name, interval, then batch index.

        For any other node type, fall back to the base class ordering.
        """
        if isinstance(other, EvaluateNode):
            self_key = (
                self.__class__.__name__,
                self.snapshot_name,
                self.interval,
                self.batch_index,
            )
            other_key = (
                other.__class__.__name__,
                other.snapshot_name,
                other.interval,
                other.batch_index,
            )
            return self_key < other_key
        return super().__lt__(other)
85+
86+
@dataclass(frozen=True)
class CreateNode(BaseNode):
    """DAG node that creates a snapshot's physical table before evaluation."""

    snapshot_name: str
91+
@dataclass(frozen=True)
class DummyNode(BaseNode):
    """No-op DAG node used as a fan-in point when an upstream snapshot has
    multiple interval batches; evaluation skips it entirely."""

    snapshot_name: str
95+
96+
97+ SchedulingUnit = t .Union [EvaluateNode , CreateNode , DummyNode ]
6198
6299
63100class Scheduler :
@@ -162,6 +199,7 @@ def evaluate(
162199 batch_index : int ,
163200 environment_naming_info : t .Optional [EnvironmentNamingInfo ] = None ,
164201 allow_destructive_snapshots : t .Optional [t .Set [str ]] = None ,
202+ target_table_exists : t .Optional [bool ] = None ,
165203 ** kwargs : t .Any ,
166204 ) -> t .List [AuditResult ]:
167205 """Evaluate a snapshot and add the processed interval to the state sync.
@@ -175,6 +213,7 @@ def evaluate(
175213 deployability_index: Determines snapshots that are deployable in the context of this evaluation.
176214 batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it
177215 auto_restatement_enabled: Whether to enable auto restatements.
216+ target_table_exists: Whether the target table exists. If None, the table will be checked for existence.
178217 kwargs: Additional kwargs to pass to the renderer.
179218
180219 Returns:
@@ -195,6 +234,7 @@ def evaluate(
195234 allow_destructive_snapshots = allow_destructive_snapshots ,
196235 deployability_index = deployability_index ,
197236 batch_index = batch_index ,
237+ target_table_exists = target_table_exists ,
198238 ** kwargs ,
199239 )
200240 audit_results = self ._audit_snapshot (
@@ -404,7 +444,14 @@ def run_merged_intervals(
404444 audit_only = audit_only ,
405445 )
406446
407- dag = self ._dag (batched_intervals )
447+ snapshots_to_create = {
448+ s .snapshot_id
449+ for s in self .snapshot_evaluator .get_snapshots_to_create (
450+ merged_intervals .keys (), deployability_index
451+ )
452+ }
453+
454+ dag = self ._dag (batched_intervals , snapshots_to_create = snapshots_to_create )
408455
409456 if run_environment_statements :
410457 environment_statements = self .state_sync .get_environment_statements (
@@ -425,55 +472,63 @@ def run_merged_intervals(
425472 def evaluate_node (node : SchedulingUnit ) -> None :
426473 if circuit_breaker and circuit_breaker ():
427474 raise CircuitBreakerError ()
428-
429- snapshot_name , ((start , end ), batch_idx ) = node
430- if batch_idx == - 1 :
475+ if isinstance (node , DummyNode ):
431476 return
432- snapshot = self .snapshots_by_name [snapshot_name ]
433-
434- self .console .start_snapshot_evaluation_progress (snapshot )
435-
436- execution_start_ts = now_timestamp ()
437- evaluation_duration_ms : t .Optional [int ] = None
438477
439- audit_results : t .List [AuditResult ] = []
440- try :
441- assert execution_time # mypy
442- assert deployability_index # mypy
443-
444- if audit_only :
445- audit_results = self ._audit_snapshot (
446- snapshot = snapshot ,
447- environment_naming_info = environment_naming_info ,
448- deployability_index = deployability_index ,
449- snapshots = self .snapshots_by_name ,
450- start = start ,
451- end = end ,
452- execution_time = execution_time ,
453- )
454- else :
455- audit_results = self .evaluate (
456- snapshot = snapshot ,
457- environment_naming_info = environment_naming_info ,
458- start = start ,
459- end = end ,
460- execution_time = execution_time ,
461- deployability_index = deployability_index ,
462- batch_index = batch_idx ,
463- allow_destructive_snapshots = allow_destructive_snapshots ,
478+ snapshot = self .snapshots_by_name [node .snapshot_name ]
479+
480+ if isinstance (node , EvaluateNode ):
481+ self .console .start_snapshot_evaluation_progress (snapshot )
482+ execution_start_ts = now_timestamp ()
483+ evaluation_duration_ms : t .Optional [int ] = None
484+ start , end = node .interval
485+
486+ audit_results : t .List [AuditResult ] = []
487+ try :
488+ assert execution_time # mypy
489+ assert deployability_index # mypy
490+
491+ if audit_only :
492+ audit_results = self ._audit_snapshot (
493+ snapshot = snapshot ,
494+ environment_naming_info = environment_naming_info ,
495+ deployability_index = deployability_index ,
496+ snapshots = self .snapshots_by_name ,
497+ start = start ,
498+ end = end ,
499+ execution_time = execution_time ,
500+ )
501+ else :
502+ audit_results = self .evaluate (
503+ snapshot = snapshot ,
504+ environment_naming_info = environment_naming_info ,
505+ start = start ,
506+ end = end ,
507+ execution_time = execution_time ,
508+ deployability_index = deployability_index ,
509+ batch_index = node .batch_index ,
510+ allow_destructive_snapshots = allow_destructive_snapshots ,
511+ target_table_exists = snapshot .snapshot_id not in snapshots_to_create ,
512+ )
513+
514+ evaluation_duration_ms = now_timestamp () - execution_start_ts
515+ finally :
516+ num_audits = len (audit_results )
517+ num_audits_failed = sum (1 for result in audit_results if result .count )
518+ self .console .update_snapshot_evaluation_progress (
519+ snapshot ,
520+ batched_intervals [snapshot ][node .batch_index ],
521+ node .batch_index ,
522+ evaluation_duration_ms ,
523+ num_audits - num_audits_failed ,
524+ num_audits_failed ,
464525 )
465-
466- evaluation_duration_ms = now_timestamp () - execution_start_ts
467- finally :
468- num_audits = len (audit_results )
469- num_audits_failed = sum (1 for result in audit_results if result .count )
470- self .console .update_snapshot_evaluation_progress (
471- snapshot ,
472- batched_intervals [snapshot ][batch_idx ],
473- batch_idx ,
474- evaluation_duration_ms ,
475- num_audits - num_audits_failed ,
476- num_audits_failed ,
526+ elif isinstance (node , CreateNode ):
527+ self .snapshot_evaluator .create_snapshot (
528+ snapshot = snapshot ,
529+ snapshots = self .snapshots_by_name ,
530+ deployability_index = deployability_index ,
531+ allow_destructive_snapshots = allow_destructive_snapshots or set (),
477532 )
478533
479534 try :
@@ -486,7 +541,9 @@ def evaluate_node(node: SchedulingUnit) -> None:
486541 )
487542 self .console .stop_evaluation_progress (success = not errors )
488543
489- skipped_snapshots = {i [0 ] for i in skipped_intervals }
544+ skipped_snapshots = {
545+ i .snapshot_name for i in skipped_intervals if isinstance (i , EvaluateNode )
546+ }
490547 self .console .log_skipped_models (skipped_snapshots )
491548 for skipped in skipped_snapshots :
492549 logger .info (f"SKIPPED snapshot { skipped } \n " )
@@ -515,11 +572,16 @@ def evaluate_node(node: SchedulingUnit) -> None:
515572
516573 self .state_sync .recycle ()
517574
518- def _dag (self , batches : SnapshotToIntervals ) -> DAG [SchedulingUnit ]:
575+ def _dag (
576+ self ,
577+ batches : SnapshotToIntervals ,
578+ snapshots_to_create : t .Optional [t .Set [SnapshotId ]] = None ,
579+ ) -> DAG [SchedulingUnit ]:
519580 """Builds a DAG of snapshot intervals to be evaluated.
520581
521582 Args:
522583 batches: The batches of snapshots and intervals to evaluate.
584+ snapshots_to_create: The snapshots with missing physical tables.
523585
524586 Returns:
525587 A DAG of snapshot intervals to be evaluated.
@@ -528,46 +590,64 @@ def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
528590 intervals_per_snapshot = {
529591 snapshot .name : intervals for snapshot , intervals in batches .items ()
530592 }
593+ snapshots_to_create = snapshots_to_create or set ()
531594
532595 dag = DAG [SchedulingUnit ]()
533- terminal_node = ((to_timestamp (0 ), to_timestamp (0 )), - 1 )
534596
535597 for snapshot , intervals in batches .items ():
536598 if not intervals :
537599 continue
538600
539- upstream_dependencies = []
601+ upstream_dependencies : t . List [ SchedulingUnit ] = []
540602
541603 for p_sid in snapshot .parents :
542604 if p_sid in self .snapshots :
543605 p_intervals = intervals_per_snapshot .get (p_sid .name , [])
544606
545607 if len (p_intervals ) > 1 :
546- upstream_dependencies .append (( p_sid .name , terminal_node ))
608+ upstream_dependencies .append (DummyNode ( snapshot_name = p_sid .name ))
547609 else :
548610 for i , interval in enumerate (p_intervals ):
549- upstream_dependencies .append ((p_sid .name , (interval , i )))
611+ upstream_dependencies .append (
612+ EvaluateNode (
613+ snapshot_name = p_sid .name , interval = interval , batch_index = i
614+ )
615+ )
550616
551617 batch_concurrency = snapshot .node .batch_concurrency
552618 if snapshot .depends_on_past :
553619 batch_concurrency = 1
554620
621+ create_node : t .Optional [CreateNode ] = None
622+ if (
623+ batch_concurrency
624+ and batch_concurrency > 1
625+ and snapshot .snapshot_id in snapshots_to_create
626+ ):
627+ # Add a separate node for table creation in case when there multiple concurrent
628+ # evaluation nodes.
629+ create_node = CreateNode (snapshot_name = snapshot .name )
630+
555631 for i , interval in enumerate (intervals ):
556- node = ( snapshot .name , ( interval , i ) )
632+ node = EvaluateNode ( snapshot_name = snapshot .name , interval = interval , batch_index = i )
557633 dag .add (node , upstream_dependencies )
558634
559635 if len (intervals ) > 1 :
560- dag .add ((snapshot .name , terminal_node ), [node ])
636+ dag .add (DummyNode (snapshot_name = snapshot .name ), [node ])
637+
638+ if create_node :
639+ dag .add (node , [create_node ])
561640
562641 if batch_concurrency and i >= batch_concurrency :
563642 batch_idx_to_wait_for = i - batch_concurrency
564643 dag .add (
565644 node ,
566645 [
567- (
568- snapshot .name ,
569- (intervals [batch_idx_to_wait_for ], batch_idx_to_wait_for ),
570- )
646+ EvaluateNode (
647+ snapshot_name = snapshot .name ,
648+ interval = intervals [batch_idx_to_wait_for ],
649+ batch_index = batch_idx_to_wait_for ,
650+ ),
571651 ],
572652 )
573653 return dag
0 commit comments