diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py
index 5cbfc6a71c..fc44d8f356 100644
--- a/sqlmesh/core/config/scheduler.py
+++ b/sqlmesh/core/config/scheduler.py
@@ -130,7 +130,6 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
     def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
         return BuiltInPlanEvaluator(
             state_sync=context.state_sync,
-            snapshot_evaluator=context.snapshot_evaluator,
             create_scheduler=context.create_scheduler,
             default_catalog=context.default_catalog,
             console=context.console,
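With the constructor change above, a BuiltInPlanEvaluator is no longer bound to one snapshot evaluator for its lifetime; callers pass one per evaluate() call instead, so each plan application can carry its own correlation ID. A minimal sketch of the new wiring, assuming an already-constructed SQLMesh Context (`context`) and a built Plan (`plan`):

    from sqlmesh.core.plan.evaluator import BuiltInPlanEvaluator
    from sqlmesh.utils import CorrelationId

    # Constructed without a snapshot evaluator (per the change above).
    evaluator = BuiltInPlanEvaluator(
        state_sync=context.state_sync,
        create_scheduler=context.create_scheduler,
        default_catalog=context.default_catalog,
        console=context.console,
    )

    # The snapshot evaluator is bound at call time, so the plan_id travels
    # with every query that this particular evaluation emits.
    evaluator.evaluate(
        plan.to_evaluatable(),
        snapshot_evaluator=context.snapshot_evaluator(
            correlation_id=CorrelationId.from_plan_id(plan.plan_id)
        ),
    )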
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index 0317aad894..402ed22fee 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -116,7 +116,7 @@
     run_tests,
 )
 from sqlmesh.core.user import User
-from sqlmesh.utils import UniqueKeyDict, Verbosity
+from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
 from sqlmesh.utils.concurrency import concurrent_apply_to_values
 from sqlmesh.utils.dag import DAG
 from sqlmesh.utils.date import (
@@ -418,7 +418,7 @@ def __init__(
             self.config.get_state_connection(self.gateway) or self.connection_config
         )
 
-        self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
+        self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
 
         self.console = get_console()
         setattr(self.console, "dialect", self.config.dialect)
@@ -446,18 +446,22 @@ def engine_adapter(self) -> EngineAdapter:
             self._engine_adapter = self.connection_config.create_engine_adapter()
         return self._engine_adapter
 
-    @property
-    def snapshot_evaluator(self) -> SnapshotEvaluator:
-        if not self._snapshot_evaluator:
-            self._snapshot_evaluator = SnapshotEvaluator(
+    def snapshot_evaluator(
+        self, correlation_id: t.Optional[CorrelationId] = None
+    ) -> SnapshotEvaluator:
+        # Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations
+        if correlation_id not in self._snapshot_evaluators:
+            self._snapshot_evaluators[correlation_id] = SnapshotEvaluator(
                 {
-                    gateway: adapter.with_log_level(logging.INFO)
+                    gateway: adapter.with_settings(
+                        log_level=logging.INFO, correlation_id=correlation_id
+                    )
                     for gateway, adapter in self.engine_adapters.items()
                 },
                 ddl_concurrent_tasks=self.concurrent_tasks,
                 selected_gateway=self.selected_gateway,
             )
-        return self._snapshot_evaluator
+        return self._snapshot_evaluators[correlation_id]
 
     def execution_context(
         self,
@@ -538,7 +542,9 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
 
         return self.create_scheduler(snapshots)
 
-    def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
+    def create_scheduler(
+        self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
+    ) -> Scheduler:
         """Creates the built-in scheduler.
 
         Args:
@@ -549,7 +555,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
         """
         return Scheduler(
             snapshots,
-            self.snapshot_evaluator,
+            self.snapshot_evaluator(correlation_id),
             self.state_sync,
             default_catalog=self.default_catalog,
             max_workers=self.concurrent_tasks,
@@ -714,7 +720,7 @@ def run(
             NotificationEvent.RUN_START, environment=environment
         )
         analytics_run_id = analytics.collector.on_run_start(
-            engine_type=self.snapshot_evaluator.adapter.dialect,
+            engine_type=self.snapshot_evaluator().adapter.dialect,
             state_sync_type=self.state_sync.state_type(),
         )
         self._load_materializations()
@@ -1076,7 +1082,7 @@ def evaluate(
             and not parent_snapshot.categorized
         ]
 
-        df = self.snapshot_evaluator.evaluate_and_fetch(
+        df = self.snapshot_evaluator().evaluate_and_fetch(
             snapshot,
             start=start,
             end=end,
@@ -1588,7 +1594,12 @@ def apply(
                 default_catalog=self.default_catalog,
                 console=self.console,
             )
-            explainer.evaluate(plan.to_evaluatable())
+            explainer.evaluate(
+                plan.to_evaluatable(),
+                snapshot_evaluator=self.snapshot_evaluator(
+                    correlation_id=CorrelationId.from_plan_id(plan.plan_id)
+                ),
+            )
             return
 
         self.notification_target_manager.notify(
@@ -1902,7 +1913,7 @@ def _table_diff(
             )
 
         return TableDiff(
-            adapter=adapter.with_log_level(logger.getEffectiveLevel()),
+            adapter=adapter.with_settings(logger.getEffectiveLevel()),
             source=source,
             target=target,
             on=on,
@@ -2111,7 +2122,7 @@ def audit(
         errors = []
         skipped_count = 0
         for snapshot in snapshots:
-            for audit_result in self.snapshot_evaluator.audit(
+            for audit_result in self.snapshot_evaluator().audit(
                 snapshot=snapshot,
                 start=start,
                 end=end,
@@ -2143,7 +2154,7 @@ def audit(
             self.console.log_status_update(f"Got {error.count} results, expected 0.")
             if error.query:
                 self.console.show_sql(
-                    f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
+                    f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
                 )
 
         self.console.log_status_update("Done.")
@@ -2335,11 +2346,14 @@ def print_environment_names(self) -> None:
 
     def close(self) -> None:
         """Releases all resources allocated by this context."""
-        if self._snapshot_evaluator:
-            self._snapshot_evaluator.close()
+        for evaluator in self._snapshot_evaluators.values():
+            evaluator.close()
+
         if self._state_sync:
             self._state_sync.close()
 
+        self._snapshot_evaluators.clear()
+
     def _run(
         self,
         environment: str,
@@ -2390,7 +2404,11 @@ def _run(
 
     def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
         self._scheduler.create_plan_evaluator(self).evaluate(
-            plan.to_evaluatable(), circuit_breaker=circuit_breaker
+            plan.to_evaluatable(),
+            snapshot_evaluator=self.snapshot_evaluator(
+                correlation_id=CorrelationId.from_plan_id(plan.plan_id)
+            ),
+            circuit_breaker=circuit_breaker,
        )
 
     @python_api_analytics
@@ -2683,7 +2701,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
         )
 
         # Remove the expired snapshots tables
-        self.snapshot_evaluator.cleanup(
+        self.snapshot_evaluator().cleanup(
             target_snapshots=cleanup_targets,
             on_complete=self.console.update_cleanup_progress,
         )
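Since Context.snapshot_evaluator() is now a memoizing accessor rather than a property, repeated calls with the same correlation ID (including the default None) return the same cached instance, and close() disposes of every cached evaluator before clearing the cache. A usage sketch, where `ctx` is a hypothetical Context and "p1" a hypothetical plan id:

    from sqlmesh.utils import CorrelationId

    default_evaluator = ctx.snapshot_evaluator()  # cached under the key None
    plan_evaluator = ctx.snapshot_evaluator(
        correlation_id=CorrelationId.from_plan_id("p1")
    )

    assert default_evaluator is ctx.snapshot_evaluator()  # same cached instance
    assert plan_evaluator is not default_evaluator  # distinct adapters per correlation id

    ctx.close()  # closes all cached evaluators, then clears the cache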
EngineAdapter.with_log_level() keeps this property when it makes a clone + # which means that EngineAdapter.with_settings() keeps this property when it makes a clone super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs) self.s3_warehouse_location = s3_warehouse_location diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 924aca8c99..591d81c9ae 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -39,7 +39,7 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import columns_to_types_all_known, random_id +from sqlmesh.utils import columns_to_types_all_known, random_id, CorrelationId from sqlmesh.utils.connection_pool import create_connection_pool, ConnectionPool from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column from sqlmesh.utils.errors import ( @@ -123,6 +123,7 @@ def __init__( pre_ping: bool = False, pretty_sql: bool = False, shared_connection: bool = False, + correlation_id: t.Optional[CorrelationId] = None, **kwargs: t.Any, ): self.dialect = dialect.lower() or self.DIALECT @@ -144,19 +145,21 @@ def __init__( self._pre_ping = pre_ping self._pretty_sql = pretty_sql self._multithreaded = multithreaded + self.correlation_id = correlation_id - def with_log_level(self, level: int) -> EngineAdapter: + def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter: adapter = self.__class__( self._connection_pool, dialect=self.dialect, sql_gen_kwargs=self._sql_gen_kwargs, default_catalog=self._default_catalog, - execute_log_level=level, + execute_log_level=log_level, register_comments=self._register_comments, null_connection=True, multithreaded=self._multithreaded, pretty_sql=self._pretty_sql, **self._extra_config, + **kwargs, ) return adapter @@ -2211,6 +2214,9 @@ def execute( else: sql = t.cast(str, e) + if self.correlation_id: + sql = f"/* {self.correlation_id} */ {sql}" + self._log_sql( sql, expression=e if isinstance(e, exp.Expression) else None, diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index d959fd27a4..562f2ed60e 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -49,7 +49,10 @@ class PlanEvaluator(abc.ABC): @abc.abstractmethod def evaluate( - self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None + self, + plan: EvaluatablePlan, + snapshot_evaluator: SnapshotEvaluator, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: """Evaluates a plan by pushing snapshots and backfilling data. @@ -60,6 +63,8 @@ def evaluate( Args: plan: The plan to evaluate. + snapshot_evaluator: The snapshot evaluator to use. + circuit_breaker: The circuit breaker to use. 
""" @@ -67,13 +72,11 @@ class BuiltInPlanEvaluator(PlanEvaluator): def __init__( self, state_sync: StateSync, - snapshot_evaluator: SnapshotEvaluator, create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler], default_catalog: t.Optional[str], console: t.Optional[Console] = None, ): self.state_sync = state_sync - self.snapshot_evaluator = snapshot_evaluator self.create_scheduler = create_scheduler self.default_catalog = default_catalog self.console = console or get_console() @@ -82,9 +85,12 @@ def __init__( def evaluate( self, plan: EvaluatablePlan, + snapshot_evaluator: SnapshotEvaluator, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: self._circuit_breaker = circuit_breaker + self.snapshot_evaluator = snapshot_evaluator + self.console.start_plan_evaluation(plan) analytics.collector.on_plan_apply_start( plan=plan, diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index 4d1ee2256d..d3c6480f74 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -20,6 +20,7 @@ from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator logger = logging.getLogger(__name__) @@ -37,7 +38,10 @@ def __init__( self.console = console or get_console() def evaluate( - self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None + self, + plan: EvaluatablePlan, + snapshot_evaluator: SnapshotEvaluator, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog) explainer_console = _get_explainer_console( diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index f102f23292..80e4fa5934 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -13,6 +13,7 @@ import types import typing as t import uuid +from dataclasses import dataclass from collections import defaultdict from contextlib import contextmanager from copy import deepcopy @@ -382,3 +383,23 @@ def to_snake_case(name: str) -> str: return "".join( f"_{c.lower()}" if c.isupper() and idx != 0 else c.lower() for idx, c in enumerate(name) ) + + +class JobType(Enum): + PLAN = "SQLMESH_PLAN" + RUN = "SQLMESH_RUN" + + +@dataclass(frozen=True) +class CorrelationId: + """ID that is added to each query in order to identify the job that created it.""" + + job_type: JobType + job_id: str + + def __str__(self) -> str: + return f"{self.job_type.value}: {self.job_id}" + + @classmethod + def from_plan_id(cls, plan_id: str) -> CorrelationId: + return CorrelationId(JobType.PLAN, plan_id) diff --git a/tests/conftest.py b/tests/conftest.py index 574c802c0e..a874bd7590 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,7 +42,7 @@ SnapshotDataVersion, SnapshotFingerprint, ) -from sqlmesh.utils import random_id +from sqlmesh.utils import random_id, CorrelationId from sqlmesh.utils.date import TimeLike, to_date from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path from sqlmesh.core.engine_adapter.shared import CatalogSupport @@ -266,10 +266,12 @@ def duck_conn() -> duckdb.DuckDBPyConnection: def push_plan(context: Context, plan: Plan) -> None: plan_evaluator = BuiltInPlanEvaluator( context.state_sync, - context.snapshot_evaluator, context.create_scheduler, context.default_catalog, ) + plan_evaluator.snapshot_evaluator = context.snapshot_evaluator( + 
diff --git a/tests/conftest.py b/tests/conftest.py
index 574c802c0e..a874bd7590 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -42,7 +42,7 @@
     SnapshotDataVersion,
     SnapshotFingerprint,
 )
-from sqlmesh.utils import random_id
+from sqlmesh.utils import random_id, CorrelationId
 from sqlmesh.utils.date import TimeLike, to_date
 from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
 from sqlmesh.core.engine_adapter.shared import CatalogSupport
@@ -266,10 +266,12 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
 def push_plan(context: Context, plan: Plan) -> None:
     plan_evaluator = BuiltInPlanEvaluator(
         context.state_sync,
-        context.snapshot_evaluator,
         context.create_scheduler,
         context.default_catalog,
     )
+    plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
+        CorrelationId.from_plan_id(plan.plan_id)
+    )
     deployability_index = DeployabilityIndex.create(context.snapshots.values())
     evaluatable_plan = plan.to_evaluatable()
     stages = plan_stages.build_plan_stages(
diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py
index 766a788ac8..f68cb7ac47 100644
--- a/tests/core/test_integration.py
+++ b/tests/core/test_integration.py
@@ -67,6 +67,7 @@
     SnapshotInfoLike,
     SnapshotTableInfo,
 )
+from sqlmesh.utils import CorrelationId
 from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp
 from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
 from sqlmesh.utils.pydantic import validate_string
@@ -1137,7 +1138,7 @@ def test_non_breaking_change_after_forward_only_in_dev(
     init_and_plan_context: t.Callable, has_view_binding: bool
 ):
     context, plan = init_and_plan_context("examples/sushi")
-    context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding
+    context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding
     context.apply(plan)
 
     model = context.get_model("sushi.waiter_revenue_by_day")
@@ -6793,3 +6794,29 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Callable):
     # valid_from should be the epoch, valid_to should be NaT
     assert str(row["valid_from"]) == "1970-01-01 00:00:00"
     assert pd.isna(row["valid_to"])
+
+
+def test_plan_evaluator_correlation_id(tmp_path: Path):
+    def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
+        sqls = [call[0][0] for call in mock_logger.call_args_list]
+        return any(f"/* {correlation_id} */" in sql for sql in sqls)
+
+    create_temp_file(
+        tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
+    )
+
+    # Case 1: Ensure that the correlation id (plan_id) is included in the SQL
+    with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
+        ctx = Context(paths=[tmp_path], config=Config())
+        plan = ctx.plan(auto_apply=True, no_prompts=True)
+
+        correlation_id = CorrelationId.from_plan_id(plan.plan_id)
+        assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"
+
+        assert _correlation_id_in_sqls(correlation_id, mock_logger)
+
+    # Case 2: Ensure that the previous correlation id is not included in the SQL for other operations
+    with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
+        ctx.snapshot_evaluator().adapter.execute("SELECT 1")
+
+        assert not _correlation_id_in_sqls(correlation_id, mock_logger)
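Case 2 above passes because the no-argument call ctx.snapshot_evaluator() resolves to the evaluator cached under the key None, whose adapters were cloned with correlation_id=None, so execute() skips the comment prefix entirely. The invariant being relied on, expressed with the attributes introduced in this diff:

    default_adapter = ctx.snapshot_evaluator().adapter
    assert default_adapter.correlation_id is None  # hence no /* ... */ prefix on execute()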
diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py
index a784644b6b..467c3e60bd 100644
--- a/tests/core/test_plan_evaluator.py
+++ b/tests/core/test_plan_evaluator.py
@@ -11,6 +11,7 @@
     stages as plan_stages,
 )
 from sqlmesh.core.snapshot import SnapshotChangeCategory
+from sqlmesh.utils import CorrelationId
 
 
 @pytest.fixture
@@ -59,11 +60,14 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):
     evaluator = BuiltInPlanEvaluator(
         sushi_context.state_sync,
-        sushi_context.snapshot_evaluator,
         sushi_context.create_scheduler,
         sushi_context.default_catalog,
         console=sushi_context.console,
     )
 
+    evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(
+        CorrelationId.from_plan_id(plan.plan_id)
+    )
+
     evaluatable_plan = plan.to_evaluatable()
     stages = plan_stages.build_plan_stages(
         evaluatable_plan, sushi_context.state_sync, sushi_context.default_catalog
diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py
index ee4ab0ac73..1b5c39e2dd 100644
--- a/tests/core/test_table_diff.py
+++ b/tests/core/test_table_diff.py
@@ -335,11 +335,11 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
     sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"'
     drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
 
-    # make with_log_level() return the current instance of engine_adapter so we can still spy on _execute
+    # make with_settings() return the current instance of engine_adapter so we can still spy on _execute
     mocker.patch.object(
-        engine_adapter, "with_log_level", new_callable=lambda: lambda _: engine_adapter
+        engine_adapter, "with_settings", new_callable=lambda: lambda _: engine_adapter
     )
-    assert engine_adapter.with_log_level(1) == engine_adapter
+    assert engine_adapter.with_settings(1) == engine_adapter
 
     spy_execute = mocker.spy(engine_adapter, "_execute")
     mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")
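One caveat on the patch above: the stand-in lambda _: engine_adapter accepts exactly one positional argument, which matches the positional call style with_settings(level) used by Context._table_diff(). If a call site ever passes keyword arguments, as Context.snapshot_evaluator() does with log_level= and correlation_id=, the test would need a more permissive stand-in, e.g.:

    mocker.patch.object(
        engine_adapter, "with_settings", new_callable=lambda: lambda *a, **kw: engine_adapter
    )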