Skip to content

Commit deca19c

Browse files
committed
Fix: Concurrent dialect patching in model testing
1 parent fb5c52f commit deca19c

4 files changed

Lines changed: 125 additions & 22 deletions

File tree

sqlmesh/core/test/definition.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import datetime
4+
import threading
45
import typing as t
56
import unittest
67
from collections import Counter
7-
from contextlib import AbstractContextManager, nullcontext
8+
from contextlib import nullcontext, contextmanager, AbstractContextManager
89
from itertools import chain
910
from pathlib import Path
1011
from unittest.mock import patch
@@ -46,6 +47,8 @@
4647
class ModelTest(unittest.TestCase):
4748
__test__ = False
4849

50+
CONCURRENT_RENDER_LOCK = threading.Lock()
51+
4952
def __init__(
5053
self,
5154
body: t.Dict[str, t.Any],
@@ -57,6 +60,7 @@ def __init__(
5760
path: Path | None = None,
5861
preserve_fixtures: bool = False,
5962
default_catalog: str | None = None,
63+
concurrency: bool = False,
6064
) -> None:
6165
"""ModelTest encapsulates a unit test for a model.
6266
@@ -79,6 +83,7 @@ def __init__(
7983
self.preserve_fixtures = preserve_fixtures
8084
self.default_catalog = default_catalog
8185
self.dialect = dialect
86+
self.concurrency = concurrency
8287

8388
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
8489
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -310,6 +315,7 @@ def create_test(
310315
path: Path | None,
311316
preserve_fixtures: bool = False,
312317
default_catalog: str | None = None,
318+
concurrency: bool = False,
313319
) -> t.Optional[ModelTest]:
314320
"""Create a SqlModelTest or a PythonModelTest.
315321
@@ -353,6 +359,7 @@ def create_test(
353359
path,
354360
preserve_fixtures,
355361
default_catalog,
362+
concurrency,
356363
)
357364

358365
def __str__(self) -> str:
@@ -512,10 +519,40 @@ def _normalize_column_name(self, name: str) -> str:
512519

513520
return normalized_name
514521

522+
@contextmanager
523+
def _concurrent_render_context(self) -> t.Iterator[None]:
524+
"""
525+
Context manager that ensures that the tests are executed safely in a concurrent environment.
526+
This is needed in case `execution_time` is set, as we'd then have to:
527+
- Freeze time through `time_machine` (not thread safe)
528+
- Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation
529+
"""
530+
import time_machine
531+
532+
lock_ctx: AbstractContextManager = (
533+
self.CONCURRENT_RENDER_LOCK
534+
if (self.concurrency and self._execution_time)
535+
else nullcontext()
536+
)
537+
time_ctx: AbstractContextManager = nullcontext()
538+
dialect_patch_ctx: AbstractContextManager = nullcontext()
539+
540+
if self._execution_time:
541+
time_ctx = time_machine.travel(self._execution_time, tick=False)
542+
dialect_patch_ctx = patch.dict(
543+
self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
544+
)
545+
546+
with lock_ctx, time_ctx, dialect_patch_ctx:
547+
yield
548+
515549
def _execute(self, query: exp.Query) -> pd.DataFrame:
516550
"""Executes the given query using the testing engine adapter and returns a DataFrame."""
517-
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
518-
return self.engine_adapter.fetchdf(query)
551+
552+
with self._concurrent_render_context():
553+
sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql)
554+
555+
return self.engine_adapter.fetchdf(sql)
519556

