1111import concurrent
1212from concurrent .futures import ThreadPoolExecutor
1313
14+ from sqlmesh .core .engine_adapter import EngineAdapter
1415from sqlmesh .core .model import Model
1516from sqlmesh .core .test .definition import ModelTest as ModelTest , generate_test as generate_test
1617from sqlmesh .core .test .discovery import (
2021 load_model_test_file as load_model_test_file ,
2122)
2223from sqlmesh .core .test .result import ModelTextTestResult as ModelTextTestResult
24+ from sqlmesh .core .test .runner import ModelTextTestRunner as ModelTextTestRunner
2325from sqlmesh .utils import UniqueKeyDict , Verbosity
2426
2527if t .TYPE_CHECKING :
2628 from sqlmesh .core .config .loader import C
2729
2830
class ModelTextTestRunner(unittest.TextTestRunner):
    """A ``unittest`` runner whose output is buffered per run.

    Each instance writes to its own in-memory stream and produces
    ``ModelTextTestResult`` results, so tests executed in parallel
    threads never interleave their textual output.
    """

    def __init__(self, **kwargs: t.Any) -> None:
        # Buffer output in memory: the tests run in parallel and we
        # don't want concurrent runs mixing their output streams.
        import io

        super().__init__(
            stream=io.StringIO(),
            resultclass=ModelTextTestResult,
            **kwargs,
        )
