Skip to content

Commit cb452e6

Browse files
committed
Move get_variables up to Context
1 parent d460f6f commit cb452e6

5 files changed

Lines changed: 59 additions & 45 deletions

File tree

sqlmesh/core/context.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def __init__(
363363
self._excluded_requirements: t.Set[str] = set()
364364
self._default_catalog: t.Optional[str] = None
365365
self._linters: t.Dict[str, Linter] = {}
366+
self._variables_by_project_gateway: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {}
366367
self._loaded: bool = False
367368

368369
self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items())))
@@ -1781,9 +1782,10 @@ def test(
17811782
pd.set_option("display.max_columns", None)
17821783

17831784
test_meta = load_model_tests(
1784-
loaders=self._loaders,
1785+
configs=self.configs,
17851786
tests=tests,
17861787
patterns=match_patterns,
1788+
get_variables=self._get_variables,
17871789
)
17881790

17891791
return run_tests(
@@ -2461,6 +2463,32 @@ def lint_models(
24612463
"Linter detected errors in the code. Please fix them before proceeding."
24622464
)
24632465

2466+
def _get_variables(
2467+
self, config: t.Optional[C] = None, gateway_name: t.Optional[str] = None
2468+
) -> t.Dict[str, t.Any]:
2469+
config = config or self.config
2470+
gateway_name = gateway_name or self.selected_gateway
2471+
2472+
key = (config.project, gateway_name)
2473+
if key not in self._variables_by_project_gateway:
2474+
try:
2475+
gateway = config.get_gateway(gateway_name)
2476+
except ConfigError:
2477+
from sqlmesh.core.console import get_console
2478+
2479+
get_console().log_warning(
2480+
f"Gateway '{gateway_name}' not found in project '{config.project}'."
2481+
)
2482+
gateway = None
2483+
2484+
self._variables_by_project_gateway[key] = {
2485+
**config.variables,
2486+
**(gateway.variables if gateway else {}),
2487+
c.GATEWAY: gateway_name,
2488+
}
2489+
2490+
return self._variables_by_project_gateway[key]
2491+
24642492

24652493
class Context(GenericContext[Config]):
24662494
CONFIG_TYPE = Config

sqlmesh/core/loader.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def __init__(self, context: GenericContext, path: Path) -> None:
7474
self.context = context
7575
self.config_path = path
7676
self.config = self.context.configs[self.config_path]
77-
self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {}
7877

7978
def load(self) -> LoadedProject:
8079
"""
@@ -329,26 +328,7 @@ def _track_file(self, path: Path) -> None:
329328
self._path_mtimes[path] = path.stat().st_mtime
330329

331330
def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]:
332-
gateway_name = gateway_name or self.context.selected_gateway
333-
334-
if gateway_name not in self._variables_by_gateway:
335-
try:
336-
gateway = self.config.get_gateway(gateway_name)
337-
except ConfigError:
338-
from sqlmesh.core.console import get_console
339-
340-
get_console().log_warning(
341-
f"Gateway '{gateway_name}' not found in project '{self.config.project}'."
342-
)
343-
gateway = None
344-
345-
self._variables_by_gateway[gateway_name] = {
346-
**self.config.variables,
347-
**(gateway.variables if gateway else {}),
348-
c.GATEWAY: gateway_name,
349-
}
350-
351-
return self._variables_by_gateway[gateway_name]
331+
return self.context._get_variables(config=self.config, gateway_name=gateway_name)
352332

353333

354334
class SqlMeshLoader(Loader):

sqlmesh/core/test/discovery.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlmesh.utils.yaml import load as yaml_load
1515

1616
if t.TYPE_CHECKING:
17-
from sqlmesh.core.loader import Loader
17+
from sqlmesh.core.config.loader import C
1818

1919

2020
class ModelTestMetadata(PydanticModel):
@@ -32,7 +32,8 @@ def __hash__(self) -> int:
3232

3333
def load_model_test_file(
3434
path: pathlib.Path,
35-
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
35+
config: C,
36+
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
3637
) -> dict[str, ModelTestMetadata]:
3738
"""Load a single model test file.
3839
@@ -43,7 +44,7 @@ def load_model_test_file(
4344
A list of ModelTestMetadata named tuples.
4445
"""
4546
model_test_metadata = {}
46-
contents = yaml_load(path, get_variables=get_variables)
47+
contents = yaml_load(path, config=config, get_variables=get_variables)
4748

4849
for test_name, value in contents.items():
4950
model_test_metadata[test_name] = ModelTestMetadata(
@@ -54,8 +55,8 @@ def load_model_test_file(
5455

5556
def discover_model_tests(
5657
path: pathlib.Path,
57-
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
58-
ignore_patterns: list[str] | None = None,
58+
config: C,
59+
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
5960
) -> Iterator[ModelTestMetadata]:
6061
"""Discover model tests.
6162
@@ -74,12 +75,12 @@ def discover_model_tests(
7475
search_path.glob("**/test*.yaml"),
7576
search_path.glob("**/test*.yml"),
7677
):
77-
for ignore_pattern in ignore_patterns or []:
78+
for ignore_pattern in config.ignore_patterns or []:
7879
if yaml_file.match(ignore_pattern):
7980
break
8081
else:
8182
for model_test_metadata in load_model_test_file(
82-
yaml_file, get_variables=get_variables
83+
yaml_file, config=config, get_variables=get_variables
8384
).values():
8485
yield model_test_metadata
8586

@@ -105,7 +106,8 @@ def filter_tests_by_patterns(
105106

106107

107108
def load_model_tests(
108-
loaders: list[Loader],
109+
configs: t.Dict[pathlib.Path, C],
110+
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]],
109111
tests: t.Optional[t.List[str]] = None,
110112
patterns: list[str] | None = None,
111113
) -> list[ModelTestMetadata]:
@@ -123,23 +125,22 @@ def load_model_tests(
123125
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
124126

125127
test_file = load_model_test_file(
126-
pathlib.Path(filename), get_variables=loaders[0]._get_variables
128+
pathlib.Path(filename),
129+
config=next(iter(configs.values())),
130+
get_variables=get_variables,
127131
)
128132
if test_name:
129133
test_meta.append(test_file[test_name])
130134
else:
131135
test_meta.extend(test_file.values())
132136
else:
133-
for loader in loaders:
137+
for path, config in configs.items():
134138
test_meta.extend(
135-
[
136-
meta
137-
for meta in discover_model_tests(
138-
pathlib.Path(loader.config_path / c.TESTS),
139-
ignore_patterns=loader.config.ignore_patterns, # type: ignore
140-
get_variables=loader._get_variables,
141-
)
142-
]
139+
discover_model_tests(
140+
pathlib.Path(path / c.TESTS),
141+
config=config,
142+
get_variables=get_variables,
143+
)
143144
)
144145

145146
if patterns:

sqlmesh/magics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None
272272
if not args.test_name and not args.ls:
273273
raise MagicError("Must provide either test name or `--ls` to list tests")
274274

275-
test_meta = load_model_tests(loaders=context._loaders)
275+
test_meta = load_model_tests(configs=context.configs, get_variables=context._get_variables)
276276

277277
tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict)
278278
for model_test_metadata in test_meta:

sqlmesh/utils/yaml.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,16 @@
1313
from sqlmesh.utils.errors import SQLMeshError
1414
from sqlmesh.utils.jinja import ENVIRONMENT, create_var
1515

16+
if t.TYPE_CHECKING:
17+
from sqlmesh.core.config.loader import C
18+
19+
1620
JINJA_METHODS = {
1721
"env_var": lambda key, default=None: getenv(key, default),
1822
}
1923

2024

21-
gateway_pattern = re.compile(r"gateway:\s*([^\s]+)")
25+
GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)")
2226

2327

2428
def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML:
@@ -40,7 +44,8 @@ def load(
4044
render_jinja: bool = True,
4145
allow_duplicate_keys: bool = False,
4246
variables: t.Optional[t.Dict[str, t.Any]] = None,
43-
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None,
47+
config: t.Optional[C] = None,
48+
get_variables: t.Callable[[t.Optional[C], t.Optional[str]], t.Dict[str, str]] | None = None,
4449
) -> t.Dict:
4550
"""Loads a YAML object from either a raw string or a file."""
4651
path: t.Optional[Path] = None
@@ -51,8 +56,8 @@ def load(
5156
source = file.read()
5257

5358
if get_variables:
54-
gateway = gateway_pattern.search(source)
55-
variables = get_variables(gateway.group(1) if gateway else None)
59+
gateway = GATEWAY_PATTERN.search(source)
60+
variables = get_variables(config, gateway.group(1) if gateway else None)
5661

5762
if render_jinja:
5863
source = ENVIRONMENT.from_string(source).render(

0 commit comments

Comments
 (0)