Skip to content

Commit 19e24a3

Browse files
committed
Fix: Concurrent dialect patching in model testing
1 parent fb5c52f commit 19e24a3

4 files changed

Lines changed: 113 additions & 24 deletions

File tree

sqlmesh/core/test/definition.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import datetime
4+
import threading
45
import typing as t
56
import unittest
67
from collections import Counter
@@ -57,6 +58,7 @@ def __init__(
5758
path: Path | None = None,
5859
preserve_fixtures: bool = False,
5960
default_catalog: str | None = None,
61+
lock: t.Optional[threading.Lock] = None,
6062
) -> None:
6163
"""ModelTest encapsulates a unit test for a model.
6264
@@ -79,6 +81,7 @@ def __init__(
7981
self.preserve_fixtures = preserve_fixtures
8082
self.default_catalog = default_catalog
8183
self.dialect = dialect
84+
self.lock = lock
8285

8386
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
8487
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -102,6 +105,7 @@ def __init__(
102105
)
103106
self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog)
104107

108+
self._exec_time_transforms: t.Dict[type[exp.Expression], exp.Expression] = {}
105109
self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS
106110
self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "")
107111

@@ -112,20 +116,20 @@ def __init__(
112116
# When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it
113117
if self._execution_time:
    exec_time = exp.Literal.string(self._execution_time)

    # Replacement expressions keyed by the CURRENT_* node type they stand in for.
    self._exec_time_transforms = {
        exp.CurrentDate: exp.cast(exec_time, "date", dialect=dialect),
        exp.CurrentDatetime: exp.cast(exec_time, "datetime", dialect=dialect),
        exp.CurrentTime: exp.cast(exec_time, "time", dialect=dialect),
        exp.CurrentTimestamp: exp.cast(exec_time, "timestamp", dialect=dialect),
    }

    self._transforms = {
        **self._transforms,
        **{
            # `replacement=value` binds the current value when each lambda is created.
            # A plain closure over `value` is resolved only when the lambda is called,
            # i.e. after the comprehension has finished, so every key would render the
            # last entry (the "timestamp" cast) instead of its own cast.
            key: lambda self, _, replacement=value: self.sql(replacement)
            for key, value in self._exec_time_transforms.items()
        },
    }
130134

131135
super().__init__()
@@ -310,6 +314,7 @@ def create_test(
310314
path: Path | None,
311315
preserve_fixtures: bool = False,
312316
default_catalog: str | None = None,
317+
lock: t.Optional[threading.Lock] = None,
313318
) -> t.Optional[ModelTest]:
314319
"""Create a SqlModelTest or a PythonModelTest.
315320
@@ -353,6 +358,7 @@ def create_test(
353358
path,
354359
preserve_fixtures,
355360
default_catalog,
361+
lock=lock,
356362
)
357363

358364
def __str__(self) -> str:
@@ -514,8 +520,13 @@ def _normalize_column_name(self, name: str) -> str:
514520

515521
def _execute(self, query: exp.Query) -> pd.DataFrame:
    """Executes the given query using the testing engine adapter and returns a DataFrame.

    When the test defines an execution time, every node of the query whose type is a key
    of `self._exec_time_transforms` is swapped for the mapped replacement expression
    before the query is handed to the engine adapter. Without an execution time the
    query is executed as-is.

    Args:
        query: The query to execute.

    Returns:
        The DataFrame returned by the engine adapter's `fetchdf`.
    """
    if not self._execution_time:
        return self.engine_adapter.fetchdf(query)

    def replace_execution_time(expression: exp.Expression) -> exp.Expression:
        replacement = self._exec_time_transforms.get(type(expression))
        if replacement is None:
            return expression
        # The mapped expressions are built once and reused for every query this test
        # runs. Return a copy so the same node object is never attached to more than
        # one place in a tree, or to more than one tree.
        return replacement.copy()

    return self.engine_adapter.fetchdf(query.transform(replace_execution_time))
519530

