Skip to content

Commit 99c23ac

Browse files
refactor make jinja registry in a method; ensure correct execution order
1 parent c867283 commit 99c23ac

8 files changed

Lines changed: 113 additions & 35 deletions

File tree

sqlmesh/dbt/loader.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sqlmesh.utils.jinja import (
3131
JinjaMacroRegistry,
3232
extract_macro_references_and_variables,
33+
make_jinja_registry,
3334
)
3435

3536
if sys.version_info >= (3, 12):
@@ -244,22 +245,27 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
244245
dialect = self.config.dialect
245246
for project in self._load_projects():
246247
context = project.context
248+
hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {}
247249
for package_name, package in project.packages.items():
248250
context.set_and_render_variables(package.variables, package_name)
249-
on_run_start: t.List[str] = []
250-
on_run_end: t.List[str] = []
251-
for hook in package.on_run_start.values():
252-
on_run_start.append(hook.sql)
253-
for hook in package.on_run_end.values():
254-
on_run_end.append(hook.sql)
251+
on_run_start: t.List[str] = [
252+
on_run_hook.sql
253+
for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
254+
]
255+
on_run_end: t.List[str] = [
256+
on_run_hook.sql
257+
for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index)
258+
]
255259

256260
if statements := on_run_start + on_run_end:
257261
jinja_references, used_variables = extract_macro_references_and_variables(
258262
*(gen(stmt) for stmt in statements)
259263
)
260-
jinja_registry = context.jinja_macros.copy()
261-
jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {}
262-
jinja_registry = jinja_registry.trim(jinja_references)
264+
265+
jinja_registry = make_jinja_registry(
266+
context.jinja_macros, package_name, jinja_references
267+
)
268+
263269
python_env = make_python_env(
264270
[s for stmt in statements for s in d.parse(stmt, default_dialect=dialect)],
265271
jinja_macro_references=jinja_references,
@@ -270,20 +276,26 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
270276
path=self.config_path,
271277
)
272278

273-
environment_statements.append(
274-
EnvironmentStatements(
275-
before_all=[
276-
d.jinja_statement(stmt).sql(dialect=dialect)
277-
for stmt in on_run_start or []
278-
],
279-
after_all=[
280-
d.jinja_statement(stmt).sql(dialect=dialect)
281-
for stmt in on_run_end or []
282-
],
283-
python_env=python_env,
284-
jinja_macros=jinja_registry,
285-
)
279+
hooks_by_package_name[package_name] = EnvironmentStatements(
280+
before_all=[
281+
d.jinja_statement(stmt).sql(dialect=dialect)
282+
for stmt in on_run_start or []
283+
],
284+
after_all=[
285+
d.jinja_statement(stmt).sql(dialect=dialect)
286+
for stmt in on_run_end or []
287+
],
288+
python_env=python_env,
289+
jinja_macros=jinja_registry,
286290
)
291+
# Project hooks should be executed first and then rest of the packages
292+
environment_statements = [
293+
statements
294+
for _, statements in sorted(
295+
hooks_by_package_name.items(),
296+
key=lambda item: 0 if item[0] == context.project_name else 1,
297+
)
298+
]
287299
return environment_statements
288300

289301
def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]:

sqlmesh/dbt/manifest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,11 @@ def _load_on_run_start_end(self) -> None:
294294
node_path = Path(node.original_file_path)
295295
if "on-run-start" in node.tags:
296296
self._on_run_start_per_package[node.package_name][node_name] = HookConfig(
297-
sql=sql, path=node_path
297+
sql=sql, index=node.index or 0, path=node_path
298298
)
299299
else:
300300
self._on_run_end_per_package[node.package_name][node_name] = HookConfig(
301-
sql=sql, path=node_path
301+
sql=sql, index=node.index or 0, path=node_path
302302
)
303303

304304
@property

sqlmesh/dbt/package.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class HookConfig(PydanticModel):
3232
"""Class to contain on run start / on run end hooks."""
3333

3434
sql: str
35+
index: int
3536
path: Path
3637

3738

sqlmesh/utils/jinja.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,31 @@ def create_builtin_globals(
608608
c.GATEWAY: lambda: variables.get(c.GATEWAY, None),
609609
**global_vars,
610610
}
611+
612+
613+
def make_jinja_registry(
614+
jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference]
615+
) -> JinjaMacroRegistry:
616+
"""
617+
Creates a Jinja macro registry for a specific package.
618+
619+
This function takes an existing Jinja macro registry and returns a new
620+
registry that includes only the macros associated with the specified
621+
package and trims the registry to include only the macros referenced
622+
in the provided set of macro references.
623+
624+
Args:
625+
jinja_macros: The original Jinja macro registry containing all macros.
626+
package_name: The name of the package for which to create the registry.
627+
jinja_references: A set of macro references to retain in the new registry.
628+
629+
Returns:
630+
A new JinjaMacroRegistry containing only the macros for the specified
631+
package and the referenced macros.
632+
"""
633+
634+
jinja_registry = jinja_macros.copy()
635+
jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {}
636+
jinja_registry = jinja_registry.trim(jinja_references)
637+
638+
return jinja_registry

