Skip to content

Commit c867283

Browse files
Fix: Include on run start / on run end hooks of dbt packages
1 parent 5201abf commit c867283

9 files changed

Lines changed: 180 additions & 84 deletions

File tree

sqlmesh/core/context.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -610,8 +610,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
610610
self._standalone_audits.update(project.standalone_audits)
611611
self._requirements.update(project.requirements)
612612
self._excluded_requirements.update(project.excluded_requirements)
613-
if project.environment_statements:
614-
self._environment_statements.append(project.environment_statements)
613+
self._environment_statements.extend(project.environment_statements)
615614

616615
config = loader.config
617616
self._linters[config.project] = Linter.from_rules(

sqlmesh/core/loader.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ class LoadedProject:
6161
metrics: UniqueKeyDict[str, Metric]
6262
requirements: t.Dict[str, str]
6363
excluded_requirements: t.Set[str]
64-
environment_statements: t.Optional[EnvironmentStatements]
64+
environment_statements: t.List[EnvironmentStatements]
6565
user_rules: RuleSet
6666

6767

@@ -187,9 +187,9 @@ def _load_audits(
187187
) -> UniqueKeyDict[str, Audit]:
188188
"""Loads all audits."""
189189

190-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
190+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
191191
"""Loads environment statements."""
192-
return None
192+
return []
193193

194194
def load_materializations(self) -> None:
195195
"""Loads custom materializations."""
@@ -651,7 +651,7 @@ def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
651651

652652
return metrics
653653

654-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
654+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
655655
"""Loads environment statements."""
656656

657657
if self.config.before_all or self.config.after_all:
@@ -673,8 +673,8 @@ def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStat
673673
path=self.config_path,
674674
)
675675

676-
return EnvironmentStatements(**statements, python_env=python_env)
677-
return None
676+
return [EnvironmentStatements(**statements, python_env=python_env)]
677+
return []
678678

679679
def _load_linting_rules(self) -> RuleSet:
680680
user_rules: UniqueKeyDict[str, type[Rule]] = UniqueKeyDict("rules")

sqlmesh/dbt/loader.py

Lines changed: 44 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929
from sqlmesh.utils.errors import ConfigError
3030
from sqlmesh.utils.jinja import (
3131
JinjaMacroRegistry,
32-
MacroInfo,
3332
extract_macro_references_and_variables,
3433
)
3534

@@ -238,59 +237,54 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
238237

239238
return requirements, excluded_requirements
240239

241-
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
240+
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
242241
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
243242

244-
on_run_start: t.List[str] = []
245-
on_run_end: t.List[str] = []
246-
jinja_root_macros: t.Dict[str, MacroInfo] = {}
247-
variables: t.Dict[str, t.Any] = self._get_variables()
243+
environment_statements: t.List[EnvironmentStatements] = []
248244
dialect = self.config.dialect
249245
for project in self._load_projects():
250-
context = project.context.copy()
251-
if manifest := context._manifest:
252-
on_run_start.extend(manifest._on_run_start or [])
253-
on_run_end.extend(manifest._on_run_end or [])
254-
255-
if root_package := context.jinja_macros.root_package_name:
256-
if root_macros := context.jinja_macros.packages.get(root_package):
257-
jinja_root_macros |= root_macros
258-
context.set_and_render_variables(context.variables, root_package)
259-
variables |= context.variables
260-
261-
if statements := on_run_start + on_run_end:
262-
jinja_macro_references, used_variables = extract_macro_references_and_variables(
263-
*(gen(stmt) for stmt in statements)
264-
)
265-
jinja_macros = context.jinja_macros
266-
jinja_macros.root_macros = jinja_root_macros
267-
jinja_macros = (
268-
jinja_macros.trim(jinja_macro_references)
269-
if not jinja_macros.trimmed
270-
else jinja_macros
271-
)
272-
273-
python_env = make_python_env(
274-
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
275-
jinja_macro_references=jinja_macro_references,
276-
module_path=self.config_path,
277-
macros=macros,
278-
variables=variables,
279-
used_variables=used_variables,
280-
path=self.config_path,
281-
)
282-
283-
return EnvironmentStatements(
284-
before_all=[
285-
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or []
286-
],
287-
after_all=[
288-
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or []
289-
],
290-
python_env=python_env,
291-
jinja_macros=jinja_macros,
292-
)
293-
return None
246+
context = project.context
247+
for package_name, package in project.packages.items():
248+
context.set_and_render_variables(package.variables, package_name)
249+
on_run_start: t.List[str] = []
250+
on_run_end: t.List[str] = []
251+
for hook in package.on_run_start.values():
252+
on_run_start.append(hook.sql)
253+
for hook in package.on_run_end.values():
254+
on_run_end.append(hook.sql)
255+
256+
if statements := on_run_start + on_run_end:
257+
jinja_references, used_variables = extract_macro_references_and_variables(
258+
*(gen(stmt) for stmt in statements)
259+
)
260+
jinja_registry = context.jinja_macros.copy()
261+
jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {}
262+
jinja_registry = jinja_registry.trim(jinja_references)
263+
python_env = make_python_env(
264+
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
265+
jinja_macro_references=jinja_references,
266+
module_path=self.config_path,
267+
macros=macros,
268+
variables=context.variables,
269+
used_variables=used_variables,
270+
path=self.config_path,
271+
)
272+
273+
environment_statements.append(
274+
EnvironmentStatements(
275+
before_all=[
276+
d.jinja_statement(stmt).sql(dialect=dialect)
277+
for stmt in on_run_start or []
278+
],
279+
after_all=[
280+
d.jinja_statement(stmt).sql(dialect=dialect)
281+
for stmt in on_run_end or []
282+
],
283+
python_env=python_env,
284+
jinja_macros=jinja_registry,
285+
)
286+
)
287+
return environment_statements
294288

