Skip to content

Commit 417eb37

Browse files
committed
Fix!: serialize blueprint variables separately to leverage AST
1 parent bb9826a commit 417eb37

6 files changed

Lines changed: 106 additions & 22 deletions

File tree

sqlmesh/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
DEFAULT_SCHEMA = "default"
7777

7878
SQLMESH_VARS = "__sqlmesh__vars__"
79+
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint_vars__"
80+
7981
VAR = "var"
8082
GATEWAY = "gateway"
8183

sqlmesh/core/macros.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def __init__(
183183
self.default_catalog = default_catalog
184184
self._path = path
185185
self._environment_naming_info = environment_naming_info
186+
self._blueprint_vars = {}
186187

187188
prepare_env(self.python_env, self.env)
188189
for k, v in self.python_env.items():
@@ -191,7 +192,13 @@ def __init__(
191192
elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None):
192193
self.macros[normalize_macro_name(k)] = self.env[k]
193194
elif v.is_value:
194-
self.locals[k] = self.env[k]
195+
if k == c.SQLMESH_BLUEPRINT_VARS:
196+
self._blueprint_vars = {
197+
var: self.parse_one(value) for var, value in self.env[k].items()
198+
}
199+
self.locals.update(self._blueprint_vars)
200+
else:
201+
self.locals[k] = self.env[k]
195202

196203
def send(
197204
self, name: str, *args: t.Any, **kwargs: t.Any
@@ -465,7 +472,10 @@ def gateway(self) -> t.Optional[str]:
465472

466473
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
467474
"""Returns the value of the specified variable, or the default value if it doesn't exist."""
468-
return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default)
475+
return {
476+
**(self.locals.get(c.SQLMESH_VARS) or {}),
477+
**self._blueprint_vars,
478+
}.get(var_name.lower(), default)
469479

470480
def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
471481
"""Coerces the given expression to the specified type on a best-effort basis."""

sqlmesh/core/model/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sqlmesh.utils.pydantic import ValidationInfo, field_validator
1818

1919
if t.TYPE_CHECKING:
20+
from sqlglot.dialects.dialect import DialectType
2021
from sqlmesh.utils.jinja import MacroReference
2122

2223

@@ -30,6 +31,8 @@ def make_python_env(
3031
path: t.Optional[str | Path] = None,
3132
python_env: t.Optional[t.Dict[str, Executable]] = None,
3233
strict_resolution: bool = True,
34+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
35+
dialect: DialectType = None,
3336
) -> t.Dict[str, Executable]:
3437
python_env = {} if python_env is None else python_env
3538
variables = variables or {}
@@ -86,6 +89,8 @@ def make_python_env(
8689
python_env,
8790
used_variables,
8891
variables,
92+
blueprint_variables=blueprint_variables,
93+
dialect=dialect,
8994
strict_resolution=strict_resolution,
9095
)
9196

@@ -95,6 +100,8 @@ def _add_variables_to_python_env(
95100
used_variables: t.Optional[t.Set[str]],
96101
variables: t.Optional[t.Dict[str, t.Any]],
97102
strict_resolution: bool = True,
103+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
104+
dialect: DialectType = None,
98105
) -> t.Dict[str, Executable]:
99106
_, python_used_variables = parse_dependencies(
100107
python_env,
@@ -107,6 +114,13 @@ def _add_variables_to_python_env(
107114
if variables:
108115
python_env[c.SQLMESH_VARS] = Executable.value(variables)
109116

117+
if blueprint_variables:
118+
blueprint_variables = {
119+
k: v.sql(dialect=dialect) if isinstance(v, exp.Expression) else v
120+
for k, v in blueprint_variables.items()
121+
}
122+
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables)
123+
110124
return python_env
111125

112126

sqlmesh/core/model/decorator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def model(
122122
default_catalog: t.Optional[str] = None,
123123
variables: t.Optional[t.Dict[str, t.Any]] = None,
124124
infer_names: t.Optional[bool] = False,
125+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
125126
) -> Model:
126127
"""Get the model registered by this function."""
127128
env: t.Dict[str, t.Any] = {}
@@ -155,6 +156,7 @@ def model(
155156
path=path,
156157
dialect=dialect,
157158
default_catalog=default_catalog,
159+
blueprint_variables=blueprint_variables,
158160
)
159161

160162
rendered_name = rendered_fields["name"]
@@ -193,6 +195,7 @@ def model(
193195
"macros": macros,
194196
"jinja_macros": jinja_macros,
195197
"audit_definitions": audit_definitions,
198+
"blueprint_variables": blueprint_variables,
196199
**rendered_fields,
197200
}
198201

sqlmesh/core/model/definition.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,7 @@ def render(
17481748

17491749
variables = env.get(c.SQLMESH_VARS, {})
17501750
variables.update(kwargs.pop("variables", {}))
1751+
variables.update(env.get(c.SQLMESH_BLUEPRINT_VARS, {}))
17511752

17521753
try:
17531754
kwargs = {
@@ -1855,18 +1856,14 @@ def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]:
18551856
return [] # This is unreachable, but is done to satisfy mypy
18561857

18571858

1858-
def _extract_blueprint_variables(
1859-
blueprint: t.Any,
1860-
dialect: DialectType,
1861-
path: Path,
1862-
) -> t.Dict[str, str]:
1859+
def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.Any]:
18631860
if not blueprint:
18641861
return {}
18651862
if isinstance(blueprint, exp.Paren):
18661863
blueprint = blueprint.unnest()
1867-
return {blueprint.left.name: blueprint.right.sql(dialect=dialect)}
1864+
return {blueprint.left.name: blueprint.right}
18681865
if isinstance(blueprint, (exp.Tuple, exp.Array)):
1869-
return {e.left.name: e.right.sql(dialect=dialect) for e in blueprint.expressions}
1866+
return {e.left.name: e.right for e in blueprint.expressions}
18701867
if isinstance(blueprint, dict):
18711868
return blueprint
18721869

@@ -1889,18 +1886,18 @@ def create_models_from_blueprints(
18891886
) -> t.List[Model]:
18901887
model_blueprints: t.List[Model] = []
18911888
for blueprint in _extract_blueprints(blueprints, path):
1892-
variables = _extract_blueprint_variables(blueprint, dialect, path)
1889+
blueprint_variables = _extract_blueprint_variables(blueprint, path)
18931890

18941891
if gateway:
18951892
rendered_gateway = render_expression(
18961893
expression=exp.maybe_parse(gateway, dialect=dialect),
18971894
module_path=module_path,
18981895
macros=loader_kwargs.get("macros"),
18991896
jinja_macros=loader_kwargs.get("jinja_macros"),
1900-
variables=variables,
19011897
path=path,
19021898
dialect=dialect,
19031899
default_catalog=loader_kwargs.get("default_catalog"),
1900+
blueprint_variables=blueprint_variables,
19041901
)
19051902
gateway_name = rendered_gateway[0].name if rendered_gateway else None
19061903
else:
@@ -1911,7 +1908,8 @@ def create_models_from_blueprints(
19111908
path=path,
19121909
module_path=module_path,
19131910
dialect=dialect,
1914-
variables={**get_variables(gateway_name), **variables},
1911+
variables=get_variables(gateway_name),
1912+
blueprint_variables=blueprint_variables,
19151913
**loader_kwargs,
19161914
)
19171915
)
@@ -1983,6 +1981,7 @@ def load_sql_based_model(
19831981
default_catalog: t.Optional[str] = None,
19841982
variables: t.Optional[t.Dict[str, t.Any]] = None,
19851983
infer_names: t.Optional[bool] = False,
1984+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
19861985
**kwargs: t.Any,
19871986
) -> Model:
19881987
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2059,6 +2058,7 @@ def load_sql_based_model(
20592058
path=path,
20602059
dialect=dialect,
20612060
default_catalog=default_catalog,
2061+
blueprint_variables=blueprint_variables,
20622062
)
20632063

20642064
if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1:
@@ -2143,6 +2143,7 @@ def load_sql_based_model(
21432143
variables=variables,
21442144
default_audits=default_audits,
21452145
inline_audits=inline_audits,
2146+
blueprint_variables=blueprint_variables,
21462147
**meta_fields,
21472148
)
21482149

@@ -2247,6 +2248,7 @@ def create_python_model(
22472248
module_path: Path = Path(),
22482249
depends_on: t.Optional[t.Set[str]] = None,
22492250
variables: t.Optional[t.Dict[str, t.Any]] = None,
2251+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
22502252
**kwargs: t.Any,
22512253
) -> Model:
22522254
"""Creates a Python model.
@@ -2259,6 +2261,7 @@ def create_python_model(
22592261
path: An optional path to the model definition file.
22602262
depends_on: The custom set of model's upstream dependencies.
22612263
variables: The variables to pass to the model.
2264+
blueprint_variables: The blueprint's variables to pass to the model.
22622265
"""
22632266
# Find dependencies for python models by parsing code if they are not explicitly defined
22642267
# Also remove self-references that are found
@@ -2307,6 +2310,7 @@ def create_python_model(
23072310
jinja_macros=jinja_macros,
23082311
module_path=module_path,
23092312
variables=variables,
2313+
blueprint_variables=blueprint_variables,
23102314
**kwargs,
23112315
)
23122316

@@ -2361,6 +2365,7 @@ def _create_model(
23612365
macros: t.Optional[MacroRegistry] = None,
23622366
signal_definitions: t.Optional[SignalRegistry] = None,
23632367
variables: t.Optional[t.Dict[str, t.Any]] = None,
2368+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
23642369
**kwargs: t.Any,
23652370
) -> Model:
23662371
_validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path)
@@ -2469,6 +2474,8 @@ def _create_model(
24692474
path=path,
24702475
python_env=python_env,
24712476
strict_resolution=depends_on is None,
2477+
blueprint_variables=blueprint_variables,
2478+
dialect=dialect,
24722479
)
24732480

