|
3 | 3 | import pandas as pd |
4 | 4 | from sqlglot import exp |
5 | 5 | from sqlmesh.core import dialect as d |
| 6 | +import re |
| 7 | +import typing as t |
| 8 | +from io import StringIO |
| 9 | +from rich.console import Console |
| 10 | +from sqlmesh.core.console import TerminalConsole |
6 | 11 | from sqlmesh.core.context import Context |
7 | 12 | from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig |
8 | 13 | from sqlmesh.core.model import SqlModel, load_sql_based_model |
9 | 14 | from sqlmesh.core.table_diff import TableDiff |
| 15 | +import numpy as np |
| 16 | + |
| 17 | + |
def create_test_console() -> t.Tuple[StringIO, TerminalConsole]:
    """Build a TerminalConsole that renders into an in-memory buffer.

    Returns:
        A ``(buffer, terminal_console)`` pair: whatever the console prints can
        be read back from the buffer with ``getvalue()``.
    """
    buffer = StringIO()
    # force_terminal=True is passed although the target is a StringIO, so the
    # captured text may contain terminal control codes for callers to strip.
    terminal_console = TerminalConsole(console=Console(file=buffer, force_terminal=True))
    return buffer, terminal_console
| 24 | + |
| 25 | + |
def capture_console_output(method_name: str, **kwargs) -> str:
    """Invoke a TerminalConsole method and return everything it printed.

    Args:
        method_name: Name of the TerminalConsole method to call
        **kwargs: Keyword arguments forwarded to that method

    Returns:
        The captured output as a string
    """
    buffer, terminal_console = create_test_console()
    # Leaving the with-block closes the buffer, also when the call raises.
    with buffer:
        getattr(terminal_console, method_name)(**kwargs)
        return buffer.getvalue()
| 43 | + |
| 44 | + |
def strip_ansi_codes(text: str) -> str:
    """Strip ANSI CSI escape sequences (colors, styling, cursor control) from text.

    Args:
        text: Text that may contain ANSI escape sequences.

    Returns:
        The text with the sequences removed and leading/trailing whitespace
        stripped.
    """
    # Full CSI grammar: ESC [ + parameter bytes (0x30-0x3F) + intermediate
    # bytes (0x20-0x2F) + one final byte (0x40-0x7E). Accepting the whole
    # parameter range also removes private-mode sequences such as "\x1b[?25l"
    # (hide cursor), which a digits-and-semicolons-only pattern leaves behind.
    ansi_escape = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
    return ansi_escape.sub("", text).strip()
