77import uuid
88import logging
99import textwrap
10+ from itertools import zip_longest
1011from pathlib import Path
1112from hyperscript import h
1213from rich .console import Console as RichConsole
2627from rich .tree import Tree
2728from sqlglot import exp
2829
30+ from sqlmesh .core .test .result import ModelTextTestResult
2931from sqlmesh .core .environment import EnvironmentNamingInfo , EnvironmentSummary
3032from sqlmesh .core .linter .rule import RuleViolation
3133from sqlmesh .core .model import Model
4648 NodeAuditsErrors ,
4749 format_destructive_change_msg ,
4850)
51+ from sqlmesh .utils .rich import strip_ansi_codes
4952
5053if t .TYPE_CHECKING :
5154 import ipywidgets as widgets
@@ -316,6 +319,12 @@ def log_destructive_change(
316319 """Display a destructive change error or warning to the user."""
317320
318321
322+ class UnitTestConsole (abc .ABC ):
323+ @abc .abstractmethod
324+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
325+ """Display the test result and output."""
326+
327+
319328class Console (
320329 PlanBuilderConsole ,
321330 LinterConsole ,
@@ -327,6 +336,7 @@ class Console(
327336 DifferenceConsole ,
328337 TableDiffConsole ,
329338 BaseConsole ,
339+ UnitTestConsole ,
330340 abc .ABC ,
331341):
332342 """Abstract base class for defining classes used for displaying information to the user and also interact
@@ -461,9 +471,7 @@ def plan(
461471 """
462472
463473 @abc .abstractmethod
464- def log_test_results (
465- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
466- ) -> None :
474+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
467475 """Display the test result and output.
468476
469477 Args:
@@ -496,6 +504,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
496504 def loading_stop (self , id : uuid .UUID ) -> None :
497505 """Stop loading for the given id."""
498506
507+ @abc .abstractmethod
508+ def log_unit_test_results (self , result : ModelTextTestResult ) -> None :
509+ """Print the unit test results."""
510+
499511
500512class NoopConsole (Console ):
501513 def start_plan_evaluation (self , plan : EvaluatablePlan ) -> None :
@@ -668,9 +680,7 @@ def plan(
668680 if auto_apply :
669681 plan_builder .apply ()
670682
671- def log_test_results (
672- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
673- ) -> None :
683+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
674684 pass
675685
676686 def show_sql (self , sql : str ) -> None :
@@ -777,6 +787,9 @@ def start_destroy(self) -> bool:
777787 def stop_destroy (self , success : bool = True ) -> None :
778788 pass
779789
790+ def log_unit_test_results (self , result : ModelTextTestResult ) -> None :
791+ pass
792+
780793
781794def make_progress_bar (
782795 message : str ,
@@ -1952,10 +1965,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
19521965 ):
19531966 plan_builder .apply ()
19541967
1955- def log_test_results (
1956- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
1957- ) -> None :
1968+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
19581969 divider_length = 70
1970+
1971+ self .log_unit_test_results (result )
1972+ self ._print ("\n " )
1973+
19591974 if result .wasSuccessful ():
19601975 self ._print ("=" * divider_length )
19611976 self ._print (
@@ -1972,9 +1987,13 @@ def log_test_results(
19721987 )
19731988 for test , _ in result .failures + result .errors :
19741989 if isinstance (test , ModelTest ):
1975- self ._print (f"Failure Test: { test .model . name } { test .test_name } " )
1990+ self ._print (f"Failure Test: { test .path } :: { test .test_name } " )
19761991 self ._print ("=" * divider_length )
1977- self ._print (output )
1992+
1993+ def _captured_unit_test_results (self , result : ModelTextTestResult ) -> str :
1994+ with self .console .capture () as capture :
1995+ self .log_unit_test_results (result )
1996+ return strip_ansi_codes (capture .get ())
19781997
19791998 def show_sql (self , sql : str ) -> None :
19801999 self ._print (Syntax (sql , "sql" , word_wrap = True ), crop = False )
@@ -2492,6 +2511,56 @@ def show_linter_violations(
24922511 else :
24932512 self .log_warning (msg )
24942513
2514+ def log_unit_test_results (self , result : ModelTextTestResult ) -> None :
2515+ tests_run = result .testsRun
2516+ errors = result .errors
2517+ failures = result .failures
2518+ skipped = result .skipped
2519+ is_success = not (errors or failures )
2520+
2521+ infos = []
2522+ if failures :
2523+ infos .append (f"failures={ len (failures )} " )
2524+ if errors :
2525+ infos .append (f"errors={ len (errors )} " )
2526+ if skipped :
2527+ infos .append (f"skipped={ skipped } " )
2528+
2529+ self ._print ("\n " , end = "" )
2530+
2531+ for (test_case , failure ), test_failure_tables in zip_longest ( # type: ignore
2532+ failures , result .failure_tables
2533+ ):
2534+ self ._print (unittest .TextTestResult .separator1 )
2535+ self ._print (f"FAIL: { test_case } " )
2536+
2537+ if test_description := test_case .shortDescription ():
2538+ self ._print (test_description )
2539+ self ._print (f"{ unittest .TextTestResult .separator2 } " )
2540+
2541+ if not test_failure_tables :
2542+ self ._print (failure )
2543+ else :
2544+ for failure_table in test_failure_tables :
2545+ self ._print (failure_table )
2546+ self ._print ("\n " , end = "" )
2547+
2548+ for test_case , error in errors :
2549+ self ._print (unittest .TextTestResult .separator1 )
2550+ self ._print (f"ERROR: { test_case } " )
2551+ self ._print (f"{ unittest .TextTestResult .separator2 } " )
2552+ self ._print (error )
2553+
2554+ # Output final report
2555+ self ._print (unittest .TextTestResult .separator2 )
2556+ test_duration_msg = f" in { result .duration :.3f} s" if result .duration else ""
2557+ self ._print (
2558+ f"\n Ran { tests_run } { 'tests' if tests_run > 1 else 'test' } { test_duration_msg } \n "
2559+ )
2560+ self ._print (
2561+ f"{ 'OK' if is_success else 'FAILED' } { ' (' + ', ' .join (infos ) + ')' if infos else '' } "
2562+ )
2563+
24952564
24962565def _cells_match (x : t .Any , y : t .Any ) -> bool :
24972566 """Helper function to compare two cells and returns true if they're equal, handling array objects."""
@@ -2763,9 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
27632832 )
27642833 self .display (radio )
27652834
2766- def log_test_results (
2767- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
2768- ) -> None :
2835+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
27692836 import ipywidgets as widgets
27702837
27712838 divider_length = 70
@@ -2781,12 +2848,14 @@ def log_test_results(
27812848 h (
27822849 "span" ,
27832850 {"style" : {** shared_style , ** success_color }},
2784- f"Successfully Ran { str (result .testsRun )} Tests Against { target_dialect } " ,
2851+ f"Successfully Ran { str (result .testsRun )} tests against { target_dialect } " ,
27852852 )
27862853 )
27872854 footer = str (h ("span" , {"style" : shared_style }, "=" * divider_length ))
27882855 self .display (widgets .HTML ("<br>" .join ([header , message , footer ])))
27892856 else :
2857+ output = self ._captured_unit_test_results (result )
2858+
27902859 fail_color = {"color" : "#db3737" }
27912860 fail_shared_style = {** shared_style , ** fail_color }
27922861 header = str (h ("span" , {"style" : fail_shared_style }, "-" * divider_length ))
@@ -3137,21 +3206,22 @@ def stop_promotion_progress(self, success: bool = True) -> None:
31373206 def log_success (self , message : str ) -> None :
31383207 self ._print (message )
31393208
3140- def log_test_results (
3141- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
3142- ) -> None :
3209+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
31433210 if result .wasSuccessful ():
31443211 self ._print (
31453212 f"**Successfully Ran `{ str (result .testsRun )} ` Tests Against `{ target_dialect } `**\n \n "
31463213 )
31473214 else :
3215+ self ._print ("```" )
3216+ self .log_unit_test_results (result )
3217+ self ._print ("```\n \n " )
3218+
31483219 self ._print (
31493220 f"**Num Successful Tests: { result .testsRun - len (result .failures ) - len (result .errors )} **\n \n "
31503221 )
31513222 for test , _ in result .failures + result .errors :
31523223 if isinstance (test , ModelTest ):
31533224 self ._print (f"* Failure Test: `{ test .model .name } ` - `{ test .test_name } `\n \n " )
3154- self ._print (f"```{ output } ```\n \n " )
31553225
31563226 def log_skipped_models (self , snapshot_names : t .Set [str ]) -> None :
31573227 if snapshot_names :
@@ -3530,9 +3600,7 @@ def show_model_difference_summary(
35303600 for modified in context_diff .modified_snapshots :
35313601 self ._write (f" Modified: { modified } " )
35323602
3533- def log_test_results (
3534- self , result : unittest .result .TestResult , output : t .Optional [str ], target_dialect : str
3535- ) -> None :
3603+ def log_test_results (self , result : ModelTextTestResult , target_dialect : str ) -> None :
35363604 self ._write ("Test Results:" , result )
35373605
35383606 def show_sql (self , sql : str ) -> None :
0 commit comments