Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 80 additions & 7 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from rich.tree import Tree
from sqlglot import exp

from sqlmesh.core.test.result import ModelTextTestResult
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary
from sqlmesh.core.linter.rule import RuleViolation
from sqlmesh.core.model import Model
Expand Down Expand Up @@ -462,7 +463,7 @@ def plan(

@abc.abstractmethod
def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
"""Display the test result and output.

Expand Down Expand Up @@ -496,6 +497,12 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
def loading_stop(self, id: uuid.UUID) -> None:
"""Stop loading for the given id."""

@abc.abstractmethod
def log_unit_test_results(
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
) -> None:
"""Print the unit test results."""


class NoopConsole(Console):
def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
Expand Down Expand Up @@ -669,7 +676,7 @@ def plan(
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
pass

Expand Down Expand Up @@ -777,6 +784,11 @@ def start_destroy(self) -> bool:
def stop_destroy(self, success: bool = True) -> None:
pass

def log_unit_test_results(
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
) -> None:
pass


def make_progress_bar(
message: str,
Expand Down Expand Up @@ -1953,9 +1965,13 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
plan_builder.apply()

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
divider_length = 70

self.log_unit_test_results(result)

self._print("\n")
if result.wasSuccessful():
self._print("=" * divider_length)
self._print(
Expand All @@ -1972,7 +1988,7 @@ def log_test_results(
)
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"Failure Test: {test.model.name} {test.test_name}")
self._print(f"Failure Test: {test.path}::{test.test_name}")
self._print("=" * divider_length)
self._print(output)

Expand Down Expand Up @@ -2492,6 +2508,58 @@ def show_linter_violations(
else:
self.log_warning(msg)

def log_unit_test_results(
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
) -> None:
tests_run = result.testsRun
errors = result.errors
failures = result.original_failures
skipped = result.skipped

is_success = not (errors or failures)

infos = []
if failures:
infos.append(f"failures={len(failures)}")
if errors:
infos.append(f"errors={len(errors)}")
if skipped:
infos.append(f"skipped={skipped}")

self._print("\n", end="")

for test_case, failure in failures:
self._print(unittest.TextTestResult.separator1)
self._print(f"FAIL: {test_case}")

if test_description := test_case.shortDescription():
self._print(test_description)
self._print(f"{unittest.TextTestResult.separator2}")

if exception := failure[1]:
for i, arg in enumerate(exception.args):
arg = f"Exception: {arg}" if isinstance(arg, str) else arg
self._print(arg)

if i < len(exception.args) - 1:
self._print("\n")

for test_case, error in errors:
self._print(unittest.TextTestResult.separator1)
self._print(f"ERROR: {test_case}")
self._print(f"{unittest.TextTestResult.separator2}")
self._print(error)

# Output final report
self._print(unittest.TextTestResult.separator2)
test_duration_msg = f" in {test_duration:.3f}s" if test_duration else ""
self._print(
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
)
self._print(
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
)


def _cells_match(x: t.Any, y: t.Any) -> bool:
"""Helper function to compare two cells and returns true if they're equal, handling array objects."""
Expand Down Expand Up @@ -2764,7 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
self.display(radio)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
import ipywidgets as widgets

Expand Down Expand Up @@ -3138,8 +3206,12 @@ def log_success(self, message: str) -> None:
self._print(message)

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
# self._print("```")
self.log_unit_test_results(result)
# self._print("```\n\n")

if result.wasSuccessful():
self._print(
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
Expand All @@ -3151,6 +3223,7 @@ def log_test_results(
for test, _ in result.failures + result.errors:
if isinstance(test, ModelTest):
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")

self._print(f"```{output}```\n\n")

def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
Expand Down Expand Up @@ -3531,7 +3604,7 @@ def show_model_difference_summary(
self._write(f" Modified: {modified}")

def log_test_results(
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
) -> None:
self._write("Test Results:", result)

Expand Down
6 changes: 5 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,7 +2053,7 @@ def test(

test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)

return run_tests(
result = run_tests(
model_test_metadata=test_meta,
models=self._models,
config=self.config,
Expand All @@ -2066,6 +2066,10 @@ def test(
default_catalog_dialect=self.config.dialect or "",
)

self.console.log_test_results(result, output="", target_dialect=self.default_dialect)

return result

@python_api_analytics
def audit(
self,
Expand Down
84 changes: 78 additions & 6 deletions sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from pathlib import Path
from unittest.mock import patch

from rich.table import Table
from rich.align import Align

from io import StringIO
from sqlglot import Dialect, exp
from sqlglot.optimizer.annotate_types import annotate_types
Expand All @@ -24,6 +27,7 @@
from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime
from sqlmesh.utils.errors import ConfigError, TestError
from sqlmesh.utils.yaml import load as yaml_load
from sqlmesh.utils import Verbosity

if t.TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -60,6 +64,7 @@ def __init__(
preserve_fixtures: bool = False,
default_catalog: str | None = None,
concurrency: bool = False,
verbosity: Verbosity = Verbosity.DEFAULT,
) -> None:
"""ModelTest encapsulates a unit test for a model.

Expand All @@ -83,6 +88,7 @@ def __init__(
self.default_catalog = default_catalog
self.dialect = dialect
self.concurrency = concurrency
self.verbosity = verbosity

self._fixture_table_cache: t.Dict[str, exp.Table] = {}
self._normalized_column_name_cache: t.Dict[str, str] = {}
Expand Down Expand Up @@ -134,6 +140,12 @@ def __init__(

super().__init__()

def defaultTestResult(self) -> unittest.TestResult:
from sqlmesh.core.test.result import ModelTextTestResult
import sys

return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)

def shortDescription(self) -> t.Optional[str]:
return self.body.get("description")

Expand Down Expand Up @@ -281,23 +293,41 @@ def _to_hashable(x: t.Any) -> t.Any:
check_like=True, # Ignore column order
)
except AssertionError as e:
args: t.List[t.Any] = []
if expected.shape != actual.shape:
_raise_if_unexpected_columns(expected.columns, actual.columns)

error_msg = "Data mismatch (rows are different)"
args.append("Data mismatch (rows are different)")

missing_rows = _row_difference(expected, actual)
if not missing_rows.empty:
error_msg += f"\n\nMissing rows:\n\n{missing_rows}"
args.append(df_to_table("Missing rows", missing_rows))

unexpected_rows = _row_difference(actual, expected)

if not unexpected_rows.empty:
error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
args.append(df_to_table("Unexpected rows", unexpected_rows))

e.args = (error_msg,)
else:
diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",)
diff = expected.compare(actual).rename(
columns={"self": "Expected", "other": "Actual"}
)

if self.verbosity == Verbosity.DEFAULT:
args.append(df_to_table("Data mismatch", diff))
else:
from pandas import MultiIndex

levels = t.cast(MultiIndex, diff.columns).levels[0]
for col in levels:
col_diff = diff[col]
if not col_diff.empty:
table = df_to_table(
f"[bold red]Column '{col}' mismatch[/bold red]", col_diff
)
args.append(table)

Comment thread
georgesittas marked this conversation as resolved.
e.args = (*args,)

raise e

Expand All @@ -319,6 +349,7 @@ def create_test(
preserve_fixtures: bool = False,
default_catalog: str | None = None,
concurrency: bool = False,
verbosity: Verbosity = Verbosity.DEFAULT,
) -> t.Optional[ModelTest]:
"""Create a SqlModelTest or a PythonModelTest.

Expand Down Expand Up @@ -364,6 +395,7 @@ def create_test(
preserve_fixtures,
default_catalog,
concurrency,
verbosity,
)
except Exception as e:
raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}")
Expand Down Expand Up @@ -683,6 +715,7 @@ def __init__(
preserve_fixtures: bool = False,
default_catalog: str | None = None,
concurrency: bool = False,
verbosity: Verbosity = Verbosity.DEFAULT,
) -> None:
"""PythonModelTest encapsulates a unit test for a Python model.

Expand All @@ -709,6 +742,7 @@ def __init__(
preserve_fixtures,
default_catalog,
concurrency,
verbosity,
)

self.context = TestExecutionContext(
Expand Down Expand Up @@ -942,3 +976,41 @@ def _normalize_df_value(value: t.Any) -> t.Any:
return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])}
return {k: _normalize_df_value(v) for k, v in value.items()}
return value


def df_to_table(
header: str,
df: pd.DataFrame,
show_index: bool = True,
index_name: str = "Row",
) -> Table:
"""Convert a pandas.DataFrame obj into a rich.Table obj.
Args:
df (DataFrame): A Pandas DataFrame to be converted to a rich Table.
rich_table (Table): A rich Table that should be populated by the DataFrame values.
show_index (bool): Add a column with a row count to the table. Defaults to True.
index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value.
Returns:
Table: The rich Table instance passed, populated with the DataFrame values."""

rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60)
if show_index:
index_name = str(index_name) if index_name else ""
rich_table.add_column(Align.center(index_name))

for column in df.columns:
column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column)
if "expected" in column_name.lower():
column_name = f"[green]{column_name}[/green]"
else:
column_name = f"[red]{column_name}[/red]"

rich_table.add_column(Align.center(column_name))

for index, value_list in enumerate(df.values.tolist()):
row = [str(index)] if show_index else []
row += [str(x) for x in value_list]
center = [Align.center(x) for x in row]
rich_table.add_row(*center)

return rich_table
Loading