Skip to content

Commit 2e170ae

Browse files
committed
Reuse engine adapters and close at the end
1 parent 30b1f18 commit 2e170ae

2 files changed

Lines changed: 101 additions & 72 deletions

File tree

sqlmesh/core/test/__init__.py

Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import concurrent
1212
from concurrent.futures import ThreadPoolExecutor
1313

14+
from sqlmesh.core.engine_adapter import EngineAdapter
1415
from sqlmesh.core.model import Model
1516
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
1617
from sqlmesh.core.test.discovery import (
@@ -20,28 +21,13 @@
2021
load_model_test_file as load_model_test_file,
2122
)
2223
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
24+
from sqlmesh.core.test.runner import ModelTextTestRunner as ModelTextTestRunner
2325
from sqlmesh.utils import UniqueKeyDict, Verbosity
2426

2527
if t.TYPE_CHECKING:
2628
from sqlmesh.core.config.loader import C
2729

2830

29-
class ModelTextTestRunner(unittest.TextTestRunner):
30-
def __init__(
31-
self,
32-
**kwargs: t.Any,
33-
) -> None:
34-
# StringIO is used to capture the output of the tests since we'll
35-
# run them in parallel and we don't want to mix the output streams
36-
from io import StringIO
37-
38-
super().__init__(
39-
stream=StringIO(),
40-
resultclass=ModelTextTestResult,
41-
**kwargs,
42-
)
43-
44-
4531
def log_test_report(results: ModelTextTestResult, test_duration: float) -> None:
4632
# Aggregate parallel test run results
4733
tests_run = results.testsRun
@@ -65,23 +51,23 @@ def log_test_report(results: ModelTextTestResult, test_duration: float) -> None:
6551

6652
stream.write("\n")
6753

68-
if errors or failures:
54+
for test_case, err in failures:
6955
stream.writeln(unittest.TextTestResult.separator1)
70-
for failure in failures:
71-
stream.writeln(f"FAIL: {failure[0]}")
56+
stream.writeln(f"FAIL: {test_case}")
57+
stream.writeln(unittest.TextTestResult.separator2)
58+
stream.writeln(err)
7259

60+
for error in errors:
61+
stream.writeln(unittest.TextTestResult.separator1)
62+
stream.writeln(f"ERROR: {error[1]}")
7363
stream.writeln(unittest.TextTestResult.separator2)
74-
for error in errors:
75-
stream.writeln(error[1])
76-
for failure in failures:
77-
stream.writeln(failure[1])
7864

