|
2 | 2 |
|
3 | 3 | import abc |
4 | 4 | import glob |
| 5 | +import itertools |
5 | 6 | import linecache |
6 | 7 | import logging |
7 | 8 | import os |
| 9 | +import re |
8 | 10 | import typing as t |
9 | 11 | from collections import Counter, defaultdict |
10 | 12 | from dataclasses import dataclass |
|
31 | 33 | from sqlmesh.core.model import model as model_registry |
32 | 34 | from sqlmesh.core.model.common import make_python_env |
33 | 35 | from sqlmesh.core.signal import signal |
| 36 | +from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns |
34 | 37 | from sqlmesh.utils import UniqueKeyDict, sys_path |
35 | 38 | from sqlmesh.utils.errors import ConfigError |
36 | 39 | from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor |
37 | 40 | from sqlmesh.utils.metaprogramming import import_python_file |
38 | | -from sqlmesh.utils.yaml import YAML |
| 41 | +from sqlmesh.utils.yaml import YAML, load as yaml_load |
| 42 | + |
39 | 43 |
|
40 | 44 | if t.TYPE_CHECKING: |
41 | 45 | from sqlmesh.core.context import GenericContext |
42 | 46 |
|
43 | 47 |
|
44 | 48 | logger = logging.getLogger(__name__) |
45 | 49 |
|
| 50 | +GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") |
| 51 | + |
46 | 52 |
|
47 | 53 | @dataclass |
48 | 54 | class LoadedProject: |
@@ -290,6 +296,12 @@ def _load_linting_rules(self) -> RuleSet: |
290 | 296 | """Loads user linting rules""" |
291 | 297 | return RuleSet() |
292 | 298 |
|
| 299 | + def load_model_tests( |
| 300 | + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None |
| 301 | + ) -> t.List[ModelTestMetadata]: |
| 302 | + """Loads YAML-based model tests""" |
| 303 | + return [] |
| 304 | + |
293 | 305 | def _glob_paths( |
294 | 306 | self, |
295 | 307 | path: Path, |
@@ -680,6 +692,61 @@ def _load_linting_rules(self) -> RuleSet: |
680 | 692 |
|
681 | 693 | return RuleSet(user_rules.values()) |
682 | 694 |
|
| 695 | + def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: |
| 696 | + """Load a single model test file.""" |
| 697 | + model_test_metadata = {} |
| 698 | + |
| 699 | + with open(path, "r", encoding="utf-8") as file: |
| 700 | + source = file.read() |
| 701 | + # If the user has specified a quoted/escaped gateway (e.g. "gateway: 'ma\tin'"), we need to |
| 702 | + # parse it as YAML to match the gateway name stored in the config |
| 703 | + gateway_line = GATEWAY_PATTERN.search(source) |
| 704 | + gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None |
| 705 | + |
| 706 | + contents = yaml_load(source, variables=self._get_variables(gateway)) |
| 707 | + |
| 708 | + for test_name, value in contents.items(): |
| 709 | + model_test_metadata[test_name] = ModelTestMetadata( |
| 710 | + path=path, test_name=test_name, body=value |
| 711 | + ) |
| 712 | + |
| 713 | + return model_test_metadata |
| 714 | + |
| 715 | + def load_model_tests( |
| 716 | + self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None |
| 717 | + ) -> t.List[ModelTestMetadata]: |
| 718 | + """Loads YAML-based model tests""" |
| 719 | + test_meta_list: t.List[ModelTestMetadata] = [] |
| 720 | + |
| 721 | + if tests: |
| 722 | + for test in tests: |
| 723 | + filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") |
| 724 | + |
| 725 | + test_meta = self._load_model_test_file(Path(filename)) |
| 726 | + if test_name: |
| 727 | + test_meta_list.append(test_meta[test_name]) |
| 728 | + else: |
| 729 | + test_meta_list.extend(test_meta.values()) |
| 730 | + else: |
| 731 | + search_path = Path(self.config_path) / c.TESTS |
| 732 | + |
| 733 | + for yaml_file in itertools.chain( |
| 734 | + search_path.glob("**/test*.yaml"), |
| 735 | + search_path.glob("**/test*.yml"), |
| 736 | + ): |
| 737 | + if any( |
| 738 | + yaml_file.match(ignore_pattern) |
| 739 | + for ignore_pattern in self.config.ignore_patterns or [] |
| 740 | + ): |
| 741 | + continue |
| 742 | + |
| 743 | + test_meta_list.extend(self._load_model_test_file(yaml_file).values()) |
| 744 | + |
| 745 | + if patterns: |
| 746 | + test_meta_list = filter_tests_by_patterns(test_meta_list, patterns) |
| 747 | + |
| 748 | + return test_meta_list |
| 749 | + |
683 | 750 | class _Cache(CacheBase): |
684 | 751 | def __init__(self, loader: SqlMeshLoader, config_path: Path): |
685 | 752 | self._loader = loader |
|
0 commit comments