Skip to content

Commit 8f8b58a

Browse files
committed
Fix: Propagate gateway variables to model tests
1 parent 8db5700 commit 8f8b58a

5 files changed

Lines changed: 190 additions & 127 deletions

File tree

sqlmesh/core/context.py

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@
109109
from sqlmesh.core.test import (
110110
ModelTextTestResult,
111111
generate_test,
112-
get_all_model_tests,
113-
run_model_tests,
112+
load_model_tests,
114113
run_tests,
115114
)
116115
from sqlmesh.core.user import User
@@ -1781,47 +1780,30 @@ def test(
17811780
if verbosity >= Verbosity.VERBOSE:
17821781
pd.set_option("display.max_columns", None)
17831782

1784-
if tests:
1785-
result = run_model_tests(
1786-
tests=tests,
1787-
models=self._models,
1788-
config=self.config,
1789-
gateway=self.gateway,
1790-
dialect=self.default_dialect,
1791-
verbosity=verbosity,
1792-
patterns=match_patterns,
1793-
preserve_fixtures=preserve_fixtures,
1794-
stream=stream,
1795-
default_catalog=self.default_catalog,
1796-
default_catalog_dialect=self.engine_adapter.DIALECT,
1797-
)
1798-
else:
1799-
test_meta = []
1800-
1801-
for path, config in self.configs.items():
1802-
test_meta.extend(
1803-
get_all_model_tests(
1804-
path / c.TESTS,
1805-
patterns=match_patterns,
1806-
ignore_patterns=config.ignore_patterns,
1807-
variables=config.variables,
1808-
)
1809-
)
1783+
default_gateway = self.gateway or self.config.default_gateway_name
18101784

1811-
result = run_tests(
1812-
model_test_metadata=test_meta,
1813-
models=self._models,
1814-
config=self.config,
1815-
gateway=self.gateway,
1816-
dialect=self.default_dialect,
1817-
verbosity=verbosity,
1818-
preserve_fixtures=preserve_fixtures,
1819-
stream=stream,
1820-
default_catalog=self.default_catalog,
1821-
default_catalog_dialect=self.engine_adapter.DIALECT,
1822-
)
1785+
# Merge the root variables with the gateway's variables
1786+
variables = {**self.config.variables, **self.config.get_gateway(default_gateway).variables}
18231787

1824-
return result
1788+
test_meta = load_model_tests(
1789+
configs=self.configs,
1790+
tests=tests,
1791+
patterns=match_patterns,
1792+
variables=variables,
1793+
)
1794+
1795+
return run_tests(
1796+
model_test_metadata=test_meta,
1797+
models=self._models,
1798+
config=self.config,
1799+
default_gateway=default_gateway,
1800+
dialect=self.default_dialect,
1801+
verbosity=verbosity,
1802+
preserve_fixtures=preserve_fixtures,
1803+
stream=stream,
1804+
default_catalog=self.default_catalog,
1805+
default_catalog_dialect=self.engine_adapter.DIALECT,
1806+
)
18251807

18261808
@python_api_analytics
18271809
def audit(

sqlmesh/core/test/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
filter_tests_by_patterns as filter_tests_by_patterns,
77
get_all_model_tests as get_all_model_tests,
88
load_model_test_file as load_model_test_file,
9+
load_model_tests as load_model_tests,
910
)
1011
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
11-
from sqlmesh.core.test.runner import (
12-
run_model_tests as run_model_tests,
13-
run_tests as run_tests,
14-
)
12+
from sqlmesh.core.test.runner import run_tests as run_tests

sqlmesh/core/test/discovery.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
import ruamel
1010

11+
from sqlmesh.core import constants as c
1112
from sqlmesh.utils import unique
1213
from sqlmesh.utils.pydantic import PydanticModel
1314
from sqlmesh.utils.yaml import load as yaml_load
1415

16+
if t.TYPE_CHECKING:
17+
from sqlmesh.core.config.loader import C
18+
1519

1620
class ModelTestMetadata(PydanticModel):
1721
path: pathlib.Path
@@ -113,3 +117,46 @@ def get_all_model_tests(
113117
if patterns:
114118
model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns)
115119
return model_test_metadatas
120+
121+
122+
def load_model_tests(
123+
configs: dict[pathlib.Path, C],
124+
tests: t.Optional[t.List[str]] = None,
125+
patterns: list[str] | None = None,
126+
variables: dict[str, t.Any] | None = None,
127+
) -> list[ModelTestMetadata]:
128+
"""Load model tests into a list of ModelTestMetadata which will be propagated to the test runner.
129+
130+
Args:
131+
tests: A list of tests to load; If not specified, all tests are loaded
132+
patterns: A list of patterns to match against.
133+
variables: A dictionary of variables to use when loading the tests.
134+
configs: A dictionary of configs to use when loading all the tests.
135+
"""
136+
test_meta = []
137+
138+
if tests:
139+
for test in tests:
140+
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
141+
142+
test_file = load_model_test_file(pathlib.Path(filename), variables=variables)
143+
if test_name:
144+
test_meta.append(test_file[test_name])
145+
else:
146+
test_meta.extend(test_file.values())
147+
148+
if patterns:
149+
test_meta = filter_tests_by_patterns(test_meta, patterns)
150+
151+
else:
152+
for path, config in configs.items():
153+
test_meta.extend(
154+
get_all_model_tests(
155+
path / c.TESTS,
156+
patterns=patterns,
157+
ignore_patterns=config.ignore_patterns,
158+
variables=variables,
159+
)
160+
)
161+
162+
return test_meta

sqlmesh/core/test/runner.py

Lines changed: 33 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import sys
44
import time
5-
import pathlib
65
import threading
76
import typing as t
87
import unittest
@@ -86,7 +85,7 @@ def run_tests(
8685
model_test_metadata: list[ModelTestMetadata],
8786
models: UniqueKeyDict[str, Model],
8887
config: C,
89-
gateway: t.Optional[str] = None,
88+
default_gateway: str,
9089
dialect: str | None = None,
9190
verbosity: Verbosity = Verbosity.DEFAULT,
9291
preserve_fixtures: bool = False,
@@ -102,8 +101,6 @@ def run_tests(
102101
verbosity: The verbosity level.
103102
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
104103
"""
105-
default_gateway = gateway or config.default_gateway_name
106-
107104
default_test_connection = config.get_test_connection(
108105
gateway_name=default_gateway,
109106
default_catalog=default_catalog,
@@ -128,34 +125,40 @@ def run_tests(
128125

129126
def _run_single_test(
130127
metadata: ModelTestMetadata, engine_adapter: EngineAdapter
131-
) -> ModelTextTestResult:
132-
test = ModelTest.create_test(
133-
body=metadata.body,
134-
test_name=metadata.test_name,
135-
models=models,
136-
engine_adapter=engine_adapter,
137-
dialect=dialect,
138-
path=metadata.path,
139-
default_catalog=default_catalog,
140-
preserve_fixtures=preserve_fixtures,
141-
)
128+
) -> t.Optional[ModelTextTestResult]:
129+
result: t.Optional[ModelTextTestResult] = None
130+
try:
131+
test = ModelTest.create_test(
132+
body=metadata.body,
133+
test_name=metadata.test_name,
134+
models=models,
135+
engine_adapter=engine_adapter,
136+
dialect=dialect,
137+
path=metadata.path,
138+
default_catalog=default_catalog,
139+
preserve_fixtures=preserve_fixtures,
140+
)
142141

143-
result = t.cast(
144-
ModelTextTestResult,
145-
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
146-
)
142+
result = t.cast(
143+
ModelTextTestResult,
144+
ModelTextTestRunner().run(t.cast(unittest.TestCase, test)),
145+
)
146+
147+
with lock:
148+
if result.successes:
149+
combined_results.addSuccess(result.successes[0])
150+
elif result.errors:
151+
combined_results.addError(result.original_err[0], result.original_err[1])
152+
elif result.failures:
153+
combined_results.addFailure(result.original_err[0], result.original_err[1])
154+
elif result.skipped:
155+
skipped_args = result.skipped[0]
156+
combined_results.addSkip(skipped_args[0], skipped_args[1])
147157

148-
with lock:
149-
if result.successes:
150-
combined_results.addSuccess(result.successes[0])
151-
elif result.errors:
152-
combined_results.addError(result.original_err[0], result.original_err[1])
153-
elif result.failures:
154-
combined_results.addFailure(result.original_err[0], result.original_err[1])
155-
elif result.skipped:
156-
skipped_args = result.skipped[0]
157-
combined_results.addSkip(skipped_args[0], skipped_args[1])
158-
return result
158+
combined_results.testsRun += 1
159+
160+
finally:
161+
return result
159162

160163
test_results = []
161164

@@ -180,57 +183,6 @@ def _run_single_test(
180183

181184
end_time = time.perf_counter()
182185

183-
combined_results.testsRun = len(test_results)
184-
185186
combined_results.log_test_report(test_duration=end_time - start_time)
186187

187188
return combined_results
188-
189-
190-
def run_model_tests(
191-
tests: list[str],
192-
models: UniqueKeyDict[str, Model],
193-
config: C,
194-
gateway: t.Optional[str] = None,
195-
dialect: str | None = None,
196-
verbosity: Verbosity = Verbosity.DEFAULT,
197-
patterns: list[str] | None = None,
198-
preserve_fixtures: bool = False,
199-
stream: t.TextIO | None = None,
200-
default_catalog: t.Optional[str] = None,
201-
default_catalog_dialect: str = "",
202-
) -> ModelTextTestResult:
203-
"""Load and run tests.
204-
205-
Args:
206-
tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order]
207-
models: All models to use for expansion and mapping of physical locations.
208-
verbosity: The verbosity level.
209-
patterns: A list of patterns to match against.
210-
preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging.
211-
"""
212-
loaded_tests = []
213-
for test in tests:
214-
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
215-
path = pathlib.Path(filename)
216-
217-
if test_name:
218-
loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name])
219-
else:
220-
loaded_tests.extend(load_model_test_file(path, variables=config.variables).values())
221-
222-
if patterns:
223-
loaded_tests = filter_tests_by_patterns(loaded_tests, patterns)
224-
225-
return run_tests(
226-
loaded_tests,
227-
models,
228-
config,
229-
gateway=gateway,
230-
dialect=dialect,
231-
verbosity=verbosity,
232-
preserve_fixtures=preserve_fixtures,
233-
stream=stream,
234-
default_catalog=default_catalog,
235-
default_catalog_dialect=default_catalog_dialect,
236-
)

0 commit comments

Comments
 (0)