Skip to content

Commit a81cfda

Browse files
Feat: Improve the cli before all after all diff and include python env (#4116)
1 parent c9429ae commit a81cfda

6 files changed

Lines changed: 90 additions & 42 deletions

File tree

sqlmesh/core/audit/definition.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
default_catalog_validator,
1717
depends_on_validator,
1818
expression_validator,
19+
sort_python_env,
20+
sorted_python_env_payloads,
1921
)
2022
from sqlmesh.core.model.common import make_python_env, single_value_or_tuple
2123
from sqlmesh.core.node import _Node
@@ -236,7 +238,7 @@ def depends_on(self) -> t.Set[str]:
236238
@property
237239
def sorted_python_env(self) -> t.List[t.Tuple[str, Executable]]:
238240
"""Returns the python env sorted by executable kind and then var name."""
239-
return sorted(self.python_env.items(), key=lambda x: (x[1].kind, x[0]))
241+
return sort_python_env(self.python_env)
240242

241243
@property
242244
def data_hash(self) -> str:
@@ -337,12 +339,7 @@ def render_definition(
337339
jinja_expressions = []
338340
python_expressions = []
339341
if include_python:
340-
python_env = d.PythonCode(
341-
expressions=[
342-
v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}"
343-
for k, v in self.sorted_python_env
344-
]
345-
)
342+
python_env = d.PythonCode(expressions=sorted_python_env_payloads(self.python_env))
346343
if python_env.expressions:
347344
python_expressions.append(python_env)
348345

