Skip to content

Commit 7f98b34

Browse files
committed
PR Feedback 3
1 parent 205aa3d commit 7f98b34

4 files changed

Lines changed: 304 additions & 201 deletions

File tree

sqlmesh/core/test/__init__.py

Lines changed: 2 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,5 @@
11
from __future__ import annotations
22

3-
import sys
4-
import time
5-
import pathlib
6-
import threading
7-
import typing as t
8-
import unittest
9-
10-
11-
import concurrent
12-
from concurrent.futures import ThreadPoolExecutor
13-
14-
from sqlmesh.core.engine_adapter import EngineAdapter
15-
from sqlmesh.core.model import Model
163
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
174
from sqlmesh.core.test.discovery import (
185
ModelTestMetadata as ModelTestMetadata,
@@ -22,185 +9,6 @@
229
)
2310
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
2411
from sqlmesh.core.test.runner import (
25-
ModelTextTestRunner,
26-
log_test_report,
12+
run_model_tests as run_model_tests,
13+
run_tests as run_tests,
2714
)
28-
from sqlmesh.utils import UniqueKeyDict, Verbosity
29-
30-
if t.TYPE_CHECKING:
31-
from sqlmesh.core.config.loader import C
32-
33-
34-
def run_tests(
35-
model_test_metadata: list[ModelTestMetadata],
36-
models: UniqueKeyDict[str, Model],
37-
config: C,
38-
gateway: t.Optional[str] = None,
39-
dialect: str | None = None,
40-
verbosity: Verbosity = Verbosity.DEFAULT,
41-
preserve_fixtures: bool = False,
42-
stream: t.TextIO | None = None,
43-
default_catalog: str | None = None,
44-
default_catalog_dialect: str = "",
45-
) -> ModelTextTestResult:
46-
"""Create a test suite of ModelTest objects and run it.
47-
48-
Args:
49-
model_test_metadata: A list of ModelTestMetadata named tuples.
50-
models: All models to use for expansion and mapping of physical locations.
51-
verbosity: The verbosity level.
52-
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
53-
"""
54-
testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {}
55-
default_gateway = gateway or config.default_gateway_name
56-
57-
default_test_connection = config.get_test_connection(
58-
gateway_name=default_gateway,
59-
default_catalog=default_catalog,
60-
default_catalog_dialect=default_catalog_dialect,
61-
)
62-
63-
lock = threading.Lock()
64-
65-
combined_results = ModelTextTestResult(
66-
stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore
67-
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
68-
descriptions=None,
69-
)
70-
71-
worker_payload = []
72-
73-
for metadata in model_test_metadata:
74-
gateway = metadata.body.get("gateway") or default_gateway
75-
test_connection = config.get_test_connection(
76-
gateway, default_catalog, default_catalog_dialect
77-
)
78-
79-
concurrent_tasks = test_connection.concurrent_tasks
80-
81-
from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig
82-
83-
is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig)
84-
85-
engine_adapter = None
86-
if is_duckdb_connection:
87-
# Ensure DuckDB connections are fully isolated from each other
88-
# by forcing the creation of a new adapter with SingletonConnectionPool
89-
test_connection.concurrent_tasks = 1
90-
engine_adapter = test_connection.create_engine_adapter(register_comments_override=False)
91-
test_connection.concurrent_tasks = concurrent_tasks
92-
elif gateway not in testing_adapter_by_gateway:
93-
# All other engines can share connections between threads
94-
testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter(
95-
register_comments_override=False
96-
)
97-
98-
engine_adapter = engine_adapter or testing_adapter_by_gateway[gateway]
99-
worker_payload.append((metadata, engine_adapter))
100-
101-
def _run_single_test(
102-
metadata: ModelTestMetadata, engine_adapter: EngineAdapter
103-
) -> ModelTextTestResult:
104-
test = ModelTest.create_test(
105-
body=metadata.body,
106-
test_name=metadata.test_name,
107-
models=models,
108-
engine_adapter=engine_adapter,
109-
dialect=dialect,
110-
path=metadata.path,
111-
default_catalog=default_catalog,
112-
preserve_fixtures=preserve_fixtures,
113-
)
114-
115-
result = t.cast(
116-
ModelTextTestResult,
117-
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
118-
)
119-
120-
with lock:
121-
if result.successes:
122-
combined_results.addSuccess(result.successes[0])
123-
elif result.errors:
124-
combined_results.addError(result.err[0], result.err[1])
125-
elif result.failures:
126-
combined_results.addFailure(result.err[0], result.err[1])
127-
elif result.skipped:
128-
skipped_args = result.skipped[0]
129-
combined_results.addSkip(skipped_args[0], skipped_args[1])
130-
return result
131-
132-
test_results = []
133-
134-
workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks)
135-
136-
start_time = time.perf_counter()
137-
try:
138-
with ThreadPoolExecutor(max_workers=workers) as pool:
139-
futures = [
140-
pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter)
141-
for metadata, engine_adapter in worker_payload
142-
]
143-
144-
for future in concurrent.futures.as_completed(futures):
145-
test_results.append(future.result())
146-
finally:
147-
for _, engine_adapter in worker_payload:
148-
if engine_adapter:
149-
engine_adapter.close()
150-
151-
end_time = time.perf_counter()
152-
153-
combined_results.testsRun = len(test_results)
154-
155-
log_test_report(combined_results, test_duration=end_time - start_time)
156-
157-
return combined_results
158-
159-
160-
def run_model_tests(
161-
tests: list[str],
162-
models: UniqueKeyDict[str, Model],
163-
config: C,
164-
gateway: t.Optional[str] = None,
165-
dialect: str | None = None,
166-
verbosity: Verbosity = Verbosity.DEFAULT,
167-
patterns: list[str] | None = None,
168-
preserve_fixtures: bool = False,
169-
stream: t.TextIO | None = None,
170-
default_catalog: t.Optional[str] = None,
171-
default_catalog_dialect: str = "",
172-
) -> ModelTextTestResult:
173-
"""Load and run tests.
174-
175-
Args:
176-
tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order]
177-
models: All models to use for expansion and mapping of physical locations.
178-
verbosity: The verbosity level.
179-
patterns: A list of patterns to match against.
180-
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
181-
"""
182-
loaded_tests = []
183-
for test in tests:
184-
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
185-
path = pathlib.Path(filename)
186-
187-
if test_name:
188-
loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name])
189-
else:
190-
loaded_tests.extend(load_model_test_file(path, variables=config.variables).values())
191-
192-
if patterns:
193-
loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)
194-
195-
return run_tests(
196-
loaded_tests,
197-
models,
198-
config,
199-
gateway=gateway,
200-
dialect=dialect,
201-
verbosity=verbosity,
202-
preserve_fixtures=preserve_fixtures,
203-
stream=stream,
204-
default_catalog=default_catalog,
205-
default_catalog_dialect=default_catalog_dialect,
206-
)

sqlmesh/core/test/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None:
4949
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
5050
"""
5151
exctype, value, tb = err
52-
self.err = (test, err)
52+
self.original_err = (test, err)
5353
return super().addFailure(test, (exctype, value, None)) # type: ignore
5454

5555
def addError(self, test: unittest.TestCase, err: ErrorType) -> None:
@@ -61,7 +61,7 @@ def addError(self, test: unittest.TestCase, err: ErrorType) -> None:
6161
test: The test case.
6262
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
6363
"""
64-
self.err = (test, err)
64+
self.original_err = (test, err)
6565
return super().addError(test, err) # type: ignore
6666

6767
def addSuccess(self, test: unittest.TestCase) -> None:

0 commit comments

Comments
 (0)