Skip to content

Commit e6b8cab

Browse files
Refactor to handle multiple projects; add multi_repo dbt test
1 parent 0ff6643 commit e6b8cab

5 files changed

Lines changed: 56 additions & 23 deletions

File tree

examples/multi_dbt/bronze/dbt_project.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
1919
models:
2020
start: "2024-01-01"
2121
+materialized: table
22+
23+
on-run-start:
24+
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'

examples/multi_dbt/silver/dbt_project.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"]
1919
models:
2020
start: "2024-01-01"
2121
+materialized: table
22+
23+
on-run-end:
24+
- '{{ store_schemas(schemas) }}'
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{% macro store_schemas(schemas) %}
2+
create or replace table schema_table as select {{schemas}} as all_schemas;
3+
{% endmacro %}

sqlmesh/dbt/loader.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@
2727
from sqlmesh.dbt.target import TargetConfig
2828
from sqlmesh.utils import UniqueKeyDict
2929
from sqlmesh.utils.errors import ConfigError
30-
from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables
30+
from sqlmesh.utils.jinja import (
31+
JinjaMacroRegistry,
32+
MacroInfo,
33+
extract_macro_references_and_variables,
34+
)
3135

3236
if sys.version_info >= (3, 12):
3337
from importlib import metadata
@@ -237,39 +241,41 @@ def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]:
237241
def _load_environment_statements(self, macros: MacroRegistry) -> EnvironmentStatements | None:
238242
"""Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively."""
239243

240-
on_run_start = []
241-
on_run_end = []
242-
244+
on_run_start: t.List[str] = []
245+
on_run_end: t.List[str] = []
246+
jinja_root_macros: t.Dict[str, MacroInfo] = {}
247+
variables: t.Dict[str, t.Any] = self._get_variables()
243248
dialect = self.config.dialect
244249
for project in self._load_projects():
245-
if manifest := project.context._manifest:
246-
if stmts := manifest._on_run_start:
247-
on_run_start.extend(stmts)
248-
if stmts := manifest._on_run_end:
249-
on_run_end.extend(stmts)
250+
context = project.context.copy()
251+
if manifest := context._manifest:
252+
on_run_start.extend(manifest._on_run_start or [])
253+
on_run_end.extend(manifest._on_run_end or [])
254+
255+
if root_package := context.jinja_macros.root_package_name:
256+
if root_macros := context.jinja_macros.packages.get(root_package):
257+
jinja_root_macros |= root_macros
258+
context.set_and_render_variables(context.variables, root_package)
259+
variables |= context.variables
250260

251261
if statements := on_run_start + on_run_end:
252262
jinja_macro_references, used_variables = extract_macro_references_and_variables(
253-
*(gen(e) for e in statements)
263+
*(gen(stmt) for stmt in statements)
264+
)
265+
jinja_macros = context.jinja_macros
266+
jinja_macros.root_macros = jinja_root_macros
267+
jinja_macros = (
268+
jinja_macros.trim(jinja_macro_references)
269+
if not jinja_macros.trimmed
270+
else jinja_macros
254271
)
255-
256-
if jinja_macros := project.context.jinja_macros:
257-
if root_package := jinja_macros.root_package_name:
258-
jinja_macros.root_macros = jinja_macros.packages[root_package]
259-
jinja_macros = (
260-
jinja_macros
261-
if jinja_macros.trimmed
262-
else jinja_macros.trim(jinja_macro_references)
263-
)
264-
else:
265-
jinja_macros = JinjaMacroRegistry()
266272

267273
python_env = make_python_env(
268274
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
269275
jinja_macro_references=jinja_macro_references,
270276
module_path=self.config_path,
271-
macros=macros or macro.get_registry(),
272-
variables=self._get_variables(),
277+
macros=macros,
278+
variables=variables,
273279
used_variables=used_variables,
274280
path=self.config_path,
275281
)

tests/core/test_integration.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4502,6 +4502,24 @@ def test_multi_dbt(mocker):
45024502
context.apply(plan)
45034503
validate_apply_basics(context, c.PROD, plan.snapshots.values())
45044504

4505+
environment_statements = context.state_sync.get_environment_statements(c.PROD)
4506+
assert len(environment_statements) == 2
4507+
bronze_statements = environment_statements[0]
4508+
assert bronze_statements.before_all == [
4509+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
4510+
]
4511+
assert not bronze_statements.after_all
4512+
silver_statements = environment_statements[1]
4513+
assert not silver_statements.before_all
4514+
assert silver_statements.after_all == [
4515+
"JINJA_STATEMENT_BEGIN;\n{{ store_schemas(schemas) }}\nJINJA_END;"
4516+
]
4517+
assert "store_schemas" in silver_statements.jinja_macros.root_macros
4518+
analytics_table = context.fetchdf("select * from analytic_stats;")
4519+
assert sorted(analytics_table.columns) == sorted(["physical_table", "evaluation_time"])
4520+
schema_table = context.fetchdf("select * from schema_table;")
4521+
assert sorted(schema_table.all_schemas[0]) == sorted(["bronze", "silver"])
4522+
45054523

45064524
def test_multi_hybrid(mocker):
45074525
context = Context(

0 commit comments

Comments
 (0)