11from __future__ import annotations
22
3+ import sys
4+ import time
35import pathlib
6+ import threading
47import typing as t
58import unittest
69
7- from sqlmesh .core .engine_adapter import EngineAdapter
10+
11+ import concurrent
12+ from concurrent .futures import ThreadPoolExecutor
13+
814from sqlmesh .core .model import Model
915from sqlmesh .core .test .definition import ModelTest as ModelTest , generate_test as generate_test
1016from sqlmesh .core .test .discovery import (
2026 from sqlmesh .core .config .loader import C
2127
2228
29+ class ModelTextTestRunner (unittest .TextTestRunner ):
30+ def __init__ (
31+ self ,
32+ ** kwargs : t .Any ,
33+ ) -> None :
34+ # StringIO is used to capture the output of the tests since we'll
35+ # run them in parallel and we don't want to mix the output streams
36+ from io import StringIO
37+
38+ super ().__init__ (
39+ stream = StringIO (),
40+ resultclass = ModelTextTestResult ,
41+ ** kwargs ,
42+ )
43+
44+
45+ def log_test_report (results : ModelTextTestResult , test_duration : float ) -> None :
46+ # Aggregate parallel test run results
47+ tests_run = results .testsRun
48+ errors = results .errors
49+ failures = results .failures
50+ skipped = results .skipped
51+
52+ is_success = not (errors or failures )
53+
54+ # Compute test info
55+ infos = []
56+ if failures :
57+ infos .append (f"failures={ len (failures )} " )
58+ if errors :
59+ infos .append (f"errors={ len (errors )} " )
60+ if skipped :
61+ infos .append (f"skipped={ skipped } " )
62+
63+ # Report test errors
64+ stream = results .stream
65+
66+ stream .write ("\n " )
67+
68+ if errors or failures :
69+ stream .writeln (unittest .TextTestResult .separator1 )
70+ for failure in failures :
71+ stream .writeln (f"FAIL: { failure [0 ]} " )
72+
73+ stream .writeln (unittest .TextTestResult .separator2 )
74+ for error in errors :
75+ stream .writeln (error [1 ])
76+ for failure in failures :
77+ stream .writeln (failure [1 ])
78+
79+ # Test report
80+ stream .writeln (unittest .TextTestResult .separator2 )
81+ stream .writeln (
82+ f'Ran { tests_run } { "tests" if tests_run > 1 else "test" } in { test_duration :.3f} s \n '
83+ )
84+ stream .write (
85+ f'{ "OK" if is_success else "FAILED" } { " (" + ", " .join (infos ) + ")" if infos else "" } '
86+ )
87+
88+
2389def run_tests (
2490 model_test_metadata : list [ModelTestMetadata ],
2591 models : UniqueKeyDict [str , Model ],
@@ -40,22 +106,35 @@ def run_tests(
40106 verbosity: The verbosity level.
41107 preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
42108 """
43- testing_adapter_by_gateway : t .Dict [str , EngineAdapter ] = {}
44109 default_gateway = gateway or config .default_gateway_name
45110
46- try :
47- tests = []
48- for metadata in model_test_metadata :
111+ default_test_connection = config .get_test_connection (
112+ gateway_name = default_gateway ,
113+ default_catalog = default_catalog ,
114+ default_catalog_dialect = default_catalog_dialect ,
115+ )
116+
117+ lock = threading .Lock ()
118+
119+ combined_results = ModelTextTestResult (
120+ stream = unittest .runner ._WritelnDecorator (stream or sys .stderr ), # type: ignore
121+ verbosity = 2 if verbosity >= Verbosity .VERBOSE else 1 ,
122+ descriptions = None ,
123+ )
124+
125+ def _run_single_test (metadata : ModelTestMetadata ) -> ModelTextTestResult :
126+ testing_engine_adapter = None
127+
128+ try :
49129 body = metadata .body
50130 gateway = body .get ("gateway" ) or default_gateway
51- testing_engine_adapter = testing_adapter_by_gateway .get (gateway )
52- if not testing_engine_adapter :
53- testing_engine_adapter = config .get_test_connection (
54- gateway ,
55- default_catalog ,
56- default_catalog_dialect ,
57- ).create_engine_adapter (register_comments_override = False )
58- testing_adapter_by_gateway [gateway ] = testing_engine_adapter
131+
132+ # Create new connection for each test to avoid concurrency issues
133+ testing_engine_adapter = config .get_test_connection (
134+ gateway ,
135+ default_catalog ,
136+ default_catalog_dialect ,
137+ ).create_engine_adapter (register_comments_override = False )
59138
60139 test = ModelTest .create_test (
61140 body = body ,
@@ -67,22 +146,49 @@ def run_tests(
67146 default_catalog = default_catalog ,
68147 preserve_fixtures = preserve_fixtures ,
69148 )
70- if test :
71- tests .append (test )
72-
73- result = t .cast (
74- ModelTextTestResult ,
75- unittest .TextTestRunner (
76- stream = stream ,
77- verbosity = 2 if verbosity >= Verbosity .VERBOSE else 1 ,
78- resultclass = ModelTextTestResult ,
79- ).run (unittest .TestSuite (tests )),
80- )
81- finally :
82- for testing_engine_adapter in testing_adapter_by_gateway .values ():
83- testing_engine_adapter .close ()
84149
85- return result
150+ result = t .cast (
151+ ModelTextTestResult ,
152+ ModelTextTestRunner ().run (t .cast (unittest .TestCase , test )),
153+ )
154+
155+ with lock :
156+ if result .successes :
157+ combined_results .addSuccess (result .successes [0 ])
158+ elif result .errors :
159+ combined_results .addError (result .err [0 ], result .err [1 ])
160+ elif result .failures :
161+ combined_results .addFailure (result .err [0 ], result .err [1 ])
162+ elif result .skipped :
163+ skipped_args = result .skipped [0 ]
164+ combined_results .addSkip (skipped_args [0 ], skipped_args [1 ])
165+
166+ finally :
167+ if testing_engine_adapter :
168+ testing_engine_adapter .close ()
169+
170+ return result
171+
172+ test_results = []
173+
174+ workers = min (len (model_test_metadata ) or 1 , default_test_connection .concurrent_tasks )
175+
176+ start_time = time .perf_counter ()
177+ with ThreadPoolExecutor (max_workers = workers ) as pool :
178+ futures = [
179+ pool .submit (_run_single_test , metadata = metadata ) for metadata in model_test_metadata
180+ ]
181+
182+ for future in concurrent .futures .as_completed (futures ):
183+ test_results .append (future .result ())
184+
185+ end_time = time .perf_counter ()
186+
187+ combined_results .testsRun = len (test_results )
188+
189+ log_test_report (combined_results , test_duration = end_time - start_time )
190+
191+ return combined_results
86192
87193
88194def run_model_tests (
0 commit comments