# NOTE(review): this span is a unified-diff rendering ("+" added, "-" removed,
# fused line numbers); the middle of log_test_report is elided by the "@@" hunk
# header below — `stream`, `errors`, `failures`, `is_success` and `infos` are
# presumably defined there. TODO: confirm against the full file.
# Purpose (from visible code): writes per-failure/per-error sections followed by
# a unittest-style summary ("Ran N tests in X.XXXs" and "OK"/"FAILED").
4531def log_test_report (results : ModelTextTestResult , test_duration : float ) -> None :
4632 # Aggregate parallel test run results
4733 tests_run = results .testsRun
@@ -65,23 +51,23 @@ def log_test_report(results: ModelTextTestResult, test_duration: float) -> None:
6551
6652 stream .write ("\n " )
6753
# Diff change: each failure now gets its own separator1/FAIL/separator2/err
# section (previously only the test names were listed under one separator).
68- if errors or failures :
54+ for test_case , err in failures :
6955 stream .writeln (unittest .TextTestResult .separator1 )
70- for failure in failures :
71- stream .writeln (f"FAIL: { failure [0 ]} " )
56+ stream .writeln (f"FAIL: { test_case } " )
57+ stream .writeln (unittest .TextTestResult .separator2 )
58+ stream .writeln (err )
7259
# Errors get the same sectioned treatment; note the ERROR header prints
# error[1] (the traceback text), not the test case — TODO confirm intended.
60+ for error in errors :
61+ stream .writeln (unittest .TextTestResult .separator1 )
62+ stream .writeln (f"ERROR: { error [1 ]} " )
7363 stream .writeln (unittest .TextTestResult .separator2 )
74- for error in errors :
75- stream .writeln (error [1 ])
76- for failure in failures :
77- stream .writeln (failure [1 ])
7864
7965 # Test report
8066 stream .writeln (unittest .TextTestResult .separator2 )
8167 stream .writeln (
8268 f'Ran { tests_run } { "tests" if tests_run > 1 else "test" } in { test_duration :.3f} s \n '
8369 )
# Diff change: final status line switched from write() to writeln() so the
# report ends with a newline.
84- stream .write (
70+ stream .writeln (
8571 f'{ "OK" if is_success else "FAILED" } { " (" + ", " .join (infos ) + ")" if infos else "" } '
8672 )
8773
# NOTE(review): interior of run_tests() rendered as a unified diff; the
# signature and docstring head are elided by the "@@" hunk headers.
# Visible intent of the diff: replace the old per-test connection creation
# (deleted "-" closure below) with adapters resolved up-front, one shared
# adapter per gateway — except DuckDB, which gets one adapter per test.
@@ -106,6 +92,7 @@ def run_tests(
10692 verbosity: The verbosity level.
10793 preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
10894 """
# Cache of adapters shared across tests, keyed by gateway name.
95+ testing_adapter_by_gateway : t .Dict [str , EngineAdapter ] = {}
10996 default_gateway = gateway or config .default_gateway_name
11097
11198 default_test_connection = config .get_test_connection (
@@ -122,65 +109,85 @@
122109 descriptions = None ,
123110 )
124111
# Old implementation (removed): created and closed a fresh engine adapter
# inside each test to avoid concurrency issues.
125- def _run_single_test (metadata : ModelTestMetadata ) -> ModelTextTestResult :
126- testing_engine_adapter = None
127-
128- try :
129- body = metadata .body
130- gateway = body .get ("gateway" ) or default_gateway
131-
132- # Create new connection for each test to avoid concurrency issues
133- testing_engine_adapter = config .get_test_connection (
134- gateway ,
135- default_catalog ,
136- default_catalog_dialect ,
137- ).create_engine_adapter (register_comments_override = False )
138-
139- test = ModelTest .create_test (
140- body = body ,
141- test_name = metadata .test_name ,
142- models = models ,
143- engine_adapter = testing_engine_adapter ,
144- dialect = dialect ,
145- path = metadata .path ,
146- default_catalog = default_catalog ,
147- preserve_fixtures = preserve_fixtures ,
148- )
# New implementation: precompute (metadata, engine_adapter) pairs so the
# worker threads don't create connections themselves.
112+ worker_payload = []
149113
150- result = t .cast (
151- ModelTextTestResult ,
152- ModelTextTestRunner ().run (t .cast (unittest .TestCase , test )),
114+ for metadata in model_test_metadata :
115+ gateway = metadata .body .get ("gateway" ) or default_gateway
116+ test_connection = config .get_test_connection (
117+ gateway , default_catalog , default_catalog_dialect
118+ )
119+
120+ concurrent_tasks = test_connection .concurrent_tasks
121+
122+ from sqlmesh .core .config .connection import BaseDuckDBConnectionConfig
123+
124+ is_duckdb_connection = isinstance (test_connection , BaseDuckDBConnectionConfig )
125+
126+ engine_adapter = None
127+ if is_duckdb_connection :
128+ # Ensure DuckDB connections are fully isolated from each other
129+ # by forcing the creation of a new adapter with SingletonConnectionPool
# NOTE(review): concurrent_tasks is temporarily mutated to 1 and restored
# right after — not exception-safe (no try/finally); confirm acceptable.
130+ test_connection .concurrent_tasks = 1
131+ engine_adapter = test_connection .create_engine_adapter (register_comments_override = False )
132+ test_connection .concurrent_tasks = concurrent_tasks
133+ elif gateway not in testing_adapter_by_gateway :
134+ # All other engines can share connections between threads
135+ testing_adapter_by_gateway [gateway ] = test_connection .create_engine_adapter (
136+ register_comments_override = False
153137 )
154138
# Old closure's result-merging/finally-close code (removed here; the new
# equivalent lives in the new _run_single_test defined after this loop).
155- with lock :
156- if result .successes :
157- combined_results .addSuccess (result .successes [0 ])
158- elif result .errors :
159- combined_results .addError (result .err [0 ], result .err [1 ])
160- elif result .failures :
161- combined_results .addFailure (result .err [0 ], result .err [1 ])
162- elif result .skipped :
163- skipped_args = result .skipped [0 ]
164- combined_results .addSkip (skipped_args [0 ], skipped_args [1 ])
165-
166- finally :
167- if testing_engine_adapter :
168- testing_engine_adapter .close ()
# DuckDB tests keep their dedicated adapter; others fall back to the
# shared per-gateway one.
139+ engine_adapter = engine_adapter or testing_adapter_by_gateway [gateway ]
140+ worker_payload .append ((metadata , engine_adapter ))
141+
# Runs one model test on a pre-created adapter and merges its result into
# combined_results under `lock`. Closure over enclosing run_tests() names:
# models, dialect, default_catalog, preserve_fixtures, lock, combined_results.
142+ def _run_single_test (
143+ metadata : ModelTestMetadata , engine_adapter : EngineAdapter
144+ ) -> ModelTextTestResult :
145+ test = ModelTest .create_test (
146+ body = metadata .body ,
147+ test_name = metadata .test_name ,
148+ models = models ,
149+ engine_adapter = engine_adapter ,
150+ dialect = dialect ,
151+ path = metadata .path ,
152+ default_catalog = default_catalog ,
153+ preserve_fixtures = preserve_fixtures ,
154+ )
155+
# Each test gets its own buffered runner so parallel output doesn't mix.
156+ result = t .cast (
157+ ModelTextTestResult ,
158+ ModelTextTestRunner ().run (t .cast (unittest .TestCase , test )),
159+ )
169160
# combined_results is shared across worker threads, hence the lock.
161+ with lock :
162+ if result .successes :
163+ combined_results .addSuccess (result .successes [0 ])
164+ elif result .errors :
# NOTE(review): these branches read `result.err`, not result.errors[0] /
# result.failures[0]; presumably ModelTextTestResult stores the last
# (test, exc_info) pair in `.err` — verify, otherwise this raises
# AttributeError exactly when a test errors or fails.
165+ combined_results .addError (result .err [0 ], result .err [1 ])
166+ elif result .failures :
167+ combined_results .addFailure (result .err [0 ], result .err [1 ])
168+ elif result .skipped :
169+ skipped_args = result .skipped [0 ]
170+ combined_results .addSkip (skipped_args [0 ], skipped_args [1 ])
170171 return result
171172
172173 test_results = []
173174
# Worker count is capped by the default test connection's concurrency;
# `or 1` guards against an empty metadata list producing 0 workers.
174175 workers = min (len (model_test_metadata ) or 1 , default_test_connection .concurrent_tasks )
175176
176177 start_time = time .perf_counter ()
# Old fan-out (removed): submitted metadata only; adapters were created
# inside each worker.
177- with ThreadPoolExecutor (max_workers = workers ) as pool :
178- futures = [
179- pool .submit (_run_single_test , metadata = metadata ) for metadata in model_test_metadata
180- ]
181-
182- for future in concurrent .futures .as_completed (futures ):
183- test_results .append (future .result ())
# New fan-out: submit the precomputed (metadata, adapter) pairs and close
# every adapter afterwards in a finally block.
178+ try :
179+ with ThreadPoolExecutor (max_workers = workers ) as pool :
180+ futures = [
181+ pool .submit (_run_single_test , metadata = metadata , engine_adapter = engine_adapter )
182+ for metadata , engine_adapter in worker_payload
183+ ]
184+
185+ for future in concurrent .futures .as_completed (futures ):
186+ test_results .append (future .result ())
187+ finally :
# NOTE(review): shared per-gateway adapters appear in worker_payload once
# per test, so close() is called on them repeatedly — presumably idempotent;
# confirm against EngineAdapter.close().
188+ for _ , engine_adapter in worker_payload :
189+ if engine_adapter :
190+ engine_adapter .close ()
184191
185192 end_time = time .perf_counter ()
186193
0 commit comments