Skip to content

Commit 42511e1

Browse files
Fix: Account for array types when showing sample in table diff (#4077)
1 parent 0fc89dd commit 42511e1

3 files changed

Lines changed: 180 additions & 5 deletions

File tree

sqlmesh/core/console.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import logging
99
import textwrap
1010
from pathlib import Path
11-
11+
import pandas as pd
12+
import numpy as np
1213
from hyperscript import h
1314
from rich.console import Console as RichConsole
1415
from rich.live import Live
@@ -1904,10 +1905,15 @@ def show_row_diff(
19041905
# Create a table with the joined keys and comparison columns
19051906
column_table = row_diff.joined_sample[keys + [source_column, target_column]]
19061907

1907-
# Filter out identical-valued rows
1908+
# Filter to retain non identical-valued rows
19081909
column_table = column_table[
1909-
column_table[source_column] != column_table[target_column]
1910+
column_table.apply(
1911+
lambda row: not _cells_match(row[source_column], row[target_column]),
1912+
axis=1,
1913+
)
19101914
]
1915+
1916+
# Rename the column headers for readability
19111917
column_table = column_table.rename(
19121918
columns={
19131919
source_column: source_name,
@@ -1921,7 +1927,16 @@ def show_row_diff(
19211927
table.add_column(column_name, style=style, header_style=style)
19221928

19231929
for _, row in column_table.iterrows():
1924-
table.add_row(*[str(cell) for cell in row])
1930+
table.add_row(
1931+
*[
1932+
str(
1933+
round(cell, row_diff.decimals)
1934+
if isinstance(cell, float)
1935+
else cell
1936+
)
1937+
for cell in row
1938+
]
1939+
)
19251940

19261941
self.console.print(
19271942
f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]",
@@ -2027,6 +2042,16 @@ def show_linter_violations(
20272042
self.log_warning(msg)
20282043

20292044

2045+
def _cells_match(x: t.Any, y: t.Any) -> bool:
2046+
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
2047+
2048+
# Convert array-like objects to list for consistent comparison
2049+
def _normalize(val: t.Any) -> t.Any:
2050+
return list(val) if isinstance(val, (pd.Series, np.ndarray)) else val
2051+
2052+
return _normalize(x) == _normalize(y)
2053+
2054+
20302055
def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget:
20312056
"""Helper function to add a widget to a layout widget.
20322057

sqlmesh/core/table_diff.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class RowDiff(PydanticModel, frozen=True):
7070
source_alias: t.Optional[str] = None
7171
target_alias: t.Optional[str] = None
7272
model_name: t.Optional[str] = None
73+
decimals: int = 3
7374

7475
@property
7576
def source_count(self) -> int:
@@ -576,5 +577,6 @@ def name(e: exp.Expression) -> str:
576577
source_alias=self.source_alias,
577578
target_alias=self.target_alias,
578579
model_name=self.model_name,
580+
decimals=self.decimals,
579581
)
580582
return self._row_diff

tests/core/test_table_diff.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,49 @@
33
import pandas as pd
44
from sqlglot import exp
55
from sqlmesh.core import dialect as d
6+
import re
7+
import typing as t
8+
from io import StringIO
9+
from rich.console import Console
10+
from sqlmesh.core.console import TerminalConsole
611
from sqlmesh.core.context import Context
712
from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig
813
from sqlmesh.core.model import SqlModel, load_sql_based_model
914
from sqlmesh.core.table_diff import TableDiff
15+
import numpy as np
16+
17+
18+
def create_test_console() -> t.Tuple[StringIO, TerminalConsole]:
    """Build a TerminalConsole that writes into an in-memory buffer.

    Returns:
        A ``(buffer, terminal_console)`` pair; text printed through the
        console can be read back from the buffer.
    """
    buffer = StringIO()
    rich_console = Console(file=buffer, force_terminal=True)
    return buffer, TerminalConsole(console=rich_console)
24+
25+
26+
def capture_console_output(method_name: str, **kwargs) -> str:
    """Invoke a TerminalConsole method and return everything it printed.

    Args:
        method_name: Name of the TerminalConsole method to call
        **kwargs: Arguments to pass to the method

    Returns:
        The captured output as a string
    """
    buffer, terminal_console = create_test_console()
    # Leaving the ``with`` block closes the buffer, even if the call raises.
    with buffer:
        getattr(terminal_console, method_name)(**kwargs)
        return buffer.getvalue()
43+
44+
45+
def strip_ansi_codes(text: str) -> str:
    """Strip ANSI CSI escape sequences from text and trim surrounding whitespace.

    Covers color/styling sequences (e.g. ``ESC[1;36m``) as well as other CSI
    control sequences such as private-mode cursor show/hide (``ESC[?25l``).

    Args:
        text: The raw text, possibly containing ANSI escape sequences.

    Returns:
        The text with CSI sequences removed and leading/trailing whitespace stripped.
    """
    # CSI grammar: ESC [ , parameter bytes 0x30-0x3F, intermediate bytes
    # 0x20-0x2F, then a single final byte 0x40-0x7E.
    ansi_escape = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
    return ansi_escape.sub("", text).strip()
1049

1150

1251
@pytest.mark.slow
@@ -121,7 +160,7 @@ def test_data_diff_decimals(sushi_context_fixed_date):
121160
pd.DataFrame(
122161
{
123162
"key": [1, 2, 3],
124-
"value": [1.0, 2.0, 3.1234],
163+
"value": [1.0, 2.0, 3.1234321],
125164
}
126165
),
127166
)
@@ -162,6 +201,32 @@ def test_data_diff_decimals(sushi_context_fixed_date):
162201
assert "DEV__value" in aliased_joined_sample
163202
assert "PROD__value" in aliased_joined_sample
164203

204+
output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff())
205+
206+
# Expected output with box-drawings
207+
expected_output = r"""
208+
Row Counts:
209+
├── FULL MATCH: 2 rows (66.67%)
210+
└── PARTIAL MATCH: 1 rows (33.33%)
211+
212+
COMMON ROWS column comparison stats:
213+
pct_match
214+
value 66.666667
215+
216+
217+
COMMON ROWS sample data differences:
218+
Column: value
219+
┏━━━━━┳━━━━━━━━┳━━━━━━━━┓
220+
┃ key ┃ DEV ┃ PROD ┃
221+
┡━━━━━╇━━━━━━━━╇━━━━━━━━┩
222+
│ 3.0 │ 3.1233 │ 3.1234 │
223+
└─────┴────────┴────────┘
224+
"""
225+
226+
stripped_output = strip_ansi_codes(output)
227+
stripped_expected = expected_output.strip()
228+
assert stripped_output == stripped_expected
229+
165230

166231
@pytest.mark.slow
167232
def test_grain_check(sushi_context_fixed_date):
@@ -363,3 +428,86 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context)
363428

364429
_, _, col_names = table_diff.key_columns
365430
assert col_names == ["waiter_id", "event_date"]
431+
432+
433+
@pytest.mark.slow
434+
def test_data_diff_array_dict(sushi_context_fixed_date):
435+
engine_adapter = sushi_context_fixed_date.engine_adapter
436+
437+
engine_adapter.ctas(
438+
"table_diff_source",
439+
pd.DataFrame(
440+
{
441+
"key": [1, 2, 3],
442+
"value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])],
443+
"dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}],
444+
}
445+
),
446+
)
447+
448+
engine_adapter.ctas(
449+
"table_diff_target",
450+
pd.DataFrame(
451+
{
452+
"key": [1, 2, 3],
453+
"value": [
454+
np.array([51.2, 4.5679]),
455+
np.array([2.31, 12.2, 3.6, 1.9]),
456+
np.array([5.0]),
457+
],
458+
"dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}],
459+
}
460+
),
461+
)
462+
463+
table_diff = TableDiff(
464+
adapter=engine_adapter,
465+
source="table_diff_source",
466+
target="table_diff_target",
467+
source_alias="dev",
468+
target_alias="prod",
469+
on=["key"],
470+
decimals=4,
471+
)
472+
473+
diff = table_diff.row_diff()
474+
aliased_joined_sample = diff.joined_sample.columns
475+
476+
assert "DEV__value" in aliased_joined_sample
477+
assert "PROD__value" in aliased_joined_sample
478+
assert diff.full_match_count == 1
479+
assert diff.partial_match_count == 2
480+
481+
output = capture_console_output("show_row_diff", row_diff=diff)
482+
483+
# Expected output with boxes
484+
expected_output = r"""
485+
Row Counts:
486+
├── FULL MATCH: 1 rows (33.33%)
487+
└── PARTIAL MATCH: 2 rows (66.67%)
488+
489+
COMMON ROWS column comparison stats:
490+
pct_match
491+
value 33.333333
492+
dict 66.666667
493+
494+
495+
COMMON ROWS sample data differences:
496+
Column: value
497+
┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
498+
┃ key ┃ DEV ┃ PROD ┃
499+
┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
500+
│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │
501+
│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │
502+
└─────┴────────────────┴────────────────────────┘
503+
Column: dict
504+
┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
505+
┃ key ┃ DEV ┃ PROD ┃
506+
┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
507+
│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │
508+
└─────┴─────────────────────────────┴────────────────────┘
509+
"""
510+
511+
stripped_output = strip_ansi_codes(output)
512+
stripped_expected = expected_output.strip()
513+
assert stripped_output == stripped_expected

0 commit comments

Comments
 (0)