Skip to content

Commit 52da5f5

Browse files
Feat(dbt): Add support for on-run-start and on-run-end hooks
1 parent 98decf5 commit 52da5f5

10 files changed

Lines changed: 172 additions & 6 deletions

File tree

docs/integrations/dbt.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,6 @@ The dbt jinja methods that are not currently supported are:
324324
* selected_sources
325325
* adapter.expand_target_column_types
326326
* adapter.rename_relation
327-
* schemas
328327
* graph.nodes.values
329328
* graph.metrics.values
330329

sqlmesh/core/environment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot
1515
from sqlmesh.utils import word_characters_only
1616
from sqlmesh.utils.date import TimeLike, now_timestamp
17+
from sqlmesh.utils.jinja import JinjaMacroRegistry
1718
from sqlmesh.utils.metaprogramming import Executable
1819
from sqlmesh.utils.pydantic import PydanticModel, field_validator
1920

@@ -218,6 +219,7 @@ class EnvironmentStatements(PydanticModel):
218219
before_all: t.List[str]
219220
after_all: t.List[str]
220221
python_env: t.Dict[str, Executable]
222+
jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry()
221223

222224

223225
def execute_environment_statements(
@@ -239,6 +241,7 @@ def execute_environment_statements(
239241
dialect=adapter.dialect,
240242
default_catalog=default_catalog,
241243
python_env=statements.python_env,
244+
jinja_macros=statements.jinja_macros,
242245
snapshots=snapshots,
243246
start=start,
244247
end=end,

sqlmesh/core/renderer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,15 @@ def _render(
107107

108108
if environment_naming_info := kwargs.get("environment_naming_info", None):
109109
kwargs["this_env"] = getattr(environment_naming_info, "name")
110+
if snapshots and (
111+
schemas := set(
112+
[
113+
s.qualified_view_name.schema_for_environment(environment_naming_info)
114+
for s in snapshots.values()
115+
]
116+
)
117+
):
118+
kwargs["schemas"] = list(schemas)
110119

111120
this_model = kwargs.pop("this_model", None)
112121

@@ -411,19 +420,21 @@ def render(
411420

412421
def render_statements(
413422
statements: t.List[str],
414-
dialect: DialectType = None,
423+
dialect: str,
415424
default_catalog: t.Optional[str] = None,
416425
python_env: t.Optional[t.Dict[str, Executable]] = None,
426+
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
417427
**render_kwargs: t.Any,
418428
) -> t.List[str]:
419429
rendered_statements: t.List[str] = []
420430
for statement in statements:
421-
for expression in parse(statement, dialect=dialect):
431+
for expression in d.parse(statement, default_dialect=dialect):
422432
if expression:
423433
rendered = ExpressionRenderer(
424434
expression,
425435
dialect,
426436
[],
437+
jinja_macro_registry=jinja_macros or JinjaMacroRegistry(),
427438
python_env=python_env,
428439
default_catalog=default_catalog,
429440
quote_identifiers=False,

sqlmesh/dbt/loader.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import logging
44
import sys
55
import typing as t
6+
import sqlmesh.core.dialect as d
7+
from sqlglot.optimizer.simplify import gen
68
from pathlib import Path
79
from sqlmesh.core import constants as c
810
from sqlmesh.core.config import (
@@ -11,9 +13,11 @@
1113
GatewayConfig,
1214
ModelDefaultsConfig,
1315
)
16+
from sqlmesh.core.environment import EnvironmentStatements
1417
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
1518
from sqlmesh.core.macros import MacroRegistry, macro
1619
from sqlmesh.core.model import Model, ModelCache
20+
from sqlmesh.core.model.common import make_python_env
1721
from sqlmesh.core.signal import signal
1822
from sqlmesh.dbt.basemodel import BMC, BaseModelConfig
1923
from sqlmesh.dbt.context import DbtContext
@@ -23,7 +27,7 @@
2327
from sqlmesh.dbt.target import TargetConfig
2428
from sqlmesh.utils import UniqueKeyDict
2529
from sqlmesh.utils.errors import ConfigError
26-
from sqlmesh.utils.jinja import JinjaMacroRegistry
30+
from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables
2731

2832
if sys.version_info >= (3, 12):
2933
from importlib import metadata
@@ -230,6 +234,58 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
230234

231235
return requirements, excluded_requirements
232236

237+
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
238+
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
239+
240+
on_run_start = []
241+
on_run_end = []
242+
243+
dialect = self.config.dialect
244+
for project in self._load_projects():
245+
if manifest := project.context._manifest:
246+
if stmts := manifest._on_run_start:
247+
on_run_start.extend(stmts)
248+
if stmts := manifest._on_run_end:
249+
on_run_end.extend(stmts)
250+
251+
if statements := on_run_start + on_run_end:
252+
jinja_macro_references, used_variables = extract_macro_references_and_variables(
253+
*(gen(e) for e in statements)
254+
)
255+
256+
if jinja_macros := project.context.jinja_macros:
257+
if root_package := jinja_macros.root_package_name:
258+
jinja_macros.root_macros = jinja_macros.packages[root_package]
259+
jinja_macros = (
260+
jinja_macros
261+
if jinja_macros.trimmed
262+
else jinja_macros.trim(jinja_macro_references)
263+
)
264+
else:
265+
jinja_macros = JinjaMacroRegistry()
266+
267+
python_env = make_python_env(
268+
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
269+
jinja_macro_references=jinja_macro_references,
270+
module_path=self.config_path,
271+
macros=macros or macro.get_registry(),
272+
variables=self._get_variables(),
273+
used_variables=used_variables,
274+
path=self.config_path,
275+
)
276+
277+
return EnvironmentStatements(
278+
before_all=[
279+
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_start or []
280+
],
281+
after_all=[
282+
d.jinja_statement(stmt).sql(dialect=dialect) for stmt in on_run_end or []
283+
],
284+
python_env=python_env,
285+
jinja_macros=jinja_macros,
286+
)
287+
return None
288+
233289
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
234290
if not root.is_dir():
235291
return {}

sqlmesh/dbt/manifest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def __init__(
9494
self.project_path / c.CACHE, "jinja_calls"
9595
)
9696

97+
self._on_run_start: t.Optional[t.List[str]] = None
98+
self._on_run_end: t.Optional[t.List[str]] = None
99+
97100
def tests(self, package_name: t.Optional[str] = None) -> TestConfigs:
98101
self._load_all()
99102
return self._tests_per_package[package_name or self._project_name]
@@ -312,6 +315,11 @@ def _load_manifest(self) -> Manifest:
312315

313316
runtime_config = RuntimeConfig.from_parts(project, profile, args)
314317

318+
if runtime_config.on_run_start:
319+
self._on_run_start = runtime_config.on_run_start
320+
if runtime_config.on_run_end:
321+
self._on_run_end = runtime_config.on_run_end
322+
315323
self._project_name = project.project_name
316324

317325
if DBT_VERSION >= (1, 8):

tests/core/test_context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ def test_environment_statements(tmp_path: pathlib.Path):
14171417
after_all=[
14181418
"@grant_schema_usage()",
14191419
"@grant_select_privileges()",
1420+
"@grant_usage_role(@schemas, 'admin')",
14201421
],
14211422
)
14221423

@@ -1481,6 +1482,22 @@ def grant_schema_usage(evaluator):
14811482
""",
14821483
)
14831484

1485+
create_temp_file(
1486+
tmp_path,
1487+
pathlib.Path(macros_dir, "grant_usage_file.py"),
1488+
"""
1489+
from sqlmesh import macro
1490+
1491+
@macro()
1492+
def grant_usage_role(evaluator, schemas, role):
1493+
if evaluator._environment_naming_info:
1494+
return [
1495+
f"GRANT USAGE ON SCHEMA {schema} TO {role};"
1496+
for schema in schemas
1497+
]
1498+
""",
1499+
)
1500+
14841501
context = Context(paths=tmp_path, config=config)
14851502
snapshots = {s.name: s for s in context.snapshots.values()}
14861503

@@ -1515,6 +1532,7 @@ def grant_schema_usage(evaluator):
15151532
assert after_all_rendered == [
15161533
"GRANT USAGE ON SCHEMA db TO user_role",
15171534
"GRANT SELECT ON VIEW memory.db.test_after_model TO ROLE admin_role",
1535+
'GRANT USAGE ON SCHEMA "db" TO "admin"',
15181536
]
15191537

15201538
after_all_rendered_dev = render_statements(
@@ -1529,6 +1547,7 @@ def grant_schema_usage(evaluator):
15291547
assert after_all_rendered_dev == [
15301548
"GRANT USAGE ON SCHEMA db__dev TO user_role",
15311549
"GRANT SELECT ON VIEW memory.db__dev.test_after_model TO ROLE admin_role",
1550+
'GRANT USAGE ON SCHEMA "db__dev" TO "admin"',
15321551
]
15331552

15341553

tests/dbt/test_adapter.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414
from sqlmesh import Context
1515
from sqlmesh.core.dialect import schema_
16+
from sqlmesh.core.environment import EnvironmentNamingInfo
17+
from sqlmesh.core.macros import RuntimeStage
18+
from sqlmesh.core.renderer import render_statements
1619
from sqlmesh.core.snapshot import SnapshotId
1720
from sqlmesh.dbt.adapter import ParsetimeAdapter
1821
from sqlmesh.dbt.project import Project
@@ -270,3 +273,49 @@ def test_quote_as_configured():
270273
adapter.quote_as_configured("foo", "identifier") == '"foo"'
271274
adapter.quote_as_configured("foo", "schema") == "foo"
272275
adapter.quote_as_configured("foo", "database") == "foo"
276+
277+
278+
def test_on_run_start_end(copy_to_temp_path):
279+
project_root = "tests/fixtures/dbt/sushi_test"
280+
sushi_context = Context(paths=copy_to_temp_path(project_root))
281+
assert len(sushi_context._environment_statements) == 1
282+
environment_statements = sushi_context._environment_statements[0]
283+
284+
assert environment_statements.before_all == [
285+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
286+
]
287+
assert environment_statements.after_all == [
288+
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;"
289+
]
290+
assert "create_tables" in environment_statements.jinja_macros.root_macros
291+
292+
rendered_before_all = render_statements(
293+
environment_statements.before_all,
294+
dialect=sushi_context.default_dialect,
295+
python_env=environment_statements.python_env,
296+
jinja_macros=environment_statements.jinja_macros,
297+
runtime_stage=RuntimeStage.BEFORE_ALL,
298+
)
299+
300+
rendered_after_all = render_statements(
301+
environment_statements.after_all,
302+
dialect=sushi_context.default_dialect,
303+
python_env=environment_statements.python_env,
304+
jinja_macros=environment_statements.jinja_macros,
305+
snapshots=sushi_context.snapshots,
306+
runtime_stage=RuntimeStage.AFTER_ALL,
307+
environment_naming_info=EnvironmentNamingInfo(name="dev"),
308+
)
309+
310+
assert rendered_before_all == [
311+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)"
312+
]
313+
314+
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
315+
assert sorted(rendered_after_all) == sorted(
316+
[
317+
"CREATE OR REPLACE TABLE schema_table_raw__dev AS SELECT 'raw__dev' AS schema",
318+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
319+
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
320+
]
321+
)

tests/dbt/test_transformation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,16 @@ def test_dbt_version(sushi_test_project: Project):
997997
assert context.render("{{ dbt_version }}").startswith("1.")
998998

999999

1000+
@pytest.mark.xdist_group("dbt_manifest")
1001+
def test_dbt_on_run_start_end(sushi_test_project: Project):
1002+
context = sushi_test_project.context
1003+
assert context._manifest
1004+
assert context._manifest._on_run_start == [
1005+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);"
1006+
]
1007+
assert context._manifest._on_run_end == ["{{ create_tables(schemas) }}"]
1008+
1009+
10001010
@pytest.mark.xdist_group("dbt_manifest")
10011011
def test_parsetime_adapter_call(
10021012
assert_exp_eq, sushi_test_project: Project, sushi_test_dbt_context: Context

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ models:
2525
+materialized: table
2626
+pre-hook:
2727
- '{{ log("pre-hook") }}'
28-
+post-hook:
28+
+post-hook:
2929
- '{{ log("post-hook") }}'
3030

3131
seeds:
3232
sushi:
3333
+pre-hook:
3434
- '{{ log("pre-hook") }}'
35-
+post-hook:
35+
+post-hook:
3636
- '{{ log("post-hook") }}'
3737

3838
vars:
@@ -57,3 +57,9 @@ vars:
5757
value: 1
5858
- name: 'item2'
5959
value: 2
60+
61+
62+
on-run-start:
63+
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'
64+
on-run-end:
65+
- '{{ create_tables(schemas) }}'
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{% macro create_tables(schemas) %}
2+
{% for schema in schemas %}
3+
create or replace table schema_table_{{schema}} as select '{{schema}}' as schema;
4+
{% endfor%}
5+
{% endmacro %}

0 commit comments

Comments
 (0)