tests/dbt/test_adapter.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,16 +280,20 @@ def test_on_run_start_end(copy_to_temp_path):
280280
sushi_context = Context(paths=copy_to_temp_path(project_root))
281281
assert len(sushi_context._environment_statements) == 2
282282

283-
# Root project on run start / on run end
283+
# Root project's on run start / on run end should be first by checking the macros
284284
root_environment_statements = sushi_context._environment_statements[0]
285+
assert "create_tables" in root_environment_statements.jinja_macros.root_macros
286+
287+
# Validate order of execution to be correct
285288
assert root_environment_statements.before_all == [
286-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
289+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
290+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);\nJINJA_END;",
287291
]
288292
assert root_environment_statements.after_all == [
289-
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;"
293+
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
294+
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
290295
]
291296

292-
assert "create_tables" in root_environment_statements.jinja_macros.root_macros
293297
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
294298

295299
rendered_before_all = render_statements(
@@ -311,25 +315,30 @@ def test_on_run_start_end(copy_to_temp_path):
311315
)
312316

313317
assert rendered_before_all == [
314-
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)"
318+
"CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)",
319+
"CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)",
315320
]
316321

317322
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
318323
assert sorted(rendered_after_all) == sorted(
319324
[
320325
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
321326
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
327+
"DROP TABLE to_be_executed_last",
322328
]
323329
)
324330

325331
# Nested dbt_packages on run start / on run end
326332
packaged_environment_statements = sushi_context._environment_statements[1]
327333

334+
# Validate order of execution to be correct
328335
assert packaged_environment_statements.before_all == [
329-
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;"
336+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);\nJINJA_END;",
337+
"JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;",
330338
]
331339
assert packaged_environment_statements.after_all == [
332-
"JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;"
340+
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_first\nJINJA_END;",
341+
"JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;",
333342
]
334343

335344
assert "packaged_tables" in packaged_environment_statements.jinja_macros.root_macros
@@ -353,13 +362,19 @@ def test_on_run_start_end(copy_to_temp_path):
353362
environment_naming_info=EnvironmentNamingInfo(name="dev"),
354363
)
355364

365+
# Validate order of execution to match dbt's
356366
assert rendered_before_all == [
357-
"CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)"
367+
"CREATE TABLE IF NOT EXISTS to_be_executed_first (col TEXT)",
368+
"CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)",
358369
]
359370

371+
# This on run end statement should be executed first
372+
assert rendered_after_all[0] == "DROP TABLE to_be_executed_first"
373+
360374
# The table names is an indication of the rendering of the dbt_packages statements
361375
assert sorted(rendered_after_all) == sorted(
362376
[
377+
"DROP TABLE to_be_executed_first",
363378
"CREATE OR REPLACE TABLE schema_table_snapshots__dev_nested_package AS SELECT 'snapshots__dev' AS schema",
364379
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
365380
]

tests/dbt/test_transformation.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,15 +1003,33 @@ def test_dbt_version(sushi_test_project: Project):
10031003

10041004
@pytest.mark.xdist_group("dbt_manifest")
10051005
def test_dbt_on_run_start_end(sushi_test_project: Project):
1006+
# Validate perservation of dbt's order of execution
1007+
assert sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-0"].index == 0
1008+
assert sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-1"].index == 1
1009+
assert sushi_test_project.packages["sushi"].on_run_end["sushi-on-run-end-0"].index == 0
1010+
assert sushi_test_project.packages["sushi"].on_run_end["sushi-on-run-end-1"].index == 1
1011+
assert (
1012+
sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-0"].index == 0
1013+
)
1014+
assert (
1015+
sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-1"].index == 1
1016+
)
1017+
assert sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-0"].index == 0
1018+
assert sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-1"].index == 1
1019+
10061020
assert (
10071021
sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-0"].sql
1022+
== "CREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);"
1023+
)
1024+
assert (
1025+
sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-1"].sql
10081026
== "CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);"
10091027
)
10101028
assert (
1011-
sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-0"].sql
1029+
sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-1"].sql
10121030
== "{{ packaged_tables(schemas) }}"
10131031
)
1014-
assert sushi_test_project.packages["sushi"].on_run_end
1032+
10151033
assert (
10161034
sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-0"].sql
10171035
== "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);"

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,7 @@ vars:
6161

6262
on-run-start:
6363
- 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);'
64+
- 'CREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);'
6465
on-run-end:
65-
- '{{ create_tables(schemas) }}'
66+
- '{{ create_tables(schemas) }}'
67+
- 'DROP TABLE to_be_executed_last;'

tests/fixtures/dbt/sushi_test/packages/customers/dbt_project.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ vars:
3333

3434

3535
on-run-start:
36+
- 'CREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);'
3637
- 'CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);'
3738
on-run-end:
39+
- 'DROP TABLE to_be_executed_first'
3840
- '{{ packaged_tables(schemas) }}'

0 commit comments

Comments
 (0)