Skip to content

Commit d5a5588

Browse files
Fix: Account for array types when showing sample in table diff
1 parent 0fc89dd commit d5a5588

3 files changed

Lines changed: 183 additions & 5 deletions

File tree

sqlmesh/core/console.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import logging
99
import textwrap
1010
from pathlib import Path
11-
11+
import pandas as pd
12+
import numpy as np
1213
from hyperscript import h
1314
from rich.console import Console as RichConsole
1415
from rich.live import Live
@@ -1904,10 +1905,37 @@ def show_row_diff(
19041905
# Create a table with the joined keys and comparison columns
19051906
column_table = row_diff.joined_sample[keys + [source_column, target_column]]
19061907

1907-
# Filter out identical-valued rows
1908+
def compare_cells(x: t.Any, y: t.Any) -> bool:
1909+
"""Compare two cells and returns true if they're not equal, handling array objects."""
1910+
if x is None or y is None:
1911+
return x != y
1912+
1913+
# Convert any array-like object to list for consistent comparison
1914+
def to_list(val: t.Any) -> t.Any:
1915+
return (
1916+
list(val)
1917+
if isinstance(val, (pd.Series, np.ndarray, list, tuple, set))
1918+
else val
1919+
)
1920+
1921+
x = to_list(x)
1922+
y = to_list(y)
1923+
if isinstance(x, list) and isinstance(y, list):
1924+
if len(x) != len(y):
1925+
return True
1926+
return any(a != b for a, b in zip(x, y))
1927+
1928+
return x != y
1929+
1930+
# Filter to retain non identical-valued rows
19081931
column_table = column_table[
1909-
column_table[source_column] != column_table[target_column]
1932+
column_table.apply(
1933+
lambda row: compare_cells(row[source_column], row[target_column]),
1934+
axis=1,
1935+
)
19101936
]
1937+
1938+
# Rename the column headers for readability
19111939
column_table = column_table.rename(
19121940
columns={
19131941
source_column: source_name,
@@ -1921,7 +1949,16 @@ def show_row_diff(
19211949
table.add_column(column_name, style=style, header_style=style)
19221950

19231951
for _, row in column_table.iterrows():
1924-
table.add_row(*[str(cell) for cell in row])
1952+
table.add_row(
1953+
*[
1954+
str(
1955+
round(cell, row_diff.decimals)
1956+
if isinstance(cell, float)
1957+
else cell
1958+
)
1959+
for cell in row
1960+
]
1961+
)
19251962

19261963
self.console.print(
19271964
f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]",

sqlmesh/core/table_diff.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class RowDiff(PydanticModel, frozen=True):
7070
source_alias: t.Optional[str] = None
7171
target_alias: t.Optional[str] = None
7272
model_name: t.Optional[str] = None
73+
decimals: int = 3
7374

7475
@property
7576
def source_count(self) -> int:
@@ -576,5 +577,6 @@ def name(e: exp.Expression) -> str:
576577
source_alias=self.source_alias,
577578
target_alias=self.target_alias,
578579
model_name=self.model_name,
580+
decimals=self.decimals,
579581
)
580582
return self._row_diff

tests/core/test_table_diff.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,49 @@
33
import pandas as pd
44
from sqlglot import exp
55
from sqlmesh.core import dialect as d
6+
import re
7+
import typing as t
8+
from io import StringIO
9+
from rich.console import Console
10+
from sqlmesh.core.console import TerminalConsole
611
from sqlmesh.core.context import Context
712
from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig
813
from sqlmesh.core.model import SqlModel, load_sql_based_model
914
from sqlmesh.core.table_diff import TableDiff
15+
import numpy as np
16+
17+
18+
def create_test_console() -> t.Tuple[StringIO, TerminalConsole]:
19+
"""Creates a console and buffer for validating console output."""
20+
console_output = StringIO()
21+
console = Console(file=console_output, force_terminal=True)
22+
terminal_console = TerminalConsole(console=console)
23+
return console_output, terminal_console
24+
25+
26+
def capture_console_output(method_name: str, **kwargs) -> str:
27+
"""Factory function to invoke and capture output a TerminalConsole method.
28+
29+
Args:
30+
method_name: Name of the TerminalConsole method to call
31+
**kwargs: Arguments to pass to the method
32+
33+
Returns:
34+
The captured output as a string
35+
"""
36+
console_output, terminal_console = create_test_console()
37+
try:
38+
method = getattr(terminal_console, method_name)
39+
method(**kwargs)
40+
return console_output.getvalue()
41+
finally:
42+
console_output.close()
43+
44+
45+
def strip_ansi_codes(text: str) -> str:
46+
"""Strip ANSI color codes and styling from text."""
47+
ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")
48+
return ansi_escape.sub("", text).strip()
1049

1150

1251
@pytest.mark.slow
@@ -121,7 +160,7 @@ def test_data_diff_decimals(sushi_context_fixed_date):
121160
pd.DataFrame(
122161
{
123162
"key": [1, 2, 3],
124-
"value": [1.0, 2.0, 3.1234],
163+
"value": [1.0, 2.0, 3.1234321],
125164
}
126165
),
127166
)
@@ -162,6 +201,32 @@ def test_data_diff_decimals(sushi_context_fixed_date):
162201
assert "DEV__value" in aliased_joined_sample
163202
assert "PROD__value" in aliased_joined_sample
164203

204+
output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff())
205+
206+
# Expected output with box-drawings
207+
expected_output = r"""
208+
Row Counts:
209+
├── FULL MATCH: 2 rows (66.67%)
210+
└── PARTIAL MATCH: 1 rows (33.33%)
211+
212+
COMMON ROWS column comparison stats:
213+
pct_match
214+
value 66.666667
215+
216+
217+
COMMON ROWS sample data differences:
218+
Column: value
219+
┏━━━━━┳━━━━━━━━┳━━━━━━━━┓
220+
┃ key ┃ DEV ┃ PROD ┃
221+
┡━━━━━╇━━━━━━━━╇━━━━━━━━┩
222+
│ 3.0 │ 3.1233 │ 3.1234 │
223+
└─────┴────────┴────────┘
224+
"""
225+
226+
stripped_output = strip_ansi_codes(output)
227+
stripped_expected = expected_output.strip()
228+
assert stripped_output == stripped_expected
229+
165230

166231
@pytest.mark.slow
167232
def test_grain_check(sushi_context_fixed_date):
@@ -363,3 +428,77 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context)
363428

