Skip to content

Commit 38c5321

Browse files
committed
Feat: Add verbose result comparison in tests
1 parent 2f8e3b7 commit 38c5321

5 files changed

Lines changed: 170 additions & 73 deletions

File tree

sqlmesh/core/console.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from rich.tree import Tree
2929
from sqlglot import exp
3030

31+
from sqlmesh.core.test.result import ModelTextTestResult
3132
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
3233
from sqlmesh.core.linter.rule import RuleViolation
3334
from sqlmesh.core.model import Model
@@ -498,6 +499,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
498499
def loading_stop(self, id: uuid.UUID) -> None:
499500
"""Stop loading for the given id."""
500501

502+
@abc.abstractmethod
503+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
504+
"""Print the unit test results."""
505+
501506

502507
class NoopConsole(Console):
503508
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -779,6 +784,9 @@ def start_destroy(self) -> bool:
779784
def stop_destroy(self, success: bool = True) -> None:
780785
pass
781786

787+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
788+
pass
789+
782790

783791
def make_progress_bar(
784792
message: str,
@@ -2494,6 +2502,51 @@ def show_linter_violations(
24942502
else:
24952503
self.log_warning(msg)
24962504

2505+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
2506+
tests_run = result.testsRun
2507+
errors = result.errors
2508+
failures = result.original_failures
2509+
skipped = result.skipped
2510+
2511+
is_success = not (errors or failures)
2512+
2513+
infos = []
2514+
if failures:
2515+
infos.append(f"failures={len(failures)}")
2516+
if errors:
2517+
infos.append(f"errors={len(errors)}")
2518+
if skipped:
2519+
infos.append(f"skipped={skipped}")
2520+
2521+
self._print("\n", end="")
2522+
2523+
for test_case, failure in failures:
2524+
self._print(unittest.TextTestResult.separator1)
2525+
self._print(f"FAIL: {test_case}")
2526+
2527+
if test_description := test_case.shortDescription():
2528+
self._print(test_description)
2529+
self._print(f"{unittest.TextTestResult.separator2}\n")
2530+
2531+
if exception := failure[1]:
2532+
for arg in exception.args:
2533+
self._print(arg)
2534+
self._print("\n")
2535+
2536+
for test_case, error in errors:
2537+
self._print(unittest.TextTestResult.separator1)
2538+
self._print(f"ERROR: {test_case}")
2539+
self._print(error)
2540+
2541+
# Output final report
2542+
self._print(unittest.TextTestResult.separator2)
2543+
self._print(
2544+
f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n"
2545+
)
2546+
self._print(
2547+
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
2548+
)
2549+
24972550

24982551
def _cells_match(x: t.Any, y: t.Any) -> bool:
24992552
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""

sqlmesh/core/test/definition.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from pathlib import Path
1111
from unittest.mock import patch
1212

13+
from rich.table import Table
14+
from rich.tree import Tree
15+
from rich.align import Align
16+
1317
import numpy as np
1418
import pandas as pd
1519
from io import StringIO
@@ -27,6 +31,7 @@
2731
from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime
2832
from sqlmesh.utils.errors import ConfigError, TestError
2933
from sqlmesh.utils.yaml import load as yaml_load
34+
from sqlmesh.utils import Verbosity
3035

3136
if t.TYPE_CHECKING:
3237
from sqlglot.dialects.dialect import DialectType
@@ -61,6 +66,8 @@ def __init__(
6166
preserve_fixtures: bool = False,
6267
default_catalog: str | None = None,
6368
concurrency: bool = False,
69+
verbosity: Verbosity = Verbosity.DEFAULT,
70+
rich_output: bool = True,
6471
) -> None:
6572
"""ModelTest encapsulates a unit test for a model.
6673
@@ -84,6 +91,8 @@ def __init__(
8491
self.default_catalog = default_catalog
8592
self.dialect = dialect
8693
self.concurrency = concurrency
94+
self.verbosity = verbosity
95+
self.rich_output = rich_output
8796

8897
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
8998
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -278,6 +287,7 @@ def _to_hashable(x: t.Any) -> t.Any:
278287
check_like=True, # Ignore column order
279288
)
280289
except AssertionError as e:
290+
args: t.List[t.Any] = []
281291
if expected.shape != actual.shape:
282292
_raise_if_unexpected_columns(expected.columns, actual.columns)
283293

@@ -291,10 +301,35 @@ def _to_hashable(x: t.Any) -> t.Any:
291301
if not unexpected_rows.empty:
292302
error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
293303

294-
e.args = (error_msg,)
304+
args.append(error_msg)
295305
else:
296-
diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
297-
e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",)
306+
diff = expected.compare(actual).rename(
307+
columns={"self": "Expected", "other": "Actual"}
308+
)
309+
310+
if not self.rich_output:
311+
args.append(f"Data mismatch\n\n{diff}")
312+
elif self.verbosity == Verbosity.DEFAULT:
313+
args.append(df_to_table("Data mismatch", diff))
314+
else:
315+
from pandas import MultiIndex
316+
317+
levels = t.cast(MultiIndex, diff.columns).levels[0]
318+
for col in levels:
319+
col_diff = diff[col]
320+
if not col_diff.empty:
321+
table = df_to_table(
322+
f"[bold red]Column '{col}' mismatch[/bold red]", col_diff
323+
)
324+
args.append(table)
325+
326+
# Show summary statistics
327+
summary_tree = Tree("[bold][summary]Summary:[/summary]")
328+
summary_tree.add(f"Total differences: {len(diff)}\n")
329+
summary_tree.add(f"Different columns: {len(levels)}\n")
330+
args.append(summary_tree)
331+
332+
e.args = (*args,)
298333

299334
raise e
300335

@@ -316,6 +351,7 @@ def create_test(
316351
preserve_fixtures: bool = False,
317352
default_catalog: str | None = None,
318353
concurrency: bool = False,
354+
verbosity: Verbosity = Verbosity.DEFAULT,
319355
) -> t.Optional[ModelTest]:
320356
"""Create a SqlModelTest or a PythonModelTest.
321357
@@ -361,6 +397,7 @@ def create_test(
361397
preserve_fixtures,
362398
default_catalog,
363399
concurrency,
400+
verbosity,
364401
)
365402
except Exception as e:
366403
raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")
@@ -676,6 +713,8 @@ def __init__(
676713
preserve_fixtures: bool = False,
677714
default_catalog: str | None = None,
678715
concurrency: bool = False,
716+
verbosity: Verbosity = Verbosity.DEFAULT,
717+
rich_output: bool = True,
679718
) -> None:
680719
"""PythonModelTest encapsulates a unit test for a Python model.
681720
@@ -702,6 +741,8 @@ def __init__(
702741
preserve_fixtures,
703742
default_catalog,
704743
concurrency,
744+
verbosity,
745+
rich_output,
705746
)
706747

707748
self.context = TestExecutionContext(
@@ -926,3 +967,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
926967
return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
927968
return {k: _normalize_df_value(v) for k, v in value.items()}
928969
return value
970+
971+
972+
def df_to_table(
973+
header: str,
974+
df: pd.DataFrame,
975+
show_index: bool = True,
976+
index_name: str = "Row",
977+
) -> Table:
978+
"""Convert a pandas.DataFrame obj into a rich.Table obj.
979+
Args:
980+
df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
981+
rich_table (Table): A rich Table that should be populated by the DataFrame values.
982+
show_index (bool): Add a column with a row count to the table. Defaults to True.
983+
index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value.
984+
Returns:
985+
Table: The rich Table instance passed, populated with the DataFrame values."""
986+
987+
rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60)
988+
if show_index:
989+
index_name = str(index_name) if index_name else ""
990+
rich_table.add_column(index_name)
991+
992+
for column in df.columns:
993+
column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column)
994+
if "expected" in column_name.lower():
995+
column_name = f"[green]{column_name}[/green]"
996+
else:
997+
column_name = f"[red]{column_name}[/red]"
998+
999+
rich_table.add_column(Align.center(column_name))
1000+
1001+
for index, value_list in enumerate(df.values.tolist()):
1002+
row = [str(index)] if show_index else []
1003+
row += [str(x) for x in value_list]
1004+
center = [Align.center(x) for x in row]
1005+
rich_table.add_row(*center)
1006+
1007+
return rich_table

sqlmesh/core/test/result.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -83,43 +83,6 @@ def log_test_report(self, test_duration: float) -> None:
8383
Args:
8484
test_duration: The duration of the tests.
8585
"""
86-
tests_run = self.testsRun
87-
errors = self.errors
88-
failures = self.failures
89-
skipped = self.skipped
90-
91-
is_success = not (errors or failures)
92-
93-
infos = []
94-
if failures:
95-
infos.append(f"failures={len(failures)}")
96-
if errors:
97-
infos.append(f"errors={len(errors)}")
98-
if skipped:
99-
infos.append(f"skipped={skipped}")
100-
101-
stream = self.stream
102-
103-
stream.write("\n")
104-
105-
for test_case, failure in failures:
106-
stream.writeln(unittest.TextTestResult.separator1)
107-
stream.writeln(f"FAIL: {test_case}")
108-
if test_description := test_case.shortDescription():
109-
stream.writeln(test_description)
110-
stream.writeln(unittest.TextTestResult.separator2)
111-
stream.writeln(failure)
112-
113-
for test_case, error in errors:
114-
stream.writeln(unittest.TextTestResult.separator1)
115-
stream.writeln(f"ERROR: {test_case}")
116-
stream.writeln(error)
117-
118-
# Output final report
119-
stream.writeln(unittest.TextTestResult.separator2)
120-
stream.writeln(
121-
f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n"
122-
)
123-
stream.writeln(
124-
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
125-
)
86+
from sqlmesh.core.console import get_console
87+
88+
get_console().log_unit_test_results(self, test_duration)

sqlmesh/core/test/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def run_tests(
107107
lock = threading.Lock()
108108

109109
combined_results = ModelTextTestResult(
110-
stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore
110+
stream=unittest.runner._WritelnDecorator(stream or sys.stdout), # type: ignore
111111
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
112112
descriptions=True,
113113
)
@@ -136,6 +136,7 @@ def _run_single_test(
136136
default_catalog=default_catalog,
137137
preserve_fixtures=preserve_fixtures,
138138
concurrency=num_workers > 1,
139+
verbosity=verbosity,
139140
)
140141

141142
if not test:

0 commit comments

Comments
 (0)