295289
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
296290
if not root.is_dir():

sqlmesh/dbt/manifest.py

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
from sqlmesh.dbt.basemodel import Dependencies
2828
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
2929
from sqlmesh.dbt.model import ModelConfig
30-
from sqlmesh.dbt.package import MacroConfig
30+
from sqlmesh.dbt.package import HookConfig, MacroConfig
3131
from sqlmesh.dbt.seed import SeedConfig
3232
from sqlmesh.dbt.source import SourceConfig
3333
from sqlmesh.dbt.target import TargetConfig
@@ -54,6 +54,7 @@
5454
SeedConfigs = t.Dict[str, SeedConfig]
5555
SourceConfigs = t.Dict[str, SourceConfig]
5656
MacroConfigs = t.Dict[str, MacroConfig]
57+
HookConfigs = t.Dict[str, HookConfig]
5758

5859

5960
IGNORED_PACKAGES = {"elementary"}
@@ -94,8 +95,8 @@ def __init__(
9495
self.project_path / c.CACHE, "jinja_calls"
9596
)
9697

97-
self._on_run_start: t.Optional[t.List[str]] = None
98-
self._on_run_end: t.Optional[t.List[str]] = None
98+
self._on_run_start_per_package: t.Dict[str, HookConfigs] = defaultdict(dict)
99+
self._on_run_end_per_package: t.Dict[str, HookConfigs] = defaultdict(dict)
99100

100101
def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
101102
self._load_all()
@@ -117,6 +118,14 @@ def macros(self, package_name: t.Optional[str] = None) -> MacroConfigs:
117118
self._load_all()
118119
return self._macros_per_package[package_name or self._project_name]
119120

121+
def on_run_start(self, package_name: t.Optional[str] = None) -> HookConfigs:
122+
self._load_all()
123+
return self._on_run_start_per_package[package_name or self._project_name]
124+
125+
def on_run_end(self, package_name: t.Optional[str] = None) -> HookConfigs:
126+
self._load_all()
127+
return self._on_run_end_per_package[package_name or self._project_name]
128+
120129
@property
121130
def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]:
122131
self._load_all()
@@ -136,6 +145,7 @@ def _load_all(self) -> None:
136145
self._load_sources()
137146
self._load_tests()
138147
self._load_models_and_seeds()
148+
self._load_on_run_start_end()
139149
self._is_loaded = True
140150

141151
self._call_cache.put("", value={k: v for k, (v, used) in self._calls.items() if used})
@@ -274,6 +284,23 @@ def _load_models_and_seeds(self) -> None:
274284
**node_config,
275285
)
276286

287+
def _load_on_run_start_end(self) -> None:
288+
for node in self._manifest.nodes.values():
289+
if node.resource_type == "operation" and (
290+
set(node.tags) & {"on-run-start", "on-run-end"}
291+
):
292+
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
293+
node_name = node.name
294+
node_path = Path(node.original_file_path)
295+
if "on-run-start" in node.tags:
296+
self._on_run_start_per_package[node.package_name][node_name] = HookConfig(
297+
sql=sql, path=node_path
298+
)
299+
else:
300+
self._on_run_end_per_package[node.package_name][node_name] = HookConfig(
301+
sql=sql, path=node_path
302+
)
303+
277304
@property
278305
def _manifest(self) -> Manifest:
279306
if not self.__manifest:
@@ -315,11 +342,6 @@ def _load_manifest(self) -> Manifest:
315342

316343
runtime_config = RuntimeConfig.from_parts(project, profile, args)
317344

318-
if runtime_config.on_run_start:
319-
self._on_run_start = runtime_config.on_run_start
320-
if runtime_config.on_run_end:
321-
self._on_run_end = runtime_config.on_run_end
322-
323345
self._project_name = project.project_name
324346

325347
if DBT_VERSION >= (1, 8):

sqlmesh/dbt/package.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,13 @@ class MacroConfig(PydanticModel):
2828
path: Path
2929

3030

31+
class HookConfig(PydanticModel):
32+
"""Class to contain on run start / on run end hooks."""
33+
34+
sql: str
35+
path: Path
36+
37+
3138
class Package(PydanticModel):
3239
"""Class to contain package configuration"""
3340

@@ -38,6 +45,8 @@ class Package(PydanticModel):
3845
models: t.Dict[str, ModelConfig]
3946
variables: t.Dict[str, t.Any]
4047
macros: t.Dict[str, MacroConfig]
48+
on_run_start: t.Dict[str, HookConfig]
49+
on_run_end: t.Dict[str, HookConfig]
4150
files: t.Set[Path]
4251

