Skip to content

Commit 30b1f18

Browse files
committed
Feat: Make model tests concurrent
1 parent 213e010 commit 30b1f18

2 files changed

Lines changed: 147 additions & 28 deletions

File tree

sqlmesh/core/test/__init__.py

Lines changed: 134 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from __future__ import annotations
22

3+
import sys
4+
import time
35
import pathlib
6+
import threading
47
import typing as t
58
import unittest
69

7-
from sqlmesh.core.engine_adapter import EngineAdapter
10+
11+
import concurrent
12+
from concurrent.futures import ThreadPoolExecutor
13+
814
from sqlmesh.core.model import Model
915
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
1016
from sqlmesh.core.test.discovery import (
@@ -20,6 +26,66 @@
2026
from sqlmesh.core.config.loader import C
2127

2228

29+
class ModelTextTestRunner(unittest.TextTestRunner):
    """A ``TextTestRunner`` that reports into a private in-memory stream.

    Model tests are executed in parallel, so each runner writes to its own
    ``StringIO`` buffer rather than a shared stream — otherwise the output
    of concurrently running tests would interleave.
    """

    def __init__(self, **kwargs: t.Any) -> None:
        # Imported locally to keep the module-level import surface unchanged.
        from io import StringIO

        buffer = StringIO()
        super().__init__(
            stream=buffer,
            resultclass=ModelTextTestResult,
            **kwargs,
        )
43+
44+
45+
def log_test_report(results: "ModelTextTestResult", test_duration: float) -> None:
    """Write an aggregated, unittest-style report for a parallel test run.

    Args:
        results: The combined result object holding every test's outcome
            (``testsRun``, ``errors``, ``failures``, ``skipped``) and the
            output ``stream`` to report into.
        test_duration: Wall-clock duration of the whole run, in seconds.
    """
    # Aggregate parallel test run results
    tests_run = results.testsRun
    errors = results.errors
    failures = results.failures
    skipped = results.skipped

    is_success = not (errors or failures)

    # Summary counters shown after the OK/FAILED verdict
    infos = []
    if failures:
        infos.append(f"failures={len(failures)}")
    if errors:
        infos.append(f"errors={len(errors)}")
    if skipped:
        # Count the skipped tests; `skipped` itself is a list of
        # (test, reason) tuples and must not be rendered verbatim.
        infos.append(f"skipped={len(skipped)}")

    # Report test errors
    stream = results.stream

    stream.write("\n")

    if errors or failures:
        stream.writeln(unittest.TextTestResult.separator1)
        for failure in failures:
            stream.writeln(f"FAIL: {failure[0]}")

        stream.writeln(unittest.TextTestResult.separator2)
        for error in errors:
            stream.writeln(error[1])
        for failure in failures:
            stream.writeln(failure[1])

    # Final unittest-style footer: "Ran N test(s) in X.XXXs" + verdict.
    # Pluralize for any count other than exactly one ("Ran 0 tests").
    stream.writeln(unittest.TextTestResult.separator2)
    stream.writeln(
        f'Ran {tests_run} {"test" if tests_run == 1 else "tests"} in {test_duration:.3f}s \n'
    )
    stream.write(
        f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}'
    )
