Skip to content

Commit 9823b9e

Browse files
Fix: Support on run start / on run end hooks of dbt packages (#4222)
1 parent 5201abf commit 9823b9e

11 files changed

Lines changed: 261 additions & 87 deletions

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
610610
self._standalone_audits.update(project.standalone_audits)
611611
self._requirements.update(project.requirements)
612612
self._excluded_requirements.update(project.excluded_requirements)
613-
if project.environment_statements:
614-
self._environment_statements.append(project.environment_statements)
613+
self._environment_statements.extend(project.environment_statements)
615614

616615
config = loader.config
617616
self._linters[config.project] = Linter.from_rules(

sqlmesh/core/loader.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class LoadedProject:
6161
metrics: UniqueKeyDict[str, Metric]
6262
requirements: t.Dict[str, str]
6363
excluded_requirements: t.Set[str]
64-
environment_statements: t.Optional[EnvironmentStatements]
64+
environment_statements: t.List[EnvironmentStatements]
6565
user_rules: RuleSet
6666

6767

@@ -187,9 +187,9 @@ def _load_audits(
187187
) -> UniqueKeyDict[str, Audit]:
188188
"""Loads all audits."""
189189

190-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
190+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
191191
"""Loads environment statements."""
192-
return None
192+
return []
193193

194194
def load_materializations(self) -> None:
195195
"""Loads custom materializations."""
@@ -651,7 +651,7 @@ def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
651651

652652
return metrics
653653

654-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
654+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
655655
"""Loads environment statements."""
656656

657657
if self.config.before_all or self.config.after_all:
@@ -673,8 +673,8 @@ def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStat
673673
path=self.config_path,
674674
)
675675

676-
return EnvironmentStatements(**statements, python_env=python_env)
677-
return None
676+
return [EnvironmentStatements(**statements, python_env=python_env)]
677+
return []
678678

679679
def _load_linting_rules(self) -> RuleSet:
680680
user_rules: UniqueKeyDict[str, type[Rule]] = UniqueKeyDict("rules")

sqlmesh/dbt/loader.py

Lines changed: 55 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
from sqlmesh.utils.errors import ConfigError
3030
from sqlmesh.utils.jinja import (
3131
JinjaMacroRegistry,
32-
MacroInfo,
3332
extract_macro_references_and_variables,
33+
make_jinja_registry,
3434
)
3535

3636
if sys.version_info >= (3, 12):
@@ -238,59 +238,65 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
238238

239239
return requirements, excluded_requirements
240240

241-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
241+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
242242
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
243243

244-
on_run_start: t.List[str] = []
245-
on_run_end: t.List[str] = []
246-
jinja_root_macros: t.Dict[str, MacroInfo] = {}
247-
variables: t.Dict[str, t.Any] = self._get_variables()
244+
environment_statements: t.List[EnvironmentStatements] = []
248245
dialect = self.config.dialect
249246
for project in self._load_projects():
250-
context = project.context.copy()
251-
if manifest := context._manifest:
252-
on_run_start.extend(manifest._on_run_start or [])
253-
on_run_end.extend(manifest._on_run_end or [])
254-
255-
if root_package := context.jinja_macros.root_package_name:
256-
if root_macros := context.jinja_macros.packages.get(root_package):
257-
jinja_root_macros |= root_macros
258-
context.set_and_render_variables(context.variables, root_package)
259-
variables |= context.variables
260-
261-
if statements := on_run_start + on_run_end:
262-
jinja_macro_references, used_variables = extract_macro_references_and_variables(
263-
*(gen(stmt) for stmt in statements)
264-
)
265-
jinja_macros = context.jinja_macros
266-
jinja_macros.root_macros = jinja_root_macros
267-
jinja_macros = (
268-
jinja_macros.trim(jinja_macro_references)
269-
if not jinja_macros.trimmed
270-
else jinja_macros
271-
)
272-
273-
python_env = make_python_env(
274-
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
275-
jinja_macro_references=jinja_macro_references,
276-
module_path=self.config_path,
277-
macros=macros,
278-
variables=variables,
279-
used_variables=used_variables,
280-
path=self.config_path,
281-
)
247+
context = project.context
248+
hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
249+
for package_name, package in project.packages.items():
250+
context.set_and_render_variables(package.variables, package_name)
251+
on_run_start: t.List[str] = [
252+
on_run_hook.sql
253+
for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
254+
]
255+
on_run_end: t.List[str] = [
256+
on_run_hook.sql
257+
for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
258+
]
282259

283-
return EnvironmentStatements(
284-
before_all=[
285-
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or []
286-
],
287-
after_all=[
288-
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or []
289-
],
290-
python_env=python_env,
291-
jinja_macros=jinja_macros,
292-
)
293-
return None
260+
if statements := on_run_start + on_run_end:
261+
jinja_references, used_variables = extract_macro_references_and_variables(
262+
*(gen(stmt) for stmt in statements)
263+
)
264+
265+
jinja_registry = make_jinja_registry(
266+
context.jinja_macros, package_name, jinja_references
267+
)
268+
269+
python_env = make_python_env(
270+
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
271+
jinja_macro_references=jinja_references,
272+
module_path=self.config_path,
273+
macros=macros,
274+
variables=context.variables,
275+
used_variables=used_variables,
276+
path=self.config_path,
277+
)
278+
279+
hooks_by_package_name[package_name] = EnvironmentStatements(
280+
before_all=[
281+
d.jinja_statement(stmt).sql(dialect=dialect)
282+
for stmt in on_run_start or []
283+
],
284+
after_all=[
285+
d.jinja_statement(stmt).sql(dialect=dialect)
286+
for stmt in on_run_end or []
287+
],
288+
python_env=python_env,
289+
jinja_macros=jinja_registry,
290+
)
291+
# Project hooks should be executed first and then rest of the packages
292+
environment_statements = [
293+
statements
294+
for _, statements in sorted(
295+
hooks_by_package_name.items(),
296+
key=lambda item: 0 if item[0] == context.project_name else 1,
297+
)
298+
]
299+
return environment_statements
294300

295301
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
296302
if not root.is_dir():

sqlmesh/dbt/manifest.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sqlmesh.dbt.basemodel import Dependencies
2828
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
2929
from sqlmesh.dbt.model import ModelConfig
30-
from sqlmesh.dbt.package import MacroConfig
30+
from sqlmesh.dbt.package import HookConfig, MacroConfig
3131
from sqlmesh.dbt.seed import SeedConfig
3232
from sqlmesh.dbt.source import SourceConfig
3333
from sqlmesh.dbt.target import TargetConfig
@@ -54,6 +54,7 @@
5454
SeedConfigs = t.Dict[str, SeedConfig]
5555
SourceConfigs = t.Dict[str, SourceConfig]
5656
MacroConfigs = t.Dict[str, MacroConfig]
57+
HookConfigs = t.Dict[str, HookConfig]
5758

5859

5960
IGNORED_PACKAGES = {"elementary"}
@@ -94,8 +95,8 @@ def __init__(
9495
self.project_path / c.CACHE, "jinja_calls"
9596
)
9697

97-
self._on_run_start: t.Optional[t.List[str]] = None
98-
self._on_run_end: t.Optional[t.List[str]] = None
98+
self._on_run_start_per_package: t.Dict[str, HookConfigs] = defaultdict(dict)
99+
self._on_run_end_per_package: t.Dict[str, HookConfigs] = defaultdict(dict)
99100

100101
def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
101102
self._load_all()
@@ -117,6 +118,14 @@ def macros(self, package_name: t.Optional[str] = None) -> MacroConfigs:
117118
self._load_all()
118119
return self._macros_per_package[package_name or self._project_name]
119120

121+
def on_run_start(self, package_name: t.Optional[str] = None) -> HookConfigs:
122+
self._load_all()
123+
return self._on_run_start_per_package[package_name or self._project_name]
124+
125+
def on_run_end(self, package_name: t.Optional[str] = None) -> HookConfigs:
126+
self._load_all()
127+
return self._on_run_end_per_package[package_name or self._project_name]
128+
120129
@property
121130
def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]:
122131
self._load_all()
@@ -136,6 +145,7 @@ def _load_all(self) -> None:
136145
self._load_sources()
137146
self._load_tests()
138147
self._load_models_and_seeds()
148+
self._load_on_run_start_end()
139149
self._is_loaded = True
140150

