Skip to content

Commit 0fc89dd

Browse files
authored
Feat!: Make model tests run concurrently (#4047)
1 parent 96ed132 commit 0fc89dd

4 files changed

Lines changed: 373 additions & 124 deletions

File tree

sqlmesh/core/test/__init__.py

Lines changed: 4 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
from __future__ import annotations
22

3-
import pathlib
4-
import typing as t
5-
import unittest
6-
7-
from sqlmesh.core.engine_adapter import EngineAdapter
8-
from sqlmesh.core.model import Model
93
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
104
from sqlmesh.core.test.discovery import (
115
ModelTestMetadata as ModelTestMetadata,
@@ -14,121 +8,7 @@
148
load_model_test_file as load_model_test_file,
159
)
1610
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
17-
from sqlmesh.utils import UniqueKeyDict, Verbosity
18-
19-
if t.TYPE_CHECKING:
20-
from sqlmesh.core.config.loader import C
21-
22-
23-
def run_tests(
24-
model_test_metadata: list[ModelTestMetadata],
25-
models: UniqueKeyDict[str, Model],
26-
config: C,
27-
gateway: t.Optional[str] = None,
28-
dialect: str | None = None,
29-
verbosity: Verbosity = Verbosity.DEFAULT,
30-
preserve_fixtures: bool = False,
31-
stream: t.TextIO | None = None,
32-
default_catalog: str | None = None,
33-
default_catalog_dialect: str = "",
34-
) -> ModelTextTestResult:
35-
"""Create a test suite of ModelTest objects and run it.
36-
37-
Args:
38-
model_test_metadata: A list of ModelTestMetadata named tuples.
39-
models: All models to use for expansion and mapping of physical locations.
40-
verbosity: The verbosity level.
41-
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
42-
"""
43-
testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {}
44-
default_gateway = gateway or config.default_gateway_name
45-
46-
try:
47-
tests = []
48-
for metadata in model_test_metadata:
49-
body = metadata.body
50-
gateway = body.get("gateway") or default_gateway
51-
testing_engine_adapter = testing_adapter_by_gateway.get(gateway)
52-
if not testing_engine_adapter:
53-
testing_engine_adapter = config.get_test_connection(
54-
gateway,
55-
default_catalog,
56-
default_catalog_dialect,
57-
).create_engine_adapter(register_comments_override=False)
58-
testing_adapter_by_gateway[gateway] = testing_engine_adapter
59-
60-
test = ModelTest.create_test(
61-
body=body,
62-
test_name=metadata.test_name,
63-
models=models,
64-
engine_adapter=testing_engine_adapter,
65-
dialect=dialect,
66-
path=metadata.path,
67-
default_catalog=default_catalog,
68-
preserve_fixtures=preserve_fixtures,
69-
)
70-
if test:
71-
tests.append(test)
72-
73-
result = t.cast(
74-
ModelTextTestResult,
75-
unittest.TextTestRunner(
76-
stream=stream,
77-
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
78-
resultclass=ModelTextTestResult,
79-
).run(unittest.TestSuite(tests)),
80-
)
81-
finally:
82-
for testing_engine_adapter in testing_adapter_by_gateway.values():
83-
testing_engine_adapter.close()
84-
85-
return result
86-
87-
88-
def run_model_tests(
89-
tests: list[str],
90-
models: UniqueKeyDict[str, Model],
91-
config: C,
92-
gateway: t.Optional[str] = None,
93-
dialect: str | None = None,
94-
verbosity: Verbosity = Verbosity.DEFAULT,
95-
patterns: list[str] | None = None,
96-
preserve_fixtures: bool = False,
97-
stream: t.TextIO | None = None,
98-
default_catalog: t.Optional[str] = None,
99-
default_catalog_dialect: str = "",
100-
) -> ModelTextTestResult:
101-
"""Load and run tests.
102-
103-
Args:
104-
tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order]
105-
models: All models to use for expansion and mapping of physical locations.
106-
verbosity: The verbosity level.
107-
patterns: A list of patterns to match against.
108-
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
109-
"""
110-
loaded_tests = []
111-
for test in tests:
112-
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
113-
path = pathlib.Path(filename)
114-
115-
if test_name:
116-
loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name])
117-
else:
118-
loaded_tests.extend(load_model_test_file(path, variables=config.variables).values())
119-
120-
if patterns:
121-
loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)
122-
123-
return run_tests(
124-
loaded_tests,
125-
models,
126-
config,
127-
gateway=gateway,
128-
dialect=dialect,
129-
verbosity=verbosity,
130-
preserve_fixtures=preserve_fixtures,
131-
stream=stream,
132-
default_catalog=default_catalog,
133-
default_catalog_dialect=default_catalog_dialect,
134-
)
11+
from sqlmesh.core.test.runner import (
12+
run_model_tests as run_model_tests,
13+
run_tests as run_tests,
14+
)

sqlmesh/core/test/result.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,19 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None:
4949
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
5050
"""
5151
exctype, value, tb = err
52+
self.original_err = (test, err)
5253
return super().addFailure(test, (exctype, value, None)) # type: ignore
5354

55+
def addError(self, test: unittest.TestCase, err: ErrorType) -> None:
56+
"""Called when the test case test signals an error.
57+
58+
Args:
59+
test: The test case.
60+
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
61+
"""
62+
self.original_err = (test, err)
63+
return super().addError(test, err) # type: ignore
64+
5465
def addSuccess(self, test: unittest.TestCase) -> None:
5566
"""Called when the test case test succeeds.
5667
@@ -59,3 +70,50 @@ def addSuccess(self, test: unittest.TestCase) -> None:
5970
"""
6071
super().addSuccess(test)
6172
self.successes.append(test)
73+
74+
def log_test_report(self, test_duration: float) -> None:
75+
"""
76+
Log the test report following unittest's conventions.
77+
78+
Args:
79+
test_duration: The duration of the tests.
80+
"""
81+
tests_run = self.testsRun
82+
errors = self.errors
83+
failures = self.failures
84+
skipped = self.skipped
85+
86+
is_success = not (errors or failures)
87+
88+
infos = []
89+
if failures:
90+
infos.append(f"failures={len(failures)}")
91+
if errors:
92+
infos.append(f"errors={len(errors)}")
93+
if skipped:
94+
infos.append(f"skipped={skipped}")
95+
96+
stream = self.stream
97+
98+
stream.write("\n")
99+
100+
for test_case, failure in failures:
101+
stream.writeln(unittest.TextTestResult.separator1)
102+
stream.writeln(f"FAIL: {test_case}")
103+
stream.writeln(f"{test_case.shortDescription()}")
104+
stream.writeln(unittest.TextTestResult.separator2)
105+
stream.writeln(failure)
106+
107+
for _, error in errors:
108+
stream.writeln(unittest.TextTestResult.separator1)
109+
stream.writeln(f"ERROR: {error}")
110+
stream.writeln(unittest.TextTestResult.separator2)
111+
112+
# Output final report
113+
stream.writeln(unittest.TextTestResult.separator2)
114+
stream.writeln(
115+
f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n'
116+
)
117+
stream.writeln(
118+
f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}'
119+
)

0 commit comments

Comments
 (0)