1010from pathlib import Path
1111from unittest .mock import patch
1212
13+ from rich .table import Table
14+ from rich .tree import Tree
15+ from rich .align import Align
16+
1317import numpy as np
1418import pandas as pd
1519from io import StringIO
2731from sqlmesh .utils .date import date_dict , pandas_timestamp_to_pydatetime , to_datetime
2832from sqlmesh .utils .errors import ConfigError , TestError
2933from sqlmesh .utils .yaml import load as yaml_load
34+ from sqlmesh .utils import Verbosity
3035
3136if t .TYPE_CHECKING :
3237 from sqlglot .dialects .dialect import DialectType
@@ -61,6 +66,8 @@ def __init__(
6166 preserve_fixtures : bool = False ,
6267 default_catalog : str | None = None ,
6368 concurrency : bool = False ,
69+ verbosity : Verbosity = Verbosity .DEFAULT ,
70+ rich_output : bool = True ,
6471 ) -> None :
6572 """ModelTest encapsulates a unit test for a model.
6673
@@ -84,6 +91,8 @@ def __init__(
8491 self .default_catalog = default_catalog
8592 self .dialect = dialect
8693 self .concurrency = concurrency
94+ self .verbosity = verbosity
95+ self .rich_output = rich_output
8796
8897 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8998 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -278,6 +287,7 @@ def _to_hashable(x: t.Any) -> t.Any:
278287 check_like = True , # Ignore column order
279288 )
280289 except AssertionError as e :
290+ args : t .List [t .Any ] = []
281291 if expected .shape != actual .shape :
282292 _raise_if_unexpected_columns (expected .columns , actual .columns )
283293
@@ -291,10 +301,29 @@ def _to_hashable(x: t.Any) -> t.Any:
291301 if not unexpected_rows .empty :
292302 error_msg += f"\n \n Unexpected rows:\n \n { unexpected_rows } "
293303
294- e . args = (error_msg , )
304+ args . append (error_msg )
295305 else :
296- diff = expected .compare (actual ).rename (columns = {"self" : "exp" , "other" : "act" })
297- e .args = (f"Data mismatch (exp: expected, act: actual)\n \n { diff } " ,)
306+ diff = expected .compare (actual ).rename (
307+ columns = {"self" : "Expected" , "other" : "Actual" }
308+ )
309+
310+ if not self .rich_output :
311+ args .append (f"Data mismatch\n \n { diff } " )
312+ elif self .verbosity == Verbosity .DEFAULT :
313+ args .append (df_to_table ("Data mismatch" , diff ))
314+ else :
315+ from pandas import MultiIndex
316+
317+ levels = t .cast (MultiIndex , diff .columns ).levels [0 ]
318+ for col in levels :
319+ col_diff = diff [col ]
320+ if not col_diff .empty :
321+ table = df_to_table (
322+ f"[bold red]Column '{ col } ' mismatch[/bold red]" , col_diff
323+ )
324+ args .append (table )
325+
326+ e .args = (* args ,)
298327
299328 raise e
300329
@@ -316,6 +345,7 @@ def create_test(
316345 preserve_fixtures : bool = False ,
317346 default_catalog : str | None = None ,
318347 concurrency : bool = False ,
348+ verbosity : Verbosity = Verbosity .DEFAULT ,
319349 ) -> t .Optional [ModelTest ]:
320350 """Create a SqlModelTest or a PythonModelTest.
321351
@@ -361,6 +391,7 @@ def create_test(
361391 preserve_fixtures ,
362392 default_catalog ,
363393 concurrency ,
394+ verbosity ,
364395 )
365396 except Exception as e :
366397 raise TestError (f"Failed to create test { test_name } ({ path } )\n { str (e )} " )
@@ -676,6 +707,8 @@ def __init__(
676707 preserve_fixtures : bool = False ,
677708 default_catalog : str | None = None ,
678709 concurrency : bool = False ,
710+ verbosity : Verbosity = Verbosity .DEFAULT ,
711+ rich_output : bool = True ,
679712 ) -> None :
680713 """PythonModelTest encapsulates a unit test for a Python model.
681714
@@ -702,6 +735,8 @@ def __init__(
702735 preserve_fixtures ,
703736 default_catalog ,
704737 concurrency ,
738+ verbosity ,
739+ rich_output ,
705740 )
706741
707742 self .context = TestExecutionContext (
@@ -926,3 +961,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
926961 return {k : _normalize_df_value (v ) for k , v in zip (value ["key" ], value ["value" ])}
927962 return {k : _normalize_df_value (v ) for k , v in value .items ()}
928963 return value
964+
965+
966+ def df_to_table (
967+ header : str ,
968+ df : pd .DataFrame ,
969+ show_index : bool = True ,
970+ index_name : str = "Row" ,
971+ ) -> Table :
972+ """Convert a pandas.DataFrame obj into a rich.Table obj.
973+ Args:
974+ df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
975+ rich_table (Table): A rich Table that should be populated by the DataFrame values.
976+ show_index (bool): Add a column with a row count to the table. Defaults to True.
977+ index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value.
978+ Returns:
979+ Table: The rich Table instance passed, populated with the DataFrame values."""
980+
981+ rich_table = Table (title = f"[bold red]{ header } [/bold red]" , show_lines = True , min_width = 60 )
982+ if show_index :
983+ index_name = str (index_name ) if index_name else ""
984+ rich_table .add_column (index_name )
985+
986+ for column in df .columns :
987+ column_name = column if isinstance (column , str ) else ": " .join (str (col ) for col in column )
988+ if "expected" in column_name .lower ():
989+ column_name = f"[green]{ column_name } [/green]"
990+ else :
991+ column_name = f"[red]{ column_name } [/red]"
992+
993+ rich_table .add_column (Align .center (column_name ))
994+
995+ for index , value_list in enumerate (df .values .tolist ()):
996+ row = [str (index )] if show_index else []
997+ row += [str (x ) for x in value_list ]
998+ center = [Align .center (x ) for x in row ]
999+ rich_table .add_row (* center )
1000+
1001+ return rich_table
0 commit comments