|
27 | 27 | from sqlmesh.dbt.target import TargetConfig |
28 | 28 | from sqlmesh.utils import UniqueKeyDict |
29 | 29 | from sqlmesh.utils.errors import ConfigError |
30 | | -from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables |
| 30 | +from sqlmesh.utils.jinja import ( |
| 31 | + JinjaMacroRegistry, |
| 32 | + MacroInfo, |
| 33 | + extract_macro_references_and_variables, |
| 34 | +) |
31 | 35 |
|
32 | 36 | if sys.version_info >= (3, 12): |
33 | 37 | from importlib import metadata |
@@ -237,39 +241,41 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: |
237 | 241 | def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None: |
238 | 242 | """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively.""" |
239 | 243 |
|
240 | | - on_run_start = [] |
241 | | - on_run_end = [] |
242 | | - |
| 244 | + on_run_start: t.List[str] = [] |
| 245 | + on_run_end: t.List[str] = [] |
| 246 | + jinja_root_macros: t.Dict[str, MacroInfo] = {} |
| 247 | + variables: t.Dict[str, t.Any] = self._get_variables() |
243 | 248 | dialect = self.config.dialect |
244 | 249 | for project in self._load_projects(): |
245 | | - if manifest := project.context._manifest: |
246 | | - if stmts := manifest._on_run_start: |
247 | | - on_run_start.extend(stmts) |
248 | | - if stmts := manifest._on_run_end: |
249 | | - on_run_end.extend(stmts) |
| 250 | + context = project.context.copy() |
| 251 | + if manifest := context._manifest: |
| 252 | + on_run_start.extend(manifest._on_run_start or []) |
| 253 | + on_run_end.extend(manifest._on_run_end or []) |
| 254 | + |
| 255 | + if root_package := context.jinja_macros.root_package_name: |
| 256 | + if root_macros := context.jinja_macros.packages.get(root_package): |
| 257 | + jinja_root_macros |= root_macros |
| 258 | + context.set_and_render_variables(context.variables, root_package) |
| 259 | + variables |= context.variables |
250 | 260 |
|
251 | 261 | if statements := on_run_start + on_run_end: |
252 | 262 | jinja_macro_references, used_variables = extract_macro_references_and_variables( |
253 | | - *(gen(e) for e in statements) |
| 263 | + *(gen(stmt) for stmt in statements) |
| 264 | + ) |
| 265 | + jinja_macros = context.jinja_macros |
| 266 | + jinja_macros.root_macros = jinja_root_macros |
| 267 | + jinja_macros = ( |
| 268 | + jinja_macros.trim(jinja_macro_references) |
| 269 | + if not jinja_macros.trimmed |
| 270 | + else jinja_macros |
254 | 271 | ) |
255 | | - |
256 | | - if jinja_macros := project.context.jinja_macros: |
257 | | - if root_package := jinja_macros.root_package_name: |
258 | | - jinja_macros.root_macros = jinja_macros.packages[root_package] |
259 | | - jinja_macros = ( |
260 | | - jinja_macros |
261 | | - if jinja_macros.trimmed |
262 | | - else jinja_macros.trim(jinja_macro_references) |
263 | | - ) |
264 | | - else: |
265 | | - jinja_macros = JinjaMacroRegistry() |
266 | 272 |
|
267 | 273 | python_env = make_python_env( |
268 | 274 | [s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)], |
269 | 275 | jinja_macro_references=jinja_macro_references, |
270 | 276 | module_path=self.config_path, |
271 | | - macros=macros or macro.get_registry(), |
272 | | - variables=self._get_variables(), |
| 277 | + macros=macros, |
| 278 | + variables=variables, |
273 | 279 | used_variables=used_variables, |
274 | 280 | path=self.config_path, |
275 | 281 | ) |
|
0 commit comments