Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/core/config/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
return BuiltInPlanEvaluator(
state_sync=context.state_sync,
snapshot_evaluator=context.snapshot_evaluator,
create_scheduler=context.create_scheduler,
default_catalog=context.default_catalog,
console=context.console,
Expand Down
55 changes: 20 additions & 35 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
run_tests,
)
from sqlmesh.core.user import User
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
from sqlmesh.utils import UniqueKeyDict, Verbosity
from sqlmesh.utils.concurrency import concurrent_apply_to_values
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import (
Expand Down Expand Up @@ -417,7 +417,7 @@ def __init__(
self.config.get_state_connection(self.gateway) or self.connection_config
)

self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None

self.console = get_console()
setattr(self.console, "dialect", self.config.dialect)
Expand Down Expand Up @@ -445,22 +445,18 @@ def engine_adapter(self) -> EngineAdapter:
self._engine_adapter = self.connection_config.create_engine_adapter()
return self._engine_adapter

def snapshot_evaluator(
self, correlation_id: t.Optional[CorrelationId] = None
) -> SnapshotEvaluator:
# Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations
if correlation_id not in self._snapshot_evaluators:
self._snapshot_evaluators[correlation_id] = SnapshotEvaluator(
@property
def snapshot_evaluator(self) -> SnapshotEvaluator:
if not self._snapshot_evaluator:
self._snapshot_evaluator = SnapshotEvaluator(
{
gateway: adapter.with_settings(
log_level=logging.INFO, correlation_id=correlation_id
)
gateway: adapter.with_settings(log_level=logging.INFO)
for gateway, adapter in self.engine_adapters.items()
},
ddl_concurrent_tasks=self.concurrent_tasks,
selected_gateway=self.selected_gateway,
)
return self._snapshot_evaluators[correlation_id]
return self._snapshot_evaluator

def execution_context(
self,
Expand Down Expand Up @@ -539,10 +535,10 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
if not snapshots:
raise ConfigError("No models were found")

return self.create_scheduler(snapshots)
return self.create_scheduler(snapshots, self.snapshot_evaluator)

def create_scheduler(
self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator
) -> Scheduler:
"""Creates the built-in scheduler.

Expand All @@ -554,7 +550,7 @@ def create_scheduler(
"""
return Scheduler(
snapshots,
self.snapshot_evaluator(correlation_id),
snapshot_evaluator,
self.state_sync,
default_catalog=self.default_catalog,
max_workers=self.concurrent_tasks,
Expand Down Expand Up @@ -719,7 +715,7 @@ def run(
NotificationEvent.RUN_START, environment=environment
)
analytics_run_id = analytics.collector.on_run_start(
engine_type=self.snapshot_evaluator().adapter.dialect,
engine_type=self.snapshot_evaluator.adapter.dialect,
state_sync_type=self.state_sync.state_type(),
)
self._load_materializations()
Expand Down Expand Up @@ -1081,7 +1077,7 @@ def evaluate(
and not parent_snapshot.categorized
]

df = self.snapshot_evaluator().evaluate_and_fetch(
df = self.snapshot_evaluator.evaluate_and_fetch(
snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -1593,12 +1589,7 @@ def apply(
default_catalog=self.default_catalog,
console=self.console,
)
explainer.evaluate(
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
)
explainer.evaluate(plan.to_evaluatable())
return

self.notification_target_manager.notify(
Expand Down Expand Up @@ -2121,7 +2112,7 @@ def audit(
errors = []
skipped_count = 0
for snapshot in snapshots:
for audit_result in self.snapshot_evaluator().audit(
for audit_result in self.snapshot_evaluator.audit(
snapshot=snapshot,
start=start,
end=end,
Expand Down Expand Up @@ -2153,7 +2144,7 @@ def audit(
self.console.log_status_update(f"Got {error.count} results, expected 0.")
if error.query:
self.console.show_sql(
f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
)

self.console.log_status_update("Done.")
Expand Down Expand Up @@ -2345,14 +2336,12 @@ def print_environment_names(self) -> None:

def close(self) -> None:
"""Releases all resources allocated by this context."""
for evaluator in self._snapshot_evaluators.values():
evaluator.close()
if self._snapshot_evaluator:
self._snapshot_evaluator.close()

if self._state_sync:
self._state_sync.close()

self._snapshot_evaluators.clear()

def _run(
self,
environment: str,
Expand Down Expand Up @@ -2403,11 +2392,7 @@ def _run(

def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
self._scheduler.create_plan_evaluator(self).evaluate(
plan.to_evaluatable(),
snapshot_evaluator=self.snapshot_evaluator(
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
),
circuit_breaker=circuit_breaker,
plan.to_evaluatable(), circuit_breaker=circuit_breaker
)

@python_api_analytics
Expand Down Expand Up @@ -2700,7 +2685,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
)

# Remove the expired snapshots tables
self.snapshot_evaluator().cleanup(
self.snapshot_evaluator.cleanup(
target_snapshots=cleanup_targets,
on_complete=self.console.update_cleanup_progress,
)
Expand Down
12 changes: 5 additions & 7 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class PlanEvaluator(abc.ABC):
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
"""Evaluates a plan by pushing snapshots and backfilling data.
Expand All @@ -63,7 +62,6 @@ def evaluate(

Args:
plan: The plan to evaluate.
snapshot_evaluator: The snapshot evaluator to use.
circuit_breaker: The circuit breaker to use.
"""

Expand All @@ -72,11 +70,13 @@ class BuiltInPlanEvaluator(PlanEvaluator):
def __init__(
self,
state_sync: StateSync,
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
snapshot_evaluator: SnapshotEvaluator,
create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler],
default_catalog: t.Optional[str],
console: t.Optional[Console] = None,
):
self.state_sync = state_sync
self.snapshot_evaluator = snapshot_evaluator
self.create_scheduler = create_scheduler
self.default_catalog = default_catalog
self.console = console or get_console()
Expand All @@ -85,11 +85,9 @@ def __init__(
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
self._circuit_breaker = circuit_breaker
self.snapshot_evaluator = snapshot_evaluator

self.console.start_plan_evaluation(plan)
analytics.collector.on_plan_apply_start(
Expand Down Expand Up @@ -230,7 +228,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
self.console.log_success("SKIP: No model batches to execute")
return

scheduler = self.create_scheduler(stage.all_snapshots.values())
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
errors, _ = scheduler.run_merged_intervals(
merged_intervals=stage.snapshot_to_intervals,
deployability_index=stage.deployability_index,
Expand All @@ -251,7 +249,7 @@ def visit_audit_only_run_stage(
return

# If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them
scheduler = self.create_scheduler(audit_snapshots)
scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator)
completion_status = scheduler.audit(
plan.environment,
plan.start,
Expand Down
2 changes: 0 additions & 2 deletions sqlmesh/core/plan/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
from sqlmesh.utils.date import to_ts
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator


logger = logging.getLogger(__name__)
Expand All @@ -40,7 +39,6 @@ def __init__(
def evaluate(
self,
plan: EvaluatablePlan,
snapshot_evaluator: SnapshotEvaluator,
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
) -> None:
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
Expand Down
6 changes: 2 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SnapshotDataVersion,
SnapshotFingerprint,
)
from sqlmesh.utils import random_id, CorrelationId
from sqlmesh.utils import random_id
from sqlmesh.utils.date import TimeLike, to_date
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
from sqlmesh.core.engine_adapter.shared import CatalogSupport
Expand Down Expand Up @@ -266,12 +266,10 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
def push_plan(context: Context, plan: Plan) -> None:
plan_evaluator = BuiltInPlanEvaluator(
context.state_sync,
context.snapshot_evaluator,
context.create_scheduler,
context.default_catalog,
)
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
CorrelationId.from_plan_id(plan.plan_id)
)
deployability_index = DeployabilityIndex.create(context.snapshots.values())
evaluatable_plan = plan.to_evaluatable()
stages = plan_stages.build_plan_stages(
Expand Down
29 changes: 1 addition & 28 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
SnapshotInfoLike,
SnapshotTableInfo,
)
from sqlmesh.utils import CorrelationId
from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
from sqlmesh.utils.pydantic import validate_string
Expand Down Expand Up @@ -1138,7 +1137,7 @@ def test_non_breaking_change_after_forward_only_in_dev(
init_and_plan_context: t.Callable, has_view_binding: bool
):
context, plan = init_and_plan_context("examples/sushi")
context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding
context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding
context.apply(plan)

model = context.get_model("sushi.waiter_revenue_by_day")
Expand Down Expand Up @@ -6794,29 +6793,3 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
# valid_from should be the epoch, valid_to should be NaT
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
assert pd.isna(row["valid_to"])


def test_plan_evaluator_correlation_id(tmp_path: Path):
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
sqls = [call[0][0] for call in mock_logger.call_args_list]
return any(f"/* {correlation_id} */" in sql for sql in sqls)

create_temp_file(
tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
)

# Case 1: Ensure that the correlation id (plan_id) is included in the SQL
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
ctx = Context(paths=[tmp_path], config=Config())
plan = ctx.plan(auto_apply=True, no_prompts=True)

correlation_id = CorrelationId.from_plan_id(plan.plan_id)
assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"

assert _correlation_id_in_sqls(correlation_id, mock_logger)

# Case 2: Ensure that the previous correlation id is not included in the SQL for other operations
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
ctx.snapshot_evaluator().adapter.execute("SELECT 1")

assert not _correlation_id_in_sqls(correlation_id, mock_logger)
5 changes: 1 addition & 4 deletions tests/core/test_plan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
stages as plan_stages,
)
from sqlmesh.core.snapshot import SnapshotChangeCategory
from sqlmesh.utils import CorrelationId


@pytest.fixture
Expand Down Expand Up @@ -60,13 +59,11 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):

evaluator = BuiltInPlanEvaluator(
sushi_context.state_sync,
sushi_context.snapshot_evaluator,
sushi_context.create_scheduler,
sushi_context.default_catalog,
console=sushi_context.console,
)
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(
CorrelationId.from_plan_id(plan.plan_id)
)

evaluatable_plan = plan.to_evaluatable()
stages = plan_stages.build_plan_stages(
Expand Down