Skip to content

Commit 11d89ff

Browse files
committed
feat: add execution context to signals
1 parent 561e4fd commit 11d89ff

9 files changed

Lines changed: 81 additions & 36 deletions

File tree

docs/guides/signals.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,18 @@ MODEL (
116116

117117
SELECT @start_ds AS ds
118118
```
119+
120+
### Accessing execution context / engine adapter
121+
It is possible to access the execution context in a signal, which provides access to the engine adapter (warehouse connection).
122+
123+
```python
124+
import typing as t
125+
126+
from sqlmesh import signal, DatetimeRanges, ExecutionContext
127+
128+
129+
# add the context argument to your function
130+
@signal()
131+
def one_week_ago(batch: DatetimeRanges, context: ExecutionContext) -> t.Union[bool, DatetimeRanges]:
132+
    return len(context.engine_adapter.fetchdf("SELECT 1")) > 0
133+
```

sqlmesh/core/macros.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def send(
214214
raise SQLMeshError(f"Macro '{name}' does not exist.")
215215

216216
try:
217-
return call_macro(func, self.dialect, self._path, self, *args, **kwargs) # type: ignore
217+
return call_macro(func, self.dialect, self._path, args=(self, *args), kwargs=kwargs) # type: ignore
218218
except Exception as e:
219219
print_exception(e, self.python_env)
220220
raise MacroEvalError("Error trying to eval macro.") from e
@@ -1286,11 +1286,17 @@ def call_macro(
12861286
func: t.Callable,
12871287
dialect: DialectType,
12881288
path: Path,
1289-
*args: t.Any,
1290-
**kwargs: t.Any,
1289+
args: t.Tuple[t.Any, ...],
1290+
kwargs: t.Dict[str, t.Any],
1291+
**optional_kwargs: t.Any,
12911292
) -> t.Any:
12921293
# Bind the macro's actual parameters to its formal parameters
12931294
sig = inspect.signature(func)
1295+
1296+
for k, v in optional_kwargs.items():
1297+
if k in sig.parameters:
1298+
kwargs[k] = v
1299+
12941300
bound = sig.bind(*args, **kwargs)
12951301
bound.apply_defaults()
12961302

sqlmesh/core/scheduler.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
):
9696
self.state_sync = state_sync
9797
self.snapshots = {s.snapshot_id: s for s in snapshots}
98+
self.snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
9899
self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values())
99100
self.default_catalog = default_catalog
100101
self.snapshot_evaluator = snapshot_evaluator
@@ -348,7 +349,11 @@ def run(
348349

349350
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
350351

351-
def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snapshot, Intervals]:
352+
def batch_intervals(
353+
self,
354+
merged_intervals: SnapshotToIntervals,
355+
deployability_index: t.Optional[DeployabilityIndex],
356+
) -> t.Dict[Snapshot, Intervals]:
352357
dag = snapshots_to_dag(merged_intervals)
353358

354359
snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = {
@@ -369,7 +374,20 @@ def batch_intervals(self, merged_intervals: SnapshotToIntervals) -> t.Dict[Snaps
369374
continue
370375
snapshot, intervals = snapshot_intervals[snapshot_id]
371376
unready = set(intervals)
372-
intervals = snapshot.check_ready_intervals(intervals)
377+
378+
from sqlmesh.core.context import ExecutionContext
379+
380+
adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway)
381+
382+
context = ExecutionContext(
383+
adapter,
384+
self.snapshots_by_name,
385+
deployability_index,
386+
default_dialect=adapter.dialect,
387+
default_catalog=self.default_catalog,
388+
)
389+
390+
intervals = snapshot.check_ready_intervals(intervals, context)
373391
unready -= set(intervals)
374392

375393
for parent in snapshot.parents:
@@ -424,7 +442,7 @@ def run_merged_intervals(
424442
"""
425443
execution_time = execution_time or now_timestamp()
426444

427-
batched_intervals = self.batch_intervals(merged_intervals)
445+
batched_intervals = self.batch_intervals(merged_intervals, deployability_index)
428446

429447
self.console.start_evaluation_progress(
430448
{snapshot: len(intervals) for snapshot, intervals in batched_intervals.items()},
@@ -434,8 +452,6 @@ def run_merged_intervals(
434452

435453
dag = self._dag(batched_intervals)
436454

437-
snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
438-
439455
if run_environment_statements:
440456
environment_statements = self.state_sync.get_environment_statements(
441457
environment_naming_info.name
@@ -446,7 +462,7 @@ def run_merged_intervals(
446462
runtime_stage=RuntimeStage.BEFORE_ALL,
447463
environment_naming_info=environment_naming_info,
448464
default_catalog=self.default_catalog,
449-
snapshots=snapshots_by_name,
465+
snapshots=self.snapshots_by_name,
450466
start=start,
451467
end=end,
452468
execution_time=execution_time,
@@ -459,7 +475,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
459475
snapshot_name, ((start, end), batch_idx) = node
460476
if batch_idx == -1:
461477
return
462-
snapshot = snapshots_by_name[snapshot_name]
478+
snapshot = self.snapshots_by_name[snapshot_name]
463479

464480
self.console.start_snapshot_evaluation_progress(snapshot)
465481

@@ -520,7 +536,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
520536
runtime_stage=RuntimeStage.AFTER_ALL,
521537
environment_naming_info=environment_naming_info,
522538
default_catalog=self.default_catalog,
523-
snapshots=snapshots_by_name,
539+
snapshots=self.snapshots_by_name,
524540
start=start,
525541
end=end,
526542
execution_time=execution_time,

sqlmesh/core/snapshot/definition.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import inspect
34
import sys
45
import typing as t
56
from collections import defaultdict
@@ -46,6 +47,7 @@
4647
if t.TYPE_CHECKING:
4748
from sqlglot.dialects.dialect import DialectType
4849
from sqlmesh.core.environment import EnvironmentNamingInfo
50+
from sqlmesh.core.context import ExecutionContext
4951

5052
Interval = t.Tuple[int, int]
5153
Intervals = t.List[Interval]
@@ -940,7 +942,7 @@ def missing_intervals(
940942
model_end_ts,
941943
)
942944

943-
def check_ready_intervals(self, intervals: Intervals) -> Intervals:
945+
def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext) -> Intervals:
944946
"""Returns a list of intervals that are considered ready by the provided signal.
945947
946948
Note that this will handle gaps in the provided intervals. The returned intervals
@@ -959,6 +961,7 @@ def check_ready_intervals(self, intervals: Intervals) -> Intervals:
959961
intervals = _check_ready_intervals(
960962
env[signal_name],
961963
intervals,
964+
context,
962965
dialect=self.model.dialect,
963966
path=self.model._path,
964967
kwargs=kwargs,
@@ -2148,17 +2151,22 @@ def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]:
21482151
def _check_ready_intervals(
21492152
check: t.Callable,
21502153
intervals: Intervals,
2154+
context: ExecutionContext,
21512155
dialect: DialectType = None,
21522156
path: Path = Path(),
21532157
kwargs: t.Optional[t.Dict] = None,
21542158
) -> Intervals:
21552159
checked_intervals: Intervals = []
21562160

2161+
inspect.signature(check)
2162+
21572163
for interval_batch in _contiguous_intervals(intervals):
21582164
batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch]
21592165

21602166
try:
2161-
ready_intervals = call_macro(check, dialect, path, batch, **(kwargs or {}))
2167+
ready_intervals = call_macro(
2168+
check, dialect, path, args=(batch,), kwargs=(kwargs or {}), context=context
2169+
)
21622170
except Exception:
21632171
raise SQLMeshError("Error evaluating signal")
21642172

sqlmesh/core/snapshot/evaluator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def create(
321321

322322
def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.Set[str]:
323323
logger.info("Listing data objects in schema %s", schema.sql())
324-
objs = self._get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
324+
objs = self.get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
325325
return {obj.name for obj in objs}
326326

327327
with self.concurrent_context():
@@ -409,7 +409,7 @@ def migrate(
409409
s,
410410
snapshots,
411411
allow_destructive_snapshots,
412-
self._get_adapter(s.model_gateway),
412+
self.get_adapter(s.model_gateway),
413413
deployability_index,
414414
),
415415
self.ddl_concurrent_tasks,
@@ -437,7 +437,7 @@ def cleanup(
437437
lambda s: self._cleanup_snapshot(
438438
s,
439439
snapshots_to_dev_table_only[s.snapshot_id],
440-
self._get_adapter(
440+
self.get_adapter(
441441
snapshot_gateways.get(s.snapshot_id.name) if snapshot_gateways else None
442442
),
443443
on_complete,
@@ -471,7 +471,7 @@ def audit(
471471
kwargs: Additional kwargs to pass to the renderer.
472472
"""
473473
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
474-
adapter = self._get_adapter(snapshot.model_gateway)
474+
adapter = self.get_adapter(snapshot.model_gateway)
475475

476476
if not snapshot.version:
477477
raise ConfigError(
@@ -605,7 +605,7 @@ def _evaluate_snapshot(
605605
else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
606606
)
607607

608-
adapter = self._get_adapter(model.gateway)
608+
adapter = self.get_adapter(model.gateway)
609609
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
610610

611611
# https://github.com/TobikoData/sqlmesh/issues/2609
@@ -764,7 +764,7 @@ def _create_snapshot(
764764

765765
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
766766

767-
adapter = self._get_adapter(snapshot.model.gateway)
767+
adapter = self.get_adapter(snapshot.model.gateway)
768768
create_render_kwargs: t.Dict[str, t.Any] = dict(
769769
engine_adapter=adapter,
770770
snapshots=parent_snapshots_by_name(snapshot, snapshots),
@@ -994,7 +994,7 @@ def _wap_publish_snapshot(
994994
) -> None:
995995
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
996996
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
997-
adapter = self._get_adapter(snapshot.model_gateway)
997+
adapter = self.get_adapter(snapshot.model_gateway)
998998
adapter.wap_publish(table_name, wap_id)
999999

10001000
def _audit(
@@ -1021,7 +1021,7 @@ def _audit(
10211021
blocking = audit_args.pop("blocking", None)
10221022
blocking = blocking == exp.true() if blocking else audit.blocking
10231023

1024-
adapter = self._get_adapter(snapshot.model_gateway)
1024+
adapter = self.get_adapter(snapshot.model_gateway)
10251025

10261026
kwargs = {
10271027
"start": start,
@@ -1068,10 +1068,10 @@ def _create_schemas(
10681068
for schema_name, catalog in unique_schemas:
10691069
schema = schema_(schema_name, catalog)
10701070
logger.info("Creating schema '%s'", schema)
1071-
adapter = self._get_adapter(gateways.get(schema)) if gateways else self.adapter
1071+
adapter = self.get_adapter(gateways.get(schema)) if gateways else self.adapter
10721072
adapter.create_schema(schema)
10731073

1074-
def _get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
1074+
def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
10751075
"""Returns the adapter for the specified gateway or the default adapter if none is provided."""
10761076
if gateway:
10771077
if adapter := self.adapters.get(gateway):
@@ -1089,7 +1089,7 @@ def _execute_create(
10891089
rendered_physical_properties: t.Dict[str, exp.Expression],
10901090
dry_run: bool,
10911091
) -> None:
1092-
adapter = self._get_adapter(snapshot.model.gateway)
1092+
adapter = self.get_adapter(snapshot.model.gateway)
10931093
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
10941094

10951095
# It can still be useful for some strategies to know if the snapshot was actually deployable

tests/core/test_scheduler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context
6666

6767

6868
@pytest.fixture
69-
def get_batched_missing_intervals() -> (
70-
t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]
71-
):
69+
def get_batched_missing_intervals(
70+
mocker: MockerFixture,
71+
) -> t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]:
7272
def _get_batched_missing_intervals(
7373
scheduler: Scheduler,
7474
start: TimeLike,
7575
end: TimeLike,
7676
execution_time: t.Optional[TimeLike] = None,
7777
) -> SnapshotToIntervals:
7878
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
79-
return scheduler.batch_intervals(merged_intervals)
79+
return scheduler.batch_intervals(merged_intervals, mocker.Mock())
8080

8181
return _get_batched_missing_intervals
8282

@@ -622,7 +622,7 @@ def test_interval_diff():
622622

623623
def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals):
624624
@signal()
625-
def signal_a(batch: DatetimeRanges):
625+
def signal_a(batch: DatetimeRanges, context):
626626
return [batch[0], batch[1]]
627627

628628
@signal()

tests/core/test_snapshot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,7 +2532,7 @@ def assert_check_intervals(
25322532
):
25332533
mock = mocker.Mock()
25342534
mock.side_effect = [to_intervals(r) for r in ready]
2535-
_check_ready_intervals(mock, intervals) == expected
2535+
_check_ready_intervals(mock, intervals, mocker.Mock()) == expected
25362536

25372537
assert_check_intervals([], [], [])
25382538
assert_check_intervals([(0, 1)], [[]], [])
@@ -2894,7 +2894,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
28942894
]
28952895

28962896

2897-
def test_render_signal(make_snapshot):
2897+
def test_render_signal(make_snapshot, mocker):
28982898
@signal()
28992899
def check_types(batch, env: str, default: int = 0):
29002900
if env != "in_memory" or not default == 0:
@@ -2917,4 +2917,4 @@ def check_types(batch, env: str, default: int = 0):
29172917
signal_definitions=signal.get_registry(),
29182918
)
29192919
snapshot_a = make_snapshot(sql_model)
2920-
assert snapshot_a.check_ready_intervals([(0, 1)]) == [(0, 1)]
2920+
assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)]

tests/core/test_snapshot_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3751,9 +3751,9 @@ def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot):
37513751

37523752
assert len(evaluator.adapters) == 3
37533753
assert evaluator.adapter == engine_adapters["default"]
3754-
assert evaluator._get_adapter() == engine_adapters["default"]
3755-
assert evaluator._get_adapter("third") == engine_adapters["third"]
3756-
assert evaluator._get_adapter("secondary") == engine_adapters["secondary"]
3754+
assert evaluator.get_adapter() == engine_adapters["default"]
3755+
assert evaluator.get_adapter("third") == engine_adapters["third"]
3756+
assert evaluator.get_adapter("secondary") == engine_adapters["secondary"]
37573757

37583758
model = load_sql_based_model(
37593759
parse( # type: ignore

web/server/api/endpoints/plan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges:
132132
def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]:
133133
"""Get plan backfills"""
134134
merged_intervals = context.scheduler().merged_missing_intervals()
135-
batches = context.scheduler().batch_intervals(merged_intervals)
135+
batches = context.scheduler().batch_intervals(merged_intervals, None)
136136
tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()}
137137
snapshots = plan.context_diff.snapshots
138138
default_catalog = context.default_catalog

0 commit comments

Comments
 (0)