Skip to content

Commit 974cb04

Browse files
committed
Feat: Add verbose result comparison in tests
1 parent 7dd52c5 commit 974cb04

5 files changed

Lines changed: 165 additions & 73 deletions

File tree

sqlmesh/core/console.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from rich.tree import Tree
2727
from sqlglot import exp
2828

29+
from sqlmesh.core.test.result import ModelTextTestResult
2930
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
3031
from sqlmesh.core.linter.rule import RuleViolation
3132
from sqlmesh.core.model import Model
@@ -496,6 +497,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
496497
def loading_stop(self, id: uuid.UUID) -> None:
497498
"""Stop loading for the given id."""
498499

500+
@abc.abstractmethod
501+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
502+
"""Print the unit test results."""
503+
499504

500505
class NoopConsole(Console):
501506
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -777,6 +782,9 @@ def start_destroy(self) -> bool:
777782
def stop_destroy(self, success: bool = True) -> None:
778783
pass
779784

785+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
786+
pass
787+
780788

781789
def make_progress_bar(
782790
message: str,
@@ -2492,6 +2500,51 @@ def show_linter_violations(
24922500
else:
24932501
self.log_warning(msg)
24942502

2503+
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
    """Print the unit test report: every failure, every error, then a summary.

    Args:
        result: The aggregated test result. Its `testsRun`, `errors`, `skipped`
            and `original_failures` attributes are read. Each entry of
            `original_failures` is a `(test_case, failure)` pair where
            `failure[1]` is the raised exception.
        test_duration: The duration of the tests, in seconds.
    """
    tests_run = result.testsRun
    errors = result.errors
    failures = result.original_failures
    skipped = result.skipped

    is_success = not (errors or failures)

    infos = []
    if failures:
        infos.append(f"failures={len(failures)}")
    if errors:
        infos.append(f"errors={len(errors)}")
    if skipped:
        # unittest keeps skipped tests as a list of (test, reason) pairs, so
        # report how many there are instead of interpolating the list itself.
        infos.append(f"skipped={len(skipped)}")

    self._print("\n", end="")

    for test_case, failure in failures:
        self._print(unittest.TextTestResult.separator1)
        self._print(f"FAIL: {test_case}")

        if test_description := test_case.shortDescription():
            self._print(test_description)
        self._print(f"{unittest.TextTestResult.separator2}\n")

        exception = failure[1]
        if exception is not None:
            # Each argument is printed on its own (a message string or a
            # renderable such as a table), followed by a blank separator.
            for arg in exception.args:
                self._print(arg)
                self._print("\n")

    for test_case, error in errors:
        self._print(unittest.TextTestResult.separator1)
        self._print(f"ERROR: {test_case}")
        self._print(error)

    # Output final report
    self._print(unittest.TextTestResult.separator2)
    # Only exactly one test is singular; zero tests reads "Ran 0 tests".
    noun = "test" if tests_run == 1 else "tests"
    self._print(f"Ran {tests_run} {noun} in {test_duration:.3f}s \n")

    status = "OK" if is_success else "FAILED"
    details = f" ({', '.join(infos)})" if infos else ""
    self._print(f"{status}{details}")
2547+
24952548

24962549
def _cells_match(x: t.Any, y: t.Any) -> bool:
24972550
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""

sqlmesh/core/test/definition.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from pathlib import Path
1111
from unittest.mock import patch
1212

13+
from rich.table import Table
14+
from rich.align import Align
15+
16+
import numpy as np
17+
import pandas as pd
1318
from io import StringIO
1419
from sqlglot import Dialect, exp
1520
from sqlglot.optimizer.annotate_types import annotate_types
@@ -24,6 +29,7 @@
2429
from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime
2530
from sqlmesh.utils.errors import ConfigError, TestError
2631
from sqlmesh.utils.yaml import load as yaml_load
32+
from sqlmesh.utils import Verbosity
2733

2834
if t.TYPE_CHECKING:
2935
import pandas as pd
@@ -60,6 +66,8 @@ def __init__(
6066
preserve_fixtures: bool = False,
6167
default_catalog: str | None = None,
6268
concurrency: bool = False,
69+
verbosity: Verbosity = Verbosity.DEFAULT,
70+
rich_output: bool = True,
6371
) -> None:
6472
"""ModelTest encapsulates a unit test for a model.
6573
@@ -83,6 +91,8 @@ def __init__(
8391
self.default_catalog = default_catalog
8492
self.dialect = dialect
8593
self.concurrency = concurrency
94+
self.verbosity = verbosity
95+
self.rich_output = rich_output
8696

8797
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
8898
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -281,6 +291,7 @@ def _to_hashable(x: t.Any) -> t.Any:
281291
check_like=True, # Ignore column order
282292
)
283293
except AssertionError as e:
294+
args: t.List[t.Any] = []
284295
if expected.shape != actual.shape:
285296
_raise_if_unexpected_columns(expected.columns, actual.columns)
286297

@@ -294,10 +305,29 @@ def _to_hashable(x: t.Any) -> t.Any:
294305
if not unexpected_rows.empty:
295306
error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
296307

297-
e.args = (error_msg,)
308+
args.append(error_msg)
298309
else:
299-
diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
300-
e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",)
310+
diff = expected.compare(actual).rename(
311+
columns={"self": "Expected", "other": "Actual"}
312+
)
313+
314+
if not self.rich_output:
315+
args.append(f"Data mismatch\n\n{diff}")
316+
elif self.verbosity == Verbosity.DEFAULT:
317+
args.append(df_to_table("Data mismatch", diff))
318+
else:
319+
from pandas import MultiIndex
320+
321+
levels = t.cast(MultiIndex, diff.columns).levels[0]
322+
for col in levels:
323+
col_diff = diff[col]
324+
if not col_diff.empty:
325+
table = df_to_table(
326+
f"[bold red]Column '{col}' mismatch[/bold red]", col_diff
327+
)
328+
args.append(table)
329+
330+
e.args = (*args,)
301331

302332
raise e
303333

@@ -319,6 +349,7 @@ def create_test(
319349
preserve_fixtures: bool = False,
320350
default_catalog: str | None = None,
321351
concurrency: bool = False,
352+
verbosity: Verbosity = Verbosity.DEFAULT,
322353
) -> t.Optional[ModelTest]:
323354
"""Create a SqlModelTest or a PythonModelTest.
324355
@@ -364,6 +395,7 @@ def create_test(
364395
preserve_fixtures,
365396
default_catalog,
366397
concurrency,
398+
verbosity,
367399
)
368400
except Exception as e:
369401
raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")
@@ -683,6 +715,8 @@ def __init__(
683715
preserve_fixtures: bool = False,
684716
default_catalog: str | None = None,
685717
concurrency: bool = False,
718+
verbosity: Verbosity = Verbosity.DEFAULT,
719+
rich_output: bool = True,
686720
) -> None:
687721
"""PythonModelTest encapsulates a unit test for a Python model.
688722
@@ -709,6 +743,8 @@ def __init__(
709743
preserve_fixtures,
710744
default_catalog,
711745
concurrency,
746+
verbosity,
747+
rich_output,
712748
)
713749

714750
self.context = TestExecutionContext(
@@ -942,3 +978,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
942978
return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
943979
return {k: _normalize_df_value(v) for k, v in value.items()}
944980
return value
981+
982+
983+
def df_to_table(
    header: str,
    df: pd.DataFrame,
    show_index: bool = True,
    index_name: str = "Row",
) -> Table:
    """Convert a pandas DataFrame into a rich Table.

    Args:
        header: The table title, rendered in bold red.
        df: The DataFrame to convert. Every cell is rendered through `str`.
        show_index: Whether to prepend a column holding each row's index label.
            Defaults to True.
        index_name: The header of the index column. Defaults to "Row"; a falsy
            value leaves that header empty.

    Returns:
        A new rich Table populated with the DataFrame's values.
    """
    rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60)
    if show_index:
        rich_table.add_column(str(index_name) if index_name else "")

    for column in df.columns:
        # MultiIndex labels arrive as tuples, e.g. ("col", "Expected"); flatten
        # them into one header. Any other label (str, int, ...) is stringified.
        if isinstance(column, tuple):
            column_name = ": ".join(str(level) for level in column)
        else:
            column_name = str(column)

        # "Expected" columns are green, everything else red.
        color = "green" if "expected" in column_name.lower() else "red"
        rich_table.add_column(Align.center(f"[{color}]{column_name}[/{color}]"))

    # Pair each row with its own index label rather than its position: a
    # filtered frame (such as the output of DataFrame.compare) keeps the labels
    # of the original rows, and those are what identify the mismatching rows.
    for label, values in zip(df.index, df.values.tolist()):
        cells = [str(label)] if show_index else []
        cells.extend(str(value) for value in values)
        rich_table.add_row(*(Align.center(cell) for cell in cells))

    return rich_table

sqlmesh/core/test/result.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -83,43 +83,6 @@ def log_test_report(self, test_duration: float) -> None:
8383
Args:
8484
test_duration: The duration of the tests.
8585
"""
86-
tests_run = self.testsRun
87-
errors = self.errors
88-
failures = self.failures
89-
skipped = self.skipped
90-
91-
is_success = not (errors or failures)
92-
93-
infos = []
94-
if failures:
95-
infos.append(f"failures={len(failures)}")
96-
if errors:
97-
infos.append(f"errors={len(errors)}")
98-
if skipped:
99-
infos.append(f"skipped={skipped}")
100-
101-
stream = self.stream
102-
103-
stream.write("\n")
104-
105-
for test_case, failure in failures:
106-
stream.writeln(unittest.TextTestResult.separator1)
107-
stream.writeln(f"FAIL: {test_case}")
108-
if test_description := test_case.shortDescription():
109-
stream.writeln(test_description)
110-
stream.writeln(unittest.TextTestResult.separator2)
111-
stream.writeln(failure)
112-
113-
for test_case, error in errors:
114-
stream.writeln(unittest.TextTestResult.separator1)
115-
stream.writeln(f"ERROR: {test_case}")
116-
stream.writeln(error)
117-
118-
# Output final report
119-
stream.writeln(unittest.TextTestResult.separator2)
120-
stream.writeln(
121-
f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n"
122-
)
123-
stream.writeln(
124-
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
125-
)
86+
from sqlmesh.core.console import get_console
87+
88+
get_console().log_unit_test_results(self, test_duration)

sqlmesh/core/test/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def run_tests(
107107
lock = threading.Lock()
108108

109109
combined_results = ModelTextTestResult(
110-
stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore
110+
stream=unittest.runner._WritelnDecorator(stream or sys.stdout), # type: ignore
111111
verbosity=2 if verbosity >= Verbosity.VERBOSE else 1,
112112
descriptions=True,
113113
)
@@ -136,6 +136,7 @@ def _run_single_test(
136136
default_catalog=default_catalog,
137137
preserve_fixtures=preserve_fixtures,
138138
concurrency=num_workers > 1,
139+
verbosity=verbosity,
139140
)
140141

141142
if not test:

0 commit comments

Comments
 (0)