sqlmesh/core/console.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,9 @@ def show_model_difference_summary(
13111311
self._print(f"[bold]Requirements:\n{context_diff.requirements_diff()}")
13121312

13131313
if context_diff.has_environment_statements_changes:
1314-
self._print(
1315-
f"[bold]Environment statements:\n{context_diff.environment_statements_diff()}"
1316-
)
1314+
self._print("[bold]Environment statements:\n")
1315+
for type, diff in context_diff.environment_statements_diff():
1316+
self._print(Syntax(diff, type, line_numbers=False))
13171317

13181318
self._show_summary_tree_for(
13191319
context_diff,
@@ -2463,7 +2463,9 @@ def show_model_difference_summary(
24632463
self._print(f"Requirements:\n{context_diff.requirements_diff()}")
24642464

24652465
if context_diff.has_environment_statements_changes:
2466-
self._print(f"Environment statements:\n{context_diff.environment_statements_diff()}")
2466+
self._print("[bold]Environment statements:\n")
2467+
for _, diff in context_diff.environment_statements_diff():
2468+
self._print(diff)
24672469

24682470
added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added}
24692471
added_snapshot_models = {s for s in added_snapshots if s.is_model}
@@ -2976,7 +2978,9 @@ def show_model_difference_summary(
29762978
self._write(f"Requirements:\n{context_diff.requirements_diff()}")
29772979

29782980
if context_diff.has_environment_statements_changes:
2979-
self._write(f"Environment statements:\n{context_diff.environment_statements_diff()}")
2981+
self._write("Environment statements:\n")
2982+
for _, diff in context_diff.environment_statements_diff():
2983+
self._write(diff)
29802984

29812985
for added in context_diff.new_snapshots:
29822986
self._write(f" Added: {added}")

sqlmesh/core/context_diff.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sqlmesh.core import constants as c
2020
from sqlmesh.core.console import get_console
2121
from sqlmesh.core.macros import RuntimeStage
22+
from sqlmesh.core.model.common import sorted_python_env_payloads
2223
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo
2324
from sqlmesh.utils.errors import SQLMeshError
2425
from sqlmesh.utils.pydantic import PydanticModel
@@ -340,22 +341,43 @@ def requirements_diff(self) -> str:
340341
)
341342
)
342343

343-
def environment_statements_diff(self) -> str:
344+
def environment_statements_diff(self) -> t.List[t.Tuple[str, str]]:
344345
def extract_statements(statements: t.List[EnvironmentStatements], attr: str) -> t.List[str]:
345-
return [str(stmt) for statement in statements for stmt in getattr(statement, attr)]
346-
347-
def format_diff(runtime_stage: str) -> str:
348-
previous = extract_statements(self.previous_environment_statements, runtime_stage)
349-
current = extract_statements(self.environment_statements, runtime_stage)
350-
return (
351-
f" {runtime_stage}:\n " + "\n ".join(ndiff(previous, current)) + "\n"
352-
if previous or current
353-
else ""
354-
)
346+
return [
347+
expr
348+
for statement in statements
349+
for expr in (
350+
sorted_python_env_payloads(statement.python_env)
351+
if attr == "python_env"
352+
else getattr(statement, attr)
353+
)
354+
]
355355

356-
return format_diff(RuntimeStage.BEFORE_ALL.value) + format_diff(
357-
RuntimeStage.AFTER_ALL.value
358-
)
356+
def compute_diff(attribute: str) -> t.Optional[t.Tuple[str, str]]:
357+
previous = extract_statements(self.previous_environment_statements, attribute)
358+
current = extract_statements(self.environment_statements, attribute)
359+
360+
if previous == current:
361+
return None
362+
363+
diff_lines = list(ndiff(previous, current))
364+
diff_text = attribute if not attribute == "python_env" else "dependencies"
365+
diff_text += ":\n"
366+
367+
if any(line.startswith(("-", "+")) for line in diff_lines):
368+
diff_text += " " + "\n ".join(diff_lines) + "\n"
369+
370+
return "python" if attribute == "python_env" else "sql", diff_text
371+
372+
return [
373+
diff
374+
for attribute in [
375+
RuntimeStage.BEFORE_ALL.value,
376+
RuntimeStage.AFTER_ALL.value,
377+
"python_env",
378+
]
379+
if (diff := compute_diff(attribute)) is not None
380+
]
359381

360382
@property
361383
def environment_snapshots(self) -> t.List[SnapshotTableInfo]:

sqlmesh/core/model/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,19 @@ def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[
352352
return v
353353

354354

355+
def sort_python_env(python_env: t.Dict[str, Executable]) -> t.List[t.Tuple[str, Executable]]:
356+
"""Returns the python env sorted."""
357+
return sorted(python_env.items(), key=lambda x: (x[1].kind, x[0]))
358+
359+
360+
def sorted_python_env_payloads(python_env: t.Dict[str, Executable]) -> t.List[str]:
361+
"""Returns the payloads of the sorted python env."""
362+
return [
363+
v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}"
364+
for k, v in sort_python_env(python_env)
365+
]
366+
367+
355368
expression_validator: t.Callable = field_validator(
356369
"query",
357370
"expressions_",

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
make_python_env,
3131
parse_dependencies,
3232
single_value_or_tuple,
33+
sorted_python_env_payloads,
3334
)
3435
from sqlmesh.core.model.meta import ModelMeta, FunctionCall
3536
from sqlmesh.core.model.kind import (
@@ -257,12 +258,7 @@ def render_definition(
257258
jinja_expressions = []
258259
python_expressions = []
259260
if include_python:
260-
python_env = d.PythonCode(
261-
expressions=[
262-
v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}"
263-
for k, v in self.sorted_python_env
264-
]
265-
)
261+
python_env = d.PythonCode(expressions=sorted_python_env_payloads(self.python_env))
266262
if python_env.expressions:
267263
python_expressions.append(python_env)
268264

tests/core/test_plan.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from unittest.mock import patch
55

66
import pytest
7+
from sqlmesh.utils.metaprogramming import Executable
8+
from tests.core.test_table_diff import create_test_console, strip_ansi_codes
79
import time_machine
810
from pytest_mock.plugin import MockerFixture
911
from sqlglot import parse_one
@@ -3000,9 +3002,13 @@ def test_plan_environment_statements_diff(make_snapshot):
30003002
previous_finalized_snapshots=None,
30013003
environment_statements=[
30023004
EnvironmentStatements(
3003-
before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 1"],
3005+
before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 1", "@test_macro()"],
30043006
after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 2"],
3005-
python_env={},
3007+
python_env={
3008+
"test_macro": Executable(
3009+
payload="def test_macro(evaluator):\n return 'one'"
3010+
),
3011+
},
30063012
)
30073013
],
30083014
previous_gateway_managed_virtual_layer=False,
@@ -3011,11 +3017,21 @@ def test_plan_environment_statements_diff(make_snapshot):
30113017

30123018
assert context_diff.has_changes
30133019
assert context_diff.has_environment_statements_changes
3014-
assert (
3015-
context_diff.environment_statements_diff()
3016-
== """ before_all:
3017-
+ CREATE OR REPLACE TABLE table_1 AS SELECT 1
3018-
after_all:
3019-
+ CREATE OR REPLACE TABLE table_2 AS SELECT 2
3020-
"""
3021-
)
3020+
3021+
console_output, terminal_console = create_test_console()
3022+
for _, diff in context_diff.environment_statements_diff():
3023+
terminal_console._print(diff)
3024+
output = console_output.getvalue()
3025+
stripped = strip_ansi_codes(output)
3026+
expected_output = (
3027+
"before_all:\n"
3028+
" + CREATE OR REPLACE TABLE table_1 AS SELECT 1\n"
3029+
" + @test_macro()\n\n"
3030+
"after_all:\n"
3031+
" + CREATE OR REPLACE TABLE table_2 AS SELECT 2\n\n"
3032+
"dependencies:\n"
3033+
" + def test_macro(evaluator):\n"
3034+
" return 'one'"
3035+
)
3036+
assert stripped == expected_output
3037+
console_output.close()

0 commit comments

Comments
 (0)