Skip to content

Commit 99e6591

Browse files
committed
Move test loading logic to Loader
1 parent f8a572d commit 99e6591

8 files changed

Lines changed: 156 additions & 149 deletions

File tree

sqlmesh/core/context.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@
108108
from sqlmesh.core.table_diff import TableDiff
109109
from sqlmesh.core.test import (
110110
ModelTextTestResult,
111+
ModelTestMetadata,
111112
generate_test,
112-
load_model_tests,
113113
run_tests,
114114
)
115115
from sqlmesh.core.user import User
@@ -363,7 +363,6 @@ def __init__(
363363
self._excluded_requirements: t.Set[str] = set()
364364
self._default_catalog: t.Optional[str] = None
365365
self._linters: t.Dict[str, Linter] = {}
366-
self._variables_by_project_gateway: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {}
367366
self._loaded: bool = False
368367

369368
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
@@ -1781,12 +1780,7 @@ def test(
17811780
if verbosity >= Verbosity.VERBOSE:
17821781
pd.set_option("display.max_columns", None)
17831782

1784-
test_meta = load_model_tests(
1785-
configs=self.configs,
1786-
tests=tests,
1787-
patterns=match_patterns,
1788-
get_variables=self._get_variables,
1789-
)
1783+
test_meta = self._load_model_tests(tests=tests, patterns=match_patterns)
17901784

17911785
return run_tests(
17921786
model_test_metadata=test_meta,
@@ -2463,31 +2457,15 @@ def lint_models(
24632457
"Linter detected errors in the code. Please fix them before proceeding."
24642458
)
24652459

2466-
def _get_variables(
2467-
self, config: t.Optional[C] = None, gateway_name: t.Optional[str] = None
2468-
) -> t.Dict[str, t.Any]:
2469-
config = config or self.config
2470-
gateway_name = gateway_name or self.selected_gateway
2471-
2472-
key = (config.project, gateway_name)
2473-
if key not in self._variables_by_project_gateway:
2474-
try:
2475-
gateway = config.get_gateway(gateway_name)
2476-
except ConfigError:
2477-
from sqlmesh.core.console import get_console
2460+
def _load_model_tests(
2461+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
2462+
) -> t.List[ModelTestMetadata]:
2463+
model_tests = []
24782464

2479-
get_console().log_warning(
2480-
f"Gateway '{gateway_name}' not found in project '{config.project}'."
2481-
)
2482-
gateway = None
2483-
2484-
self._variables_by_project_gateway[key] = {
2485-
**config.variables,
2486-
**(gateway.variables if gateway else {}),
2487-
c.GATEWAY: gateway_name,
2488-
}
2465+
for loader in self._loaders:
2466+
model_tests.extend(loader._load_model_tests(tests=tests, patterns=patterns))
24892467

2490-
return self._variables_by_project_gateway[key]
2468+
return model_tests
24912469

24922470

24932471
class Context(GenericContext[Config]):

sqlmesh/core/loader.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import abc
44
import glob
5+
import itertools
56
import linecache
67
import logging
78
import os
@@ -31,11 +32,13 @@
3132
from sqlmesh.core.model import model as model_registry
3233
from sqlmesh.core.model.common import make_python_env
3334
from sqlmesh.core.signal import signal
35+
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
3436
from sqlmesh.utils import UniqueKeyDict, sys_path
3537
from sqlmesh.utils.errors import ConfigError
3638
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
3739
from sqlmesh.utils.metaprogramming import import_python_file
38-
from sqlmesh.utils.yaml import YAML
40+
from sqlmesh.utils.yaml import YAML, load as yaml_load
41+
3942

4043
if t.TYPE_CHECKING:
4144
from sqlmesh.core.context import GenericContext
@@ -74,6 +77,7 @@ def __init__(self, context: GenericContext, path: Path) -> None:
7477
self.context = context
7578
self.config_path = path
7679
self.config = self.context.configs[self.config_path]
80+
self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {}
7781

7882
def load(self) -> LoadedProject:
7983
"""
@@ -289,6 +293,12 @@ def _load_linting_rules(self) -> RuleSet:
289293
"""Loads user linting rules"""
290294
return RuleSet()
291295

296+
def _load_model_tests(
297+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
298+
) -> t.List[ModelTestMetadata]:
299+
"""Loads YAML-based model tests"""
300+
return []
301+
292302
def _glob_paths(
293303
self,
294304
path: Path,
@@ -328,7 +338,26 @@ def _track_file(self, path: Path) -> None:
328338
self._path_mtimes[path] = path.stat().st_mtime
329339

330340
def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
331-
return self.context._get_variables(config=self.config, gateway_name=gateway_name)
341+
gateway_name = gateway_name or self.context.selected_gateway
342+
343+
if gateway_name not in self._variables_by_gateway:
344+
try:
345+
gateway = self.config.get_gateway(gateway_name)
346+
except ConfigError:
347+
from sqlmesh.core.console import get_console
348+
349+
get_console().log_warning(
350+
f"Gateway '{gateway_name}' not found in project '{self.config.project}'."
351+
)
352+
gateway = None
353+
354+
self._variables_by_gateway[gateway_name] = {
355+
**self.config.variables,
356+
**(gateway.variables if gateway else {}),
357+
c.GATEWAY: gateway_name,
358+
}
359+
360+
return self._variables_by_gateway[gateway_name]
332361

333362

334363
class SqlMeshLoader(Loader):
@@ -658,6 +687,53 @@ def _load_linting_rules(self) -> RuleSet:
658687

659688
return RuleSet(user_rules.values())
660689

690+
def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]:
691+
"""Load a single model test file."""
692+
model_test_metadata = {}
693+
contents = yaml_load(path, get_variables=self._get_variables)
694+
695+
for test_name, value in contents.items():
696+
model_test_metadata[test_name] = ModelTestMetadata(
697+
path=path, test_name=test_name, body=value
698+
)
699+
700+
return model_test_metadata
701+
702+
def _load_model_tests(
703+
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
704+
) -> t.List[ModelTestMetadata]:
705+
"""Loads YAML-based model tests"""
706+
test_meta: t.List[ModelTestMetadata] = []
707+
708+
if tests:
709+
for test in tests:
710+
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
711+
712+
test_file = self._load_model_test_file(Path(filename))
713+
if test_name:
714+
test_meta.append(test_file[test_name])
715+
else:
716+
test_meta.extend(test_file.values())
717+
else:
718+
search_path = Path(self.config_path) / c.TESTS
719+
720+
for yaml_file in itertools.chain(
721+
search_path.glob("**/test*.yaml"),
722+
search_path.glob("**/test*.yml"),
723+
):
724+
if any(
725+
yaml_file.match(ignore_pattern)
726+
for ignore_pattern in self.config.ignore_patterns or []
727+
):
728+
continue
729+
730+
test_meta.extend(self._load_model_test_file(yaml_file).values())
731+
732+
if patterns:
733+
test_meta = filter_tests_by_patterns(test_meta, patterns)
734+
735+
return test_meta
736+
661737
class _Cache(CacheBase):
662738
def __init__(self, loader: SqlMeshLoader, config_path: Path):
663739
self._loader = loader

sqlmesh/core/test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
44
from sqlmesh.core.test.discovery import (
55
ModelTestMetadata as ModelTestMetadata,
6-
load_model_tests as load_model_tests,
6+
filter_tests_by_patterns as filter_tests_by_patterns,
77
)
88
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult
99
from sqlmesh.core.test.runner import run_tests as run_tests

sqlmesh/core/test/discovery.py

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,11 @@
44
import itertools
55
import pathlib
66
import typing as t
7-
from collections.abc import Iterator
87

98
import ruamel
109

11-
from sqlmesh.core import constants as c
1210
from sqlmesh.utils import unique
1311
from sqlmesh.utils.pydantic import PydanticModel
14-
from sqlmesh.utils.yaml import load as yaml_load
15-
16-
if t.TYPE_CHECKING:
17-
from sqlmesh.core.config.loader import C
1812

1913

2014
class ModelTestMetadata(PydanticModel):
@@ -30,61 +24,6 @@ def __hash__(self) -> int:
3024
return self.fully_qualified_test_name.__hash__()
3125

3226

33-
def load_model_test_file(
34-
path: pathlib.Path,
35-
config: C,
36-
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
37-
) -> dict[str, ModelTestMetadata]:
38-
"""Load a single model test file.
39-
40-
Args:
41-
path: The path to the test file
42-
43-
returns:
44-
A list of ModelTestMetadata named tuples.
45-
"""
46-
model_test_metadata = {}
47-
contents = yaml_load(path, config=config, get_variables=get_variables)
48-
49-
for test_name, value in contents.items():
50-
model_test_metadata[test_name] = ModelTestMetadata(
51-
path=path, test_name=test_name, body=value
52-
)
53-
return model_test_metadata
54-
55-
56-
def discover_model_tests(
57-
path: pathlib.Path,
58-
config: C,
59-
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
60-
) -> Iterator[ModelTestMetadata]:
61-
"""Discover model tests.
62-
63-
Model tests are defined in YAML files and contain the inputs and outputs used to test model queries.
64-
65-
Args:
66-
path: A path to search for tests.
67-
ignore_patterns: An optional list of patterns to ignore.
68-
69-
Returns:
70-
A list of ModelTestMetadata named tuples.
71-
"""
72-
search_path = pathlib.Path(path)
73-
74-
for yaml_file in itertools.chain(
75-
search_path.glob("**/test*.yaml"),
76-
search_path.glob("**/test*.yml"),
77-
):
78-
for ignore_pattern in config.ignore_patterns or []:
79-
if yaml_file.match(ignore_pattern):
80-
break
81-
else:
82-
for model_test_metadata in load_model_test_file(
83-
yaml_file, config=config, get_variables=get_variables
84-
).values():
85-
yield model_test_metadata
86-
87-
8827
def filter_tests_by_patterns(
8928
tests: list[ModelTestMetadata], patterns: list[str]
9029
) -> list[ModelTestMetadata]:
@@ -103,47 +42,3 @@ def filter_tests_by_patterns(
10342
if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern))
10443
or pattern in test.fully_qualified_test_name
10544
)
106-
107-
108-
def load_model_tests(
109-
configs: t.Dict[pathlib.Path, C],
110-
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
111-
tests: t.Optional[t.List[str]] = None,
112-
patterns: list[str] | None = None,
113-
) -> list[ModelTestMetadata]:
114-
"""Load model tests into a list of ModelTestMetadata which will be propagated to the test runner.
115-
116-
Args:
117-
tests: A list of tests to load; If not specified, all tests are loaded
118-
patterns: A list of patterns that'll be used to filter tests by file name.
119-
configs: A dictionary of configs to use when loading all the tests.
120-
"""
121-
test_meta = []
122-
123-
if tests:
124-
for test in tests:
125-
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
126-
127-
test_file = load_model_test_file(
128-
pathlib.Path(filename),
129-
config=next(iter(configs.values())),
130-
get_variables=get_variables,
131-
)
132-
if test_name:
133-
test_meta.append(test_file[test_name])
134-
else:
135-
test_meta.extend(test_file.values())
136-
else:
137-
for path, config in configs.items():
138-
test_meta.extend(
139-
discover_model_tests(
140-
pathlib.Path(path / c.TESTS),
141-
config=config,
142-
get_variables=get_variables,
143-
)
144-
)
145-
146-
if patterns:
147-
test_meta = filter_tests_by_patterns(test_meta, patterns)
148-
149-
return test_meta

sqlmesh/core/test/result.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def log_test_report(self, test_duration: float) -> None:
107107
for _, error in errors:
108108
stream.writeln(unittest.TextTestResult.separator1)
109109
stream.writeln(f"ERROR: {error}")
110-
stream.writeln(unittest.TextTestResult.separator2)
111110

112111
# Output final report
113112
stream.writeln(unittest.TextTestResult.separator2)

sqlmesh/magics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from sqlmesh.core.context import Context
3232
from sqlmesh.core.dialect import format_model_expressions, parse
3333
from sqlmesh.core.model import load_sql_based_model
34-
from sqlmesh.core.test import ModelTestMetadata, load_model_tests
34+
from sqlmesh.core.test import ModelTestMetadata
3535
from sqlmesh.utils import sqlglot_dialects, yaml, Verbosity
3636
from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError
3737

@@ -272,7 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None
272272
if not args.test_name and not args.ls:
273273
raise MagicError("Must provide either test name or `--ls` to list tests")
274274

275-
test_meta = load_model_tests(configs=context.configs, get_variables=context._get_variables)
275+
test_meta = context._load_model_tests()
276276

277277
tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict)
278278
for model_test_metadata in test_meta:

sqlmesh/utils/yaml.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@
1313
from sqlmesh.utils.errors import SQLMeshError
1414
from sqlmesh.utils.jinja import ENVIRONMENT, create_var
1515

16-
if t.TYPE_CHECKING:
17-
from sqlmesh.core.config.loader import C
18-
19-
2016
JINJA_METHODS = {
2117
"env_var": lambda key, default=None: getenv(key, default),
2218
}
@@ -44,8 +40,7 @@ def load(
4440
render_jinja: bool = True,
4541
allow_duplicate_keys: bool = False,
4642
variables: t.Optional[t.Dict[str, t.Any]] = None,
47-
config: t.Optional[C] = None,
48-
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]] | None = None,
43+
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None,
4944
) -> t.Dict:
5045
"""Loads a YAML object from either a raw string or a file."""
5146
path: t.Optional[Path] = None
@@ -62,7 +57,7 @@ def load(
6257
gateway_line = GATEWAY_PATTERN.search(source)
6358
gateway = yaml.load(gateway_line.group(0))["gateway"] if gateway_line else None
6459

65-
variables = get_variables(config, gateway)
60+
variables = get_variables(gateway)
6661

6762
if render_jinja:
6863
source = ENVIRONMENT.from_string(source).render(

0 commit comments

Comments
 (0)