1010from pathlib import Path
1111from unittest .mock import patch
1212
13+ from rich .table import Table
14+ from rich .align import Align
15+
16+ import numpy as np
17+ import pandas as pd
1318from io import StringIO
1419from sqlglot import Dialect , exp
1520from sqlglot .optimizer .annotate_types import annotate_types
2429from sqlmesh .utils .date import date_dict , pandas_timestamp_to_pydatetime , to_datetime
2530from sqlmesh .utils .errors import ConfigError , TestError
2631from sqlmesh .utils .yaml import load as yaml_load
32+ from sqlmesh .utils import Verbosity
2733
2834if t .TYPE_CHECKING :
2935 import pandas as pd
@@ -60,6 +66,8 @@ def __init__(
6066 preserve_fixtures : bool = False ,
6167 default_catalog : str | None = None ,
6268 concurrency : bool = False ,
69+ verbosity : Verbosity = Verbosity .DEFAULT ,
70+ rich_output : bool = True ,
6371 ) -> None :
6472 """ModelTest encapsulates a unit test for a model.
6573
@@ -83,6 +91,8 @@ def __init__(
8391 self .default_catalog = default_catalog
8492 self .dialect = dialect
8593 self .concurrency = concurrency
94+ self .verbosity = verbosity
95+ self .rich_output = rich_output
8696
8797 self ._fixture_table_cache : t .Dict [str , exp .Table ] = {}
8898 self ._normalized_column_name_cache : t .Dict [str , str ] = {}
@@ -281,6 +291,7 @@ def _to_hashable(x: t.Any) -> t.Any:
281291 check_like = True , # Ignore column order
282292 )
283293 except AssertionError as e :
294+ args : t .List [t .Any ] = []
284295 if expected .shape != actual .shape :
285296 _raise_if_unexpected_columns (expected .columns , actual .columns )
286297
@@ -294,10 +305,29 @@ def _to_hashable(x: t.Any) -> t.Any:
294305 if not unexpected_rows .empty :
295306 error_msg += f"\n \n Unexpected rows:\n \n { unexpected_rows } "
296307
297- e . args = (error_msg , )
308+ args . append (error_msg )
298309 else :
299- diff = expected .compare (actual ).rename (columns = {"self" : "exp" , "other" : "act" })
300- e .args = (f"Data mismatch (exp: expected, act: actual)\n \n { diff } " ,)
310+ diff = expected .compare (actual ).rename (
311+ columns = {"self" : "Expected" , "other" : "Actual" }
312+ )
313+
314+ if not self .rich_output :
315+ args .append (f"Data mismatch\n \n { diff } " )
316+ elif self .verbosity == Verbosity .DEFAULT :
317+ args .append (df_to_table ("Data mismatch" , diff ))
318+ else :
319+ from pandas import MultiIndex
320+
321+ levels = t .cast (MultiIndex , diff .columns ).levels [0 ]
322+ for col in levels :
323+ col_diff = diff [col ]
324+ if not col_diff .empty :
325+ table = df_to_table (
326+ f"[bold red]Column '{ col } ' mismatch[/bold red]" , col_diff
327+ )
328+ args .append (table )
329+
330+ e .args = (* args ,)
301331
302332 raise e
303333
@@ -319,6 +349,7 @@ def create_test(
319349 preserve_fixtures : bool = False ,
320350 default_catalog : str | None = None ,
321351 concurrency : bool = False ,
352+ verbosity : Verbosity = Verbosity .DEFAULT ,
322353 ) -> t .Optional [ModelTest ]:
323354 """Create a SqlModelTest or a PythonModelTest.
324355
@@ -364,6 +395,7 @@ def create_test(
364395 preserve_fixtures ,
365396 default_catalog ,
366397 concurrency ,
398+ verbosity ,
367399 )
368400 except Exception as e :
369401 raise TestError (f"Failed to create test { test_name } ({ path } )\n { str (e )} " )
@@ -683,6 +715,8 @@ def __init__(
683715 preserve_fixtures : bool = False ,
684716 default_catalog : str | None = None ,
685717 concurrency : bool = False ,
718+ verbosity : Verbosity = Verbosity .DEFAULT ,
719+ rich_output : bool = True ,
686720 ) -> None :
687721 """PythonModelTest encapsulates a unit test for a Python model.
688722
@@ -709,6 +743,8 @@ def __init__(
709743 preserve_fixtures ,
710744 default_catalog ,
711745 concurrency ,
746+ verbosity ,
747+ rich_output ,
712748 )
713749
714750 self .context = TestExecutionContext (
@@ -942,3 +978,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
942978 return {k : _normalize_df_value (v ) for k , v in zip (value ["key" ], value ["value" ])}
943979 return {k : _normalize_df_value (v ) for k , v in value .items ()}
944980 return value
981+
982+
983+ def df_to_table (
984+ header : str ,
985+ df : pd .DataFrame ,
986+ show_index : bool = True ,
987+ index_name : str = "Row" ,
988+ ) -> Table :
989+ """Convert a pandas.DataFrame obj into a rich.Table obj.
990+ Args:
991+ df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
992+ rich_table (Table): A rich Table that should be populated by the DataFrame values.
993+ show_index (bool): Add a column with a row count to the table. Defaults to True.
994+ index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value.
995+ Returns:
996+ Table: The rich Table instance passed, populated with the DataFrame values."""
997+
998+ rich_table = Table (title = f"[bold red]{ header } [/bold red]" , show_lines = True , min_width = 60 )
999+ if show_index :
1000+ index_name = str (index_name ) if index_name else ""
1001+ rich_table .add_column (index_name )
1002+
1003+ for column in df .columns :
1004+ column_name = column if isinstance (column , str ) else ": " .join (str (col ) for col in column )
1005+ if "expected" in column_name .lower ():
1006+ column_name = f"[green]{ column_name } [/green]"
1007+ else :
1008+ column_name = f"[red]{ column_name } [/red]"
1009+
1010+ rich_table .add_column (Align .center (column_name ))
1011+
1012+ for index , value_list in enumerate (df .values .tolist ()):
1013+ row = [str (index )] if show_index else []
1014+ row += [str (x ) for x in value_list ]
1015+ center = [Align .center (x ) for x in row ]
1016+ rich_table .add_row (* center )
1017+
1018+ return rich_table
0 commit comments