1010from pathlib import Path
1111from unittest .mock import patch
1212
13+ from rich .table import Table
14+ from rich .tree import Tree
15+ from rich .align import Align
16+
1317import numpy as np
1418import pandas as pd
1519from io import StringIO
2731from sqlmesh .utils .date import date_dict , pandas_timestamp_to_pydatetime , to_datetime
2832from sqlmesh .utils .errors import ConfigError , TestError
2933from sqlmesh .utils .yaml import load as yaml_load
34+ from sqlmesh .utils import Verbosity
3035
3136if t .TYPE_CHECKING :
3237 from sqlglot .dialects .dialect import DialectType
@@ -61,6 +66,8 @@ def __init__(
6166 preserve_fixtures : bool = False ,
6267 default_catalog : str | None = None ,
6368 concurrency : bool = False ,
69+ verbosity : Verbosity = Verbosity .DEFAULT ,
70+ rich_output : bool = True ,
6471 ) -> None :
6572 """ModelTest encapsulates a unit test for a model.
6673
@@ -84,6 +91,8 @@ def __init__(
8491 self .default_catalog = default_catalog
8592 self .dialect = dialect
8693 self .concurrency = concurrency
94+ self .verbosity = verbosity
95+ self .rich_output = rich_output
8796
8897 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8998 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -278,6 +287,7 @@ def _to_hashable(x: t.Any) -> t.Any:
278287 check_like = True , # Ignore column order
279288 )
280289 except AssertionError as e :
290+ args : t .List [t .Any ] = []
281291 if expected .shape != actual .shape :
282292 _raise_if_unexpected_columns (expected .columns , actual .columns )
283293
@@ -291,10 +301,35 @@ def _to_hashable(x: t.Any) -> t.Any:
291301 if not unexpected_rows .empty :
292302 error_msg += f"\n \n Unexpected rows:\n \n { unexpected_rows } "
293303
294- e . args = (error_msg , )
304+ args . append (error_msg )
295305 else :
296- diff = expected .compare (actual ).rename (columns = {"self" : "exp" , "other" : "act" })
297- e .args = (f"Data mismatch (exp: expected, act: actual)\n \n { diff } " ,)
306+ diff = expected .compare (actual ).rename (
307+ columns = {"self" : "Expected" , "other" : "Actual" }
308+ )
309+
310+ if not self .rich_output :
311+ args .append (f"Data mismatch\n \n { diff } " )
312+ elif self .verbosity == Verbosity .DEFAULT :
313+ args .append (df_to_table ("Data mismatch" , diff ))
314+ else :
315+ from pandas import MultiIndex
316+
317+ levels = t .cast (MultiIndex , diff .columns ).levels [0 ]
318+ for col in levels :
319+ col_diff = diff [col ]
320+ if not col_diff .empty :
321+ table = df_to_table (
322+ f"[bold red]Column '{ col } ' mismatch[/bold red]" , col_diff
323+ )
324+ args .append (table )
325+
326+ # Show summary statistics
327+ summary_tree = Tree ("[bold][summary]Summary:[/summary]" )
328+ summary_tree .add (f"Total differences: { len (diff )} \n " )
329+ summary_tree .add (f"Different columns: { len (levels )} \n " )
330+ args .append (summary_tree )
331+
332+ e .args = (* args ,)
298333
299334 raise e
300335
@@ -316,6 +351,7 @@ def create_test(
316351 preserve_fixtures : bool = False ,
317352 default_catalog : str | None = None ,
318353 concurrency : bool = False ,
354+ verbosity : Verbosity = Verbosity .DEFAULT ,
319355 ) -> t .Optional [ModelTest ]:
320356 """Create a SqlModelTest or a PythonModelTest.
321357
@@ -361,6 +397,7 @@ def create_test(
361397 preserve_fixtures ,
362398 default_catalog ,
363399 concurrency ,
400+ verbosity ,
364401 )
365402 except Exception as e :
366403 raise TestError (f"Failed to create test { test_name } ({ path } )\n { str (e )} " )
@@ -676,6 +713,8 @@ def __init__(
676713 preserve_fixtures : bool = False ,
677714 default_catalog : str | None = None ,
678715 concurrency : bool = False ,
716+ verbosity : Verbosity = Verbosity .DEFAULT ,
717+ rich_output : bool = True ,
679718 ) -> None :
680719 """PythonModelTest encapsulates a unit test for a Python model.
681720
@@ -702,6 +741,8 @@ def __init__(
702741 preserve_fixtures ,
703742 default_catalog ,
704743 concurrency ,
744+ verbosity ,
745+ rich_output ,
705746 )
706747
707748 self .context = TestExecutionContext (
@@ -926,3 +967,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
926967 return {k : _normalize_df_value (v ) for k , v in zip (value ["key" ], value ["value" ])}
927968 return {k : _normalize_df_value (v ) for k , v in value .items ()}
928969 return value
970+
971+
972+ def df_to_table (
973+ header : str ,
974+ df : pd .DataFrame ,
975+ show_index : bool = True ,
976+ index_name : str = "Row" ,
977+ ) -> Table :
978+ """Convert a pandas.DataFrame obj into a rich.Table obj.
979+ Args:
980+ df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
981+ rich_table (Table): A rich Table that should be populated by the DataFrame values.
982+ show_index (bool): Add a column with a row count to the table. Defaults to True.
983+ index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value.
984+ Returns:
985+ Table: The rich Table instance passed, populated with the DataFrame values."""
986+
987+ rich_table = Table (title = f"[bold red]{ header } [/bold red]" , show_lines = True , min_width = 60 )
988+ if show_index :
989+ index_name = str (index_name ) if index_name else ""
990+ rich_table .add_column (index_name )
991+
992+ for column in df .columns :
993+ column_name = column if isinstance (column , str ) else ": " .join (str (col ) for col in column )
994+ if "expected" in column_name .lower ():
995+ column_name = f"[green]{ column_name } [/green]"
996+ else :
997+ column_name = f"[red]{ column_name } [/red]"
998+
999+ rich_table .add_column (Align .center (column_name ))
1000+
1001+ for index , value_list in enumerate (df .values .tolist ()):
1002+ row = [str (index )] if show_index else []
1003+ row += [str (x ) for x in value_list ]
1004+ center = [Align .center (x ) for x in row ]
1005+ rich_table .add_row (* center )
1006+
1007+ return rich_table
0 commit comments