Skip to content

Commit a54f9fe

Browse files
committed
Feat: Refactor unit test output
1 parent 9e5bf54 commit a54f9fe

16 files changed

Lines changed: 365 additions & 193 deletions

File tree

sqlmesh/core/console.py

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import uuid
88
import logging
99
import textwrap
10+
from itertools import zip_longest
1011
from pathlib import Path
1112
from hyperscript import h
1213
from rich.console import Console as RichConsole
@@ -26,6 +27,7 @@
2627
from rich.tree import Tree
2728
from sqlglot import exp
2829

30+
from sqlmesh.core.test.result import ModelTextTestResult
2931
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
3032
from sqlmesh.core.linter.rule import RuleViolation
3133
from sqlmesh.core.model import Model
@@ -46,6 +48,7 @@
4648
NodeAuditsErrors,
4749
format_destructive_change_msg,
4850
)
51+
from sqlmesh.utils.rich import strip_ansi_codes
4952

5053
if t.TYPE_CHECKING:
5154
import ipywidgets as widgets
@@ -316,6 +319,12 @@ def log_destructive_change(
316319
"""Display a destructive change error or warning to the user."""
317320

318321

322+
class UnitTestConsole(abc.ABC):
323+
@abc.abstractmethod
324+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
325+
"""Display the test result and output."""
326+
327+
319328
class Console(
320329
PlanBuilderConsole,
321330
LinterConsole,
@@ -327,6 +336,7 @@ class Console(
327336
DifferenceConsole,
328337
TableDiffConsole,
329338
BaseConsole,
339+
UnitTestConsole,
330340
abc.ABC,
331341
):
332342
"""Abstract base class for defining classes used for displaying information to the user and also interact
@@ -461,9 +471,7 @@ def plan(
461471
"""
462472

463473
@abc.abstractmethod
464-
def log_test_results(
465-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
466-
) -> None:
474+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
467475
"""Display the test result and output.
468476
469477
Args:
@@ -496,6 +504,10 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
496504
def loading_stop(self, id: uuid.UUID) -> None:
497505
"""Stop loading for the given id."""
498506

507+
@abc.abstractmethod
508+
def log_unit_test_results(self, result: ModelTextTestResult) -> None:
509+
"""Print the unit test results."""
510+
499511

500512
class NoopConsole(Console):
501513
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -668,9 +680,7 @@ def plan(
668680
if auto_apply:
669681
plan_builder.apply()
670682

671-
def log_test_results(
672-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
673-
) -> None:
683+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
674684
pass
675685

676686
def show_sql(self, sql: str) -> None:
@@ -777,6 +787,9 @@ def start_destroy(self) -> bool:
777787
def stop_destroy(self, success: bool = True) -> None:
778788
pass
779789

790+
def log_unit_test_results(self, result: ModelTextTestResult) -> None:
791+
pass
792+
780793

781794
def make_progress_bar(
782795
message: str,
@@ -1952,10 +1965,12 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
19521965
):
19531966
plan_builder.apply()
19541967

1955-
def log_test_results(
1956-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
1957-
) -> None:
1968+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
19581969
divider_length = 70
1970+
1971+
self.log_unit_test_results(result)
1972+
self._print("\n")
1973+
19591974
if result.wasSuccessful():
19601975
self._print("=" * divider_length)
19611976
self._print(
@@ -1972,9 +1987,13 @@ def log_test_results(
19721987
)
19731988
for test, _ in result.failures + result.errors:
19741989
if isinstance(test, ModelTest):
1975-
self._print(f"Failure Test: {test.model.name} {test.test_name}")
1990+
self._print(f"Failure Test: {test.path}::{test.test_name}")
19761991
self._print("=" * divider_length)
1977-
self._print(output)
1992+
1993+
def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
1994+
with self.console.capture() as capture:
1995+
self.log_unit_test_results(result)
1996+
return strip_ansi_codes(capture.get())
19781997

19791998
def show_sql(self, sql: str) -> None:
19801999
self._print(Syntax(sql, "sql", word_wrap=True), crop=False)
@@ -2492,6 +2511,56 @@ def show_linter_violations(
24922511
else:
24932512
self.log_warning(msg)
24942513

2514+
def log_unit_test_results(self, result: ModelTextTestResult) -> None:
2515+
tests_run = result.testsRun
2516+
errors = result.errors
2517+
failures = result.failures
2518+
skipped = result.skipped
2519+
is_success = not (errors or failures)
2520+
2521+
infos = []
2522+
if failures:
2523+
infos.append(f"failures={len(failures)}")
2524+
if errors:
2525+
infos.append(f"errors={len(errors)}")
2526+
if skipped:
2527+
infos.append(f"skipped={skipped}")
2528+
2529+
self._print("\n", end="")
2530+
2531+
for (test_case, failure), test_failure_tables in zip_longest( # type: ignore
2532+
failures, result.failure_tables
2533+
):
2534+
self._print(unittest.TextTestResult.separator1)
2535+
self._print(f"FAIL: {test_case}")
2536+
2537+
if test_description := test_case.shortDescription():
2538+
self._print(test_description)
2539+
self._print(f"{unittest.TextTestResult.separator2}")
2540+
2541+
if not test_failure_tables:
2542+
self._print(failure)
2543+
else:
2544+
for failure_table in test_failure_tables:
2545+
self._print(failure_table)
2546+
self._print("\n", end="")
2547+
2548+
for test_case, error in errors:
2549+
self._print(unittest.TextTestResult.separator1)
2550+
self._print(f"ERROR: {test_case}")
2551+
self._print(f"{unittest.TextTestResult.separator2}")
2552+
self._print(error)
2553+
2554+
# Output final report
2555+
self._print(unittest.TextTestResult.separator2)
2556+
test_duration_msg = f" in {result.duration:.3f}s" if result.duration else ""
2557+
self._print(
2558+
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
2559+
)
2560+
self._print(
2561+
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
2562+
)
2563+
24952564

24962565
def _cells_match(x: t.Any, y: t.Any) -> bool:
24972566
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
@@ -2763,9 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
27632832
)
27642833
self.display(radio)
27652834

2766-
def log_test_results(
2767-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
2768-
) -> None:
2835+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
27692836
import ipywidgets as widgets
27702837

27712838
divider_length = 70
@@ -2781,12 +2848,14 @@ def log_test_results(
27812848
h(
27822849
"span",
27832850
{"style": {**shared_style, **success_color}},
2784-
f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}",
2851+
f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}",
27852852
)
27862853
)
27872854
footer = str(h("span", {"style": shared_style}, "=" * divider_length))
27882855
self.display(widgets.HTML("<br>".join([header, message, footer])))
27892856
else:
2857+
output = self._captured_unit_test_results(result)
2858+
27902859
fail_color = {"color": "#db3737"}
27912860
fail_shared_style = {**shared_style, **fail_color}
27922861
header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
@@ -3137,21 +3206,22 @@ def stop_promotion_progress(self, success: bool = True) -> None:
31373206
def log_success(self, message: str) -> None:
31383207
self._print(message)
31393208

3140-
def log_test_results(
3141-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3142-
) -> None:
3209+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
31433210
if result.wasSuccessful():
31443211
self._print(
31453212
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
31463213
)
31473214
else:
3215+
self._print("```")
3216+
self.log_unit_test_results(result)
3217+
self._print("```\n\n")
3218+
31483219
self._print(
31493220
f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n"
31503221
)
31513222
for test, _ in result.failures + result.errors:
31523223
if isinstance(test, ModelTest):
31533224
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")
3154-
self._print(f"```{output}```\n\n")
31553225

31563226
def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
31573227
if snapshot_names:
@@ -3530,9 +3600,7 @@ def show_model_difference_summary(
35303600
for modified in context_diff.modified_snapshots:
35313601
self._write(f" Modified: {modified}")
35323602

3533-
def log_test_results(
3534-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3535-
) -> None:
3603+
def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None:
35363604
self._write("Test Results:", result)
35373605

35383606
def show_sql(self, sql: str) -> None:

sqlmesh/core/context.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import time
4141
import traceback
4242
import typing as t
43-
import unittest.result
4443
from functools import cached_property
4544
from io import StringIO
4645
from itertools import chain
@@ -2044,6 +2043,7 @@ def test(
20442043
verbosity: Verbosity = Verbosity.DEFAULT,
20452044
preserve_fixtures: bool = False,
20462045
stream: t.Optional[t.TextIO] = None,
2046+
log_results: bool = True,
20472047
) -> ModelTextTestResult:
20482048
"""Discover and run model tests"""
20492049
if verbosity >= Verbosity.VERBOSE:
@@ -2053,7 +2053,7 @@ def test(
20532053

20542054
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
20552055

2056-
return run_tests(
2056+
result = run_tests(
20572057
model_test_metadata=test_meta,
20582058
models=self._models,
20592059
config=self.config,
@@ -2066,6 +2066,14 @@ def test(
20662066
default_catalog_dialect=self.config.dialect or "",
20672067
)
20682068

2069+
if log_results:
2070+
self.console.log_test_results(
2071+
result,
2072+
self.test_connection_config._engine_adapter.DIALECT,
2073+
)
2074+
2075+
return result
2076+
20692077
@python_api_analytics
20702078
def audit(
20712079
self,
@@ -2488,28 +2496,20 @@ def import_state(self, input_file: Path, clear: bool = False, confirm: bool = Tr
24882496

24892497
def _run_tests(
24902498
self, verbosity: Verbosity = Verbosity.DEFAULT
2491-
) -> t.Tuple[unittest.result.TestResult, str]:
2499+
) -> t.Tuple[ModelTextTestResult, str]:
24922500
test_output_io = StringIO()
2493-
result = self.test(stream=test_output_io, verbosity=verbosity)
2501+
result = self.test(stream=test_output_io, verbosity=verbosity, log_results=False)
24942502
return result, test_output_io.getvalue()
24952503

2496-
def _run_plan_tests(
2497-
self, skip_tests: bool = False
2498-
) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]:
2504+
def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]:
24992505
if not skip_tests:
2500-
result, test_output = self._run_tests()
2501-
if result.testsRun > 0:
2502-
self.console.log_test_results(
2503-
result,
2504-
test_output,
2505-
self.test_connection_config._engine_adapter.DIALECT,
2506-
)
2506+
result = self.test()
25072507
if not result.wasSuccessful():
25082508
raise PlanError(
25092509
"Cannot generate plan due to failing test(s). Fix test(s) and run again."
25102510
)
2511-
return result, test_output
2512-
return None, None
2511+
return result
2512+
return None
25132513

25142514
@property
25152515
def _model_tables(self) -> t.Dict[str, str]:

0 commit comments

Comments
 (0)