Skip to content

Commit b304428

Browse files
authored
Fix!: Propagate gateway variables to model tests (#4102)
1 parent 3ced033 commit b304428

8 files changed

Lines changed: 297 additions & 215 deletions

File tree

sqlmesh/core/context.py

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@
109109
from sqlmesh.core.table_diff import TableDiff
110110
from sqlmesh.core.test import (
111111
ModelTextTestResult,
112+
ModelTestMetadata,
112113
generate_test,
113-
get_all_model_tests,
114-
run_model_tests,
115114
run_tests,
116115
)
117116
from sqlmesh.core.user import User
@@ -1788,47 +1787,20 @@ def test(
17881787
if verbosity >= Verbosity.VERBOSE:
17891788
pd.set_option("display.max_columns", None)
17901789

1791-
if tests:
1792-
result = run_model_tests(
1793-
tests=tests,
1794-
models=self._models,
1795-
config=self.config,
1796-
gateway=self.gateway,
1797-
dialect=self.default_dialect,
1798-
verbosity=verbosity,
1799-
patterns=match_patterns,
1800-
preserve_fixtures=preserve_fixtures,
1801-
stream=stream,
1802-
default_catalog=self.default_catalog,
1803-
default_catalog_dialect=self.engine_adapter.DIALECT,
1804-
)
1805-
else:
1806-
test_meta = []
1807-
1808-
for path, config in self.configs.items():
1809-
test_meta.extend(
1810-
get_all_model_tests(
1811-
path / c.TESTS,
1812-
patterns=match_patterns,
1813-
ignore_patterns=config.ignore_patterns,
1814-
variables=config.variables,
1815-
)
1816-
)
1790+
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
18171791

1818-
result = run_tests(
1819-
model_test_metadata=test_meta,
1820-
models=self._models,
1821-
config=self.config,
1822-
gateway=self.gateway,
1823-
dialect=self.default_dialect,
1824-
verbosity=verbosity,
1825-
preserve_fixtures=preserve_fixtures,
1826-
stream=stream,
1827-
default_catalog=self.default_catalog,
1828-
default_catalog_dialect=self.engine_adapter.DIALECT,
1829-
)
1830-
1831-
return result
1792+
return run_tests(
1793+
model_test_metadata=test_meta,
1794+
models=self._models,
1795+
config=self.config,
1796+
selected_gateway=self.selected_gateway,
1797+
dialect=self.default_dialect,
1798+
verbosity=verbosity,
1799+
preserve_fixtures=preserve_fixtures,
1800+
stream=stream,
1801+
default_catalog=self.default_catalog,
1802+
default_catalog_dialect=self.engine_adapter.DIALECT,
1803+
)
18321804

18331805
@python_api_analytics
18341806
def audit(
@@ -2504,6 +2476,19 @@ def lint_models(
25042476
"Linter detected errors in the code. Please fix them before proceeding."
25052477
)
25062478

2479+
def load_model_tests(
2480+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
2481+
) -> t.List[ModelTestMetadata]:
2482+
# If a set of specific test path(s) are provided, we can use a single loader
2483+
# since it's not required to walk every tests/ folder in each repo
2484+
loaders = [self._loaders[0]] if tests else self._loaders
2485+
2486+
model_tests = []
2487+
for loader in loaders:
2488+
model_tests.extend(loader.load_model_tests(tests=tests, patterns=patterns))
2489+
2490+
return model_tests
2491+
25072492

25082493
class Context(GenericContext[Config]):
25092494
CONFIG_TYPE = Config

sqlmesh/core/loader.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import abc
44
import glob
5+
import itertools
56
import linecache
67
import logging
78
import os
9+
import re
810
import typing as t
911
from collections import Counter, defaultdict
1012
from dataclasses import dataclass
@@ -31,18 +33,22 @@
3133
from sqlmesh.core.model import model as model_registry
3234
from sqlmesh.core.model.common import make_python_env
3335
from sqlmesh.core.signal import signal
36+
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
3437
from sqlmesh.utils import UniqueKeyDict, sys_path
3538
from sqlmesh.utils.errors import ConfigError
3639
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
3740
from sqlmesh.utils.metaprogramming import import_python_file
38-
from sqlmesh.utils.yaml import YAML
41+
from sqlmesh.utils.yaml import YAML, load as yaml_load
42+
3943

4044
if t.TYPE_CHECKING:
4145
from sqlmesh.core.context import GenericContext
4246

4347

4448
logger = logging.getLogger(__name__)
4549

50+
GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)")
51+
4652

4753
@dataclass
4854
class LoadedProject:
@@ -290,6 +296,12 @@ def _load_linting_rules(self) -> RuleSet:
290296
"""Loads user linting rules"""
291297
return RuleSet()
292298

299+
def load_model_tests(
300+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
301+
) -> t.List[ModelTestMetadata]:
302+
"""Loads YAML-based model tests"""
303+
return []
304+
293305
def _glob_paths(
294306
self,
295307
path: Path,
@@ -680,6 +692,61 @@ def _load_linting_rules(self) -> RuleSet:
680692

681693
return RuleSet(user_rules.values())
682694

695+
def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
696+
"""Load a single model test file."""
697+
model_test_metadata = {}
698+
699+
with open(path, "r", encoding="utf-8") as file:
700+
source = file.read()
701+
# If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to
702+
# parse it as YAML to match the gateway name stored in the config
703+
gateway_line = GATEWAY_PATTERN.search(source)
704+
gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None
705+
706+
contents = yaml_load(source, variables=self._get_variables(gateway))
707+
708+
for test_name, value in contents.items():
709+
model_test_metadata[test_name] = ModelTestMetadata(
710+
path=path, test_name=test_name, body=value
711+
)
712+
713+
return model_test_metadata
714+
715+
def load_model_tests(
716+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
717+
) -> t.List[ModelTestMetadata]:
718+
"""Loads YAML-based model tests"""
719+
test_meta_list: t.List[ModelTestMetadata] = []
720+
721+
if tests:
722+
for test in tests:
723+
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
724+
725+
test_meta = self._load_model_test_file(Path(filename))
726+
if test_name:
727+
test_meta_list.append(test_meta[test_name])
728+
else:
729+
test_meta_list.extend(test_meta.values())
730+
else:
731+
search_path = Path(self.config_path) / c.TESTS
732+
733+
for yaml_file in itertools.chain(
734+
search_path.glob("**/test*.yaml"),
735+
search_path.glob("**/test*.yml"),
736+
):
737+
if any(
738+
yaml_file.match(ignore_pattern)
739+
for ignore_pattern in self.config.ignore_patterns or []
740+
):
741+
continue
742+
743+
test_meta_list.extend(self._load_model_test_file(yaml_file).values())
744+
745+
if patterns:
746+
test_meta_list = filter_tests_by_patterns(test_meta_list, patterns)
747+
748+
return test_meta_list
749+
683750
class _Cache(CacheBase):
684751
def __init__(self, loader: SqlMeshLoader, config_path: Path):
685752
self._loader = loader

sqlmesh/core/test/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44
from sqlmesh.core.test.discovery import (
55
ModelTestMetadata as ModelTestMetadata,
66
filter_tests_by_patterns as filter_tests_by_patterns,
7-
get_all_model_tests as get_all_model_tests,
8-
load_model_test_file as load_model_test_file,
97
)
108
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
11-
from sqlmesh.core.test.runner import (
12-
run_model_tests as run_model_tests,
13-
run_tests as run_tests,
14-
)
9+
from sqlmesh.core.test.runner import run_tests as run_tests

sqlmesh/core/test/discovery.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import itertools
55
import pathlib
66
import typing as t
7-
from collections.abc import Iterator
87

98
import ruamel
109

1110
from sqlmesh.utils import unique
1211
from sqlmesh.utils.pydantic import PydanticModel
13-
from sqlmesh.utils.yaml import load as yaml_load
1412

1513

1614
class ModelTestMetadata(PydanticModel):
@@ -26,59 +24,6 @@ def __hash__(self) -> int:
2624
return self.fully_qualified_test_name.__hash__()
2725

2826

29-
def load_model_test_file(
30-
path: pathlib.Path, variables: dict[str, t.Any] | None = None
31-
) -> dict[str, ModelTestMetadata]:
32-
"""Load a single model test file.
33-
34-
Args:
35-
path: The path to the test file
36-
37-
returns:
38-
A list of ModelTestMetadata named tuples.
39-
"""
40-
model_test_metadata = {}
41-
contents = yaml_load(path, variables=variables)
42-
43-
for test_name, value in contents.items():
44-
model_test_metadata[test_name] = ModelTestMetadata(
45-
path=path, test_name=test_name, body=value
46-
)
47-
return model_test_metadata
48-
49-
50-
def discover_model_tests(
51-
path: pathlib.Path,
52-
ignore_patterns: list[str] | None = None,
53-
variables: dict[str, t.Any] | None = None,
54-
) -> Iterator[ModelTestMetadata]:
55-
"""Discover model tests.
56-
57-
Model tests are defined in YAML files and contain the inputs and outputs used to test model queries.
58-
59-
Args:
60-
path: A path to search for tests.
61-
ignore_patterns: An optional list of patterns to ignore.
62-
63-
Returns:
64-
A list of ModelTestMetadata named tuples.
65-
"""
66-
search_path = pathlib.Path(path)
67-
68-
for yaml_file in itertools.chain(
69-
search_path.glob("**/test*.yaml"),
70-
search_path.glob("**/test*.yml"),
71-
):
72-
for ignore_pattern in ignore_patterns or []:
73-
if yaml_file.match(ignore_pattern):
74-
break
75-
else:
76-
for model_test_metadata in load_model_test_file(
77-
yaml_file, variables=variables
78-
).values():
79-
yield model_test_metadata
80-
81-
8227
def filter_tests_by_patterns(
8328
tests: list[ModelTestMetadata], patterns: list[str]
8429
) -> list[ModelTestMetadata]:
@@ -97,19 +42,3 @@ def filter_tests_by_patterns(
9742
if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern))
9843
or pattern in test.fully_qualified_test_name
9944
)
100-
101-
102-
def get_all_model_tests(
103-
*paths: pathlib.Path,
104-
patterns: list[str] | None = None,
105-
ignore_patterns: list[str] | None = None,
106-
variables: dict[str, t.Any] | None = None,
107-
) -> list[ModelTestMetadata]:
108-
model_test_metadatas = [
109-
meta
110-
for path in paths
111-
for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables)
112-
]
113-
if patterns:
114-
model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns)
115-
return model_test_metadatas

sqlmesh/core/test/result.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def log_test_report(self, test_duration: float) -> None:
107107
for _, error in errors:
108108
stream.writeln(unittest.TextTestResult.separator1)
109109
stream.writeln(f"ERROR: {error}")
110-
stream.writeln(unittest.TextTestResult.separator2)
111110

112111
# Output final report
113112
stream.writeln(unittest.TextTestResult.separator2)

0 commit comments

Comments
 (0)