Skip to content

Commit 15866ba

Browse files
authored
Fix!: serialize blueprint variables separately to leverage AST (#4061)
1 parent 42511e1 commit 15866ba

9 files changed

Lines changed: 336 additions & 47 deletions

File tree

sqlmesh/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
DEFAULT_SCHEMA = "default"
7777

7878
SQLMESH_VARS = "__sqlmesh__vars__"
79+
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
80+
7981
VAR = "var"
8082
GATEWAY = "gateway"
8183

sqlmesh/core/context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,15 @@ def __init__(
253253
default_dialect: t.Optional[str] = None,
254254
default_catalog: t.Optional[str] = None,
255255
variables: t.Optional[t.Dict[str, t.Any]] = None,
256+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
256257
):
257258
self.snapshots = snapshots
258259
self.deployability_index = deployability_index
259260
self._engine_adapter = engine_adapter
260261
self._default_catalog = default_catalog
261262
self._default_dialect = default_dialect
262263
self._variables = variables or {}
264+
self._blueprint_variables = blueprint_variables or {}
263265

264266
@property
265267
def default_dialect(self) -> t.Optional[str]:
@@ -288,7 +290,15 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.
288290
"""Returns a variable value."""
289291
return self._variables.get(var_name.lower(), default)
290292

291-
def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext:
293+
def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
294+
"""Returns a blueprint variable value."""
295+
return self._blueprint_variables.get(var_name.lower(), default)
296+
297+
def with_variables(
298+
self,
299+
variables: t.Dict[str, t.Any],
300+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
301+
) -> ExecutionContext:
292302
"""Returns a new ExecutionContext with additional variables."""
293303
return ExecutionContext(
294304
self._engine_adapter,
@@ -297,6 +307,7 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext:
297307
self._default_dialect,
298308
self._default_catalog,
299309
variables=variables,
310+
blueprint_variables=blueprint_variables,
300311
)
301312

302313

sqlmesh/core/macros.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sqlmesh.utils.date import DatetimeRanges, to_datetime, to_date
4242
from sqlmesh.utils.errors import MacroEvalError, SQLMeshError
4343
from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja
44-
from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception
44+
from sqlmesh.utils.metaprogramming import Executable, SqlValue, prepare_env, print_exception
4545

4646
if t.TYPE_CHECKING:
4747
from sqlglot.dialects.dialect import DialectType
@@ -173,14 +173,15 @@ def __init__(
173173
"MacroEvaluator": MacroEvaluator,
174174
}
175175
self.python_env = python_env or {}
176-
self._jinja_env: t.Optional[Environment] = jinja_env
177176
self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()}
177+
self.columns_to_types_called = False
178+
self.default_catalog = default_catalog
179+
180+
self._jinja_env: t.Optional[Environment] = jinja_env
178181
self._schema = schema
179182
self._resolve_table = resolve_table
180183
self._resolve_tables = resolve_tables
181-
self.columns_to_types_called = False
182184
self._snapshots = snapshots if snapshots is not None else {}
183-
self.default_catalog = default_catalog
184185
self._path = path
185186
self._environment_naming_info = environment_naming_info
186187

@@ -191,7 +192,18 @@ def __init__(
191192
elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None):
192193
self.macros[normalize_macro_name(k)] = self.env[k]
193194
elif v.is_value:
194-
self.locals[k] = self.env[k]
195+
value = self.env[k]
196+
if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS):
197+
value = {
198+
var_name: (
199+
self.parse_one(var_value.sql)
200+
if isinstance(var_value, SqlValue)
201+
else var_value
202+
)
203+
for var_name, var_value in value.items()
204+
}
205+
206+
self.locals[k] = value
195207

