Skip to content

Commit 2785fc9

Browse files
authored
Fix: Dependency handling when converting on-run-start / on-run-end hooks in dbt projects (#4567)
1 parent 792f5b0 commit 2785fc9

7 files changed

Lines changed: 147 additions & 138 deletions

File tree

sqlmesh/dbt/loader.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from sqlmesh.utils.errors import ConfigError
2929
from sqlmesh.utils.jinja import (
3030
JinjaMacroRegistry,
31-
extract_macro_references_and_variables,
3231
make_jinja_registry,
3332
)
3433

@@ -240,11 +239,11 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
240239
def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]:
241240
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
242241

243-
environment_statements: t.List[EnvironmentStatements] = []
242+
hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
243+
project_names: t.Set[str] = set()
244244
dialect = self.config.dialect
245245
for project in self._load_projects():
246246
context = project.context
247-
hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
248247
for package_name, package in project.packages.items():
249248
context.set_and_render_variables(package.variables, package_name)
250249
on_run_start: t.List[str] = [
@@ -256,18 +255,14 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
256255
for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
257256
]
258257

259-
if statements := on_run_start + on_run_end:
260-
jinja_references, used_variables = extract_macro_references_and_variables(
261-
*statements
262-
)
258+
if on_run_start or on_run_end:
259+
dependencies = Dependencies()
260+
for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
261+
dependencies = dependencies.union(hook.dependencies)
263262

264-
statements_context = context.context_for_dependencies(
265-
Dependencies(
266-
variables=used_variables,
267-
)
268-
)
263+
statements_context = context.context_for_dependencies(dependencies)
269264
jinja_registry = make_jinja_registry(
270-
statements_context.jinja_macros, package_name, jinja_references
265+
statements_context.jinja_macros, package_name, set(dependencies.macros)
271266
)
272267
jinja_registry.add_globals(statements_context.jinja_globals)
273268

@@ -283,15 +278,15 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
283278
python_env={},
284279
jinja_macros=jinja_registry,
285280
)
286-
# Project hooks should be executed first and then rest of the packages
287-
environment_statements = [
288-
statements
289-
for _, statements in sorted(
290-
hooks_by_package_name.items(),
291-
key=lambda item: 0 if item[0] == context.project_name else 1,
292-
)
293-
]
294-
return environment_statements
281+
project_names.add(package_name)
282+
283+
return [
284+
statements
285+
for _, statements in sorted(
286+
hooks_by_package_name.items(),
287+
key=lambda item: 0 if item[0] in project_names else 1,
288+
)
289+
]
295290

296291
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:
297292
if not root.is_dir():

sqlmesh/dbt/manifest.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,24 @@ def _load_on_run_start_end(self) -> None:
292292
sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore
293293
node_name = node.name
294294
node_path = Path(node.original_file_path)
295+
296+
dependencies = Dependencies(
297+
macros=_macro_references(self._manifest, node),
298+
refs=_refs(node),
299+
sources=_sources(node),
300+
)
301+
dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name))
302+
dependencies = dependencies.union(
303+
self._flatten_dependencies_from_macros(dependencies.macros, node.package_name)
304+
)
305+
295306
if "on-run-start" in node.tags:
296307
self._on_run_start_per_package[node.package_name][node_name] = HookConfig(
297-
sql=sql, index=node.index or 0, path=node_path
308+
sql=sql, index=node.index or 0, path=node_path, dependencies=dependencies
298309
)
299310
else:
300311
self._on_run_end_per_package[node.package_name][node_name] = HookConfig(
301-
sql=sql, index=node.index or 0, path=node_path
312+
sql=sql, index=node.index or 0, path=node_path, dependencies=dependencies
302313
)
303314

304315
@property

sqlmesh/dbt/package.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class HookConfig(PydanticModel):
3434
sql: str
3535
index: int
3636
path: Path
37+
dependencies: Dependencies
3738

3839

3940
class Package(PydanticModel):