364429
_, _, col_names = table_diff.key_columns
365430
assert col_names == ["waiter_id", "event_date"]
431+
432+
433+
@pytest.mark.slow
434+
def test_data_diff_array(sushi_context_fixed_date):
435+
engine_adapter = sushi_context_fixed_date.engine_adapter
436+
437+
engine_adapter.ctas(
438+
"table_diff_source",
439+
pd.DataFrame(
440+
{
441+
"key": [1, 2, 3],
442+
"value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])],
443+
}
444+
),
445+
)
446+
447+
engine_adapter.ctas(
448+
"table_diff_target",
449+
pd.DataFrame(
450+
{
451+
"key": [1, 2, 3],
452+
"value": [
453+
np.array([51.2, 4.5679]),
454+
np.array([2.31, 12.2, 3.6, 1.9]),
455+
np.array([5.0]),
456+
],
457+
}
458+
),
459+
)
460+
461+
table_diff = TableDiff(
462+
adapter=engine_adapter,
463+
source="table_diff_source",
464+
target="table_diff_target",
465+
source_alias="dev",
466+
target_alias="prod",
467+
on=["key"],
468+
decimals=4,
469+
)
470+
471+
diff = table_diff.row_diff()
472+
aliased_joined_sample = diff.joined_sample.columns
473+
474+
assert "DEV__value" in aliased_joined_sample
475+
assert "PROD__value" in aliased_joined_sample
476+
assert diff.full_match_count == 1
477+
assert diff.partial_match_count == 2
478+
479+
output = capture_console_output("show_row_diff", row_diff=diff)
480+
481+
# Expected output with boxes
482+
expected_output = r"""
483+
Row Counts:
484+
├── FULL MATCH: 1 rows (33.33%)
485+
└── PARTIAL MATCH: 2 rows (66.67%)
486+
487+
COMMON ROWS column comparison stats:
488+
pct_match
489+
value 33.333333
490+
491+
492+
COMMON ROWS sample data differences:
493+
Column: value
494+
┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
495+
┃ key ┃ DEV ┃ PROD ┃
496+
┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
497+
│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │
498+
│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │
499+
└─────┴────────────────┴────────────────────────┘
500+
"""
501+
502+
stripped_output = strip_ansi_codes(output)
503+
stripped_expected = expected_output.strip()
504+
assert stripped_output == stripped_expected

0 commit comments

Comments
 (0)