196208
def send(
197209
self, name: str, *args: t.Any, **kwargs: t.Any
@@ -219,20 +231,23 @@ def evaluate_macros(
219231

220232
if isinstance(node, MacroVar):
221233
changed = True
222-
variables = self.locals.get(c.SQLMESH_VARS, {})
234+
variables = self.variables
235+
223236
if node.name not in self.locals and node.name.lower() not in variables:
224237
if not isinstance(node.parent, StagedFilePath):
225238
raise SQLMeshError(f"Macro variable '{node.name}' is undefined.")
226239

227240
return node
228241

242+
# Precedence order is locals (e.g. @DEF) > blueprint variables > config variables
229243
value = self.locals.get(node.name, variables.get(node.name.lower()))
230244
if isinstance(value, list):
231245
return exp.convert(
232246
tuple(
233247
self.transform(v) if isinstance(v, exp.Expression) else v for v in value
234248
)
235249
)
250+
236251
return exp.convert(
237252
self.transform(value) if isinstance(value, exp.Expression) else value
238253
)
@@ -279,17 +294,12 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str:
279294
Returns:
280295
The rendered string.
281296
"""
282-
mapping = {}
283-
284-
variables = self.locals.get(c.SQLMESH_VARS, {})
285-
286-
for k, v in chain(variables.items(), self.locals.items(), local_variables.items()):
287-
# try to convert all variables into sqlglot expressions
288-
# because they're going to be converted into strings in sql
289-
# we don't convert strings because that would result in adding quotes
290-
if k != c.SQLMESH_VARS:
291-
mapping[k] = convert_sql(v, self.dialect)
292-
297+
# We try to convert all variables into sqlglot expressions because they're going to be converted
298+
# into strings; in sql we don't convert strings because that would result in adding quotes
299+
mapping = {
300+
k: convert_sql(v, self.dialect)
301+
for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items())
302+
}
293303
return MacroStrTemplate(str(text)).safe_substitute(mapping)
294304

295305
def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None:
@@ -467,6 +477,17 @@ def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.
467477
"""Returns the value of the specified variable, or the default value if it doesn't exist."""
468478
return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default)
469479

480+
def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
481+
"""Returns the value of the specified blueprint variable, or the default value if it doesn't exist."""
482+
return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default)
483+
484+
@property
485+
def variables(self) -> t.Dict[str, t.Any]:
486+
return {
487+
**self.locals.get(c.SQLMESH_VARS, {}),
488+
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}),
489+
}
490+
470491
def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:
471492
"""Coerces the given expression to the specified type on a best-effort basis."""
472493
return _coerce(expr, typ, self.dialect, self._path, strict)
@@ -1054,6 +1075,19 @@ def var(
10541075
return exp.convert(evaluator.var(var_name.this, default))
10551076

10561077

1078+
@macro("BLUEPRINT_VAR")
1079+
def blueprint_var(
1080+
evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None
1081+
) -> exp.Expression:
1082+
"""Returns the value of a blueprint variable or the default value if the variable is not set."""
1083+
if not var_name.is_string:
1084+
raise SQLMeshError(
1085+
f"Invalid blueprint variable name '{var_name.sql()}'. Expected a string literal."
1086+
)
1087+
1088+
return exp.convert(evaluator.blueprint_var(var_name.this, default))
1089+
1090+
10571091
@macro()
10581092
def deduplicate(
10591093
evaluator: MacroEvaluator,

sqlmesh/core/model/common.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@
1313
from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate
1414
from sqlmesh.utils import str_to_bool
1515
from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error
16-
from sqlmesh.utils.metaprogramming import Executable, build_env, prepare_env, serialize_env
16+
from sqlmesh.utils.metaprogramming import (
17+
Executable,
18+
SqlValue,
19+
build_env,
20+
prepare_env,
21+
serialize_env,
22+
)
1723
from sqlmesh.utils.pydantic import ValidationInfo, field_validator
1824

1925
if t.TYPE_CHECKING:
26+
from sqlglot.dialects.dialect import DialectType
2027
from sqlmesh.utils.jinja import MacroReference
2128

2229

@@ -30,6 +37,8 @@ def make_python_env(
3037
path: t.Optional[str | Path] = None,
3138
python_env: t.Optional[t.Dict[str, Executable]] = None,
3239
strict_resolution: bool = True,
40+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
41+
dialect: DialectType = None,
3342
) -> t.Dict[str, Executable]:
3443
python_env = {} if python_env is None else python_env
3544
variables = variables or {}
@@ -86,6 +95,8 @@ def make_python_env(
8695
python_env,
8796
used_variables,
8897
variables,
98+
blueprint_variables=blueprint_variables,
99+
dialect=dialect,
89100
strict_resolution=strict_resolution,
90101
)
91102

@@ -95,6 +106,8 @@ def _add_variables_to_python_env(
95106
used_variables: t.Optional[t.Set[str]],
96107
variables: t.Optional[t.Dict[str, t.Any]],
97108
strict_resolution: bool = True,
109+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
110+
dialect: DialectType = None,
98111
) -> t.Dict[str, Executable]:
99112
_, python_used_variables = parse_dependencies(
100113
python_env,
@@ -107,6 +120,13 @@ def _add_variables_to_python_env(
107120
if variables:
108121
python_env[c.SQLMESH_VARS] = Executable.value(variables)
109122

123+
if blueprint_variables:
124+
blueprint_variables = {
125+
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
126+
for k, v in blueprint_variables.items()
127+
}
128+
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables)
129+
110130
return python_env
111131

112132

sqlmesh/core/model/decorator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def model(
122122
default_catalog: t.Optional[str] = None,
123123
variables: t.Optional[t.Dict[str, t.Any]] = None,
124124
infer_names: t.Optional[bool] = False,
125+
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
125126
) -> Model:
126127
"""Get the model registered by this function."""
127128
env: t.Dict[str, t.Any] = {}
@@ -155,6 +156,7 @@ def model(
155156
path=path,
156157
dialect=dialect,
157158
default_catalog=default_catalog,
159+
blueprint_variables=blueprint_variables,
158160
)
159161

160162
rendered_name = rendered_fields["name"]
@@ -193,6 +195,7 @@ def model(
193195
"macros": macros,
194196
"jinja_macros": jinja_macros,
195197
"audit_definitions": audit_definitions,
198+
"blueprint_variables": blueprint_variables,
196199
**rendered_fields,
197200
}
198201

0 commit comments

Comments
 (0)