4352
@property
@@ -83,6 +92,8 @@ def load(self, package_root: Path) -> Package:
8392
models = _fix_paths(self._context.manifest.models(package_name), package_root)
8493
seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root)
8594
macros = _fix_paths(self._context.manifest.macros(package_name), package_root)
95+
on_run_start = _fix_paths(self._context.manifest.on_run_start(package_name), package_root)
96+
on_run_end = _fix_paths(self._context.manifest.on_run_end(package_name), package_root)
8697
sources = self._context.manifest.sources(package_name)
8798

8899
config_paths = {
@@ -102,10 +113,12 @@ def load(self, package_root: Path) -> Package:
102113
variables=package_variables,
103114
macros=macros,
104115
files=config_paths,
116+
on_run_start=on_run_start,
117+
on_run_end=on_run_end,
105118
)
106119

107120

108-
T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig)
121+
T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig, HookConfig)
109122

110123

111124
def _fix_paths(configs: t.Dict[str, T], package_root: Path) -> t.Dict[str, T]:

tests/dbt/test_adapter.py

Lines changed: 57 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -278,30 +278,33 @@ def test_quote_as_configured():
278278
def test_on_run_start_end(copy_to_temp_path):
279279
project_root = "tests/fixtures/dbt/sushi_test"
280280
sushi_context = Context(paths=copy_to_temp_path(project_root))
281-
assert len(sushi_context._environment_statements) == 1
282-
environment_statements = sushi_context._environment_statements[0]
281+
assert len(sushi_context._environment_statements) == 2
283282

284-
assert environment_statements.before_all == [
283+
# Root project on run start / on run end
284+
root_environment_statements = sushi_context._environment_statements[0]
285+
assert root_environment_statements.before_all == [
285286
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
286287
]
287-
assert environment_statements.after_all == [
288+
assert root_environment_statements.after_all == [
288289
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;"
289290
]
290-
assert "create_tables" in environment_statements.jinja_macros.root_macros
291+
292+
assert "create_tables" in root_environment_statements.jinja_macros.root_macros
293+
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
291294

292295
rendered_before_all = render_statements(
293-
environment_statements.before_all,
296+
root_environment_statements.before_all,
294297
dialect=sushi_context.default_dialect,
295-
python_env=environment_statements.python_env,
296-
jinja_macros=environment_statements.jinja_macros,
298+
python_env=root_environment_statements.python_env,
299+
jinja_macros=root_environment_statements.jinja_macros,
297300
runtime_stage=RuntimeStage.BEFORE_ALL,
298301
)
299302

300303
rendered_after_all = render_statements(
301-
environment_statements.after_all,
304+
root_environment_statements.after_all,
302305
dialect=sushi_context.default_dialect,
303-
python_env=environment_statements.python_env,
304-
jinja_macros=environment_statements.jinja_macros,
306+
python_env=root_environment_statements.python_env,
307+
jinja_macros=root_environment_statements.jinja_macros,
305308
snapshots=sushi_context.snapshots,
306309
runtime_stage=RuntimeStage.AFTER_ALL,
307310
environment_naming_info=EnvironmentNamingInfo(name="dev"),
@@ -318,3 +321,46 @@ def test_on_run_start_end(copy_to_temp_path):
318321
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
319322
]
320323
)
324+
325+
# Nested dbt_packages on run start / on run end
326+
packaged_environment_statements = sushi_context._environment_statements[1]
327+
328+
assert packaged_environment_statements.before_all == [
329+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
330+
]
331+
assert packaged_environment_statements.after_all == [
332+
"JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;"
333+
]
334+
335+
assert "packaged_tables" in packaged_environment_statements.jinja_macros.root_macros
336+
assert packaged_environment_statements.jinja_macros.root_package_name == "sushi"
337+
338+
rendered_before_all = render_statements(
339+
packaged_environment_statements.before_all,
340+
dialect=sushi_context.default_dialect,
341+
python_env=packaged_environment_statements.python_env,
342+
jinja_macros=packaged_environment_statements.jinja_macros,
343+
runtime_stage=RuntimeStage.BEFORE_ALL,
344+
)
345+
346+
rendered_after_all = render_statements(
347+
packaged_environment_statements.after_all,
348+
dialect=sushi_context.default_dialect,
349+
python_env=packaged_environment_statements.python_env,
350+
jinja_macros=packaged_environment_statements.jinja_macros,
351+
snapshots=sushi_context.snapshots,
352+
runtime_stage=RuntimeStage.AFTER_ALL,
353+
environment_naming_info=EnvironmentNamingInfo(name="dev"),
354+
)
355+
356+
assert rendered_before_all == [
357+
"CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)"
358+
]
359+
360+
# The table names are an indication of the rendering of the dbt_packages statements
361+
assert sorted(rendered_after_all) == sorted(
362+
[
363+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev_nested_package AS SELECT 'snapshots__dev' AS schema",
364+
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
365+
]
366+
)

0 commit comments

Comments
 (0)