tests/core/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4346,7 +4346,7 @@ def test_dbt_dialect_with_normalization_strategy(init_and_plan_context: t.Callab
43464346

43474347

43484348
@time_machine.travel("2023-01-08 15:00:00 UTC")
4349-
def test_dbt_before_all_with_var(init_and_plan_context: t.Callable):
4349+
def test_dbt_before_all_with_var_ref_source(init_and_plan_context: t.Callable):
43504350
_, plan = init_and_plan_context(
43514351
"tests/fixtures/dbt/sushi_test", config="test_config_with_normalization_strategy"
43524352
)
@@ -4356,7 +4356,7 @@ def test_dbt_before_all_with_var(init_and_plan_context: t.Callable):
43564356
assert rendered_statements[0] == [
43574357
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)",
43584358
"CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)",
4359-
'SELECT 1 AS "1"',
4359+
"SELECT 1 AS var, 'items' AS src, 'waiters' AS ref",
43604360
]
43614361

43624362

tests/dbt/test_adapter.py

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313

1414
from sqlmesh import Context
1515
from sqlmesh.core.dialect import schema_
16-
from sqlmesh.core.environment import EnvironmentNamingInfo
17-
from sqlmesh.core.macros import RuntimeStage
18-
from sqlmesh.core.renderer import render_statements
1916
from sqlmesh.core.snapshot import SnapshotId
2017
from sqlmesh.dbt.adapter import ParsetimeAdapter
2118
from sqlmesh.dbt.project import Project
@@ -275,114 +272,6 @@ def test_quote_as_configured():
275272
adapter.quote_as_configured("foo", "database") == "foo"
276273

277274

278-
def test_on_run_start_end(copy_to_temp_path):
279-
project_root = "tests/fixtures/dbt/sushi_test"
280-
sushi_context = Context(paths=copy_to_temp_path(project_root))
281-
assert len(sushi_context._environment_statements) == 2
282-
283-
# Root project's on run start / on run end should be first by checking the macros
284-
root_environment_statements = sushi_context._environment_statements[0]
285-
assert "create_tables" in root_environment_statements.jinja_macros.root_macros
286-
287-
# Validate order of execution to be correct
288-
assert root_environment_statements.before_all == [
289-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
290-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);\nJINJA_END;",
291-
'JINJA_STATEMENT_BEGIN;\nSELECT {{ var("yet_another_var") }}\nJINJA_END;',
292-
]
293-
assert root_environment_statements.after_all == [
294-
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
295-
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
296-
]
297-
298-
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
299-
300-
rendered_before_all = render_statements(
301-
root_environment_statements.before_all,
302-
dialect=sushi_context.default_dialect,
303-
python_env=root_environment_statements.python_env,
304-
jinja_macros=root_environment_statements.jinja_macros,
305-
runtime_stage=RuntimeStage.BEFORE_ALL,
306-
)
307-
308-
rendered_after_all = render_statements(
309-
root_environment_statements.after_all,
310-
dialect=sushi_context.default_dialect,
311-
python_env=root_environment_statements.python_env,
312-
jinja_macros=root_environment_statements.jinja_macros,
313-
snapshots=sushi_context.snapshots,
314-
runtime_stage=RuntimeStage.AFTER_ALL,
315-
environment_naming_info=EnvironmentNamingInfo(name="dev"),
316-
)
317-
318-
assert rendered_before_all == [
319-
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)",
320-
"CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)",
321-
'SELECT 1 AS "1"',
322-
]
323-
324-
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
325-
assert sorted(rendered_after_all) == sorted(
326-
[
327-
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
328-
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
329-
"DROP TABLE to_be_executed_last",
330-
]
331-
)
332-
333-
# Nested dbt_packages on run start / on run end
334-
packaged_environment_statements = sushi_context._environment_statements[1]
335-
336-
# Validate order of execution to be correct
337-
assert packaged_environment_statements.before_all == [
338-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);\nJINJA_END;",
339-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
340-
]
341-
assert packaged_environment_statements.after_all == [
342-
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_first\nJINJA_END;",
343-
"JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;",
344-
]
345-
346-
assert "packaged_tables" in packaged_environment_statements.jinja_macros.root_macros
347-
assert packaged_environment_statements.jinja_macros.root_package_name == "sushi"
348-
349-
rendered_before_all = render_statements(
350-
packaged_environment_statements.before_all,
351-
dialect=sushi_context.default_dialect,
352-
python_env=packaged_environment_statements.python_env,
353-
jinja_macros=packaged_environment_statements.jinja_macros,
354-
runtime_stage=RuntimeStage.BEFORE_ALL,
355-
)
356-
357-
rendered_after_all = render_statements(
358-
packaged_environment_statements.after_all,
359-
dialect=sushi_context.default_dialect,
360-
python_env=packaged_environment_statements.python_env,
361-
jinja_macros=packaged_environment_statements.jinja_macros,
362-
snapshots=sushi_context.snapshots,
363-
runtime_stage=RuntimeStage.AFTER_ALL,
364-
environment_naming_info=EnvironmentNamingInfo(name="dev"),
365-
)
366-
367-
# Validate order of execution to match dbt's
368-
assert rendered_before_all == [
369-
"CREATE TABLE IF NOT EXISTS to_be_executed_first (col TEXT)",
370-
"CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)",
371-
]
372-
373-
# This on run end statement should be executed first
374-
assert rendered_after_all[0] == "DROP TABLE to_be_executed_first"
375-
376-
# The table names is an indication of the rendering of the dbt_packages statements
377-
assert sorted(rendered_after_all) == sorted(
378-
[
379-
"DROP TABLE to_be_executed_first",
380-
"CREATE OR REPLACE TABLE schema_table_snapshots__dev_nested_package AS SELECT 'snapshots__dev' AS schema",
381-
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
382-
]
383-
)
384-
385-
386275
def test_adapter_get_relation_normalization(
387276
sushi_test_project: Project, runtime_renderer: t.Callable
388277
):