7965
# Test report
8066
stream.writeln(unittest.TextTestResult.separator2)
8167
stream.writeln(
8268
f'Ran {tests_run} {"tests" if tests_run > 1 else "test"} in {test_duration:.3f}s \n'
8369
)
84-
stream.write(
70+
stream.writeln(
8571
f'{"OK" if is_success else "FAILED"}{" (" + ", ".join(infos) + ")" if infos else ""}'
8672
)
8773

@@ -106,6 +92,7 @@ def run_tests(
10692
verbosity: The verbosity level.
10793
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
10894
"""
95+
testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {}
10996
default_gateway = gateway or config.default_gateway_name
11097

11198
default_test_connection = config.get_test_connection(
@@ -122,65 +109,85 @@ def run_tests(
122109
descriptions=None,
123110
)
124111

125-
def _run_single_test(metadata: ModelTestMetadata) -> ModelTextTestResult:
126-
testing_engine_adapter = None
127-
128-
try:
129-
body = metadata.body
130-
gateway = body.get("gateway") or default_gateway
131-
132-
# Create new connection for each test to avoid concurrency issues
133-
testing_engine_adapter = config.get_test_connection(
134-
gateway,
135-
default_catalog,
136-
default_catalog_dialect,
137-
).create_engine_adapter(register_comments_override=False)
138-
139-
test = ModelTest.create_test(
140-
body=body,
141-
test_name=metadata.test_name,
142-
models=models,
143-
engine_adapter=testing_engine_adapter,
144-
dialect=dialect,
145-
path=metadata.path,
146-
default_catalog=default_catalog,
147-
preserve_fixtures=preserve_fixtures,
148-
)
112+
worker_payload = []
149113

150-
result = t.cast(
151-
ModelTextTestResult,
152-
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
114+
for metadata in model_test_metadata:
115+
gateway = metadata.body.get("gateway") or default_gateway
116+
test_connection = config.get_test_connection(
117+
gateway, default_catalog, default_catalog_dialect
118+
)
119+
120+
concurrent_tasks = test_connection.concurrent_tasks
121+
122+
from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig
123+
124+
is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig)
125+
126+
engine_adapter = None
127+
if is_duckdb_connection:
128+
# Ensure DuckDB connections are fully isolated from each other
129+
# by forcing the creation of a new adapter with SingletonConnectionPool
130+
test_connection.concurrent_tasks = 1
131+
engine_adapter = test_connection.create_engine_adapter(register_comments_override=False)
132+
test_connection.concurrent_tasks = concurrent_tasks
133+
elif gateway not in testing_adapter_by_gateway:
134+
# All other engines can share connections between threads
135+
testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter(
136+
register_comments_override=False
153137
)
154138

155-
with lock:
156-
if result.successes:
157-
combined_results.addSuccess(result.successes[0])
158-
elif result.errors:
159-
combined_results.addError(result.err[0], result.err[1])
160-
elif result.failures:
161-
combined_results.addFailure(result.err[0], result.err[1])
162-
elif result.skipped:
163-
skipped_args = result.skipped[0]
164-
combined_results.addSkip(skipped_args[0], skipped_args[1])
165-
166-
finally:
167-
if testing_engine_adapter:
168-
testing_engine_adapter.close()
139+
engine_adapter = engine_adapter or testing_adapter_by_gateway[gateway]
140+
worker_payload.append((metadata, engine_adapter))
141+
142+
def _run_single_test(
143+
metadata: ModelTestMetadata, engine_adapter: EngineAdapter
144+
) -> ModelTextTestResult:
145+
test = ModelTest.create_test(
146+
body=metadata.body,
147+
test_name=metadata.test_name,
148+
models=models,
149+
engine_adapter=engine_adapter,
150+
dialect=dialect,
151+
path=metadata.path,
152+
default_catalog=default_catalog,
153+
preserve_fixtures=preserve_fixtures,
154+
)
155+
156+
result = t.cast(
157+
ModelTextTestResult,
158+
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
159+
)
169160

161+
with lock:
162+
if result.successes:
163+
combined_results.addSuccess(result.successes[0])
164+
elif result.errors:
165+
combined_results.addError(result.err[0], result.err[1])
166+
elif result.failures:
167+
combined_results.addFailure(result.err[0], result.err[1])
168+
elif result.skipped:
169+
skipped_args = result.skipped[0]
170+
combined_results.addSkip(skipped_args[0], skipped_args[1])
170171
return result
171172

172173
test_results = []
173174

174175
workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
175176

176177
start_time = time.perf_counter()
177-
with ThreadPoolExecutor(max_workers=workers) as pool:
178-
futures = [
179-
pool.submit(_run_single_test, metadata=metadata) for metadata in model_test_metadata
180-
]
181-
182-
for future in concurrent.futures.as_completed(futures):
183-
test_results.append(future.result())
178+
try:
179+
with ThreadPoolExecutor(max_workers=workers) as pool:
180+
futures = [
181+
pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter)
182+
for metadata, engine_adapter in worker_payload
183+
]
184+
185+
for future in concurrent.futures.as_completed(futures):
186+
test_results.append(future.result())
187+
finally:
188+
for _, engine_adapter in worker_payload:
189+
if engine_adapter:
190+
engine_adapter.close()
184191

185192
end_time = time.perf_counter()
186193

sqlmesh/core/test/runner.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
import unittest
5+
6+
from sqlmesh.core.test.result import ModelTextTestResult
7+
8+
9+
class ModelTextTestRunner(unittest.TextTestRunner):
10+
def __init__(
11+
self,
12+
**kwargs: t.Any,
13+
) -> None:
14+
# StringIO is used to capture the output of the tests since we'll
15+
# run them in parallel and we don't want to mix the output streams
16+
from io import StringIO
17+
18+
super().__init__(
19+
stream=StringIO(),
20+
resultclass=ModelTextTestResult,
21+
**kwargs,
22+
)

0 commit comments

Comments
 (0)