Skip to content

Commit 6921468

Browse files
Refactor: Move engine adapter instantiation after loading on-demand (#4578)
1 parent dbff673 commit 6921468

9 files changed

Lines changed: 99 additions & 64 deletions

File tree

sqlmesh/core/config/scheduler.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def create_state_sync(self, context: GenericContext) -> StateSync:
4747
"""
4848

4949
@abc.abstractmethod
50-
def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
51-
"""Returns the default catalog for the Scheduler.
50+
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
51+
"""Returns the default catalog for each gateway.
5252
5353
Args:
5454
context: The SQLMesh Context.
@@ -66,7 +66,7 @@ def state_sync_fingerprint(self, context: GenericContext) -> str:
6666
class _EngineAdapterStateSyncSchedulerConfig(SchedulerConfig):
6767
def create_state_sync(self, context: GenericContext) -> StateSync:
6868
state_connection = (
69-
context.config.get_state_connection(context.gateway) or context._connection_config
69+
context.config.get_state_connection(context.gateway) or context.connection_config
7070
)
7171

7272
warehouse_connection = context.config.get_connection(context.gateway)
@@ -110,7 +110,7 @@ def create_state_sync(self, context: GenericContext) -> StateSync:
110110

111111
def state_sync_fingerprint(self, context: GenericContext) -> str:
112112
state_connection = (
113-
context.config.get_state_connection(context.gateway) or context._connection_config
113+
context.config.get_state_connection(context.gateway) or context.connection_config
114114
)
115115
return md5(
116116
[
@@ -132,12 +132,16 @@ def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
132132
state_sync=context.state_sync,
133133
snapshot_evaluator=context.snapshot_evaluator,
134134
create_scheduler=context.create_scheduler,
135-
default_catalog=self.get_default_catalog(context),
135+
default_catalog=context.default_catalog,
136136
console=context.console,
137137
)
138138

139-
def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
140-
return context.engine_adapter.default_catalog
139+
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
140+
default_catalogs_per_gateway: t.Dict[str, str] = {}
141+
for gateway, adapter in context.engine_adapters.items():
142+
if catalog := adapter.default_catalog:
143+
default_catalogs_per_gateway[gateway] = catalog
144+
return default_catalogs_per_gateway
141145

142146

143147
SCHEDULER_CONFIG_TO_TYPE = {

sqlmesh/core/context.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
Config,
6363
load_configs,
6464
)
65+
from sqlmesh.core.config.connection import ConnectionConfig
6566
from sqlmesh.core.config.loader import C
67+
from sqlmesh.core.config.root import RegexKeyDict
6668
from sqlmesh.core.console import get_console
6769
from sqlmesh.core.context_diff import ContextDiff
6870
from sqlmesh.core.dialect import (
@@ -91,6 +93,7 @@
9193
from sqlmesh.core.plan.definition import UserProvidedFlags
9294
from sqlmesh.core.reference import ReferenceGraph
9395
from sqlmesh.core.scheduler import Scheduler, CompletionStatus
96+
from sqlmesh.core.schema_diff import SchemaDiffer
9497
from sqlmesh.core.schema_loader import create_external_models_file
9598
from sqlmesh.core.selector import Selector
9699
from sqlmesh.core.snapshot import (
@@ -367,6 +370,9 @@ def __init__(
367370
self._excluded_requirements: t.Set[str] = set()
368371
self._default_catalog: t.Optional[str] = None
369372
self._default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None
373+
self._engine_adapter: t.Optional[EngineAdapter] = None
374+
self._connection_config: t.Optional[ConnectionConfig] = None
375+
self._test_connection_config: t.Optional[ConnectionConfig] = None
370376
self._linters: t.Dict[str, Linter] = {}
371377
self._loaded: bool = False
372378

@@ -407,24 +413,15 @@ def __init__(
407413
for path, config in self.configs.items()
408414
]
409415

410-
self._connection_config = self.config.get_connection(self.gateway)
416+
self._concurrent_tasks = concurrent_tasks
411417
self._state_connection_config = (
412-
self.config.get_state_connection(self.gateway) or self._connection_config
418+
self.config.get_state_connection(self.gateway) or self.connection_config
413419
)
414-
self.concurrent_tasks = concurrent_tasks or self._connection_config.concurrent_tasks
415-
416-
self._engine_adapters: t.Dict[str, EngineAdapter] = {
417-
self.selected_gateway: self._connection_config.create_engine_adapter()
418-
}
419420

420421
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
421422

422423
self.console = get_console()
423-
setattr(self.console, "dialect", self.engine_adapter.dialect)
424-
425-
self._test_connection_config = self.config.get_test_connection(
426-
self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT
427-
)
424+
setattr(self.console, "dialect", self.config.dialect)
428425

429426
self._provided_state_sync: t.Optional[StateSync] = state_sync
430427
self._state_sync: t.Optional[StateSync] = None
@@ -435,14 +432,6 @@ def __init__(
435432
self.users = list({user.username: user for user in self.users}.values())
436433
self._register_notification_targets()
437434

438-
if (
439-
self.config.environment_catalog_mapping
440-
and not self.engine_adapter.catalog_support.is_multi_catalog_supported
441-
):
442-
raise SQLMeshError(
443-
"Environment catalog mapping is only supported for engine adapters that support multiple catalogs"
444-
)
445-
446435
if load:
447436
self.load()
448437

@@ -453,7 +442,9 @@ def default_dialect(self) -> t.Optional[str]:
453442
@property
454443
def engine_adapter(self) -> EngineAdapter:
455444
"""Returns the default engine adapter."""
456-
return self._engine_adapters[self.selected_gateway]
445+
if self._engine_adapter is None:
446+
self._engine_adapter = self.connection_config.create_engine_adapter()
447+
return self._engine_adapter
457448

458449
@property
459450
def snapshot_evaluator(self) -> SnapshotEvaluator:
@@ -980,8 +971,8 @@ def requirements(self) -> t.Dict[str, str]:
980971

981972
@property
982973
def default_catalog(self) -> t.Optional[str]:
983-
if self._default_catalog is None:
984-
self._default_catalog = self._scheduler.get_default_catalog(self)
974+
if self._default_catalog is None and self.default_catalog_per_gateway:
975+
self._default_catalog = self.default_catalog_per_gateway[self.selected_gateway]
985976
return self._default_catalog
986977

987978
@python_api_analytics
@@ -1538,7 +1529,7 @@ def plan_builder(
15381529
allow_destructive_models=expanded_destructive_models,
15391530
environment_ttl=environment_ttl,
15401531
environment_suffix_target=self.config.environment_suffix_target,
1541-
environment_catalog_mapping=self.config.environment_catalog_mapping,
1532+
environment_catalog_mapping=self.environment_catalog_mapping,
15421533
categorizer_config=categorizer_config or self.auto_categorize_changes,
15431534
auto_categorization_enabled=not no_auto_categorization,
15441535
effective_from=effective_from,
@@ -1550,7 +1541,7 @@ def plan_builder(
15501541
),
15511542
end_bounded=not run,
15521543
ensure_finalized_snapshots=self.config.plan.use_finalized_state,
1553-
engine_schema_differ=self.engine_adapter.SCHEMA_DIFFER,
1544+
engine_schema_differ=SchemaDiffer(), # TODO: fix to properly handle it
15541545
interval_end_per_model=max_interval_end_per_model,
15551546
console=self.console,
15561547
user_provided_flags=user_provided_flags,
@@ -1639,7 +1630,7 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> b
16391630
self.console.show_model_difference_summary(
16401631
context_diff,
16411632
EnvironmentNamingInfo.from_environment_catalog_mapping(
1642-
self.config.environment_catalog_mapping,
1633+
self.environment_catalog_mapping,
16431634
name=environment,
16441635
suffix_target=self.config.environment_suffix_target,
16451636
normalize_name=context_diff.normalize_environment_name,
@@ -1993,7 +1984,7 @@ def create_test(
19931984

19941985
try:
19951986
model_to_test = self.get_model(model, raise_if_missing=True)
1996-
test_adapter = self._test_connection_config.create_engine_adapter(
1987+
test_adapter = self.test_connection_config.create_engine_adapter(
19971988
register_comments_override=False
19981989
)
19991990

@@ -2039,7 +2030,7 @@ def test(
20392030
preserve_fixtures=preserve_fixtures,
20402031
stream=stream,
20412032
default_catalog=self.default_catalog,
2042-
default_catalog_dialect=self.engine_adapter.DIALECT,
2033+
default_catalog_dialect=self.config.dialect or "",
20432034
)
20442035

20452036
@python_api_analytics
@@ -2478,7 +2469,7 @@ def _run_plan_tests(
24782469
self.console.log_test_results(
24792470
result,
24802471
test_output,
2481-
self._test_connection_config._engine_adapter.DIALECT,
2472+
self.test_connection_config._engine_adapter.DIALECT,
24822473
)
24832474
if not result.wasSuccessful():
24842475
raise PlanError(
@@ -2499,7 +2490,7 @@ def _model_tables(self) -> t.Dict[str, str]:
24992490
if snapshot.version
25002491
else snapshot.qualified_view_name.for_environment(
25012492
EnvironmentNamingInfo.from_environment_catalog_mapping(
2502-
self.config.environment_catalog_mapping,
2493+
self.environment_catalog_mapping,
25032494
name=c.PROD,
25042495
suffix_target=self.config.environment_suffix_target,
25052496
)
@@ -2511,24 +2502,63 @@ def _model_tables(self) -> t.Dict[str, str]:
25112502
@cached_property
25122503
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
25132504
"""Returns all the engine adapters for the gateways defined in the configuration."""
2505+
adapters: t.Dict[str, EngineAdapter] = {self.selected_gateway: self.engine_adapter}
25142506
for gateway_name in self.config.gateways:
25152507
if gateway_name != self.selected_gateway:
25162508
connection = self.config.get_connection(gateway_name)
25172509
adapter = connection.create_engine_adapter(concurrent_tasks=self.concurrent_tasks)
2518-
self._engine_adapters[gateway_name] = adapter
2519-
return self._engine_adapters
2510+
adapters[gateway_name] = adapter
2511+
return adapters
25202512

25212513
@cached_property
25222514
def default_catalog_per_gateway(self) -> t.Dict[str, str]:
25232515
"""Returns the default catalogs for each engine adapter."""
25242516
if self._default_catalog_per_gateway is None:
2525-
self._default_catalog_per_gateway = {
2526-
name: adapter.default_catalog
2527-
for name, adapter in self.engine_adapters.items()
2528-
if adapter.default_catalog
2529-
}
2517+
self._default_catalog_per_gateway = self._scheduler.get_default_catalog_per_gateway(
2518+
self
2519+
)
25302520
return self._default_catalog_per_gateway
25312521

2522+
@cached_property
2523+
def concurrent_tasks(self) -> int:
2524+
if self._concurrent_tasks is None:
2525+
self._concurrent_tasks = self.connection_config.concurrent_tasks
2526+
return self._concurrent_tasks
2527+
2528+
@cached_property
2529+
def connection_config(self) -> ConnectionConfig:
2530+
if self._connection_config is None:
2531+
self._connection_config = self.config.get_connection(self.selected_gateway)
2532+
return self._connection_config
2533+
2534+
@cached_property
2535+
def test_connection_config(self) -> ConnectionConfig:
2536+
if self._test_connection_config is None:
2537+
self._test_connection_config = self.config.get_test_connection(
2538+
self.gateway,
2539+
self.default_catalog,
2540+
default_catalog_dialect=self.engine_adapter.DIALECT,
2541+
)
2542+
return self._test_connection_config
2543+
2544+
@cached_property
2545+
def environment_catalog_mapping(self) -> RegexKeyDict:
2546+
engine_adapter = None
2547+
try:
2548+
engine_adapter = self.engine_adapter
2549+
except Exception:
2550+
pass
2551+
2552+
if (
2553+
self.config.environment_catalog_mapping
2554+
and engine_adapter
2555+
and not self.engine_adapter.catalog_support.is_multi_catalog_supported
2556+
):
2557+
raise SQLMeshError(
2558+
"Environment catalog mapping is only supported for engine adapters that support multiple catalogs"
2559+
)
2560+
return self.config.environment_catalog_mapping
2561+
25322562
def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
25332563
if gateway:
25342564
if adapter := self.engine_adapters.get(gateway):

sqlmesh/integrations/github/cicd/controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def conclusion_handler(
726726
self._console.log_test_results(
727727
result,
728728
output,
729-
self._context._test_connection_config._engine_adapter.DIALECT,
729+
self._context.test_connection_config._engine_adapter.DIALECT,
730730
)
731731
test_summary = self._console.consume_captured_output()
732732
test_title = "Tests Passed" if result.wasSuccessful() else "Tests Failed"

tests/core/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture):
742742

743743
ctx = Context(paths=tmp_path, config=config)
744744

745-
assert isinstance(ctx._connection_config, RedshiftConnectionConfig)
745+
assert isinstance(ctx.connection_config, RedshiftConnectionConfig)
746746
assert len(ctx.engine_adapters) == 3
747747
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)
748748
assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter)
@@ -782,7 +782,7 @@ def test_multi_gateway_single_threaded_config(tmp_path):
782782
)
783783

784784
ctx = Context(paths=tmp_path, config=config)
785-
assert isinstance(ctx._connection_config, DuckDBConnectionConfig)
785+
assert isinstance(ctx.connection_config, DuckDBConnectionConfig)
786786
assert len(ctx.engine_adapters) == 2
787787
assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb")
788788
assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter)

tests/core/test_context.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,18 +332,18 @@ def test_evaluate_limit():
332332
def test_gateway_specific_adapters(copy_to_temp_path, mocker):
333333
path = copy_to_temp_path("examples/sushi")
334334
ctx = Context(paths=path, config="isolated_systems_config", gateway="prod")
335-
assert len(ctx._engine_adapters) == 3
336-
assert ctx.engine_adapter == ctx._engine_adapters["prod"]
337-
assert ctx._get_engine_adapter("dev") == ctx._engine_adapters["dev"]
335+
assert len(ctx.engine_adapters) == 3
336+
assert ctx.engine_adapter == ctx.engine_adapters["prod"]
337+
assert ctx._get_engine_adapter("dev") == ctx.engine_adapters["dev"]
338338

339339
ctx = Context(paths=path, config="isolated_systems_config")
340-
assert len(ctx._engine_adapters) == 3
341-
assert ctx.engine_adapter == ctx._engine_adapters["dev"]
340+
assert len(ctx.engine_adapters) == 3
341+
assert ctx.engine_adapter == ctx.engine_adapters["dev"]
342342

343343
ctx = Context(paths=path, config="isolated_systems_config")
344344
assert len(ctx.engine_adapters) == 3
345345
assert ctx.engine_adapter == ctx._get_engine_adapter()
346-
assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"]
346+
assert ctx._get_engine_adapter("test") == ctx.engine_adapters["test"]
347347

348348

349349
def test_multiple_gateways(tmp_path: Path):
@@ -800,7 +800,8 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None:
800800
),
801801
]
802802

803-
sushi_context._engine_adapters = {sushi_context.config.default_gateway: adapter_mock}
803+
sushi_context._engine_adapter = adapter_mock
804+
sushi_context.engine_adapters = {sushi_context.config.default_gateway: adapter_mock}
804805
sushi_context._state_sync = state_sync_mock
805806
state_sync_mock.get_expired_snapshots.return_value = []
806807

tests/core/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4872,13 +4872,11 @@ def test_multi(mocker):
48724872
assert context.fetchdf("select * from after_1").to_dict()["repo_1"][0] == "repo_1"
48734873
assert context.fetchdf("select * from after_2").to_dict()["repo_2"][0] == "repo_2"
48744874

4875-
adapter = context.engine_adapter
48764875
context = Context(
48774876
paths=["examples/multi/repo_1"],
48784877
state_sync=context.state_sync,
48794878
gateway="memory",
48804879
)
4881-
context._engine_adapters["memory"] = adapter
48824880

48834881
model = context.get_model("bronze.a")
48844882
assert model.project == "repo_1"
@@ -4935,6 +4933,8 @@ def test_multi_virtual_layer(copy_to_temp_path):
49354933
)
49364934

49374935
context = Context(paths=paths, config=config)
4936+
assert context.default_catalog_per_gateway == {"first": "db_1", "second": "db_2"}
4937+
assert len(context.engine_adapters) == 2
49384938

49394939
# For the model without gateway the default should be used and the gateway variable should override the global
49404940
assert (

tests/core/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4518,7 +4518,7 @@ def test_model_session_properties(sushi_context):
45184518
name test_schema.test_model,
45194519
session_properties (
45204520
'query_label' = (
4521-
'some value',
4521+
'some value',
45224522
'another value',
45234523
'yet another value',
45244524
)
@@ -8350,7 +8350,7 @@ def test_gateway_specific_render(assert_exp_eq) -> None:
83508350
default_gateway="main",
83518351
)
83528352
context = Context(config=config)
8353-
assert context.engine_adapter == context._engine_adapters["main"]
8353+
assert context.engine_adapter == context.engine_adapters["main"]
83548354

83558355
@model(
83568356
name="dummy_model",
@@ -8376,7 +8376,7 @@ def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select:
83768376
""",
83778377
)
83788378
assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter)
8379-
assert len(context._engine_adapters) == 2
8379+
assert len(context.engine_adapters) == 2
83808380

83818381

83828382
def test_model_on_virtual_update(make_snapshot: t.Callable):

0 commit comments

Comments
 (0)