Skip to content

Commit 6000d3b

Browse files
Feat: Improve the before all after all diff and include python env
1 parent fc7fe1b commit 6000d3b

3 files changed

Lines changed: 86 additions & 28 deletions

File tree

sqlmesh/core/console.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,9 @@ def show_model_difference_summary(
13111311
self._print(f"[bold]Requirements:\n{context_diff.requirements_diff()}")
13121312

13131313
if context_diff.has_environment_statements_changes:
1314-
self._print(
1315-
f"[bold]Environment statements:\n{context_diff.environment_statements_diff()}"
1316-
)
1314+
self._print("[bold]Environment statements:\n")
1315+
for changes in context_diff.environment_statements_diff():
1316+
self._print(changes)
13171317

13181318
self._show_summary_tree_for(
13191319
context_diff,
@@ -2463,7 +2463,9 @@ def show_model_difference_summary(
24632463
self._print(f"Requirements:\n{context_diff.requirements_diff()}")
24642464

24652465
if context_diff.has_environment_statements_changes:
2466-
self._print(f"Environment statements:\n{context_diff.environment_statements_diff()}")
2466+
self._print("[bold]Environment statements:\n")
2467+
for changes in context_diff.environment_statements_diff():
2468+
self._print(changes)
24672469

24682470
added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added}
24692471
added_snapshot_models = {s for s in added_snapshots if s.is_model}
@@ -2976,7 +2978,9 @@ def show_model_difference_summary(
29762978
self._write(f"Requirements:\n{context_diff.requirements_diff()}")
29772979

29782980
if context_diff.has_environment_statements_changes:
2979-
self._write(f"Environment statements:\n{context_diff.environment_statements_diff()}")
2981+
self._write("Environment statements:\n")
2982+
for changes in context_diff.environment_statements_diff():
2983+
self._write(changes)
29802984

29812985
for added in context_diff.new_snapshots:
29822986
self._write(f" Added: {added}")

sqlmesh/core/context_diff.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import typing as t
1717
from difflib import ndiff
1818
from functools import cached_property
19+
from rich.syntax import Syntax
1920
from sqlmesh.core import constants as c
2021
from sqlmesh.core.console import get_console
2122
from sqlmesh.core.macros import RuntimeStage
@@ -330,22 +331,55 @@ def requirements_diff(self) -> str:
330331
)
331332
)
332333

333-
def environment_statements_diff(self) -> str:
334+
def environment_statements_diff(self) -> t.List[Syntax]:
335+
PYTHON_ENV = "python_env"
336+
337+
def extract_python_env(python_env: t.Dict[str, Executable]) -> t.List[str]:
338+
return [
339+
v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}"
340+
for k, v in python_env.items()
341+
]
342+
334343
def extract_statements(statements: t.List[EnvironmentStatements], attr: str) -> t.List[str]:
335-
return [str(stmt) for statement in statements for stmt in getattr(statement, attr)]
336-
337-
def format_diff(runtime_stage: str) -> str:
338-
previous = extract_statements(self.previous_environment_statements, runtime_stage)
339-
current = extract_statements(self.environment_statements, runtime_stage)
340-
return (
341-
f" {runtime_stage}:\n " + "\n ".join(ndiff(previous, current)) + "\n"
342-
if previous or current
343-
else ""
344+
return [
345+
expr
346+
for statement in statements
347+
for expr in (
348+
extract_python_env(statement.python_env)
349+
if attr == PYTHON_ENV
350+
else getattr(statement, attr)
351+
)
352+
]
353+
354+
def format_diff(attribute: str) -> t.Optional[Syntax]:
355+
previous = extract_statements(self.previous_environment_statements, attribute)
356+
current = extract_statements(self.environment_statements, attribute)
357+
358+
if previous == current:
359+
return None
360+
361+
diff_text = f"=== {attribute if not attribute == PYTHON_ENV else "dependencies"} ===\n"
362+
363+
diff_lines = list(ndiff(previous, current))
364+
if any(line.startswith(("-", "+")) for line in diff_lines):
365+
diff_text += " " + "\n ".join(diff_lines) + "\n"
366+
367+
return Syntax(
368+
diff_text, "python" if attribute == PYTHON_ENV else "sql", line_numbers=False
344369
)
345370

346-
return format_diff(RuntimeStage.BEFORE_ALL.value) + format_diff(
347-
RuntimeStage.AFTER_ALL.value
348-
)
371+
return [
372+
diff
373+
for diff in (
374+
format_diff(attribute)
375+
for attribute in [
376+
RuntimeStage.BEFORE_ALL.value,
377+
RuntimeStage.AFTER_ALL.value,
378+
PYTHON_ENV,
379+
]
380+
)
381+
if diff is not None
382+
]
349383

350384
@property
351385
def environment_snapshots(self) -> t.List[SnapshotTableInfo]:

tests/core/test_plan.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from unittest.mock import patch
55

66
import pytest
7+
from sqlmesh.utils.metaprogramming import Executable
8+
from tests.core.test_table_diff import create_test_console, strip_ansi_codes
79
import time_machine
810
from pytest_mock.plugin import MockerFixture
911
from sqlglot import parse_one
@@ -2904,20 +2906,38 @@ def test_plan_environment_statements_diff(make_snapshot):
29042906
previous_finalized_snapshots=None,
29052907
environment_statements=[
29062908
EnvironmentStatements(
2907-
before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 1"],
2909+
before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 1", "@test_macro()"],
29082910
after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 2"],
2909-
python_env={},
2911+
python_env={
2912+
"test_macro": Executable(
2913+
payload="def test_macro(evaluator):\n return 'one'"
2914+
),
2915+
},
29102916
)
29112917
],
29122918
)
29132919

29142920
assert context_diff.has_changes
29152921
assert context_diff.has_environment_statements_changes
2916-
assert (
2917-
context_diff.environment_statements_diff()
2918-
== """ before_all:
2919-
+ CREATE OR REPLACE TABLE table_1 AS SELECT 1
2920-
after_all:
2921-
+ CREATE OR REPLACE TABLE table_2 AS SELECT 2
2922-
"""
2923-
)
2922+
2923+
console_output, terminal_console = create_test_console()
2924+
2925+
for stmt in context_diff.environment_statements_diff():
2926+
terminal_console._print(stmt)
2927+
output = console_output.getvalue()
2928+
stripped = strip_ansi_codes(output)
2929+
2930+
expected_output = (
2931+
"=== before_all === \n"
2932+
" + CREATE OR REPLACE TABLE table_1 AS SELECT 1 \n"
2933+
" + @test_macro() \n"
2934+
" \n"
2935+
"=== after_all === \n"
2936+
" + CREATE OR REPLACE TABLE table_2 AS SELECT 2 \n"
2937+
" \n"
2938+
"=== dependencies === \n"
2939+
" + def test_macro(evaluator): \n"
2940+
" return 'one'"
2941+
)
2942+
assert stripped == expected_output
2943+
console_output.close()

0 commit comments

Comments
 (0)