Skip to content

Commit a1dc94a

Browse files
committed
Fix!: serialize blueprint variables separately to leverage AST
1 parent 3ea0493 commit a1dc94a

8 files changed

Lines changed: 168 additions & 43 deletions

File tree

sqlmesh/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
DEFAULT_SCHEMA = "default"
7777

7878
SQLMESH_VARS = "__sqlmesh__vars__"
79+
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
80+
7981
VAR = "var"
8082
GATEWAY = "gateway"
8183

sqlmesh/core/macros.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,16 @@ def __init__(
173173
"MacroEvaluator": MacroEvaluator,
174174
}
175175
self.python_env = python_env or {}
176-
self._jinja_env: t.Optional[Environment] = jinja_env
177176
self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()}
177+
self.columns_to_types_called = False
178+
self.default_catalog = default_catalog
179+
self.variables: t.Dict[str, t.Any] = {}
180+
181+
self._jinja_env: t.Optional[Environment] = jinja_env
178182
self._schema = schema
179183
self._resolve_table = resolve_table
180184
self._resolve_tables = resolve_tables
181-
self.columns_to_types_called = False
182185
self._snapshots = snapshots if snapshots is not None else {}
183-
self.default_catalog = default_catalog
184186
self._path = path
185187
self._environment_naming_info = environment_naming_info
186188

@@ -191,7 +193,12 @@ def __init__(
191193
elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None):
192194
self.macros[normalize_macro_name(k)] = self.env[k]
193195
elif v.is_value:
194-
self.locals[k] = self.env[k]
196+
value = self.env[k]
197+
if k == c.SQLMESH_VARS:
198+
self.variables = {**value, **(value.get(c.SQLMESH_BLUEPRINT_VARS) or {})}
199+
self.variables.pop(c.SQLMESH_BLUEPRINT_VARS, None)
200+
else:
201+
self.locals[k] = value
195202

