Skip to content

Commit fe7ed31

Browse files
authored
Feat!: First-class support for calling python macros from jinja context (#4211)
1 parent aa4dfe8 commit fe7ed31

4 files changed

Lines changed: 91 additions & 39 deletions

File tree

sqlmesh/core/macros.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,9 @@ def __init__(
150150
self,
151151
dialect: DialectType = "",
152152
python_env: t.Optional[t.Dict[str, Executable]] = None,
153-
jinja_env: t.Optional[Environment] = None,
154153
schema: t.Optional[MappingSchema] = None,
155154
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
156-
resolve_table: t.Optional[t.Callable[[str | exp.Expression], str]] = None,
155+
resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None,
157156
resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = None,
158157
snapshots: t.Optional[t.Dict[str, Snapshot]] = None,
159158
default_catalog: t.Optional[str] = None,
@@ -177,7 +176,7 @@ def __init__(
177176
self.columns_to_types_called = False
178177
self.default_catalog = default_catalog
179178

180-
self._jinja_env: t.Optional[Environment] = jinja_env
179+
self._jinja_env: t.Optional[Environment] = None
181180
self._schema = schema
182181
self._resolve_table = resolve_table
183182
self._resolve_tables = resolve_tables
@@ -431,7 +430,7 @@ def get_snapshot(self, model_name: TableName | exp.Column) -> t.Optional[Snapsho
431430
)
432431
)
433432

434-
def resolve_table(self, table: str | exp.Expression) -> str:
433+
def resolve_table(self, table: str | exp.Table) -> str:
435434
"""Gets the physical table name for a given model."""
436435
if not self._resolve_table:
437436
raise SQLMeshError(

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,9 +2476,7 @@ def _create_model(
24762476
for _, kwargs in model.signals:
24772477
statements.extend((signal_kwarg, True) for signal_kwarg in kwargs.values())
24782478

2479-
python_env = python_env or {}
2480-
2481-
make_python_env(
2479+
python_env = make_python_env(
24822480
statements,
24832481
jinja_macro_references,
24842482
module_path,

sqlmesh/core/renderer.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import typing as t
55
from contextlib import contextmanager
6+
from functools import partial
67
from pathlib import Path
78

89
from sqlglot import exp, parse
@@ -134,7 +135,35 @@ def _render(
134135
if this_snapshot and (kind := this_snapshot.model_kind_name):
135136
kwargs["model_kind_name"] = kind.name
136137

137-
expressions = [self._expression]
138+
def _resolve_table(table: str | exp.Table) -> str:
139+
return self._resolve_table(
140+
d.normalize_model_name(table, self._default_catalog, self._dialect),
141+
snapshots=snapshots,
142+
table_mapping=table_mapping,
143+
deployability_index=deployability_index,
144+
).sql(dialect=self._dialect, identify=True, comments=False)
145+
146+
macro_evaluator = MacroEvaluator(
147+
self._dialect,
148+
python_env=self._python_env,
149+
schema=self.schema,
150+
runtime_stage=runtime_stage,
151+
resolve_table=_resolve_table,
152+
resolve_tables=lambda e: self._resolve_tables(
153+
e,
154+
snapshots=snapshots,
155+
table_mapping=table_mapping,
156+
deployability_index=deployability_index,
157+
start=start,
158+
end=end,
159+
execution_time=execution_time,
160+
runtime_stage=runtime_stage,
161+
),
162+
snapshots=snapshots,
163+
default_catalog=self._default_catalog,
164+
path=self._path,
165+
environment_naming_info=environment_naming_info,
166+
)
138167

139168
start_time, end_time = (
140169
make_inclusive(start or c.EPOCH, end or c.EPOCH, self._dialect)
@@ -153,18 +182,17 @@ def _render(
153182

154183
variables = kwargs.pop("variables", {})
155184
jinja_env_kwargs = {
156-
**{**render_kwargs, **prepare_env(self._python_env), **variables},
185+
**{
186+
**render_kwargs,
187+
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
188+
**variables,
189+
},
157190
"snapshots": snapshots or {},
158191
"table_mapping": table_mapping,
159192
"deployability_index": deployability_index,
160193
"default_catalog": self._default_catalog,
161194
"runtime_stage": runtime_stage.value,
162-
"resolve_table": lambda table: self._resolve_table(
163-
d.normalize_model_name(table, self._default_catalog, self._dialect),
164-
snapshots=snapshots,
165-
table_mapping=table_mapping,
166-
deployability_index=deployability_index,
167-
).sql(dialect=self._dialect, identify=True, comments=False),
195+
"resolve_table": _resolve_table,
168196
}
169197
if this_model:
170198
render_kwargs["this_model"] = this_model
@@ -174,6 +202,7 @@ def _render(
174202

175203
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
176204

205+
expressions = [self._expression]
177206
if isinstance(self._expression, d.Jinja):
178207
try:
179208
expressions = []
@@ -190,29 +219,6 @@ def _render(
190219
f"Could not render or parse jinja at '{self._path}'.\n{ex}"
191220
) from ex
192221

193-
macro_evaluator = MacroEvaluator(
194-
self._dialect,
195-
python_env=self._python_env,
196-
jinja_env=jinja_env,
197-
schema=self.schema,
198-
runtime_stage=runtime_stage,
199-
resolve_table=jinja_env.globals["resolve_table"], # type: ignore
200-
resolve_tables=lambda e: self._resolve_tables(
201-
e,
202-
snapshots=snapshots,
203-
table_mapping=table_mapping,
204-
deployability_index=deployability_index,
205-
start=start,
206-
end=end,
207-
execution_time=execution_time,
208-
runtime_stage=runtime_stage,
209-
),
210-
snapshots=snapshots,
211-
default_catalog=self._default_catalog,
212-
path=self._path,
213-
environment_naming_info=environment_naming_info,
214-
)
215-
216222
macro_evaluator.locals.update(render_kwargs)
217223

218224
if variables:
@@ -637,3 +643,15 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query:
637643
annotate_types(select)
638644

639645
return query
646+
647+
648+
def _prepare_python_env_for_jinja(
649+
evaluator: MacroEvaluator,
650+
python_env: t.Dict[str, Executable],
651+
) -> t.Dict[str, t.Any]:
652+
prepared_env = prepare_env(python_env)
653+
# Pass the evaluator to all macro functions
654+
return {
655+
key: partial(value, evaluator) if callable(value) else value
656+
for key, value in prepared_env.items()
657+
}

tests/core/test_model.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2803,7 +2803,7 @@ def test_parse_expression_list_with_jinja():
28032803

28042804
def test_no_depends_on_runtime_jinja_query():
28052805
@macro()
2806-
def runtime_macro(**kwargs) -> None:
2806+
def runtime_macro(evaluator, **kwargs) -> None:
28072807
from sqlmesh.utils.errors import ParsetimeAdapterCallError
28082808

28092809
raise ParsetimeAdapterCallError("")
@@ -9051,3 +9051,40 @@ def test_formatting_flag_serde():
90519051

90529052
deserialized_model = SqlModel.parse_raw(model_json)
90539053
assert deserialized_model.dict() == model.dict()
9054+
9055+
9056+
def test_call_python_macro_from_jinja():
9057+
def noop() -> None:
9058+
print("noop")
9059+
9060+
@macro()
9061+
def test_runtime_stage(evaluator):
9062+
noop()
9063+
return evaluator.runtime_stage
9064+
9065+
expressions = d.parse(
9066+
"""
9067+
MODEL (
9068+
name db.table,
9069+
dialect spark,
9070+
owner owner_name,
9071+
);
9072+
9073+
JINJA_QUERY_BEGIN;
9074+
SELECT '{{ test_runtime_stage() }}' AS a, '{{ test_runtime_stage_jinja('bla') }}' AS b;
9075+
JINJA_END;
9076+
"""
9077+
)
9078+
9079+
jinja_macros = JinjaMacroRegistry(
9080+
root_macros={
9081+
"test_runtime_stage_jinja": MacroInfo(
9082+
definition="{% macro test_runtime_stage_jinja(value) %}{{ test_runtime_stage() }}_{{ value }}{% endmacro %}",
9083+
depends_on=[],
9084+
)
9085+
}
9086+
)
9087+
9088+
model = load_sql_based_model(expressions, jinja_macros=jinja_macros)
9089+
assert model.render_query().sql() == "SELECT 'loading' AS a, 'loading_bla' AS b"
9090+
assert set(model.python_env) == {"noop", "test_runtime_stage"}

0 commit comments

Comments
 (0)