520531
def _create_df(
521532
self,
@@ -626,6 +637,7 @@ def __init__(
626637
path: Path | None = None,
627638
preserve_fixtures: bool = False,
628639
default_catalog: str | None = None,
640+
lock: t.Optional[threading.Lock] = None,
629641
) -> None:
630642
"""PythonModelTest encapsulates a unit test for a Python model.
631643
@@ -651,6 +663,7 @@ def __init__(
651663
path,
652664
preserve_fixtures,
653665
default_catalog,
666+
lock,
654667
)
655668

656669
self.context = TestExecutionContext(
@@ -681,15 +694,18 @@ def _execute_model(self) -> pd.DataFrame:
681694
else:
682695
time_ctx = nullcontext()
683696

684-
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
685-
with time_ctx:
686-
variables = self.body.get("vars", {}).copy()
687-
time_kwargs = {
688-
key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables
689-
}
690-
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
691-
assert not isinstance(df, exp.Expression)
692-
return df if isinstance(df, pd.DataFrame) else df.toPandas()
697+
with self.lock or nullcontext():
698+
with patch.dict(
699+
self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms
700+
):
701+
with time_ctx:
702+
variables = self.body.get("vars", {}).copy()
703+
time_kwargs = {
704+
key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables
705+
}
706+
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
707+
assert not isinstance(df, exp.Expression)
708+
return df if isinstance(df, pd.DataFrame) else df.toPandas()
693709

694710

695711
def generate_test(

sqlmesh/core/test/result.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def log_test_report(self, test_duration: float) -> None:
100100
for test_case, failure in failures:
101101
stream.writeln(unittest.TextTestResult.separator1)
102102
stream.writeln(f"FAIL: {test_case}")
103-
stream.writeln(f"{test_case.shortDescription()}")
103+
if test_description := test_case.shortDescription():
104+
stream.writeln(test_description)
104105
stream.writeln(unittest.TextTestResult.separator2)
105106
stream.writeln(failure)
106107

sqlmesh/core/test/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _run_single_test(
132132
path=metadata.path,
133133
default_catalog=default_catalog,
134134
preserve_fixtures=preserve_fixtures,
135+
lock=lock if num_workers > 1 else None,
135136
)
136137

137138
if not test:

tests/core/test_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,3 +2370,74 @@ def test_number_of_tests_found(tmp_path: Path) -> None:
23702370
# Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest
23712371
results = context.test(tests=[f"{test_file}::test_example_full_model2"])
23722372
assert len(results.successes) == 1
2373+
2374+
2375+
def test_freeze_time_concurrent(tmp_path: Path) -> None:
    """Run SQL and Python model tests with different execution times concurrently.

    For each of the two models, five test files are written per year (2019 and 2025),
    20 tests in total. Each test expects both output columns to equal its own
    execution time. The test connection is configured with 8 concurrent tasks.
    """
    tests_dir = tmp_path / "tests"
    tests_dir.mkdir()

    for model_name in ["sql_model", "py_model"]:
        for i in range(5):
            for year in ("2019", "2025"):
                test_name = f"test_{year}_{model_name}_{i}"
                (tests_dir / f"{test_name}.yaml").write_text(
                    f"""
{test_name}:
  model: {model_name}
  vars:
    execution_time: '{year}-12-01'
  outputs:
    query:
      rows:
        - col_exec_ds_time: '{year}-12-01'
          col_current_date: '{year}-12-01'
"""
                )

    test_connection = DuckDBConnectionConfig(concurrent_tasks=8)
    ctx = Context(paths=tmp_path, config=Config(default_test_connection=test_connection))

    @model(
        "py_model",
        columns={"col_exec_ds_time": "timestamp_ntz", "col_current_date": "timestamp_ntz"},
    )
    def execute(context, start, end, execution_time, **kwargs):
        # Read "now" both from Python and from the engine.
        now_utc = datetime.datetime.now(tz=datetime.timezone.utc)

        context.engine_adapter.execute(exp.select("CURRENT_DATE()"))
        engine_current_date = context.engine_adapter.cursor.fetchone()[0]

        row = {"col_exec_ds_time": now_utc, "col_current_date": engine_current_date}
        return pd.DataFrame([row])

    python_model = model.get_registry()["py_model"].model(module_path=Path("."), path=Path("."))

    sql_model = _create_model(
        meta="MODEL(NAME sql_model)",
        query="SELECT @execution_ds::timestamp_ntz AS col_exec_ds_time, current_date()::date AS col_current_date",
        default_catalog=ctx.default_catalog,
    )
    ctx.upsert_model(sql_model)
    ctx.upsert_model(python_model)

    results = ctx.test()
    assert len(results.successes) == 20

0 commit comments

Comments
 (0)