141151
self._call_cache.put("", value={k: v for k, (v, used) in self._calls.items() if used})
@@ -274,6 +284,23 @@ def _load_models_and_seeds(self) -> None:
274284
**node_config,
275285
)
276286

287+
def _load_on_run_start_end(self) -> None:
288+
for node in self._manifest.nodes.values():
289+
if node.resource_type == "operation" and (
290+
set(node.tags) & {"on-run-start", "on-run-end"}
291+
):
292+
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
293+
node_name = node.name
294+
node_path = Path(node.original_file_path)
295+
if "on-run-start" in node.tags:
296+
self._on_run_start_per_package[node.package_name][node_name] = HookConfig(
297+
sql=sql, index=node.index or 0, path=node_path
298+
)
299+
else:
300+
self._on_run_end_per_package[node.package_name][node_name] = HookConfig(
301+
sql=sql, index=node.index or 0, path=node_path
302+
)
303+
277304
@property
278305
def _manifest(self) -> Manifest:
279306
if not self.__manifest:
@@ -315,11 +342,6 @@ def _load_manifest(self) -> Manifest:
315342

316343
runtime_config = RuntimeConfig.from_parts(project, profile, args)
317344

318-
if runtime_config.on_run_start:
319-
self._on_run_start = runtime_config.on_run_start
320-
if runtime_config.on_run_end:
321-
self._on_run_end = runtime_config.on_run_end
322-
323345
self._project_name = project.project_name
324346