24742481
env: t.Dict[str, t.Any] = {}
@@ -2632,6 +2639,7 @@ def render_meta_fields(
26322639
dialect: DialectType,
26332640
variables: t.Optional[t.Dict[str, t.Any]],
26342641
default_catalog: t.Optional[str],
2642+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
26352643
) -> t.Dict[str, t.Any]:
26362644
def render_field_value(value: t.Any) -> t.Any:
26372645
if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value):
@@ -2645,6 +2653,7 @@ def render_field_value(value: t.Any) -> t.Any:
26452653
path=path,
26462654
dialect=dialect,
26472655
default_catalog=default_catalog,
2656+
blueprint_variables=blueprint_variables,
26482657
)
26492658
if not rendered_expr:
26502659
raise SQLMeshError(
@@ -2752,6 +2761,7 @@ def render_expression(
27522761
dialect: DialectType = None,
27532762
variables: t.Optional[t.Dict[str, t.Any]] = None,
27542763
default_catalog: t.Optional[str] = None,
2764+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
27552765
) -> t.Optional[t.List[exp.Expression]]:
27562766
meta_python_env = make_python_env(
27572767
expressions=expression,
@@ -2760,6 +2770,7 @@ def render_expression(
27602770
macros=macros or macro.get_registry(),
27612771
variables=variables,
27622772
path=path,
2773+
blueprint_variables=blueprint_variables,
27632774
)
27642775
return ExpressionRenderer(
27652776
expression,

tests/core/test_model.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8200,6 +8200,7 @@ def identity(evaluator, value):
82008200
)
82018201
def entrypoint(context, *args, **kwargs):
82028202
x_var = context.var("x")
8203+
assert kwargs.get("blueprint")
82038204
return pd.DataFrame({"x": [x_var]})"""
82048205
)
82058206
blueprint_pysql = tmp_path / "models" / "blueprint_sql.py"
@@ -8218,6 +8219,7 @@ def entrypoint(context, *args, **kwargs):
82188219
)
82198220
def entrypoint(evaluator):
82208221
x_var = evaluator.var("x")
8222+
assert evaluator.var("blueprint")
82218223
return f'SELECT {x_var} AS x'"""
82228224
)
82238225

@@ -8236,6 +8238,9 @@ def entrypoint(evaluator):
82368238
assert model is not None
82378239
assert "blueprints" not in model.all_fields()
82388240
assert model.python_env.get(c.SQLMESH_VARS) == Executable.value({"x": gateway_no})
8241+
assert model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value(
8242+
{"blueprint": f"gw{gateway_no}"}
8243+
)
82398244
assert context.fetchdf(f"from {model.fqn}").to_dict() == {"x": {0: gateway_no}}
82408245

82418246
multi_variable_blueprint_example = tmp_path / "models" / "multi_variable_blueprint_example.sql"
@@ -8245,15 +8250,17 @@ def entrypoint(evaluator):
82458250
MODEL (
82468251
name @{customer}.my_table,
82478252
blueprints (
8248-
(customer := customer1, foo := 'bar'),
8249-
(customer := customer2, foo := qux),
8253+
(customer := customer1, customer_field := 'bar'),
8254+
(customer := customer2, customer_field := qux),
82508255
),
82518256
kind FULL
82528257
);
82538258
82548259
SELECT
8255-
@VAR('foo') AS foo,
8256-
FROM @VAR('customer').my_source
8260+
@customer_field AS foo,
8261+
@{customer_field} AS foo2,
8262+
@VAR('customer_field') AS foo3,
8263+
FROM @{customer}.my_source
82578264
"""
82588265
)
82598266

@@ -8266,21 +8273,25 @@ def entrypoint(evaluator):
82668273
customer1_model = models.get('"db"."customer1"."my_table"')
82678274

82688275
assert customer1_model is not None
8269-
assert customer1_model.python_env.get(c.SQLMESH_VARS) == Executable.value(
8270-
{"customer": "customer1", "foo": "'bar'"}
8276+
8277+
assert customer1_model.python_env.get(c.SQLMESH_VARS) is None
8278+
assert customer1_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value(
8279+
{"customer": "customer1", "customer_field": "'bar'"}
82718280
)
82728281
assert t.cast(exp.Expression, customer1_model.render_query()).sql() == (
8273-
"""SELECT '''bar''' AS "foo" FROM "db"."customer1"."my_source" AS "my_source\""""
8282+
"""SELECT 'bar' AS "foo", "bar" AS "foo2", 'bar' AS "foo3" FROM "db"."customer1"."my_source" AS "my_source\""""
82748283
)
82758284

