Skip to content

Commit d460f6f

Browse files
committed
Load variables from gateway specified in YAML
1 parent 7ebd460 commit d460f6f

7 files changed

Lines changed: 104 additions & 77 deletions

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,17 +1780,10 @@ def test(
17801780
if verbosity >= Verbosity.VERBOSE:
17811781
pd.set_option("display.max_columns", None)
17821782

1783-
# Merge the root variables with the gateway's variables
1784-
variables = {
1785-
**self.config.variables,
1786-
**self.config.get_gateway(self.selected_gateway).variables,
1787-
}
1788-
17891783
test_meta = load_model_tests(
1790-
configs=self.configs,
1784+
loaders=self._loaders,
17911785
tests=tests,
17921786
patterns=match_patterns,
1793-
variables=variables,
17941787
)
17951788

17961789
return run_tests(

sqlmesh/core/test/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
44
from sqlmesh.core.test.discovery import (
55
ModelTestMetadata as ModelTestMetadata,
6-
filter_tests_by_patterns as filter_tests_by_patterns,
7-
get_all_model_tests as get_all_model_tests,
8-
load_model_test_file as load_model_test_file,
96
load_model_tests as load_model_tests,
107
)
118
from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult

sqlmesh/core/test/discovery.py

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlmesh.utils.yaml import load as yaml_load
1515

1616
if t.TYPE_CHECKING:
17-
from sqlmesh.core.config.loader import C
17+
from sqlmesh.core.loader import Loader
1818

1919

2020
class ModelTestMetadata(PydanticModel):
@@ -31,7 +31,8 @@ def __hash__(self) -> int:
3131

3232

3333
def load_model_test_file(
34-
path: pathlib.Path, variables: dict[str, t.Any] | None = None
34+
path: pathlib.Path,
35+
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
3536
) -> dict[str, ModelTestMetadata]:
3637
"""Load a single model test file.
3738
@@ -42,7 +43,7 @@ def load_model_test_file(
4243
A list of ModelTestMetadata named tuples.
4344
"""
4445
model_test_metadata = {}
45-
contents = yaml_load(path, variables=variables)
46+
contents = yaml_load(path, get_variables=get_variables)
4647

4748
for test_name, value in contents.items():
4849
model_test_metadata[test_name] = ModelTestMetadata(
@@ -53,8 +54,8 @@ def load_model_test_file(
5354

5455
def discover_model_tests(
5556
path: pathlib.Path,
57+
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]],
5658
ignore_patterns: list[str] | None = None,
57-
variables: dict[str, t.Any] | None = None,
5859
) -> Iterator[ModelTestMetadata]:
5960
"""Discover model tests.
6061
@@ -78,7 +79,7 @@ def discover_model_tests(
7879
break
7980
else:
8081
for model_test_metadata in load_model_test_file(
81-
yaml_file, variables=variables
82+
yaml_file, get_variables=get_variables
8283
).values():
8384
yield model_test_metadata
8485

@@ -103,34 +104,16 @@ def filter_tests_by_patterns(
103104
)
104105

105106

106-
def get_all_model_tests(
107-
*paths: pathlib.Path,
108-
patterns: list[str] | None = None,
109-
ignore_patterns: list[str] | None = None,
110-
variables: dict[str, t.Any] | None = None,
111-
) -> list[ModelTestMetadata]:
112-
model_test_metadatas = [
113-
meta
114-
for path in paths
115-
for meta in discover_model_tests(pathlib.Path(path), ignore_patterns, variables=variables)
116-
]
117-
if patterns:
118-
model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns)
119-
return model_test_metadatas
120-
121-
122107
def load_model_tests(
123-
configs: dict[pathlib.Path, C],
108+
loaders: list[Loader],
124109
tests: t.Optional[t.List[str]] = None,
125110
patterns: list[str] | None = None,
126-
variables: dict[str, t.Any] | None = None,
127111
) -> list[ModelTestMetadata]:
128112
"""Load model tests into a list of ModelTestMetadata which will be propagated to the test runner.
129113
130114
Args:
131115
tests: A list of tests to load; If not specified, all tests are loaded
132-
patterns: A list of patterns to match against.
133-
variables: A dictionary of variables to use when loading the tests.
116+
patterns: A list of patterns that'll be used to filter tests by file name.
134117
configs: A dictionary of configs to use when loading all the tests.
135118
"""
136119
test_meta = []
@@ -139,24 +122,27 @@ def load_model_tests(
139122
for test in tests:
140123
filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "")
141124

142-
test_file = load_model_test_file(pathlib.Path(filename), variables=variables)
125+
test_file = load_model_test_file(
126+
pathlib.Path(filename), get_variables=loaders[0]._get_variables
127+
)
143128
if test_name:
144129
test_meta.append(test_file[test_name])
145130
else:
146131
test_meta.extend(test_file.values())
147-
148-
if patterns:
149-
test_meta = filter_tests_by_patterns(test_meta, patterns)
150-
151132
else:
152-
for path, config in configs.items():
133+
for loader in loaders:
153134
test_meta.extend(
154-
get_all_model_tests(
155-
path / c.TESTS,
156-
patterns=patterns,
157-
ignore_patterns=config.ignore_patterns,
158-
variables=variables,
159-
)
135+
[
136+
meta
137+
for meta in discover_model_tests(
138+
pathlib.Path(loader.config_path / c.TESTS),
139+
ignore_patterns=loader.config.ignore_patterns, # type: ignore
140+
get_variables=loader._get_variables,
141+
)
142+
]
160143
)
161144

145+
if patterns:
146+
test_meta = filter_tests_by_patterns(test_meta, patterns)
147+
162148
return test_meta

sqlmesh/core/test/runner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test
1515
from sqlmesh.core.test.discovery import (
1616
ModelTestMetadata as ModelTestMetadata,
17-
filter_tests_by_patterns as filter_tests_by_patterns,
18-
get_all_model_tests as get_all_model_tests,
19-
load_model_test_file as load_model_test_file,
2017
)
2118
from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig
2219

sqlmesh/magics.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,12 @@
2626
from rich.jupyter import JupyterRenderable
2727
from sqlmesh.cli.example_project import ProjectTemplate, init_example_project
2828
from sqlmesh.core import analytics
29-
from sqlmesh.core import constants as c
3029
from sqlmesh.core.config import load_configs
3130
from sqlmesh.core.console import create_console, set_console, configure_console
3231
from sqlmesh.core.context import Context
3332
from sqlmesh.core.dialect import format_model_expressions, parse
3433
from sqlmesh.core.model import load_sql_based_model
35-
from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests
34+
from sqlmesh.core.test import ModelTestMetadata, load_model_tests
3635
from sqlmesh.utils import sqlglot_dialects, yaml, Verbosity
3736
from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError
3837

@@ -273,15 +272,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None
273272
if not args.test_name and not args.ls:
274273
raise MagicError("Must provide either test name or `--ls` to list tests")
275274

276-
test_meta = []
277-
278-
for path, config in context.configs.items():
279-
test_meta.extend(
280-
get_all_model_tests(
281-
path / c.TESTS,
282-
ignore_patterns=config.ignore_patterns,
283-
)
284-
)
275+
test_meta = load_model_tests(loaders=context._loaders)
285276

286277
tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict)
287278
for model_test_metadata in test_meta:

sqlmesh/utils/yaml.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from decimal import Decimal
66
from os import getenv
77
from pathlib import Path
8+
import re
89

910
from ruamel import yaml
1011

@@ -17,6 +18,9 @@
1718
}
1819

1920

21+
gateway_pattern = re.compile(r"gateway:\s*([^\s]+)")
22+
23+
2024
def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML:
2125
yaml_obj = yaml.YAML(typ=typ)
2226

@@ -36,6 +40,7 @@ def load(
3640
render_jinja: bool = True,
3741
allow_duplicate_keys: bool = False,
3842
variables: t.Optional[t.Dict[str, t.Any]] = None,
43+
get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]] | None = None,
3944
) -> t.Dict:
4045
"""Loads a YAML object from either a raw string or a file."""
4146
path: t.Optional[Path] = None
@@ -45,6 +50,10 @@ def load(
4550
with open(source, "r", encoding="utf-8") as file:
4651
source = file.read()
4752

53+
if get_variables:
54+
gateway = gateway_pattern.search(source)
55+
variables = get_variables(gateway.group(1) if gateway else None)
56+
4857
if render_jinja:
4958
source = ENVIRONMENT.from_string(source).render(
5059
{

tests/core/test_test.py

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,26 +1563,32 @@ def test_variable_usage(tmp_path: Path) -> None:
15631563
init_example_project(tmp_path, dialect="duckdb")
15641564

15651565
variables = {"gold": "gold_db", "silver": "silver_db"}
1566-
1567-
# Case 1: Test root variables
1568-
config = Config(
1569-
default_connection=DuckDBConnectionConfig(),
1570-
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1571-
variables=variables,
1572-
)
1573-
context = Context(paths=tmp_path, config=config)
1566+
incorrect_variables = {"gold": "foo", "silver": "bar"}
15741567

15751568
parent = _create_model(
15761569
"SELECT 1 AS id, '2022-01-02'::DATE AS ds, @start_ts AS start_ts",
15771570
meta="MODEL (name silver_db.sch.b, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))",
15781571
)
1579-
parent = t.cast(SqlModel, context.upsert_model(parent))
15801572

15811573
child = _create_model(
15821574
"SELECT ds, @IF(@VAR('myvar'), id, id + 1) AS id FROM silver_db.sch.b WHERE ds BETWEEN @start_ds and @end_ds",
15831575
meta="MODEL (name gold_db.sch.a, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))",
15841576
)
1585-
child = t.cast(SqlModel, context.upsert_model(child))
1577+
1578+
def init_context(config: Config, **kwargs):
1579+
context = Context(paths=tmp_path, config=config, **kwargs)
1580+
context.upsert_model(parent)
1581+
context.upsert_model(child)
1582+
return context
1583+
1584+
# Case 1: Test root variables
1585+
config = Config(
1586+
default_connection=DuckDBConnectionConfig(),
1587+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1588+
variables=variables,
1589+
)
1590+
1591+
context = init_context(config)
15861592

15871593
test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml"
15881594
test_file.write_text(
@@ -1617,28 +1623,76 @@ def test_variable_usage(tmp_path: Path) -> None:
16171623
assert len(results.successes) == 2
16181624

16191625
# Case 2: Test gateway variables
1620-
context.config = Config(
1626+
config = Config(
16211627
gateways={
1622-
"": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables),
1628+
"main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables),
16231629
},
16241630
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
16251631
)
16261632

1633+
context = init_context(config, gateway="main")
1634+
16271635
results = context.test()
16281636

16291637
assert not results.failures
16301638
assert not results.errors
16311639
assert len(results.successes) == 2
16321640

16331641
# Case 3: Test gateway variables overriding root variables
1634-
context.config = Config(
1642+
config = Config(
16351643
gateways={
1636-
"": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables),
1644+
"main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables),
16371645
},
16381646
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1639-
variables={"gold": "foo", "silver": "bar"},
1647+
variables=incorrect_variables,
16401648
)
16411649

1650+
context = init_context(config, gateway="main")
1651+
1652+
results = context.test()
1653+
1654+
assert not results.failures
1655+
assert not results.errors
1656+
assert len(results.successes) == 2
1657+
1658+
# Case 4: Use variable from the defined gateway
1659+
test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml"
1660+
test_file.write_text(
1661+
"""
1662+
test_parameterized_model_names:
1663+
model: {{ var('gold') }}.sch.a
1664+
gateway: secondary
1665+
vars:
1666+
myvar: True
1667+
start_ds: 2022-01-01
1668+
end_ds: 2022-01-03
1669+
inputs:
1670+
{{ var('silver') }}.sch.b:
1671+
- ds: 2022-01-01
1672+
id: 1
1673+
- ds: 2022-01-01
1674+
id: 2
1675+
outputs:
1676+
query:
1677+
- ds: 2022-01-01
1678+
id: 1
1679+
- ds: 2022-01-01
1680+
id: 2
1681+
"""
1682+
)
1683+
1684+
config = Config(
1685+
gateways={
1686+
"main": GatewayConfig(
1687+
connection=DuckDBConnectionConfig(), variables=incorrect_variables
1688+
),
1689+
"secondary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables),
1690+
},
1691+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1692+
)
1693+
1694+
context = init_context(config, gateway="main")
1695+
16421696
results = context.test()
16431697

16441698
assert not results.failures

0 commit comments

Comments
 (0)