tests/dbt/test_transformation.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
from pytest_mock.plugin import MockerFixture
1414
from sqlglot import exp, parse_one
1515
from sqlmesh.core import dialect as d
16+
from sqlmesh.core.environment import EnvironmentNamingInfo
17+
from sqlmesh.core.macros import RuntimeStage
18+
from sqlmesh.core.renderer import render_statements
1619
from sqlmesh.core.audit import StandaloneAudit
1720
from sqlmesh.core.context import Context
1821
from sqlmesh.core.console import get_console
@@ -1551,3 +1554,112 @@ def test_grain():
15511554

15521555
model.grain = "id_a"
15531556
assert model.to_sqlmesh(context).grains == [exp.to_column("id_a")]
1557+
1558+
1559+
def test_on_run_start_end(copy_to_temp_path):
1560+
project_root = "tests/fixtures/dbt/sushi_test"
1561+
sushi_context = Context(paths=copy_to_temp_path(project_root))
1562+
assert len(sushi_context._environment_statements) == 2
1563+
1564+
# The root project's on-run-start / on-run-end hooks should come first; confirm this by checking the macros
1565+
root_environment_statements = sushi_context._environment_statements[0]
1566+
assert "create_tables" in root_environment_statements.jinja_macros.root_macros
1567+
1568+
# Validate that the order of execution is correct
1569+
assert root_environment_statements.before_all == [
1570+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
1571+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);\nJINJA_END;",
1572+
"""JINJA_STATEMENT_BEGIN;\nSELECT {{ var("yet_another_var") }} AS var, '{{ source("raw", "items").identifier }}' AS src, '{{ ref("waiters").identifier }}' AS ref;\nJINJA_END;""",
1573+
"JINJA_STATEMENT_BEGIN;\n{{ log_value('on-run-start') }}\nJINJA_END;",
1574+
]
1575+
assert root_environment_statements.after_all == [
1576+
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
1577+
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
1578+
]
1579+
1580+
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
1581+
1582+
rendered_before_all = render_statements(
1583+
root_environment_statements.before_all,
1584+
dialect=sushi_context.default_dialect,
1585+
python_env=root_environment_statements.python_env,
1586+
jinja_macros=root_environment_statements.jinja_macros,
1587+
runtime_stage=RuntimeStage.BEFORE_ALL,
1588+
)
1589+
1590+
rendered_after_all = render_statements(
1591+
root_environment_statements.after_all,
1592+
dialect=sushi_context.default_dialect,
1593+
python_env=root_environment_statements.python_env,
1594+
jinja_macros=root_environment_statements.jinja_macros,
1595+
snapshots=sushi_context.snapshots,
1596+
runtime_stage=RuntimeStage.AFTER_ALL,
1597+
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1598+
)
1599+
1600+
assert rendered_before_all == [
1601+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)",
1602+
"CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)",
1603+
"SELECT 1 AS var, 'items' AS src, 'waiters' AS ref",
1604+
]
1605+
1606+
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
1607+
assert sorted(rendered_after_all) == sorted(
1608+
[
1609+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
1610+
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
1611+
"DROP TABLE to_be_executed_last",
1612+
]
1613+
)
1614+
1615+
# Nested dbt_packages on run start / on run end
1616+
packaged_environment_statements = sushi_context._environment_statements[1]
1617+
1618+
# Validate that the order of execution is correct
1619+
assert packaged_environment_statements.before_all == [
1620+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);\nJINJA_END;",
1621+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
1622+
]
1623+
assert packaged_environment_statements.after_all == [
1624+
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_first\nJINJA_END;",
1625+
"JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;",
1626+
]
1627+
1628+
assert "packaged_tables" in packaged_environment_statements.jinja_macros.root_macros
1629+
assert packaged_environment_statements.jinja_macros.root_package_name == "sushi"
1630+
1631+
rendered_before_all = render_statements(
1632+
packaged_environment_statements.before_all,
1633+
dialect=sushi_context.default_dialect,
1634+
python_env=packaged_environment_statements.python_env,
1635+
jinja_macros=packaged_environment_statements.jinja_macros,
1636+
runtime_stage=RuntimeStage.BEFORE_ALL,
1637+
)
1638+
1639+
rendered_after_all = render_statements(
1640+
packaged_environment_statements.after_all,
1641+
dialect=sushi_context.default_dialect,
1642+
python_env=packaged_environment_statements.python_env,
1643+
jinja_macros=packaged_environment_statements.jinja_macros,
1644+
snapshots=sushi_context.snapshots,
1645+
runtime_stage=RuntimeStage.AFTER_ALL,
1646+
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1647+
)
1648+
1649+
# Validate order of execution to match dbt's
1650+
assert rendered_before_all == [
1651+
"CREATE TABLE IF NOT EXISTS to_be_executed_first (col TEXT)",
1652+
"CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)",
1653+
]
1654+
1655+
# This on run end statement should be executed first
1656+
assert rendered_after_all[0] == "DROP TABLE to_be_executed_first"
1657+
1658+
# The table names indicate that the dbt_packages statements were rendered
1659+
assert sorted(rendered_after_all) == sorted(
1660+
[
1661+
"DROP TABLE to_be_executed_first",
1662+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev_nested_package AS SELECT 'snapshots__dev' AS schema",
1663+
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
1664+
]
1665+
)

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ vars:
6262
on-run-start:
6363
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'
6464
- 'CREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);'
65-
- 'SELECT {{ var("yet_another_var") }}'
65+
- SELECT {{ var("yet_another_var") }} AS var, '{{ source("raw", "items").identifier }}' AS src, '{{ ref("waiters").identifier }}' AS ref;
66+
- "{{ log_value('on-run-start') }}"
6667
on-run-end:
6768
- '{{ create_tables(schemas) }}'
6869
- 'DROP TABLE to_be_executed_last;'

0 commit comments

Comments
 (0)