11from __future__ import annotations
22
33import datetime
4+ import threading
45import typing as t
56import unittest
67from collections import Counter
7- from contextlib import AbstractContextManager , nullcontext
8+ from contextlib import nullcontext , contextmanager , AbstractContextManager
89from itertools import chain
910from pathlib import Path
1011from unittest .mock import patch
4647class ModelTest (unittest .TestCase ):
4748 __test__ = False
4849
50+ CONCURRENT_RENDER_LOCK = threading .Lock ()
51+
4952 def __init__ (
5053 self ,
5154 body : t .Dict [str , t .Any ],
@@ -57,6 +60,7 @@ def __init__(
5760 path : Path | None = None ,
5861 preserve_fixtures : bool = False ,
5962 default_catalog : str | None = None ,
63+ concurrency : bool = False ,
6064 ) -> None :
6165 """ModelTest encapsulates a unit test for a model.
6266
@@ -79,6 +83,7 @@ def __init__(
7983 self .preserve_fixtures = preserve_fixtures
8084 self .default_catalog = default_catalog
8185 self .dialect = dialect
86+ self .concurrency = concurrency
8287
8388 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8489 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -310,6 +315,7 @@ def create_test(
310315 path : Path | None ,
311316 preserve_fixtures : bool = False ,
312317 default_catalog : str | None = None ,
318+ concurrency : bool = False ,
313319 ) -> t .Optional [ModelTest ]:
314320 """Create a SqlModelTest or a PythonModelTest.
315321
@@ -353,6 +359,7 @@ def create_test(
353359 path ,
354360 preserve_fixtures ,
355361 default_catalog ,
362+ concurrency ,
356363 )
357364
358365 def __str__ (self ) -> str :
@@ -512,10 +519,40 @@ def _normalize_column_name(self, name: str) -> str:
512519
513520 return normalized_name
514521
522+ @contextmanager
523+ def _concurrent_render_context (self ) -> t .Iterator [None ]:
524+ """
525+ Context manager that ensures that the tests are executed safely in a concurrent environment.
526+ This is needed in case `execution_time` is set, as we'd then have to:
527+ - Freeze time through `time_machine` (not thread safe)
528+ - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
529+ """
530+ import time_machine
531+
532+ lock_ctx : AbstractContextManager = (
533+ self .CONCURRENT_RENDER_LOCK
534+ if (self .concurrency and self ._execution_time )
535+ else nullcontext ()
536+ )
537+ time_ctx : AbstractContextManager = nullcontext ()
538+ dialect_patch_ctx : AbstractContextManager = nullcontext ()
539+
540+ if self ._execution_time :
541+ time_ctx = time_machine .travel (self ._execution_time , tick = False )
542+ dialect_patch_ctx = patch .dict (
543+ self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms
544+ )
545+
546+ with lock_ctx , time_ctx , dialect_patch_ctx :
547+ yield
548+
515549 def _execute (self , query : exp .Query ) -> pd .DataFrame :
516550 """Executes the given query using the testing engine adapter and returns a DataFrame."""
517- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
518- return self .engine_adapter .fetchdf (query )
551+
552+ with self ._concurrent_render_context ():
553+ sql = query .sql (self ._test_adapter_dialect , pretty = self .engine_adapter ._pretty_sql )
554+
555+ return self .engine_adapter .fetchdf (sql )
519556
520557 def _create_df (
521558 self ,
@@ -626,6 +663,7 @@ def __init__(
626663 path : Path | None = None ,
627664 preserve_fixtures : bool = False ,
628665 default_catalog : str | None = None ,
666+ concurrency : bool = False ,
629667 ) -> None :
630668 """PythonModelTest encapsulates a unit test for a Python model.
631669
@@ -651,6 +689,7 @@ def __init__(
651689 path ,
652690 preserve_fixtures ,
653691 default_catalog ,
692+ concurrency ,
654693 )
655694
656695 self .context = TestExecutionContext (
@@ -674,22 +713,13 @@ def runTest(self) -> None:
674713
675714 def _execute_model (self ) -> pd .DataFrame :
676715 """Executes the python model and returns a DataFrame."""
677- if self ._execution_time :
678- import time_machine
679-
680- time_ctx : AbstractContextManager = time_machine .travel (self ._execution_time , tick = False )
681- else :
682- time_ctx = nullcontext ()
716+ with self ._concurrent_render_context ():
717+ variables = self .body .get ("vars" , {}).copy ()
718+ time_kwargs = {key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables }
719+ df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
683720
684- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
685- with time_ctx :
686- variables = self .body .get ("vars" , {}).copy ()
687- time_kwargs = {
688- key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables
689- }
690- df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
691- assert not isinstance (df , exp .Expression )
692- return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
721+ assert not isinstance (df , exp .Expression )
722+ return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
693723
694724
695725def generate_test (
0 commit comments