11from __future__ import annotations
22
33import datetime
4+ import threading
45import typing as t
56import unittest
67from collections import Counter
7- from contextlib import AbstractContextManager , nullcontext
8+ from contextlib import nullcontext , contextmanager , AbstractContextManager
89from itertools import chain
910from pathlib import Path
1011from unittest .mock import patch
4647class ModelTest (unittest .TestCase ):
4748 __test__ = False
4849
50+ CONCURRENT_RENDER_LOCK = threading .Lock ()
51+
4952 def __init__ (
5053 self ,
5154 body : t .Dict [str , t .Any ],
@@ -57,6 +60,7 @@ def __init__(
5760 path : Path | None = None ,
5861 preserve_fixtures : bool = False ,
5962 default_catalog : str | None = None ,
63+ concurrency : bool = False ,
6064 ) -> None :
6165 """ModelTest encapsulates a unit test for a model.
6266
@@ -79,6 +83,7 @@ def __init__(
7983 self .preserve_fixtures = preserve_fixtures
8084 self .default_catalog = default_catalog
8185 self .dialect = dialect
86+ self .concurrency = concurrency
8287
8388 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8489 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -310,6 +315,7 @@ def create_test(
310315 path : Path | None ,
311316 preserve_fixtures : bool = False ,
312317 default_catalog : str | None = None ,
318+ concurrency : bool = False ,
313319 ) -> t .Optional [ModelTest ]:
314320 """Create a SqlModelTest or a PythonModelTest.
315321
@@ -353,6 +359,7 @@ def create_test(
353359 path ,
354360 preserve_fixtures ,
355361 default_catalog ,
362+ concurrency ,
356363 )
357364
358365 def __str__ (self ) -> str :
@@ -512,10 +519,34 @@ def _normalize_column_name(self, name: str) -> str:
512519
513520 return normalized_name
514521
515- def _execute (self , query : exp .Query ) -> pd .DataFrame :
522+ @contextmanager
523+ def _concurrent_render_context (self ) -> t .Iterator [None ]:
524+ """
525+ Context manager that ensures that the tests are executed safely in a concurrent environment.
526+ This is needed in case `execution_time` is set, as we'd then have to:
527+ - Freeze time through `time_machine` (not thread safe)
528+ - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
529+ """
530+ import time_machine
531+
532+ lock_ctx : AbstractContextManager = (
533+ self .CONCURRENT_RENDER_LOCK if self .concurrency else nullcontext ()
534+ )
535+ time_ctx : AbstractContextManager = nullcontext ()
536+ dialect_patch_ctx : AbstractContextManager = nullcontext ()
537+
538+ if self ._execution_time :
539+ time_ctx = time_machine .travel (self ._execution_time , tick = False )
540+ dialect_patch_ctx = patch .dict (
541+ self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms
542+ )
543+
544+ with lock_ctx , time_ctx , dialect_patch_ctx :
545+ yield
546+
547+ def _execute (self , query : exp .Query | str ) -> pd .DataFrame :
516548 """Executes the given query using the testing engine adapter and returns a DataFrame."""
517- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
518- return self .engine_adapter .fetchdf (query )
549+ return self .engine_adapter .fetchdf (query )
519550
520551 def _create_df (
521552 self ,
@@ -570,13 +601,25 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False)
570601 for alias , cte in ctes .items ():
571602 cte_query = cte_query .with_ (alias , cte .this , recursive = recursive )
572603
573- actual = self ._execute (cte_query )
604+ with self ._concurrent_render_context ():
605+ # Similar to the model's query, we render the CTE query under the locked context
606+ # so that the execution (fetchdf) can continue concurrently between the threads
607+ sql = cte_query .sql (
608+ self ._test_adapter_dialect , pretty = self .engine_adapter ._pretty_sql
609+ )
610+
611+ actual = self ._execute (sql )
574612 expected = self ._create_df (values , columns = cte_query .named_selects , partial = partial )
575613
576614 self .assert_equal (expected , actual , sort = sort , partial = partial )
577615
578616 def runTest (self ) -> None :
579- query = self ._render_model_query ()
617+ with self ._concurrent_render_context ():
618+ # Render the model's query and generate the SQL under the locked context so that
619+ # execution (fetchdf) can continue concurrently between the threads
620+ query = self ._render_model_query ()
621+ sql = query .sql (self ._test_adapter_dialect , pretty = self .engine_adapter ._pretty_sql )
622+
580623 with_clause = query .args .get ("with" )
581624
582625 if with_clause :
@@ -593,7 +636,7 @@ def runTest(self) -> None:
593636 partial = values .get ("partial" )
594637 sort = query .args .get ("order" ) is None
595638
596- actual = self ._execute (query )
639+ actual = self ._execute (sql )
597640 expected = self ._create_df (values , columns = self .model .columns_to_types , partial = partial )
598641
599642 self .assert_equal (expected , actual , sort = sort , partial = partial )
@@ -626,6 +669,7 @@ def __init__(
626669 path : Path | None = None ,
627670 preserve_fixtures : bool = False ,
628671 default_catalog : str | None = None ,
672+ concurrency : bool = False ,
629673 ) -> None :
630674 """PythonModelTest encapsulates a unit test for a Python model.
631675
@@ -651,6 +695,7 @@ def __init__(
651695 path ,
652696 preserve_fixtures ,
653697 default_catalog ,
698+ concurrency ,
654699 )
655700
656701 self .context = TestExecutionContext (
@@ -674,22 +719,13 @@ def runTest(self) -> None:
674719
675720 def _execute_model (self ) -> pd .DataFrame :
676721 """Executes the python model and returns a DataFrame."""
677- if self ._execution_time :
678- import time_machine
679-
680- time_ctx : AbstractContextManager = time_machine .travel (self ._execution_time , tick = False )
681- else :
682- time_ctx = nullcontext ()
722+ with self ._concurrent_render_context ():
723+ variables = self .body .get ("vars" , {}).copy ()
724+ time_kwargs = {key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables }
725+ df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
683726
684- with patch .dict (self ._test_adapter_dialect .generator_class .TRANSFORMS , self ._transforms ):
685- with time_ctx :
686- variables = self .body .get ("vars" , {}).copy ()
687- time_kwargs = {
688- key : variables .pop (key ) for key in TIME_KWARG_KEYS if key in variables
689- }
690- df = next (self .model .render (context = self .context , ** time_kwargs , ** variables ))
691- assert not isinstance (df , exp .Expression )
692- return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
727+ assert not isinstance (df , exp .Expression )
728+ return df if isinstance (df , pd .DataFrame ) else df .toPandas ()
693729
694730
695731def generate_test (
0 commit comments