196203
def send(
197204
self, name: str, *args: t.Any, **kwargs: t.Any
@@ -219,7 +226,7 @@ def evaluate_macros(
219226

220227
if isinstance(node, MacroVar):
221228
changed = True
222-
variables = self.locals.get(c.SQLMESH_VARS, {})
229+
variables = self.variables
223230
if node.name not in self.locals and node.name.lower() not in variables:
224231
if not isinstance(node.parent, StagedFilePath):
225232
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")
@@ -279,17 +286,12 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
279286
Returns:
280287
The rendered string.
281288
"""
282-
mapping = {}
283-
284-
variables = self.locals.get(c.SQLMESH_VARS, {})
285-
286-
for k, v in chain(variables.items(), self.locals.items(), local_variables.items()):
287-
# try to convert all variables into sqlglot expressions
288-
# because they're going to be converted into strings in sql
289-
# we don't convert strings because that would result in adding quotes
290-
if k != c.SQLMESH_VARS:
291-
mapping[k] = convert_sql(v, self.dialect)
292-
289+
# We try to convert all variables into sqlglot expressions because they're going to be converted
290+
# into strings; in sql we don't convert strings because that would result in adding quotes
291+
mapping = {
292+
k: convert_sql(v, self.dialect)
293+
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
294+
}
293295
return MacroStrTemplate(str(text)).safe_substitute(mapping)
294296

295297
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
@@ -465,7 +467,7 @@ def gateway(self) -> t.Optional[str]:
465467

466468
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
467469
"""Returns the value of the specified variable, or the default value if it doesn't exist."""
468-
return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default)
470+
return self.variables.get(var_name.lower(), default)
469471

470472
def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
471473
"""Coerces the given expression to the specified type on a best-effort basis."""

sqlmesh/core/model/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sqlmesh.utils.pydantic import ValidationInfo, field_validator
1818

1919
if t.TYPE_CHECKING:
20+
from sqlglot.dialects.dialect import DialectType
2021
from sqlmesh.utils.jinja import MacroReference
2122

2223

@@ -30,6 +31,8 @@ def make_python_env(
3031
path: t.Optional[str | Path] = None,
3132
python_env: t.Optional[t.Dict[str, Executable]] = None,
3233
strict_resolution: bool = True,
34+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
35+
dialect: DialectType = None,
3336
) -> t.Dict[str, Executable]:
3437
python_env = {} if python_env is None else python_env
3538
variables = variables or {}
@@ -86,6 +89,8 @@ def make_python_env(
8689
python_env,
8790
used_variables,
8891
variables,
92+
blueprint_variables=blueprint_variables,
93+
dialect=dialect,
8994
strict_resolution=strict_resolution,
9095
)
9196

@@ -95,6 +100,8 @@ def _add_variables_to_python_env(
95100
used_variables: t.Optional[t.Set[str]],
96101
variables: t.Optional[t.Dict[str, t.Any]],
97102
strict_resolution: bool = True,
103+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
104+
dialect: DialectType = None,
98105
) -> t.Dict[str, Executable]:
99106
_, python_used_variables = parse_dependencies(
100107
python_env,
@@ -104,6 +111,20 @@ def _add_variables_to_python_env(
104111
used_variables = (used_variables or set()) | python_used_variables
105112

106113
variables = {k: v for k, v in (variables or {}).items() if k in used_variables}
114+
115+
serialized_blueprint_variables = {}
116+
for k, v in (blueprint_variables or {}).items():
117+
if isinstance(v, exp.Expression):
118+
serialized_blueprint_variables[k] = {
119+
"value": v.sql(dialect=dialect),
120+
"dialect": str(dialect or ""),
121+
}
122+
else:
123+
serialized_blueprint_variables[k] = {"value": v}
124+
125+
if serialized_blueprint_variables:
126+
variables[c.SQLMESH_BLUEPRINT_VARS] = serialized_blueprint_variables
127+
107128
if variables:
108129
python_env[c.SQLMESH_VARS] = Executable.value(variables)
109130

sqlmesh/core/model/decorator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def model(
122122
default_catalog: t.Optional[str] = None,
123123
variables: t.Optional[t.Dict[str, t.Any]] = None,
124124
infer_names: t.Optional[bool] = False,
125+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
125126
) -> Model:
126127
"""Get the model registered by this function."""
127128
env: t.Dict[str, t.Any] = {}
@@ -155,6 +156,7 @@ def model(
155156
path=path,
156157
dialect=dialect,
157158
default_catalog=default_catalog,
159+
blueprint_variables=blueprint_variables,
158160
)
159161

160162
rendered_name = rendered_fields["name"]
@@ -193,6 +195,7 @@ def model(
193195
"macros": macros,
194196
"jinja_macros": jinja_macros,
195197
"audit_definitions": audit_definitions,
198+
"blueprint_variables": blueprint_variables,
196199
**rendered_fields,
197200
}
198201

sqlmesh/core/model/definition.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,7 @@ def render(
17481748

17491749
variables = env.get(c.SQLMESH_VARS, {})
17501750
variables.update(kwargs.pop("variables", {}))
1751+
variables.update(variables.get(c.SQLMESH_BLUEPRINT_VARS) or {})
17511752

17521753
try:
17531754
kwargs = {
@@ -1855,18 +1856,14 @@ def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]:
18551856
return [] # This is unreachable, but is done to satisfy mypy
18561857

18571858

1858-
def _extract_blueprint_variables(
1859-
blueprint: t.Any,
1860-
dialect: DialectType,
1861-
path: Path,
1862-
) -> t.Dict[str, str]:
1859+
def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.Any]:
18631860
if not blueprint:
18641861
return {}
18651862
if isinstance(blueprint, exp.Paren):
18661863
blueprint = blueprint.unnest()
1867-
return {blueprint.left.name: blueprint.right.sql(dialect=dialect)}
1864+
return {blueprint.left.name: blueprint.right}
18681865
if isinstance(blueprint, (exp.Tuple, exp.Array)):
1869-
return {e.left.name: e.right.sql(dialect=dialect) for e in blueprint.expressions}
1866+
return {e.left.name: e.right for e in blueprint.expressions}
18701867
if isinstance(blueprint, dict):
18711868
return blueprint
18721869

@@ -1889,18 +1886,18 @@ def create_models_from_blueprints(
18891886
) -> t.List[Model]:
18901887
model_blueprints: t.List[Model] = []
18911888
for blueprint in _extract_blueprints(blueprints, path):
1892-
variables = _extract_blueprint_variables(blueprint, dialect, path)
1889+
blueprint_variables = _extract_blueprint_variables(blueprint, path)
18931890

18941891
if gateway:
18951892
rendered_gateway = render_expression(
18961893
expression=exp.maybe_parse(gateway, dialect=dialect),
18971894
module_path=module_path,
18981895
macros=loader_kwargs.get("macros"),
18991896
jinja_macros=loader_kwargs.get("jinja_macros"),
1900-
variables=variables,
19011897
path=path,
19021898
dialect=dialect,
19031899
default_catalog=loader_kwargs.get("default_catalog"),
1900+
blueprint_variables=blueprint_variables,
19041901
)
19051902
gateway_name = rendered_gateway[0].name if rendered_gateway else None
19061903
else:
@@ -1911,7 +1908,8 @@ def create_models_from_blueprints(
19111908
path=path,
19121909
module_path=module_path,
19131910
dialect=dialect,
1914-
variables={**get_variables(gateway_name), **variables},
1911+
variables=get_variables(gateway_name),
1912+
blueprint_variables=blueprint_variables,
19151913
**loader_kwargs,
19161914
)
19171915
)
@@ -1983,6 +1981,7 @@ def load_sql_based_model(
19831981
default_catalog: t.Optional[str] = None,
19841982
variables: t.Optional[t.Dict[str, t.Any]] = None,
19851983
infer_names: t.Optional[bool] = False,
1984+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
19861985
**kwargs: t.Any,
19871986
) -> Model:
19881987
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2059,6 +2058,7 @@ def load_sql_based_model(
20592058
path=path,
20602059
dialect=dialect,
20612060
default_catalog=default_catalog,
2061+
blueprint_variables=blueprint_variables,
20622062
)
20632063

20642064
if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1:
@@ -2143,6 +2143,7 @@ def load_sql_based_model(
21432143
variables=variables,
21442144
default_audits=default_audits,
21452145
inline_audits=inline_audits,
2146+
blueprint_variables=blueprint_variables,
21462147
**meta_fields,
21472148
)
21482149

@@ -2247,6 +2248,7 @@ def create_python_model(
22472248
module_path: Path = Path(),
22482249
depends_on: t.Optional[t.Set[str]] = None,
22492250
variables: t.Optional[t.Dict[str, t.Any]] = None,
2251+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
22502252
**kwargs: t.Any,
22512253
) -> Model:
22522254
"""Creates a Python model.
@@ -2259,6 +2261,7 @@ def create_python_model(
22592261
path: An optional path to the model definition file.
22602262
depends_on: The custom set of model's upstream dependencies.
22612263
variables: The variables to pass to the model.
2264+
blueprint_variables: The blueprint's variables to pass to the model.
22622265
"""
22632266
# Find dependencies for python models by parsing code if they are not explicitly defined
22642267
# Also remove self-references that are found
@@ -2307,6 +2310,7 @@ def create_python_model(
23072310
jinja_macros=jinja_macros,
23082311
module_path=module_path,
23092312
variables=variables,
2313+
blueprint_variables=blueprint_variables,
23102314
**kwargs,
23112315
)
23122316

@@ -2361,6 +2365,7 @@ def _create_model(
23612365
macros: t.Optional[MacroRegistry] = None,
23622366
signal_definitions: t.Optional[SignalRegistry] = None,
23632367
variables: t.Optional[t.Dict[str, t.Any]] = None,
2368+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
23642369
**kwargs: t.Any,
23652370
) -> Model:
23662371
_validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path)
@@ -2469,6 +2474,8 @@ def _create_model(
24692474
path=path,
24702475
python_env=python_env,
24712476
strict_resolution=depends_on is None,
2477+
blueprint_variables=blueprint_variables,
2478+
dialect=dialect,
24722479
)
24732480

24742481
env: t.Dict[str, t.Any] = {}
@@ -2632,6 +2639,7 @@ def render_meta_fields(
26322639
dialect: DialectType,
26332640
variables: t.Optional[t.Dict[str, t.Any]],
26342641
default_catalog: t.Optional[str],
2642+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
26352643
) -> t.Dict[str, t.Any]:
26362644
def render_field_value(value: t.Any) -> t.Any:
26372645
if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value):
@@ -2645,6 +2653,7 @@ def render_field_value(value: t.Any) -> t.Any:
26452653
path=path,
26462654
dialect=dialect,
26472655
default_catalog=default_catalog,
2656+
blueprint_variables=blueprint_variables,
26482657
)
26492658
if not rendered_expr:
26502659
raise SQLMeshError(
@@ -2752,6 +2761,7 @@ def render_expression(
27522761
dialect: DialectType = None,
27532762
variables: t.Optional[t.Dict[str, t.Any]] = None,
27542763
default_catalog: t.Optional[str] = None,
2764+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
27552765
) -> t.Optional[t.List[exp.Expression]]:
27562766
meta_python_env = make_python_env(
27572767
expressions=expression,
@@ -2760,6 +2770,7 @@ def render_expression(
27602770
macros=macros or macro.get_registry(),
27612771
variables=variables,
27622772
path=path,
2773+
blueprint_variables=blueprint_variables,
27632774
)
27642775
return ExpressionRenderer(
27652776
expression,

sqlmesh/core/renderer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def _render(
221221
macro_evaluator.locals.update(render_kwargs)
222222

223223
if variables:
224-
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
224+
macro_evaluator.variables.update(variables)
225225

226226
resolved_expressions: t.List[t.Optional[exp.Expression]] = []
227227

sqlmesh/utils/metaprogramming.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from astor import to_source
1919

2020
from sqlmesh.core import constants as c
21+
from sqlmesh.core.dialect import parse_one
2122
from sqlmesh.utils import format_exception, unique
2223
from sqlmesh.utils.errors import SQLMeshError
2324
from sqlmesh.utils.pydantic import PydanticModel
@@ -491,11 +492,33 @@ def prepare_env(
491492
python_env.items(), key=lambda item: 0 if item[1].is_import else 1
492493
):
493494
if executable.is_value:
494-
env[name] = ast.literal_eval(executable.payload)
495+
literal = ast.literal_eval(executable.payload)
496+
if (
497+
isinstance(literal, dict)
498+
and c.SQLMESH_VARS == name
499+
and c.SQLMESH_BLUEPRINT_VARS in literal
500+
):
501+
bueprint_variables = {}
502+
for var, value_metadata in literal[c.SQLMESH_BLUEPRINT_VARS].items():
503+
assert isinstance(value_metadata, dict)
504+
assert "value" in value_metadata
505+
506+
value = value_metadata["value"]
507+
dialect = value_metadata.get("dialect")
508+
509+
# The presence of dialect signals that the original blueprint value was an AST
510+
bueprint_variables[var] = (
511+
value if dialect is None else parse_one(value, dialect=dialect)
512+
)
513+
514+
literal[c.SQLMESH_BLUEPRINT_VARS] = bueprint_variables
515+
516+
env[name] = literal
495517
else:
496518
exec(executable.payload, env)
497519
if executable.alias and executable.name:
498520
env[executable.alias] = env[executable.name]
521+
499522
return env
500523

501524

0 commit comments

Comments
 (0)