|
3 | 3 | import logging |
4 | 4 | import sys |
5 | 5 | import typing as t |
| 6 | +import sqlmesh.core.dialect as d |
| 7 | +from sqlglot.optimizer.simplify import gen |
6 | 8 | from pathlib import Path |
7 | 9 | from sqlmesh.core import constants as c |
8 | 10 | from sqlmesh.core.config import ( |
|
11 | 13 | GatewayConfig, |
12 | 14 | ModelDefaultsConfig, |
13 | 15 | ) |
| 16 | +from sqlmesh.core.environment import EnvironmentStatements |
14 | 17 | from sqlmesh.core.loader import CacheBase, LoadedProject, Loader |
15 | 18 | from sqlmesh.core.macros import MacroRegistry, macro |
16 | 19 | from sqlmesh.core.model import Model, ModelCache |
| 20 | +from sqlmesh.core.model.common import make_python_env |
17 | 21 | from sqlmesh.core.signal import signal |
18 | 22 | from sqlmesh.dbt.basemodel import BMC, BaseModelConfig |
19 | 23 | from sqlmesh.dbt.context import DbtContext |
|
23 | 27 | from sqlmesh.dbt.target import TargetConfig |
24 | 28 | from sqlmesh.utils import UniqueKeyDict |
25 | 29 | from sqlmesh.utils.errors import ConfigError |
26 | | -from sqlmesh.utils.jinja import JinjaMacroRegistry |
| 30 | +from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables |
27 | 31 |
|
28 | 32 | if sys.version_info >= (3, 12): |
29 | 33 | from importlib import metadata |
@@ -230,6 +234,58 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: |
230 | 234 |
|
231 | 235 | return requirements, excluded_requirements |
232 | 236 |
|
| 237 | + def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None: |
| 238 | + """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively.""" |
| 239 | + |
| 240 | + on_run_start = [] |
| 241 | + on_run_end = [] |
| 242 | + |
| 243 | + dialect = self.config.dialect |
| 244 | + for project in self._load_projects(): |
| 245 | + if manifest := project.context._manifest: |
| 246 | + if stmts := manifest._on_run_start: |
| 247 | + on_run_start.extend(stmts) |
| 248 | + if stmts := manifest._on_run_end: |
| 249 | + on_run_end.extend(stmts) |
| 250 | + |
| 251 | + if statements := on_run_start + on_run_end: |
| 252 | + jinja_macro_references, used_variables = extract_macro_references_and_variables( |
| 253 | + *(gen(e) for e in statements) |
| 254 | + ) |
| 255 | + |
| 256 | + if jinja_macros := project.context.jinja_macros: |
| 257 | + if root_package := jinja_macros.root_package_name: |
| 258 | + jinja_macros.root_macros = jinja_macros.packages[root_package] |
| 259 | + jinja_macros = ( |
| 260 | + jinja_macros |
| 261 | + if jinja_macros.trimmed |
| 262 | + else jinja_macros.trim(jinja_macro_references) |
| 263 | + ) |
| 264 | + else: |
| 265 | + jinja_macros = JinjaMacroRegistry() |
| 266 | + |
| 267 | + python_env = make_python_env( |
| 268 | + [s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)], |
| 269 | + jinja_macro_references=jinja_macro_references, |
| 270 | + module_path=self.config_path, |
| 271 | + macros=macros or macro.get_registry(), |
| 272 | + variables=self._get_variables(), |
| 273 | + used_variables=used_variables, |
| 274 | + path=self.config_path, |
| 275 | + ) |
| 276 | + |
| 277 | + return EnvironmentStatements( |
| 278 | + before_all=[ |
| 279 | + d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or [] |
| 280 | + ], |
| 281 | + after_all=[ |
| 282 | + d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or [] |
| 283 | + ], |
| 284 | + python_env=python_env, |
| 285 | + jinja_macros=jinja_macros, |
| 286 | + ) |
| 287 | + return None |
| 288 | + |
233 | 289 | def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]: |
234 | 290 | if not root.is_dir(): |
235 | 291 | return {} |
|
0 commit comments