82768285
customer2_model = models.get('"db"."customer2"."my_table"')
82778286

82788287
assert customer2_model is not None
8279-
assert customer2_model.python_env.get(c.SQLMESH_VARS) == Executable.value(
8280-
{"customer": "customer2", "foo": "qux"}
8288+
8289+
assert customer2_model.python_env.get(c.SQLMESH_VARS) is None
8290+
assert customer2_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value(
8291+
{"customer": "customer2", "customer_field": "qux"}
82818292
)
82828293
assert t.cast(exp.Expression, customer2_model.render_query()).sql() == (
8283-
'''SELECT 'qux' AS "foo" FROM "db"."customer2"."my_source" AS "my_source"'''
8294+
'''SELECT "qux" AS "foo", "qux" AS "foo2", "qux" AS "foo3" FROM "db"."customer2"."my_source" AS "my_source"'''
82848295
)
82858296

82868297

@@ -8327,6 +8338,39 @@ def gen_blueprints(evaluator):
83278338
assert '"memory"."customer2"."some_table"' in ctx.models
83288339

83298340

8341+
def test_blueprinting_with_quotes(tmp_path: Path) -> None:
8342+
init_example_project(tmp_path, dialect="duckdb", template=ProjectTemplate.EMPTY)
8343+
8344+
template_with_quoted_vars = tmp_path / "models/template_with_quoted_vars.sql"
8345+
template_with_quoted_vars.parent.mkdir(parents=True, exist_ok=True)
8346+
template_with_quoted_vars.write_text(
8347+
"""
8348+
MODEL (
8349+
name m.@{bp_var},
8350+
blueprints (
8351+
(bp_var := "a b"),
8352+
(bp_var := 'c d'),
8353+
),
8354+
);
8355+
8356+
SELECT @bp_var AS c1, @{bp_var} AS c2
8357+
"""
8358+
)
8359+
8360+
ctx = Context(
8361+
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path
8362+
)
8363+
assert len(ctx.models) == 2
8364+
8365+
m1 = ctx.get_model('m."a b"', raise_if_missing=True)
8366+
m2 = ctx.get_model('m."c d"', raise_if_missing=True)
8367+
8368+
assert m1.name == 'm."a b"'
8369+
assert m2.name == 'm."c d"'
8370+
assert t.cast(exp.Query, m1.render_query()).sql() == '''SELECT "a b" AS "c1", "a b" AS "c2"'''
8371+
assert t.cast(exp.Query, m2.render_query()).sql() == '''SELECT 'c d' AS "c1", "c d" AS "c2"'''
8372+
8373+
83308374
@time_machine.travel("2020-01-01 00:00:00 UTC")
83318375
def test_dynamic_date_spine_model(assert_exp_eq):
83328376
@macro()

0 commit comments

Comments
 (0)