520557
def _create_df(
521558
self,
@@ -626,6 +663,7 @@ def __init__(
626663
path: Path | None = None,
627664
preserve_fixtures: bool = False,
628665
default_catalog: str | None = None,
666+
concurrency: bool = False,
629667
) -> None:
630668
"""PythonModelTest encapsulates a unit test for a Python model.
631669
@@ -651,6 +689,7 @@ def __init__(
651689
path,
652690
preserve_fixtures,
653691
default_catalog,
692+
concurrency,
654693
)
655694

656695
self.context = TestExecutionContext(
@@ -674,22 +713,13 @@ def runTest(self) -> None:
674713

675714
def _execute_model(self) -> pd.DataFrame:
676715
"""Executes the python model and returns a DataFrame."""
677-
if self._execution_time:
678-
import time_machine
679-
680-
time_ctx: AbstractContextManager = time_machine.travel(self._execution_time, tick=False)
681-
else:
682-
time_ctx = nullcontext()
716+
with self._concurrent_render_context():
717+
variables = self.body.get("vars", {}).copy()
718+
time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
719+
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
683720

684-
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
685-
with time_ctx:
686-
variables = self.body.get("vars", {}).copy()
687-
time_kwargs = {
688-
key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables
689-
}
690-
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
691-
assert not isinstance(df, exp.Expression)
692-
return df if isinstance(df, pd.DataFrame) else df.toPandas()
721+
assert not isinstance(df, exp.Expression)
722+
return df if isinstance(df, pd.DataFrame) else df.toPandas()
693723

694724

695725
def generate_test(

sqlmesh/core/test/result.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def log_test_report(self, test_duration: float) -> None:
100100
for test_case, failure in failures:
101101
stream.writeln(unittest.TextTestResult.separator1)
102102
stream.writeln(f"FAIL: {test_case}")
103-
stream.writeln(f"{test_case.shortDescription()}")
103+
if test_description := test_case.shortDescription():
104+
stream.writeln(test_description)
104105
stream.writeln(unittest.TextTestResult.separator2)
105106
stream.writeln(failure)
106107

sqlmesh/core/test/runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def run_tests(
120120
default_catalog_dialect=default_catalog_dialect,
121121
)
122122

123+
# Ensure workers are not greater than the number of tests
124+
num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
125+
123126
def _run_single_test(
124127
metadata: ModelTestMetadata, engine_adapter: EngineAdapter
125128
) -> t.Optional[ModelTextTestResult]:
@@ -132,6 +135,7 @@ def _run_single_test(
132135
path=metadata.path,
133136
default_catalog=default_catalog,
134137
preserve_fixtures=preserve_fixtures,
138+
concurrency=num_workers > 1,
135139
)
136140

137141
if not test:
@@ -159,9 +163,6 @@ def _run_single_test(
159163

160164
test_results = []
161165

162-
# Ensure workers are not greater than the number of tests
163-
num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
164-
165166
start_time = time.perf_counter()
166167
try:
167168
with ThreadPoolExecutor(max_workers=num_workers) as pool:

tests/core/test_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,3 +2370,74 @@ def test_number_of_tests_found(tmp_path: Path) -> None:
23702370
# Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest
23712371
results = context.test(tests=[f"{test_file}::test_example_full_model2"])
23722372
assert len(results.successes) == 1
2373+
2374+
2375+
def test_freeze_time_concurrent(tmp_path: Path) -> None:
2376+
tests_dir = tmp_path / "tests"
2377+
tests_dir.mkdir()
2378+
2379+
for model_name in ["sql_model", "py_model"]:
2380+
for i in range(5):
2381+
test_2019 = tmp_path / "tests" / f"test_2019_{model_name}_{i}.yaml"
2382+
test_2019.write_text(
2383+
f"""
2384+
test_2019_{model_name}_{i}:
2385+
model: {model_name}
2386+
vars:
2387+
execution_time: '2019-12-01'
2388+
outputs:
2389+
query:
2390+
rows:
2391+
- col_exec_ds_time: '2019-12-01'
2392+
col_current_date: '2019-12-01'
2393+
"""
2394+
)
2395+
2396+
test_2025 = tmp_path / "tests" / f"test_2025_{model_name}_{i}.yaml"
2397+
test_2025.write_text(
2398+
f"""
2399+
test_2025_{model_name}_{i}:
2400+
model: {model_name}
2401+
vars:
2402+
execution_time: '2025-12-01'
2403+
outputs:
2404+
query:
2405+
rows:
2406+
- col_exec_ds_time: '2025-12-01'
2407+
col_current_date: '2025-12-01'
2408+
"""
2409+
)
2410+
2411+
ctx = Context(
2412+
paths=tmp_path,
2413+
config=Config(default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8)),
2414+
)
2415+
2416+
@model(
2417+
"py_model",
2418+
columns={"col_exec_ds_time": "timestamp_ntz", "col_current_date": "timestamp_ntz"},
2419+
)
2420+
def execute(context, start, end, execution_time, **kwargs):
2421+
datetime_now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
2422+
2423+
context.engine_adapter.execute(exp.select("CURRENT_DATE()"))
2424+
current_date = context.engine_adapter.cursor.fetchone()[0]
2425+
2426+
return pd.DataFrame(
2427+
[{"col_exec_ds_time": datetime_now_utc, "col_current_date": current_date}]
2428+
)
2429+
2430+
python_model = model.get_registry()["py_model"].model(module_path=Path("."), path=Path("."))
2431+
2432+
ctx.upsert_model(
2433+
_create_model(
2434+
meta="MODEL(NAME sql_model)",
2435+
query="SELECT @execution_ds::timestamp_ntz AS col_exec_ds_time, current_date()::date AS col_current_date",
2436+
default_catalog=ctx.default_catalog,
2437+
)
2438+
)
2439+
2440+
ctx.upsert_model(python_model)
2441+
2442+
results = ctx.test()
2443+
assert len(results.successes) == 20

0 commit comments

Comments
 (0)