Skip to content

Commit cf67d85

Browse files
Refactor: Move engine adapter instantiation after loading on-demand
1 parent fa43912 commit cf67d85

6 files changed

Lines changed: 110 additions & 49 deletions

File tree

sqlmesh/core/config/scheduler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
5454
context: The SQLMesh Context.
5555
"""
5656

57+
@abc.abstractmethod
58+
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
59+
"""Returns the default catalog for each gateway.
60+
61+
Args:
62+
context: The SQLMesh Context.
63+
"""
64+
5765
@abc.abstractmethod
5866
def state_sync_fingerprint(self, context: GenericContext) -> str:
5967
"""Returns the fingerprint of the State Sync configuration.
@@ -139,6 +147,13 @@ def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
139147
def get_default_catalog(self, context: GenericContext) -> t.Optional[str]:
140148
return context.engine_adapter.default_catalog
141149

150+
def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]:
151+
return {
152+
name: adapter.default_catalog
153+
for name, adapter in context.engine_adapters.items()
154+
if adapter.default_catalog
155+
}
156+
142157

143158
SCHEDULER_CONFIG_TO_TYPE = {
144159
tpe.all_field_infos()["type_"].default: tpe

sqlmesh/core/context.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
Config,
6363
load_configs,
6464
)
65+
from sqlmesh.core.config.connection import ConnectionConfig
6566
from sqlmesh.core.config.loader import C
67+
from sqlmesh.core.config.root import RegexKeyDict
6668
from sqlmesh.core.console import get_console
6769
from sqlmesh.core.context_diff import ContextDiff
6870
from sqlmesh.core.dialect import (
@@ -91,6 +93,7 @@
9193
from sqlmesh.core.plan.definition import UserProvidedFlags
9294
from sqlmesh.core.reference import ReferenceGraph
9395
from sqlmesh.core.scheduler import Scheduler, CompletionStatus
96+
from sqlmesh.core.schema_diff import SchemaDiffer
9497
from sqlmesh.core.schema_loader import create_external_models_file
9598
from sqlmesh.core.selector import Selector
9699
from sqlmesh.core.snapshot import (
@@ -366,7 +369,7 @@ def __init__(
366369
self._environment_statements: t.List[EnvironmentStatements] = []
367370
self._excluded_requirements: t.Set[str] = set()
368371
self._default_catalog: t.Optional[str] = None
369-
self._default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None
372+
self._engine_adapter: t.Optional[EngineAdapter] = None
370373
self._linters: t.Dict[str, Linter] = {}
371374
self._loaded: bool = False
372375

@@ -407,24 +410,15 @@ def __init__(
407410
for path, config in self.configs.items()
408411
]
409412

410-
self._connection_config = self.config.get_connection(self.gateway)
413+
self._concurrent_tasks = concurrent_tasks
411414
self._state_connection_config = (
412415
self.config.get_state_connection(self.gateway) or self._connection_config
413416
)
414-
self.concurrent_tasks = concurrent_tasks or self._connection_config.concurrent_tasks
415-
416-
self._engine_adapters: t.Dict[str, EngineAdapter] = {
417-
self.selected_gateway: self._connection_config.create_engine_adapter()
418-
}
419417

420418
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
421419

422420
self.console = get_console()
423-
setattr(self.console, "dialect", self.engine_adapter.dialect)
424-
425-
self._test_connection_config = self.config.get_test_connection(
426-
self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT
427-
)
421+
setattr(self.console, "dialect", self.config.dialect)
428422

429423
self._provided_state_sync: t.Optional[StateSync] = state_sync
430424
self._state_sync: t.Optional[StateSync] = None
@@ -435,14 +429,6 @@ def __init__(
435429
self.users = list({user.username: user for user in self.users}.values())
436430
self._register_notification_targets()
437431

438-
if (
439-
self.config.environment_catalog_mapping
440-
and not self.engine_adapter.catalog_support.is_multi_catalog_supported
441-
):
442-
raise SQLMeshError(
443-
"Environment catalog mapping is only supported for engine adapters that support multiple catalogs"
444-
)
445-
446432
if load:
447433
self.load()
448434

@@ -453,7 +439,9 @@ def default_dialect(self) -> t.Optional[str]:
453439
@property
454440
def engine_adapter(self) -> EngineAdapter:
455441
"""Returns the default engine adapter."""
456-
return self._engine_adapters[self.selected_gateway]
442+
if self._engine_adapter is None:
443+
self._engine_adapter = self._connection_config.create_engine_adapter()
444+
return self._engine_adapter
457445

458446
@property
459447
def snapshot_evaluator(self) -> SnapshotEvaluator:
@@ -1535,7 +1523,7 @@ def plan_builder(
15351523
allow_destructive_models=expanded_destructive_models,
15361524
environment_ttl=environment_ttl,
15371525
environment_suffix_target=self.config.environment_suffix_target,
1538-
environment_catalog_mapping=self.config.environment_catalog_mapping,
1526+
environment_catalog_mapping=self.environment_catalog_mapping,
15391527
categorizer_config=categorizer_config or self.auto_categorize_changes,
15401528
auto_categorization_enabled=not no_auto_categorization,
15411529
effective_from=effective_from,
@@ -1547,7 +1535,7 @@ def plan_builder(
15471535
),
15481536
end_bounded=not run,
15491537
ensure_finalized_snapshots=self.config.plan.use_finalized_state,
1550-
engine_schema_differ=self.engine_adapter.SCHEMA_DIFFER,
1538+
engine_schema_differ=SchemaDiffer(), # TODO: restore the engine adapter's SCHEMA_DIFFER here without forcing the adapter to be instantiated eagerly
15511539
interval_end_per_model=max_interval_end_per_model,
15521540
console=self.console,
15531541
user_provided_flags=user_provided_flags,
@@ -1636,7 +1624,7 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> b
16361624
self.console.show_model_difference_summary(
16371625
context_diff,
16381626
EnvironmentNamingInfo.from_environment_catalog_mapping(
1639-
self.config.environment_catalog_mapping,
1627+
self.environment_catalog_mapping,
16401628
name=environment,
16411629
suffix_target=self.config.environment_suffix_target,
16421630
normalize_name=context_diff.normalize_environment_name,
@@ -2036,7 +2024,7 @@ def test(
20362024
preserve_fixtures=preserve_fixtures,
20372025
stream=stream,
20382026
default_catalog=self.default_catalog,
2039-
default_catalog_dialect=self.engine_adapter.DIALECT,
2027+
default_catalog_dialect=self.config.dialect or "",
20402028
)
20412029

20422030
@python_api_analytics
@@ -2496,7 +2484,7 @@ def _model_tables(self) -> t.Dict[str, str]:
24962484
if snapshot.version
24972485
else snapshot.qualified_view_name.for_environment(
24982486
EnvironmentNamingInfo.from_environment_catalog_mapping(
2499-
self.config.environment_catalog_mapping,
2487+
self.environment_catalog_mapping,
25002488
name=c.PROD,
25012489
suffix_target=self.config.environment_suffix_target,
25022490
)
@@ -2508,23 +2496,47 @@ def _model_tables(self) -> t.Dict[str, str]:
25082496
@cached_property
25092497
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
25102498
"""Returns all the engine adapters for the gateways defined in the configuration."""
2499+
adapters: t.Dict[str, EngineAdapter] = {self.selected_gateway: self.engine_adapter}
25112500
for gateway_name in self.config.gateways:
25122501
if gateway_name != self.selected_gateway:
25132502
connection = self.config.get_connection(gateway_name)
25142503
adapter = connection.create_engine_adapter(concurrent_tasks=self.concurrent_tasks)
2515-
self._engine_adapters[gateway_name] = adapter
2516-
return self._engine_adapters
2504+
adapters[gateway_name] = adapter
2505+
return adapters
25172506

25182507
@cached_property
25192508
def default_catalog_per_gateway(self) -> t.Dict[str, str]:
25202509
"""Returns the default catalogs for each engine adapter."""
2521-
if self._default_catalog_per_gateway is None:
2522-
self._default_catalog_per_gateway = {
2523-
name: adapter.default_catalog
2524-
for name, adapter in self.engine_adapters.items()
2525-
if adapter.default_catalog
2526-
}
2527-
return self._default_catalog_per_gateway
2510+
if self.gateway_managed_virtual_layer:
2511+
return self._scheduler.get_default_catalog_per_gateway(self)
2512+
return {}
2513+
2514+
@cached_property
2515+
def concurrent_tasks(self) -> int:
2516+
if self._concurrent_tasks is None:
2517+
self._concurrent_tasks = self._connection_config.concurrent_tasks
2518+
return self._concurrent_tasks
2519+
2520+
@cached_property
2521+
def _connection_config(self) -> ConnectionConfig:
2522+
return self.config.get_connection(self.gateway)
2523+
2524+
@cached_property
2525+
def _test_connection_config(self) -> ConnectionConfig:
2526+
return self.config.get_test_connection(
2527+
self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT
2528+
)
2529+
2530+
@cached_property
2531+
def environment_catalog_mapping(self) -> RegexKeyDict:
2532+
if (
2533+
self.config.environment_catalog_mapping
2534+
and not self.engine_adapter.catalog_support.is_multi_catalog_supported
2535+
):
2536+
raise SQLMeshError(
2537+
"Environment catalog mapping is only supported for engine adapters that support multiple catalogs"
2538+
)
2539+
return self.config.environment_catalog_mapping
25282540

25292541
def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
25302542
if gateway:

tests/core/test_context.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,18 +332,18 @@ def test_evaluate_limit():
332332
def test_gateway_specific_adapters(copy_to_temp_path, mocker):
333333
path = copy_to_temp_path("examples/sushi")
334334
ctx = Context(paths=path, config="isolated_systems_config", gateway="prod")
335-
assert len(ctx._engine_adapters) == 3
336-
assert ctx.engine_adapter == ctx._engine_adapters["prod"]
337-
assert ctx._get_engine_adapter("dev") == ctx._engine_adapters["dev"]
335+
assert len(ctx.engine_adapters) == 3
336+
assert ctx.engine_adapter == ctx.engine_adapters["prod"]
337+
assert ctx._get_engine_adapter("dev") == ctx.engine_adapters["dev"]
338338

339339
ctx = Context(paths=path, config="isolated_systems_config")
340-
assert len(ctx._engine_adapters) == 3
341-
assert ctx.engine_adapter == ctx._engine_adapters["dev"]
340+
assert len(ctx.engine_adapters) == 3
341+
assert ctx.engine_adapter == ctx.engine_adapters["dev"]
342342

343343
ctx = Context(paths=path, config="isolated_systems_config")
344344
assert len(ctx.engine_adapters) == 3
345345
assert ctx.engine_adapter == ctx._get_engine_adapter()
346-
assert ctx._get_engine_adapter("test") == ctx._engine_adapters["test"]
346+
assert ctx._get_engine_adapter("test") == ctx.engine_adapters["test"]
347347

348348

349349
def test_multiple_gateways(tmp_path: Path):
@@ -800,7 +800,8 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None:
800800
),
801801
]
802802

803-
sushi_context._engine_adapters = {sushi_context.config.default_gateway: adapter_mock}
803+
sushi_context._engine_adapter = adapter_mock
804+
sushi_context.engine_adapters = {sushi_context.config.default_gateway: adapter_mock}
804805
sushi_context._state_sync = state_sync_mock
805806
state_sync_mock.get_expired_snapshots.return_value = []
806807

tests/core/test_integration.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4878,7 +4878,7 @@ def test_multi(mocker):
48784878
state_sync=context.state_sync,
48794879
gateway="memory",
48804880
)
4881-
context._engine_adapters["memory"] = adapter
4881+
context._engine_adapter = adapter
48824882

48834883
model = context.get_model("bronze.a")
48844884
assert model.project == "repo_1"
@@ -5064,6 +5064,39 @@ def test_multi_virtual_layer(copy_to_temp_path):
50645064
context.apply(plan)
50655065

50665066

5067+
def test_multi_virtual_layer_catalogs(copy_to_temp_path):
5068+
paths = copy_to_temp_path("tests/fixtures/multi_virtual_layer")
5069+
path = Path(paths[0])
5070+
first_db_path = str(path / "db_1.db")
5071+
second_db_path = str(path / "db_2.db")
5072+
5073+
config = Config(
5074+
gateways={
5075+
"first": GatewayConfig(
5076+
connection=DuckDBConnectionConfig(database=first_db_path),
5077+
variables={"overriden_var": "gateway_1"},
5078+
),
5079+
"second": GatewayConfig(
5080+
connection=DuckDBConnectionConfig(database=second_db_path),
5081+
variables={"overriden_var": "gateway_2"},
5082+
),
5083+
},
5084+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
5085+
model_naming=NameInferenceConfig(infer_names=True),
5086+
default_gateway="first",
5087+
variables={"overriden_var": "global", "global_one": 88},
5088+
)
5089+
5090+
# With gateway_managed_virtual_layer set to False, the catalogs won't be retrieved
5091+
context = Context(paths=paths, config=config)
5092+
assert context.default_catalog_per_gateway == {}
5093+
5094+
config.gateway_managed_virtual_layer = True
5095+
context = Context(paths=paths, config=config)
5096+
assert context.default_catalog_per_gateway == {"first": "db_1", "second": "db_2"}
5097+
assert len(context.engine_adapters) == 2
5098+
5099+
50675100
def test_multi_dbt(mocker):
50685101
context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"])
50695102
context._new_state_sync().reset(default_catalog=context.default_catalog)

tests/core/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4518,7 +4518,7 @@ def test_model_session_properties(sushi_context):
45184518
name test_schema.test_model,
45194519
session_properties (
45204520
'query_label' = (
4521-
'some value',
4521+
'some value',
45224522
'another value',
45234523
'yet another value',
45244524
)
@@ -8350,7 +8350,7 @@ def test_gateway_specific_render(assert_exp_eq) -> None:
83508350
default_gateway="main",
83518351
)
83528352
context = Context(config=config)
8353-
assert context.engine_adapter == context._engine_adapters["main"]
8353+
assert context.engine_adapter == context.engine_adapters["main"]
83548354

83558355
@model(
83568356
name="dummy_model",
@@ -8376,7 +8376,7 @@ def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select:
83768376
""",
83778377
)
83788378
assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter)
8379-
assert len(context._engine_adapters) == 2
8379+
assert len(context.engine_adapters) == 2
83808380

83818381

83828382
def test_model_on_virtual_update(make_snapshot: t.Callable):

tests/core/test_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,16 +2128,16 @@ def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture)
21282128
return_value=pd.DataFrame({"c": [5]}),
21292129
)
21302130

2131-
assert context.engine_adapter == context._engine_adapters["main"]
2131+
assert context.engine_adapter == context.engine_adapters["main"]
21322132
with pytest.raises(
21332133
SQLMeshError, match=r"Gateway 'wrong' not found in the available engine adapters."
21342134
):
21352135
context._get_engine_adapter("wrong")
21362136

21372137
# Create test should use the gateway specific engine adapter
21382138
context.create_test("sqlmesh_example.gw_model", input_queries=input_queries, overwrite=True)
2139-
assert context._get_engine_adapter("second") == context._engine_adapters["second"]
2140-
assert len(context._engine_adapters) == 2
2139+
assert context._get_engine_adapter("second") == context.engine_adapters["second"]
2140+
assert len(context.engine_adapters) == 2
21412141

21422142
test = load_yaml(context.path / c.TESTS / "test_gw_model.yaml")
21432143

0 commit comments

Comments
 (0)