Skip to content

Commit fc7e68f

Browse files
Refactor; extend test for dicts
1 parent d5a5588 commit fc7e68f

2 files changed

Lines changed: 30 additions & 24 deletions

File tree

sqlmesh/core/console.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,32 +1905,10 @@ def show_row_diff(
19051905
# Create a table with the joined keys and comparison columns
19061906
column_table = row_diff.joined_sample[keys + [source_column, target_column]]
19071907

1908-
def compare_cells(x: t.Any, y: t.Any) -> bool:
1909-
"""Compare two cells and returns true if they're not equal, handling array objects."""
1910-
if x is None or y is None:
1911-
return x != y
1912-
1913-
# Convert any array-like object to list for consistent comparison
1914-
def to_list(val: t.Any) -> t.Any:
1915-
return (
1916-
list(val)
1917-
if isinstance(val, (pd.Series, np.ndarray, list, tuple, set))
1918-
else val
1919-
)
1920-
1921-
x = to_list(x)
1922-
y = to_list(y)
1923-
if isinstance(x, list) and isinstance(y, list):
1924-
if len(x) != len(y):
1925-
return True
1926-
return any(a != b for a, b in zip(x, y))
1927-
1928-
return x != y
1929-
19301908
# Filter to retain non identical-valued rows
19311909
column_table = column_table[
19321910
column_table.apply(
1933-
lambda row: compare_cells(row[source_column], row[target_column]),
1911+
lambda row: _compare_df_cells(row[source_column], row[target_column]),
19341912
axis=1,
19351913
)
19361914
]
@@ -2064,6 +2042,25 @@ def show_linter_violations(
20642042
self.log_warning(msg)
20652043

20662044

2045+
def _compare_df_cells(x: t.Any, y: t.Any) -> bool:
2046+
"""Helper function to compare two cells and returns true if they're not equal, handling array objects."""
2047+
if x is None or y is None:
2048+
return x != y
2049+
2050+
# Convert any array-like object to list for consistent comparison
2051+
def to_list(val: t.Any) -> t.Any:
2052+
return list(val) if isinstance(val, (pd.Series, np.ndarray, list, tuple, set)) else val
2053+
2054+
x = to_list(x)
2055+
y = to_list(y)
2056+
if isinstance(x, list) and isinstance(y, list):
2057+
if len(x) != len(y):
2058+
return True
2059+
return any(a != b for a, b in zip(x, y))
2060+
2061+
return x != y
2062+
2063+
20672064
def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget:
20682065
"""Helper function to add a widget to a layout widget.
20692066

tests/core/test_table_diff.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context)
431431

432432

433433
@pytest.mark.slow
434-
def test_data_diff_array(sushi_context_fixed_date):
434+
def test_data_diff_array_dict(sushi_context_fixed_date):
435435
engine_adapter = sushi_context_fixed_date.engine_adapter
436436

437437
engine_adapter.ctas(
@@ -440,6 +440,7 @@ def test_data_diff_array(sushi_context_fixed_date):
440440
{
441441
"key": [1, 2, 3],
442442
"value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])],
443+
"dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}],
443444
}
444445
),
445446
)
@@ -454,6 +455,7 @@ def test_data_diff_array(sushi_context_fixed_date):
454455
np.array([2.31, 12.2, 3.6, 1.9]),
455456
np.array([5.0]),
456457
],
458+
"dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}],
457459
}
458460
),
459461
)
@@ -487,6 +489,7 @@ def test_data_diff_array(sushi_context_fixed_date):
487489
COMMON ROWS column comparison stats:
488490
pct_match
489491
value 33.333333
492+
dict 66.666667
490493
491494
492495
COMMON ROWS sample data differences:
@@ -497,6 +500,12 @@ def test_data_diff_array(sushi_context_fixed_date):
497500
│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │
498501
│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │
499502
└─────┴────────────────┴────────────────────────┘
503+
Column: dict
504+
┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
505+
┃ key ┃ DEV ┃ PROD ┃
506+
┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
507+
│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │
508+
└─────┴─────────────────────────────┴────────────────────┘
500509
"""
501510

502511
stripped_output = strip_ansi_codes(output)

0 commit comments

Comments
 (0)