325347
if DBT_VERSION >= (1, 8):

sqlmesh/dbt/package.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ class MacroConfig(PydanticModel):
2828
path: Path
2929

3030

31+
class HookConfig(PydanticModel):
32+
"""Class to contain on run start / on run end hooks."""
33+
34+
sql: str
35+
index: int
36+
path: Path
37+
38+
3139
class Package(PydanticModel):
3240
"""Class to contain package configuration"""
3341

@@ -38,6 +46,8 @@ class Package(PydanticModel):
3846
models: t.Dict[str, ModelConfig]
3947
variables: t.Dict[str, t.Any]
4048
macros: t.Dict[str, MacroConfig]
49+
on_run_start: t.Dict[str, HookConfig]
50+
on_run_end: t.Dict[str, HookConfig]
4151
files: t.Set[Path]
4252

4353
@property
@@ -83,6 +93,8 @@ def load(self, package_root: Path) -> Package:
8393
models = _fix_paths(self._context.manifest.models(package_name), package_root)
8494
seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root)
8595
macros = _fix_paths(self._context.manifest.macros(package_name), package_root)
96+
on_run_start = _fix_paths(self._context.manifest.on_run_start(package_name), package_root)
97+
on_run_end = _fix_paths(self._context.manifest.on_run_end(package_name), package_root)
8698
sources = self._context.manifest.sources(package_name)
8799

88100
config_paths = {
@@ -102,10 +114,12 @@ def load(self, package_root: Path) -> Package:
102114
variables=package_variables,
103115
macros=macros,
104116
files=config_paths,
117+
on_run_start=on_run_start,
118+
on_run_end=on_run_end,
105119
)
106120

107121

108-
T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig)
122+
T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig, HookConfig)
109123

110124

111125
def _fix_paths(configs: t.Dict[str, T], package_root: Path) -> t.Dict[str, T]:

sqlmesh/utils/jinja.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,31 @@ def create_builtin_globals(
608608
c.GATEWAY: lambda: variables.get(c.GATEWAY, None),
609609
**global_vars,
610610
}
611+
612+
613+
def make_jinja_registry(
614+
jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
615+
) -> JinjaMacroRegistry:
616+
"""
617+
Creates a Jinja macro registry for a specific package.
618+
619+
This function takes an existing Jinja macro registry and returns a new
620+
registry that includes only the macros associated with the specified
621+
package and trims the registry to include only the macros referenced
622+
in the provided set of macro references.
623+
624+
Args:
625+
jinja_macros: The original Jinja macro registry containing all macros.
626+
package_name: The name of the package for which to create the registry.
627+
jinja_references: A set of macro references to retain in the new registry.
628+
629+
Returns:
630+
A new JinjaMacroRegistry containing only the macros for the specified
631+
package and the referenced macros.
632+
"""
633+
634+
jinja_registry = jinja_macros.copy()
635+
jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {}
636+
jinja_registry = jinja_registry.trim(jinja_references)
637+
638+
return jinja_registry

0 commit comments

Comments
 (0)