10 | 49 |
|
11 | 50 |
|
12 | 51 | @pytest.mark.slow |
@@ -121,7 +160,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): |
121 | 160 | pd.DataFrame( |
122 | 161 | { |
123 | 162 | "key": [1, 2, 3], |
124 | | - "value": [1.0, 2.0, 3.1234], |
| 163 | + "value": [1.0, 2.0, 3.1234321], |
125 | 164 | } |
126 | 165 | ), |
127 | 166 | ) |
@@ -162,6 +201,32 @@ def test_data_diff_decimals(sushi_context_fixed_date): |
162 | 201 | assert "DEV__value" in aliased_joined_sample |
163 | 202 | assert "PROD__value" in aliased_joined_sample |
164 | 203 |
|
| 204 | + output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff()) |
| 205 | + |
| 206 | + # Expected output with box-drawings |
| 207 | + expected_output = r""" |
| 208 | +Row Counts: |
| 209 | +├── FULL MATCH: 2 rows (66.67%) |
| 210 | +└── PARTIAL MATCH: 1 rows (33.33%) |
| 211 | +
|
| 212 | +COMMON ROWS column comparison stats: |
| 213 | + pct_match |
| 214 | +value 66.666667 |
| 215 | +
|
| 216 | +
|
| 217 | +COMMON ROWS sample data differences: |
| 218 | +Column: value |
| 219 | +┏━━━━━┳━━━━━━━━┳━━━━━━━━┓ |
| 220 | +┃ key ┃ DEV ┃ PROD ┃ |
| 221 | +┡━━━━━╇━━━━━━━━╇━━━━━━━━┩ |
| 222 | +│ 3.0 │ 3.1233 │ 3.1234 │ |
| 223 | +└─────┴────────┴────────┘ |
| 224 | +""" |
| 225 | + |
| 226 | + stripped_output = strip_ansi_codes(output) |
| 227 | + stripped_expected = expected_output.strip() |
| 228 | + assert stripped_output == stripped_expected |
| 229 | + |
165 | 230 |
|
166 | 231 | @pytest.mark.slow |
167 | 232 | def test_grain_check(sushi_context_fixed_date): |
@@ -363,3 +428,86 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context) |
363 | 428 |
|
364 | 429 | _, _, col_names = table_diff.key_columns |
365 | 430 | assert col_names == ["waiter_id", "event_date"] |
| 431 | + |
| 432 | + |
# Diffs two tables whose non-key columns hold arrays and dicts, then checks the
# match counts and the row-diff output rendered by the console.
| 433 | +@pytest.mark.slow |
| 434 | +def test_data_diff_array_dict(sushi_context_fixed_date): |
| 435 | + engine_adapter = sushi_context_fixed_date.engine_adapter |
| 436 | + |
# Source side: three rows keyed 1-3 with an array column ("value") and a dict
# column ("dict").
| 437 | + engine_adapter.ctas( |
| 438 | + "table_diff_source", |
| 439 | + pd.DataFrame( |
| 440 | + { |
| 441 | + "key": [1, 2, 3], |
| 442 | + "value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])], |
| 443 | + "dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}], |
| 444 | + } |
| 445 | + ), |
| 446 | + ) |
| 447 | + |
# Target side: key 1 differs in both columns, key 2 differs only in the array
# (two extra elements), key 3 is identical to the source row.
| 448 | + engine_adapter.ctas( |
| 449 | + "table_diff_target", |
| 450 | + pd.DataFrame( |
| 451 | + { |
| 452 | + "key": [1, 2, 3], |
| 453 | + "value": [ |
| 454 | + np.array([51.2, 4.5679]), |
| 455 | + np.array([2.31, 12.2, 3.6, 1.9]), |
| 456 | + np.array([5.0]), |
| 457 | + ], |
| 458 | + "dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}], |
| 459 | + } |
| 460 | + ), |
| 461 | + ) |
| 462 | + |
# Rows are matched on "key"; the aliases show up as DEV__/PROD__ column
# prefixes in the joined sample (asserted below).
# NOTE(review): decimals=4 is presumably the comparison precision for numeric
# values - confirm against TableDiff.
| 463 | + table_diff = TableDiff( |
| 464 | + adapter=engine_adapter, |
| 465 | + source="table_diff_source", |
| 466 | + target="table_diff_target", |
| 467 | + source_alias="dev", |
| 468 | + target_alias="prod", |
| 469 | + on=["key"], |
| 470 | + decimals=4, |
| 471 | + ) |
| 472 | + |
# One row (key 3) is expected to match fully; keys 1 and 2 match only partially.
| 473 | + diff = table_diff.row_diff() |
| 474 | + aliased_joined_sample = diff.joined_sample.columns |
| 475 | + |
| 476 | + assert "DEV__value" in aliased_joined_sample |
| 477 | + assert "PROD__value" in aliased_joined_sample |
| 478 | + assert diff.full_match_count == 1 |
| 479 | + assert diff.partial_match_count == 2 |
| 480 | + |
# Render the diff via TerminalConsole.show_row_diff and capture the text it prints.
| 481 | + output = capture_console_output("show_row_diff", row_diff=diff) |
| 482 | + |
| 483 | + # Expected output with boxes |
| 484 | + expected_output = r""" |
| 485 | +Row Counts: |
| 486 | +├── FULL MATCH: 1 rows (33.33%) |
| 487 | +└── PARTIAL MATCH: 2 rows (66.67%) |
| 488 | +
|
| 489 | +COMMON ROWS column comparison stats: |
| 490 | + pct_match |
| 491 | +value 33.333333 |
| 492 | +dict 66.666667 |
| 493 | +
|
| 494 | +
|
| 495 | +COMMON ROWS sample data differences: |
| 496 | +Column: value |
| 497 | +┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ |
| 498 | +┃ key ┃ DEV ┃ PROD ┃ |
| 499 | +┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ |
| 500 | +│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │ |
| 501 | +│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │ |
| 502 | +└─────┴────────────────┴────────────────────────┘ |
| 503 | +Column: dict |
| 504 | +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓ |
| 505 | +┃ key ┃ DEV ┃ PROD ┃ |
| 506 | +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩ |
| 507 | +│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │ |
| 508 | +└─────┴─────────────────────────────┴────────────────────┘ |
| 509 | +""" |
| 510 | + |
# Compare after removing ANSI escape codes from the captured output and
# trimming surrounding whitespace on both sides.
| 511 | + stripped_output = strip_ansi_codes(output) |
| 512 | + stripped_expected = expected_output.strip() |
| 513 | + assert stripped_output == stripped_expected |
0 commit comments