diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index e272442e67..47c8abaea9 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -26,6 +26,7 @@ from rich.tree import Tree from sqlglot import exp +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary from sqlmesh.core.linter.rule import RuleViolation from sqlmesh.core.model import Model @@ -462,7 +463,7 @@ def plan( @abc.abstractmethod def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: """Display the test result and output. @@ -496,6 +497,12 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: def loading_stop(self, id: uuid.UUID) -> None: """Stop loading for the given id.""" + @abc.abstractmethod + def log_unit_test_results( + self, result: ModelTextTestResult, test_duration: t.Optional[float] = None + ) -> None: + """Print the unit test results.""" + class NoopConsole(Console): def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: @@ -669,7 +676,7 @@ def plan( plan_builder.apply() def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: pass @@ -777,6 +784,11 @@ def start_destroy(self) -> bool: def stop_destroy(self, success: bool = True) -> None: pass + def log_unit_test_results( + self, result: ModelTextTestResult, test_duration: t.Optional[float] = None + ) -> None: + pass + def make_progress_bar( message: str, @@ -1953,9 +1965,13 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None: plan_builder.apply() def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: divider_length = 70 + + self.log_unit_test_results(result) + + self._print("\n") if result.wasSuccessful(): self._print("=" * divider_length) self._print( @@ -1972,7 +1988,7 @@ def log_test_results( ) for test, _ in result.failures + result.errors: if isinstance(test, ModelTest): - self._print(f"Failure Test: {test.model.name} {test.test_name}") + self._print(f"Failure Test: {test.path}::{test.test_name}") self._print("=" * divider_length) self._print(output) @@ -2492,6 +2508,58 @@ def show_linter_violations( else: self.log_warning(msg) + def log_unit_test_results( + self, result: ModelTextTestResult, test_duration: t.Optional[float] = None + ) -> None: + tests_run = result.testsRun + errors = result.errors + failures = result.original_failures + skipped = result.skipped + + is_success = not (errors or failures) + + infos = [] + if failures: + infos.append(f"failures={len(failures)}") + if errors: + infos.append(f"errors={len(errors)}") + if skipped: + infos.append(f"skipped={skipped}") + + self._print("\n", end="") + + for test_case, failure in failures: + self._print(unittest.TextTestResult.separator1) + self._print(f"FAIL: {test_case}") + + if test_description := test_case.shortDescription(): + self._print(test_description) + self._print(f"{unittest.TextTestResult.separator2}") + + if exception := failure[1]: + for i, arg in enumerate(exception.args): + arg = f"Exception: {arg}" if isinstance(arg, str) else arg + self._print(arg) + + if i < len(exception.args) - 1: + self._print("\n") + + for test_case, error in errors: + self._print(unittest.TextTestResult.separator1) + self._print(f"ERROR: {test_case}") + self._print(f"{unittest.TextTestResult.separator2}") + self._print(error) + + # Output final report + self._print(unittest.TextTestResult.separator2) + test_duration_msg = f" in {test_duration:.3f}s" if test_duration else "" + self._print( + f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n" + ) + self._print( + f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}" + ) + def _cells_match(x: t.Any, y: t.Any) -> bool: """Helper function to compare two cells and returns true if they're equal, handling array objects.""" @@ -2764,7 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None: self.display(radio) def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: import ipywidgets as widgets @@ -3138,8 +3206,12 @@ def log_success(self, message: str) -> None: self._print(message) def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: + # self._print("```") + self.log_unit_test_results(result) + # self._print("```\n\n") + if result.wasSuccessful(): self._print( f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n" @@ -3151,6 +3223,7 @@ def log_test_results( for test, _ in result.failures + result.errors: if isinstance(test, ModelTest): self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n") + self._print(f"```{output}```\n\n") def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: @@ -3531,7 +3604,7 @@ def show_model_difference_summary( self._write(f" Modified: {modified}") def log_test_results( - self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str + self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str ) -> None: self._write("Test Results:", result) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 0450827d6e..a6bb17be57 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2053,7 +2053,7 @@ def test( test_meta = self.load_model_tests(tests=tests, patterns=match_patterns) - return run_tests( + result = run_tests( model_test_metadata=test_meta, models=self._models, config=self.config, @@ -2066,6 +2066,10 @@ def test( default_catalog_dialect=self.config.dialect or "", ) + self.console.log_test_results(result, output="", target_dialect=self.default_dialect) + + return result + @python_api_analytics def audit( self, diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index a766706801..bcd8d60fdd 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -10,6 +10,9 @@ from pathlib import Path from unittest.mock import patch +from rich.table import Table +from rich.align import Align + from io import StringIO from sqlglot import Dialect, exp from sqlglot.optimizer.annotate_types import annotate_types @@ -24,6 +27,7 @@ from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime from sqlmesh.utils.errors import ConfigError, TestError from sqlmesh.utils.yaml import load as yaml_load +from sqlmesh.utils import Verbosity if t.TYPE_CHECKING: import pandas as pd @@ -60,6 +64,7 @@ def __init__( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: """ModelTest encapsulates a unit test for a model. @@ -83,6 +88,7 @@ def __init__( self.default_catalog = default_catalog self.dialect = dialect self.concurrency = concurrency + self.verbosity = verbosity self._fixture_table_cache: t.Dict[str, exp.Table] = {} self._normalized_column_name_cache: t.Dict[str, str] = {} @@ -134,6 +140,12 @@ def __init__( super().__init__() + def defaultTestResult(self) -> unittest.TestResult: + from sqlmesh.core.test.result import ModelTextTestResult + import sys + + return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity) + def shortDescription(self) -> t.Optional[str]: return self.body.get("description") @@ -281,23 +293,41 @@ def _to_hashable(x: t.Any) -> t.Any: check_like=True, # Ignore column order ) except AssertionError as e: + args: t.List[t.Any] = [] if expected.shape != actual.shape: _raise_if_unexpected_columns(expected.columns, actual.columns) - error_msg = "Data mismatch (rows are different)" + args.append("Data mismatch (rows are different)") missing_rows = _row_difference(expected, actual) if not missing_rows.empty: - error_msg += f"\n\nMissing rows:\n\n{missing_rows}" + args.append(df_to_table("Missing rows", missing_rows)) unexpected_rows = _row_difference(actual, expected) + if not unexpected_rows.empty: - error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}" + args.append(df_to_table("Unexpected rows", unexpected_rows)) - e.args = (error_msg,) else: - diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"}) - e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",) + diff = expected.compare(actual).rename( + columns={"self": "Expected", "other": "Actual"} + ) + + if self.verbosity == Verbosity.DEFAULT: + args.append(df_to_table("Data mismatch", diff)) + else: + from pandas import MultiIndex + + levels = t.cast(MultiIndex, diff.columns).levels[0] + for col in levels: + col_diff = diff[col] + if not col_diff.empty: + table = df_to_table( + f"[bold red]Column '{col}' mismatch[/bold red]", col_diff + ) + args.append(table) + + e.args = (*args,) raise e @@ -319,6 +349,7 @@ def create_test( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> t.Optional[ModelTest]: """Create a SqlModelTest or a PythonModelTest. @@ -364,6 +395,7 @@ def create_test( preserve_fixtures, default_catalog, concurrency, + verbosity, ) except Exception as e: raise TestError(f"Failed to create test {test_name} ({path})\n{str(e)}") @@ -683,6 +715,7 @@ def __init__( preserve_fixtures: bool = False, default_catalog: str | None = None, concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: """PythonModelTest encapsulates a unit test for a Python model. @@ -709,6 +742,7 @@ def __init__( preserve_fixtures, default_catalog, concurrency, + verbosity, ) self.context = TestExecutionContext( @@ -942,3 +976,41 @@ def _normalize_df_value(value: t.Any) -> t.Any: return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])} return {k: _normalize_df_value(v) for k, v in value.items()} return value + + +def df_to_table( + header: str, + df: pd.DataFrame, + show_index: bool = True, + index_name: str = "Row", +) -> Table: + """Convert a pandas.DataFrame obj into a rich.Table obj. + Args: + df (DataFrame): A Pandas DataFrame to be converted to a rich Table. + rich_table (Table): A rich Table that should be populated by the DataFrame values. + show_index (bool): Add a column with a row count to the table. Defaults to True. + index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value. + Returns: + Table: The rich Table instance passed, populated with the DataFrame values.""" + + rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60) + if show_index: + index_name = str(index_name) if index_name else "" + rich_table.add_column(Align.center(index_name)) + + for column in df.columns: + column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column) + if "expected" in column_name.lower(): + column_name = f"[green]{column_name}[/green]" + else: + column_name = f"[red]{column_name}[/red]" + + rich_table.add_column(Align.center(column_name)) + + for index, value_list in enumerate(df.values.tolist()): + row = [str(index)] if show_index else [] + row += [str(x) for x in value_list] + center = [Align.center(x) for x in row] + rich_table.add_row(*center) + + return rich_table diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index cdba66b612..62b5754753 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -19,6 +19,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any): self.successes = [] self.original_failures: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] self.original_errors: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] + self.duration: t.Optional[float] = None def addSubTest( self, @@ -76,50 +77,13 @@ def addSuccess(self, test: unittest.TestCase) -> None: super().addSuccess(test) self.successes.append(test) - def log_test_report(self, test_duration: float) -> None: + def log_test_report(self, test_duration: t.Optional[float] = None) -> None: """ Log the test report following unittest's conventions. Args: test_duration: The duration of the tests. """ - tests_run = self.testsRun - errors = self.errors - failures = self.failures - skipped = self.skipped - - is_success = not (errors or failures) - - infos = [] - if failures: - infos.append(f"failures={len(failures)}") - if errors: - infos.append(f"errors={len(errors)}") - if skipped: - infos.append(f"skipped={skipped}") - - stream = self.stream - - stream.write("\n") - - for test_case, failure in failures: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"FAIL: {test_case}") - if test_description := test_case.shortDescription(): - stream.writeln(test_description) - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln(failure) - - for test_case, error in errors: - stream.writeln(unittest.TextTestResult.separator1) - stream.writeln(f"ERROR: {test_case}") - stream.writeln(error) - - # Output final report - stream.writeln(unittest.TextTestResult.separator2) - stream.writeln( - f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n" - ) - stream.writeln( - f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}" - ) + from sqlmesh.core.console import get_console + + get_console().log_unit_test_results(self, test_duration) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py index d2a54d68e8..79235492fd 100644 --- a/sqlmesh/core/test/runner.py +++ b/sqlmesh/core/test/runner.py @@ -107,7 +107,7 @@ def run_tests( lock = threading.Lock() combined_results = ModelTextTestResult( - stream=unittest.runner._WritelnDecorator(stream or sys.stderr), # type: ignore + stream=unittest.runner._WritelnDecorator(stream or sys.stdout), # type: ignore verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, descriptions=True, ) @@ -136,6 +136,7 @@ def _run_single_test( default_catalog=default_catalog, preserve_fixtures=preserve_fixtures, concurrency=num_workers > 1, + verbosity=verbosity, ) if not test: @@ -183,6 +184,6 @@ def _run_single_test( end_time = time.perf_counter() - combined_results.log_test_report(test_duration=end_time - start_time) + combined_results.duration = end_time - start_time return combined_results diff --git a/tests/core/test_test.py b/tests/core/test_test.py index a05e66e48f..067a014a57 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -38,7 +38,7 @@ from tests.utils.test_helpers import use_terminal_console if t.TYPE_CHECKING: - from unittest import TestResult + pass pytestmark = pytest.mark.slow @@ -76,16 +76,26 @@ def _create_model( ) +@use_terminal_console def _check_successful_or_raise( - result: t.Optional[TestResult], expected_msg: t.Optional[str] = None + test_or_result: t.Union[ModelTest, ModelTextTestResult], expected_msg: t.Optional[str] = None ) -> None: - assert result is not None + if isinstance(test_or_result, ModelTextTestResult): + result = test_or_result + test_output = "" + else: + with capture_output() as output: + result = t.cast(ModelTextTestResult, test_or_result.run()) + assert result is not None + result.log_test_report() + + test_output = output.stdout + if not result.wasSuccessful(): - error_or_failure_traceback = (result.errors or result.failures)[0][1] if expected_msg: - assert expected_msg in error_or_failure_traceback + assert expected_msg in test_output else: - raise AssertionError(error_or_failure_traceback) + raise AssertionError(test_output) @pytest.fixture @@ -149,7 +159,7 @@ def test_ctes(sushi_context: Context, full_model_with_two_ctes: SqlModel) -> Non test_name="test_foo", model=sushi_context.upsert_model(full_model_with_two_ctes), context=sushi_context, - ).run() + ) ) @@ -177,7 +187,7 @@ def test_ctes_only(sushi_context: Context, full_model_with_two_ctes: SqlModel) - test_name="test_foo", model=sushi_context.upsert_model(full_model_with_two_ctes), context=sushi_context, - ).run() + ) ) @@ -202,7 +212,7 @@ def test_query_only(sushi_context: Context, full_model_with_two_ctes: SqlModel) test_name="test_foo", model=sushi_context.upsert_model(full_model_with_two_ctes), context=sushi_context, - ).run() + ) ) @@ -233,7 +243,7 @@ def test_with_rows(sushi_context: Context, full_model_with_single_cte: SqlModel) test_name="test_foo", model=sushi_context.upsert_model(full_model_with_single_cte), context=sushi_context, - ).run() + ) ) @@ -261,7 +271,7 @@ def test_without_rows(sushi_context: Context, full_model_with_single_cte: SqlMod test_name="test_foo", model=sushi_context.upsert_model(full_model_with_single_cte), context=sushi_context, - ).run() + ) ) @@ -290,7 +300,7 @@ def test_column_order(sushi_context: Context, full_model_without_ctes: SqlModel) test_name="test_foo", model=sushi_context.upsert_model(full_model_without_ctes), context=sushi_context, - ).run() + ) ) @@ -329,7 +339,7 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> test_name="test_foo", model=sushi_context.upsert_model(full_model_without_ctes), context=sushi_context, - ).run() + ) ) full_model_without_ctes_dict = full_model_without_ctes.dict() @@ -343,13 +353,17 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> test_name="test_foo", model=sushi_context.upsert_model(full_model_without_ctes_orderby), context=sushi_context, - ).run(), + ), expected_msg=( - "AssertionError: Data mismatch (exp: expected, act: actual)\n\n" - " id value ds \n" - " exp act exp act exp act\n" - "0 2 1 3 2 4 3\n" - "1 1 2 2 3 3 4\n" + """Data mismatch +┏━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┓ +┃ ┃ id: ┃ id: ┃ value: ┃ value: ┃ ds: ┃ ds: ┃ +┃ Row ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ +┡━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━┩ +│ 0 │ 2 │ 1 │ 3 │ 2 │ 4 │ 3 │ +├─────┼───────────┼───────────┼───────────┼───────────┼────────────┼───────────┤ +│ 1 │ 1 │ 2 │ 2 │ 3 │ 3 │ 4 │ +└─────┴───────────┴───────────┴───────────┴───────────┴────────────┴───────────┘""" ), ) @@ -386,12 +400,14 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> test_name="test_array_order", model=_create_model(model_sql), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( - """AssertionError: Data mismatch (exp: expected, act: actual)\n\n""" - " aggregated_duplicates \n" - " exp act\n" - "0 (c, b) (b, c)\n" + """Data mismatch +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ aggregated_duplicates: Expected ┃ aggregated_duplicates: Actual ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ ('c', 'b') │ ('b', 'c') │ +└─────┴─────────────────────────────────┴───────────────────────────────┘""" ), ) @@ -451,7 +467,7 @@ def test_partial_data(sushi_context: Context, waiter_names_input: str) -> None: ) ), context=sushi_context, - ).run() + ) ) @@ -507,7 +523,7 @@ def test_format_inline(sushi_context: Context, waiter_names_input: str) -> None: ) ), context=sushi_context, - ).run() + ) ) @@ -577,7 +593,7 @@ def test_format_path( ) ), context=sushi_context, - ).run() + ) ) @@ -686,7 +702,7 @@ def test_partial_output_columns() -> None: test_name="test_foo", model=_create_model("WITH t AS (SELECT a, b, c, d FROM raw) SELECT a, b, c, d FROM t"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) _check_successful_or_raise( @@ -725,7 +741,7 @@ def test_partial_output_columns() -> None: test_name="test_foo", model=_create_model("WITH t AS (SELECT a, b, c, d FROM raw) SELECT a, b, c, d FROM t"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -760,7 +776,7 @@ def test_partial_data_column_order(sushi_context: Context) -> None: ) ), context=sushi_context, - ).run() + ) ) @@ -786,7 +802,7 @@ def test_partial_data_missing_schemas(sushi_context: Context) -> None: test_name="test_foo", model=sushi_context.upsert_model(_create_model("SELECT * FROM unknown")), context=sushi_context, - ).run() + ) ) _check_successful_or_raise( _create_test( @@ -819,7 +835,7 @@ def test_partial_data_missing_schemas(sushi_context: Context) -> None: ) ), context=sushi_context, - ).run() + ) ) @@ -860,7 +876,7 @@ def test_partially_inferred_schemas(sushi_context: Context, mocker: MockerFixtur test = _create_test(body, "test_child", child, sushi_context) spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_any_call( 'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."memory__sushi__parent" ("s", "a", "b") AS ' @@ -890,12 +906,13 @@ def test_uninferrable_schema() -> None: test_name="test_foo", model=_create_model("SELECT value FROM raw"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( - """Failed to infer the data type of column 'value' for '"raw"'. This issue can """ - "be mitigated by casting the column in the model definition, setting its type in " - "external_models.yaml if it's an external model, setting the model's 'columns' property, " - "or setting its 'columns' mapping in the test itself\n" + "sqlmesh.utils.errors.TestError: Failed to run test:\n" + """Failed to infer the data type of column 'value' for '"raw"'. This issue can be \n""" + """mitigated by casting the column in the model definition, setting its type in \n""" + """external_models.yaml if it's an external model, setting the model's 'columns' \n""" + """property, or setting its 'columns' mapping in the test itself\n\n""" ), ) @@ -922,12 +939,14 @@ def test_missing_column_failure(sushi_context: Context, full_model_without_ctes: test_name="test_foo", model=sushi_context.upsert_model(full_model_without_ctes), context=sushi_context, - ).run(), + ), expected_msg=( - "AssertionError: Data mismatch (exp: expected, act: actual)\n\n" - " value ds \n" - " exp act exp act\n" - "0 None 2 None 3\n" + """Data mismatch +┏━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ Row ┃ value: Expected ┃ value: Actual ┃ ds: Expected ┃ ds: Actual ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ 0 │ None │ 2 │ None │ 3 │ +└─────┴─────────────────┴───────────────┴──────────────┴────────────┘""" ), ) @@ -951,12 +970,17 @@ def test_row_difference_failure() -> None: test_name="test_foo", model=_create_model("SELECT value FROM raw"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( - "AssertionError: Data mismatch (rows are different)\n\n" - "Missing rows:\n\n" - " value\n" - "0 2\n" + """Data mismatch (rows are different) + + + Missing rows +┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ value ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2 │ +└────────────────────────┴─────────────────────────────────┘""" ), ) _check_successful_or_raise( @@ -978,12 +1002,17 @@ def test_row_difference_failure() -> None: "SELECT value FROM raw UNION ALL SELECT value + 1 AS value FROM raw" ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( - "AssertionError: Data mismatch (rows are different)\n\n" - "Unexpected rows:\n\n" - " value\n" - "0 2\n" + """Data mismatch (rows are different) + + + Unexpected rows +┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ value ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2 │ +└────────────────────────┴─────────────────────────────────┘""" ), ) _check_successful_or_raise( @@ -1007,16 +1036,27 @@ def test_row_difference_failure() -> None: "SELECT value FROM raw UNION ALL SELECT value + 1 AS value FROM raw" ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( - "AssertionError: Data mismatch (rows are different)\n\n" - "Missing rows:\n\n" - " value\n" - "0 3\n" - "1 4\n\n" - "Unexpected rows:\n\n" - " value\n" - "0 2\n" + """Data mismatch (rows are different) + + + Missing rows +┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ value ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 3 │ +├────────────────────────┼─────────────────────────────────┤ +│ 1 │ 4 │ +└────────────────────────┴─────────────────────────────────┘ + + + Unexpected rows +┏━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ value ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2 │ +└────────────────────────┴─────────────────────────────────┘""" ), ) @@ -1040,7 +1080,7 @@ def test_unknown_column_error() -> None: test_name="test_foo", model=_create_model("SELECT id, value FROM raw"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run(), + ), expected_msg=( "sqlmesh.utils.errors.TestError: Failed to run test:\n" "Detected unknown column(s)\n\n" @@ -1070,7 +1110,7 @@ def test_empty_rows(sushi_context: Context) -> None: ) ), context=sushi_context, - ).run() + ) ) _check_successful_or_raise( @@ -1093,7 +1133,7 @@ def test_empty_rows(sushi_context: Context) -> None: _create_model("SELECT x FROM b", default_catalog="memory") ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -1161,7 +1201,7 @@ def test_source_func() -> None: """ ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -1200,7 +1240,7 @@ def test_nested_data_types(sushi_context: Context) -> None: "SELECT array1, array2, struct FROM sushi.raw", default_catalog="memory" ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -1228,7 +1268,7 @@ def test_freeze_time(mocker: MockerFixture) -> None: ) spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_has_calls( [ @@ -1261,7 +1301,7 @@ def test_freeze_time(mocker: MockerFixture) -> None: ) spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_has_calls( [call('''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''')] @@ -1293,7 +1333,7 @@ def execute(context, start, end, execution_time, **kwargs): test_name="test_py_model", model=model.get_registry()["py_model"].model(module_path=Path("."), path=Path(".")), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -1324,7 +1364,7 @@ def test_create_external_model_fixture(sushi_context: Context, mocker: MockerFix model=_create_model("SELECT x FROM c.db.external"), context=sushi_context, ) - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) assert len(test._fixture_table_cache) == 1 for table in test._fixture_table_cache.values(): @@ -1351,7 +1391,7 @@ def test_macro(evaluator: MacroEvaluator) -> t.List[bool]: test_name="test_foo", model=_create_model("SELECT [@test_macro()] AS c"), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -1445,7 +1485,7 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> test_name="test_example_full_model_alt", model=context.get_model("sqlmesh_example.full_model"), context=context, - ).run() + ) ) _check_successful_or_raise( @@ -1471,7 +1511,7 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> test_name="test_example_full_model_partial", model=context.get_model("sqlmesh_example.full_model"), context=context, - ).run() + ) ) _check_successful_or_raise( @@ -1498,7 +1538,7 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> test_name="test_example_full_model_partial", model=context.get_model("sqlmesh_example.full_model"), context=context, - ).run() + ) ) mocker.patch("sqlmesh.core.test.definition.random_id", return_value="jzngz56a") @@ -1520,7 +1560,7 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), ) spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_any_call( 'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS ' @@ -1586,7 +1626,7 @@ def execute(context, start, end, execution_time, **kwargs): module_path=Path("."), path=Path(".") ), context=context, - ).run() + ) ) @@ -1716,7 +1756,7 @@ def test_custom_testing_schema(mocker: MockerFixture) -> None: ) spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_has_calls( [ @@ -1745,7 +1785,7 @@ def test_pretty_query(mocker: MockerFixture) -> None: ) test.engine_adapter._pretty_sql = True spy_execute = mocker.spy(test.engine_adapter, "_execute") - _check_successful_or_raise(test.run()) + _check_successful_or_raise(test) spy_execute.assert_has_calls( [ call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"'), @@ -1845,7 +1885,7 @@ def test_complicated_recursive_cte() -> None: test_name="test_recursive_ctes", model=_create_model(model_sql), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), - ).run() + ) ) @@ -2218,6 +2258,7 @@ def test_test_with_resolve_template_macro(tmp_path: Path): _check_successful_or_raise(context.test()) +@use_terminal_console def test_test_output(tmp_path: Path) -> None: init_example_project(tmp_path, dialect="duckdb") @@ -2259,25 +2300,23 @@ def test_test_output(tmp_path: Path) -> None: with capture_output() as output: context.test() - # Order may change due to concurrent execution - assert "F." in output.stderr or ".F" in output.stderr - assert ( - f"""====================================================================== -FAIL: test_example_full_model ({new_test_file}) -This is a test ----------------------------------------------------------------------- -AssertionError: Data mismatch (exp: expected, act: actual) + stdout = output.stdout - num_orders - exp act -1 2.0 1.0 + # Order may change due to concurrent execution + assert "F." in stdout or ".F" in stdout -----------------------------------------------------------------------""" - in output.stderr + assert ( + """Data mismatch +┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ num_orders: Expected ┃ num_orders: Actual ┃ +┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2.0 │ 1.0 │ +└──────┴───────────────────────────┴───────────────────────┘""" + in stdout ) - assert "Ran 2 tests" in output.stderr - assert "FAILED (failures=1)" in output.stderr + assert "Ran 2 tests" in stdout + assert "FAILED (failures=1)" in stdout # Case 2: Assert that concurrent execution is working properly for i in range(50): @@ -2287,8 +2326,9 @@ def test_test_output(tmp_path: Path) -> None: with capture_output() as output: context.test() - assert "Ran 102 tests" in output.stderr - assert "FAILED (failures=51)" in output.stderr + stdout = output.stdout + assert "Ran 102 tests" in stdout + assert "FAILED (failures=51)" in stdout @use_terminal_console @@ -2335,10 +2375,10 @@ def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: in mock_logger.call_args[0][0] ) assert ( - ".\n----------------------------------------------------------------------\nRan 1 test in" - in output.stderr + ".\n----------------------------------------------------------------------\n\nRan 1 test in" + in output.stdout ) - assert "OK" in output.stderr + assert "OK" in output.stdout def test_number_of_tests_found(tmp_path: Path) -> None: @@ -2549,10 +2589,11 @@ def upstream_table_python(context, **kwargs): module_path=Path("."), path=Path(".") ), context=sushi_context, - ).run() + ) ) +@use_terminal_console @pytest.mark.parametrize("is_error", [True, False]) def test_model_test_text_result_reporting_no_traceback( sushi_context: Context, full_model_with_two_ctes: SqlModel, is_error: bool @@ -2596,10 +2637,10 @@ def test_model_test_text_result_reporting_no_traceback( else: result.addFailure(test, (e.__class__, e, e.__traceback__)) - result.log_test_report(0) + with capture_output() as captured_output: + result.log_test_report(0) - stream.seek(0) - output = stream.read() + output = captured_output.stdout # Make sure that the traceback is not printed assert "Traceback" not in output