87+
88+
2389
def run_tests(
2490
model_test_metadata: list[ModelTestMetadata],
2591
models: UniqueKeyDict[str, Model],
@@ -40,22 +106,35 @@ def run_tests(
40106
verbosity: The verbosity level.
41107
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
42108
"""
43-
testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {}
44109
default_gateway = gateway or config.default_gateway_name
45110

46-
try:
47-
tests = []
48-
for metadata in model_test_metadata:
111+
default_test_connection = config.get_test_connection(
112+
gateway_name=default_gateway,
113+
default_catalog=default_catalog,
114+
default_catalog_dialect=default_catalog_dialect,
115+
)
116+
117+
lock = threading.Lock()
118+
119+
combined_results = ModelTextTestResult(
120+
stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore
121+
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
122+
descriptions=None,
123+
)
124+
125+
def _run_single_test(metadata: ModelTestMetadata) -> ModelTextTestResult:
126+
testing_engine_adapter = None
127+
128+
try:
49129
body = metadata.body
50130
gateway = body.get("gateway") or default_gateway
51-
testing_engine_adapter = testing_adapter_by_gateway.get(gateway)
52-
if not testing_engine_adapter:
53-
testing_engine_adapter = config.get_test_connection(
54-
gateway,
55-
default_catalog,
56-
default_catalog_dialect,
57-
).create_engine_adapter(register_comments_override=False)
58-
testing_adapter_by_gateway[gateway] = testing_engine_adapter
131+
132+
# Create new connection for each test to avoid concurrency issues
133+
testing_engine_adapter = config.get_test_connection(
134+
gateway,
135+
default_catalog,
136+
default_catalog_dialect,
137+
).create_engine_adapter(register_comments_override=False)
59138

60139
test = ModelTest.create_test(
61140
body=body,
@@ -67,22 +146,49 @@ def run_tests(
67146
default_catalog=default_catalog,
68147
preserve_fixtures=preserve_fixtures,
69148
)
70-
if test:
71-
tests.append(test)
72-
73-
result = t.cast(
74-
ModelTextTestResult,
75-
unittest.TextTestRunner(
76-
stream=stream,
77-
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
78-
resultclass=ModelTextTestResult,
79-
).run(unittest.TestSuite(tests)),
80-
)
81-
finally:
82-
for testing_engine_adapter in testing_adapter_by_gateway.values():
83-
testing_engine_adapter.close()
84149

85-
return result
150+
result = t.cast(
151+
ModelTextTestResult,
152+
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
153+
)
154+
155+
with lock:
156+
if result.successes:
157+
combined_results.addSuccess(result.successes[0])
158+
elif result.errors:
159+
combined_results.addError(result.err[0], result.err[1])
160+
elif result.failures:
161+
combined_results.addFailure(result.err[0], result.err[1])
162+
elif result.skipped:
163+
skipped_args = result.skipped[0]
164+
combined_results.addSkip(skipped_args[0], skipped_args[1])
165+
166+
finally:
167+
if testing_engine_adapter:
168+
testing_engine_adapter.close()
169+
170+
return result
171+
172+
test_results = []
173+
174+
workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
175+
176+
start_time = time.perf_counter()
177+
with ThreadPoolExecutor(max_workers=workers) as pool:
178+
futures = [
179+
pool.submit(_run_single_test, metadata=metadata) for metadata in model_test_metadata
180+
]
181+
182+
for future in concurrent.futures.as_completed(futures):
183+
test_results.append(future.result())
184+
185+
end_time = time.perf_counter()
186+
187+
combined_results.testsRun = len(test_results)
188+
189+
log_test_report(combined_results, test_duration=end_time - start_time)
190+
191+
return combined_results
86192

87193

88194
def run_model_tests(

sqlmesh/core/test/result.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,21 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None:
4949
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
5050
"""
5151
exctype, value, tb = err
52+
self.err = (test, err)
5253
return super().addFailure(test, (exctype, value, None)) # type: ignore
5354

55+
def addError(self, test: unittest.TestCase, err: ErrorType) -> None:
    """Called when the test case test signals an error.

    Unlike ``addFailure``, the traceback is forwarded to the base class
    unchanged here; the (test, err) pair is also stashed on ``self.err``
    so that code aggregating per-test results can re-report the error.

    Args:
        test: The test case.
        err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
    """
    # Keep a handle to the most recent (test, err) pair for aggregation.
    self.err = (test, err)
    return super().addError(test, err)  # type: ignore
66+
5467
def addSuccess(self, test: unittest.TestCase) -> None:
5568
"""Called when the test case test succeeds.
5669

0 commit comments

Comments
 (0)