11from __future__ import annotations
22
33import datetime
4+ import threading
45import typing as t
56import unittest
67from collections import Counter
@@ -57,6 +58,7 @@ def __init__(
5758 path : Path | None = None ,
5859 preserve_fixtures : bool = False ,
5960 default_catalog : str | None = None ,
61+ lock : t .Optional [threading .Lock ] = None ,
6062 ) -> None :
6163 """ModelTest encapsulates a unit test for a model.
6264
@@ -79,6 +81,7 @@ def __init__(
7981 self .preserve_fixtures = preserve_fixtures
8082 self .default_catalog = default_catalog
8183 self .dialect = dialect
84+ self .lock = lock
8285
8386 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8487 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -102,6 +105,7 @@ def __init__(
102105 )
103106 self ._qualified_fixture_schema = schema_ (self ._fixture_schema , self ._fixture_catalog )
104107
108+ self ._exec_time_transforms : t .Dict [type [exp .Expression ], exp .Expression ] = {}
105109 self ._transforms = self ._test_adapter_dialect .generator_class .TRANSFORMS
106110 self ._execution_time = str (self .body .get ("vars" , {}).get ("execution_time" ) or "" )
107111
@@ -112,20 +116,20 @@ def __init__(
112116 # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
113117 if self ._execution_time :
114118 exec_time = exp .Literal .string (self ._execution_time )
119+
120+ self ._exec_time_transforms = {
121+ exp .CurrentDate : exp .cast (exec_time , "date" , dialect = dialect ),
122+ exp .CurrentDatetime : exp .cast (exec_time , "datetime" , dialect = dialect ),
123+ exp .CurrentTime : exp .cast (exec_time , "time" , dialect = dialect ),
124+ exp .CurrentTimestamp : exp .cast (exec_time , "timestamp" , dialect = dialect ),
125+ }
126+
115127 self ._transforms = {
116128 ** self ._transforms ,
117- exp .CurrentDate : lambda self , _ : self .sql (
118- exp .cast (exec_time , "date" , dialect = dialect )
119- ),
120- exp .CurrentDatetime : lambda self , _ : self .sql (
121- exp .cast (exec_time , "datetime" , dialect = dialect )
122- ),
123- exp .CurrentTime : lambda self , _ : self .sql (
124- exp .cast (exec_time , "time" , dialect = dialect )
125- ),
126- exp .CurrentTimestamp : lambda self , _ : self .sql (
127- exp .cast (exec_time , "timestamp" , dialect = dialect )
128- ),
129+ ** {
130+ key : lambda self , _ : self .sql (value )
131+ for key , value in self ._exec_time_transforms .items ()
132+ },
129133 }
130134
131135 super ().__init__ ()
@@ -310,6 +314,7 @@ def create_test(
310314 path : Path | None ,
311315 preserve_fixtures : bool = False ,
312316 default_catalog : str | None = None ,
317+ lock : t .Optional [threading .Lock ] = None ,
313318 ) -> t .Optional [ModelTest ]:
314319 """Create a SqlModelTest or a PythonModelTest.
315320
@@ -353,6 +358,7 @@ def create_test(
353358 path ,
354359 preserve_fixtures ,
355360 default_catalog ,
361+ lock = lock ,
356362 )
357363
358364 def __str__ (self ) -> str :
@@ -514,8 +520,13 @@ def _normalize_column_name(self, name: str) -> str:
514520
515521 def _execute (self , query : exp .Query ) -> pd .DataFrame :
516522 """Executes the given query using the testing engine adapter and returns a DataFrame."""
517- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
518- return self .engine_adapter .fetchdf (query )
523+
524+ def replace_execution_time (expression : exp .Expression ) -> exp .Expression :
525+ return self ._exec_time_transforms .get (type (expression ), expression )
526+
527+ return self .engine_adapter .fetchdf (
528+ query .transform (replace_execution_time ) if self ._execution_time else query
529+ )
519530
520531 def _create_df (
521532 self ,
@@ -626,6 +637,7 @@ def __init__(
626637 path : Path | None = None ,
627638 preserve_fixtures : bool = False ,
628639 default_catalog : str | None = None ,
640+ lock : t .Optional [threading .Lock ] = None ,
629641 ) -> None :
630642 """PythonModelTest encapsulates a unit test for a Python model.
631643
@@ -651,6 +663,7 @@ def __init__(
651663 path ,
652664 preserve_fixtures ,
653665 default_catalog ,
666+ lock ,
654667 )
655668
656669 self .context = TestExecutionContext (
@@ -681,15 +694,18 @@ def _execute_model(self) -> pd.DataFrame:
681694 else :
682695 time_ctx = nullcontext ()
683696
684- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
685- with time_ctx :
686- variables = self .body .get ("vars" , {}).copy ()
687- time_kwargs = {
688- key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables
689- }
690- df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
691- assert not isinstance (df , exp .Expression )
692- return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
697+ with self .lock or nullcontext ():
698+ with patch .dict (
699+ self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms
700+ ):
701+ with time_ctx :
702+ variables = self .body .get ("vars" , {}).copy ()
703+ time_kwargs = {
704+ key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables
705+ }
706+ df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
707+ assert not isinstance (df , exp .Expression )
708+ return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
693709
694710
695711def generate_test (
0 commit comments