Skip to content

Commit b303b64

Browse files
Refactors; account for identical db names between warehouses when creating schemas
1 parent 72dc0b3 commit b303b64

9 files changed

Lines changed: 62 additions & 49 deletions

File tree

sqlmesh/core/context.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def __init__(
364364
self._environment_statements: t.List[EnvironmentStatements] = []
365365
self._excluded_requirements: t.Set[str] = set()
366366
self._default_catalog: t.Optional[str] = None
367-
self._catalogs: t.Dict[str, str] = {}
367+
self._default_catalog_per_gateway: t.Dict[str, str] = {}
368368
self._linters: t.Dict[str, Linter] = {}
369369
self._loaded: bool = False
370370

@@ -2225,15 +2225,15 @@ def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
22252225
return self._engine_adapters
22262226

22272227
@cached_property
2228-
def catalogs(self) -> t.Dict[str, str]:
2228+
def default_catalog_per_gateway(self) -> t.Dict[str, str]:
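22292229
"""Returns the default catalog of each engine adapter, keyed by adapter (gateway) name, when the virtual layer is gateway managed (i.e. the catalog isn't shared); adapters without a default catalog are omitted."""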
22302230
if self.gateway_managed_virtual_layer:
2231-
self._catalogs = {
2231+
self._default_catalog_per_gateway = {
22322232
name: adapter.default_catalog
22332233
for name, adapter in self.engine_adapters.items()
22342234
if adapter.default_catalog
22352235
}
2236-
return self._catalogs
2236+
return self._default_catalog_per_gateway
22372237

22382238
def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
22392239
if gateway:
@@ -2317,10 +2317,10 @@ def _cleanup_environments(self) -> None:
23172317
expired_environments = self.state_sync.delete_expired_environments()
23182318

23192319
cleanup_expired_views(
2320-
adapter=self.engine_adapter,
2320+
default_adapter=self.engine_adapter,
2321+
engine_adapters=self.engine_adapters,
23212322
environments=expired_environments,
23222323
console=self.console,
2323-
engine_adapters=self.engine_adapters,
23242324
)
23252325

23262326
def _try_connection(self, connection_name: str, validator: t.Callable[[], None]) -> None:

sqlmesh/core/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def _load() -> t.List[Model]:
468468
default_catalog=self.context.default_catalog,
469469
infer_names=self.config.model_naming.infer_names,
470470
signal_definitions=signals,
471-
catalogs=self.context.catalogs,
471+
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
472472
)
473473
except Exception as ex:
474474
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
@@ -526,7 +526,7 @@ def _load_python_models(
526526
default_catalog=self.context.default_catalog,
527527
infer_names=self.config.model_naming.infer_names,
528528
audit_definitions=audits,
529-
catalogs=self.context.catalogs,
529+
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
530530
):
531531
if model.enabled:
532532
models[model.fqn] = model

sqlmesh/core/model/decorator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def models(
9393
path: Path,
9494
module_path: Path,
9595
dialect: t.Optional[str] = None,
96+
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
9697
**loader_kwargs: t.Any,
9798
) -> t.List[Model]:
9899
return create_models_from_blueprints(
@@ -103,6 +104,7 @@ def models(
103104
path=path,
104105
module_path=module_path,
105106
dialect=dialect,
107+
default_catalog_per_gateway=default_catalog_per_gateway,
106108
**loader_kwargs,
107109
)
108110

sqlmesh/core/model/definition.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,7 @@ def create_models_from_blueprints(
18861886
path: Path = Path(),
18871887
module_path: Path = Path(),
18881888
dialect: DialectType = None,
1889+
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
18891890
**loader_kwargs: t.Any,
18901891
) -> t.List[Model]:
18911892
model_blueprints: t.List[Model] = []
@@ -1907,11 +1908,10 @@ def create_models_from_blueprints(
19071908
else:
19081909
gateway_name = None
19091910

1910-
# We pop to avoid pydantic validation issues since catalogs is not a model property
19111911
if (
1912-
(catalogs := loader_kwargs.pop("catalogs", None))
1912+
default_catalog_per_gateway
19131913
and gateway_name
1914-
and (catalog := catalogs.get(gateway_name))
1914+
and (catalog := default_catalog_per_gateway.get(gateway_name))
19151915
):
19161916
loader_kwargs["default_catalog"] = catalog
19171917

@@ -1935,6 +1935,7 @@ def load_sql_based_models(
19351935
path: Path = Path(),
19361936
module_path: Path = Path(),
19371937
dialect: DialectType = None,
1938+
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None,
19381939
**loader_kwargs: t.Any,
19391940
) -> t.List[Model]:
19401941
gateway: t.Optional[exp.Expression] = None
@@ -1972,6 +1973,7 @@ def load_sql_based_models(
19721973
path=path,
19731974
module_path=module_path,
19741975
dialect=dialect,
1976+
default_catalog_per_gateway=default_catalog_per_gateway,
19751977
**loader_kwargs,
19761978
)
19771979

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,7 @@ def _demote_snapshots(
424424
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None,
425425
) -> None:
426426
self.snapshot_evaluator.demote(
427-
target_snapshots,
428-
environment_naming_info,
429-
on_complete=on_complete,
427+
target_snapshots, environment_naming_info, on_complete=on_complete
430428
)
431429

432430
def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -227,21 +227,20 @@ def promote(
227227
on_complete: A callback to call on each successfully promoted snapshot.
228228
"""
229229

230-
gateway_by_schema: t.Dict[t.Any, str] = {}
231-
tables: t.List[t.Any] = []
230+
tables_by_gateway: t.Dict[t.Union[str, None], t.List[exp.Table]] = defaultdict(list)
232231
for snapshot in target_snapshots:
233232
if snapshot.is_model and not snapshot.is_symbolic:
233+
gateway = (
234+
snapshot.model_gateway if environment_naming_info.gateway_managed else None
235+
)
236+
adapter = self._get_adapter(gateway)
234237
table = snapshot.qualified_view_name.table_for_environment(
235-
environment_naming_info,
236-
dialect=self._get_adapter(snapshot.model_gateway).dialect
237-
if environment_naming_info.gateway_managed
238-
else self.adapter.dialect,
238+
environment_naming_info, dialect=adapter.dialect
239239
)
240-
tables.append(table)
241-
if environment_naming_info.gateway_managed:
242-
table_schema = d.schema_(table.db, catalog=table.catalog)
243-
gateway_by_schema[table_schema] = snapshot.model_gateway or ""
244-
self._create_schemas(tables=tables, gateways=gateway_by_schema)
240+
tables_by_gateway[gateway].append(table)
241+
242+
for gateway, tables in tables_by_gateway.items():
243+
self._create_schemas(tables=tables, gateway=gateway)
245244

246245
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
247246
with self.concurrent_context():
@@ -301,8 +300,9 @@ def create(
301300
allow_destructive_snapshots: Set of snapshots that are allowed to have destructive schema changes.
302301
"""
303302
snapshots_with_table_names = defaultdict(set)
304-
tables_by_schema = defaultdict(set)
305-
gateway_by_schema: t.Dict[exp.Table, str] = {}
303+
tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = (
304+
defaultdict(lambda: defaultdict(set))
305+
)
306306
table_deployability: t.Dict[str, bool] = {}
307307
allow_destructive_snapshots = allow_destructive_snapshots or set()
308308

@@ -324,24 +324,32 @@ def create(
324324
snapshots_with_table_names[snapshot].add(table.name)
325325
table_deployability[table.name] = is_deployable
326326
table_schema = d.schema_(table.db, catalog=table.catalog)
327-
tables_by_schema[table_schema].add(table.name)
328-
gateway_by_schema[table_schema] = snapshot.model.gateway or ""
327+
tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name)
329328

330-
def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.Set[str]:
329+
def _get_data_objects(
330+
schema: exp.Table,
331+
object_names: t.Optional[t.Set[str]] = None,
332+
gateway: t.Optional[str] = None,
333+
) -> t.Set[str]:
331334
logger.info("Listing data objects in schema %s", schema.sql())
332-
objs = self.get_adapter(gateway).get_data_objects(schema, tables_by_schema[schema])
335+
objs = self._get_adapter(gateway).get_data_objects(schema, object_names)
333336
return {obj.name for obj in objs}
334337

335338
with self.concurrent_context():
336-
existing_objects = {
337-
obj
338-
for objs in concurrent_apply_to_values(
339-
list(tables_by_schema),
340-
lambda s: _get_data_objects(s, gateway_by_schema[s]),
341-
self.ddl_concurrent_tasks,
342-
)
343-
for obj in objs
344-
}
339+
existing_objects: t.Set[str] = set()
340+
for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
341+
objs_for_gateway = {
342+
obj
343+
for objs in concurrent_apply_to_values(
344+
list(tables_by_schema),
345+
lambda s: _get_data_objects(
346+
schema=s, object_names=tables_by_schema.get(s), gateway=gateway
347+
),
348+
self.ddl_concurrent_tasks,
349+
)
350+
for obj in objs
351+
}
352+
existing_objects.update(objs_for_gateway)
345353

346354
snapshots_to_create = []
347355
target_deployability_flags: t.Dict[str, t.List[bool]] = defaultdict(list)
@@ -359,7 +367,10 @@ def _get_data_objects(schema: exp.Table, gateway: t.Optional[str] = None) -> t.S
359367
return
360368
if on_start:
361369
on_start(len(snapshots_to_create))
362-
self._create_schemas(tables_by_schema, gateway_by_schema)
370+
371+
for gateway, tables_by_schema in tables_by_gateway_and_schema.items():
372+
self._create_schemas(tables=tables_by_schema, gateway=gateway)
373+
363374
self._create_snapshots(
364375
snapshots_to_create=snapshots_to_create,
365376
snapshots=snapshots,
@@ -1072,7 +1083,7 @@ def _audit(
10721083
def _create_schemas(
10731084
self,
10741085
tables: t.Iterable[t.Union[exp.Table, str]],
1075-
gateways: t.Optional[t.Dict[exp.Table, str]] = None,
1086+
gateway: t.Optional[str] = None,
10761087
) -> None:
10771088
table_exprs = [exp.to_table(t) for t in tables]
10781089
unique_schemas = {(t.args["db"], t.args.get("catalog")) for t in table_exprs if t and t.db}
@@ -1081,7 +1092,7 @@ def _create_schemas(
10811092
for schema_name, catalog in unique_schemas:
10821093
schema = schema_(schema_name, catalog)
10831094
logger.info("Creating schema '%s'", schema)
1084-
adapter = self.get_adapter(gateways.get(schema)) if gateways else self.adapter
1095+
adapter = self._get_adapter(gateway)
10851096
adapter.create_schema(schema)
10861097

10871098
def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:

sqlmesh/core/state_sync/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121

2222

2323
def cleanup_expired_views(
24-
adapter: EngineAdapter,
24+
default_adapter: EngineAdapter,
25+
engine_adapters: t.Dict[str, EngineAdapter],
2526
environments: t.List[Environment],
2627
console: t.Optional[Console] = None,
27-
engine_adapters: t.Optional[t.Dict[str, EngineAdapter]] = None,
2828
) -> None:
2929
expired_schema_environments = [
3030
environment for environment in environments if environment.suffix_target.is_schema
@@ -36,8 +36,8 @@ def cleanup_expired_views(
3636
# We have to use the corresponding adapter if the virtual layer is gateway managed
3737
def get_adapter(gateway_managed: bool, gateway: t.Optional[str] = None) -> EngineAdapter:
3838
if gateway_managed and gateway:
39-
return (engine_adapters or {}).get(gateway, adapter)
40-
return adapter
39+
return engine_adapters.get(gateway, default_adapter)
40+
return default_adapter
4141

4242
# Drop the schemas for the expired environments
4343
for engine_adapter, expired_catalog, expired_schema in {

sqlmesh/engines/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def cleanup(
140140
if isinstance(command_payload, str):
141141
command_payload = CleanupCommandPayload.parse_raw(command_payload)
142142

143-
cleanup_expired_views(evaluator.adapter, command_payload.environments)
143+
cleanup_expired_views(evaluator.adapter, evaluator.adapters, command_payload.environments)
144144
evaluator.cleanup(command_payload.tasks)
145145

146146

tests/core/state_sync/test_state_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2635,7 +2635,7 @@ def test_cleanup_expired_views(
26352635
previous_plan_id="test_plan_id",
26362636
catalog_name_override="catalog_override",
26372637
)
2638-
cleanup_expired_views(adapter, [schema_environment, table_environment])
2638+
cleanup_expired_views(adapter, {}, [schema_environment, table_environment])
26392639
assert adapter.drop_schema.called
26402640
assert adapter.drop_view.called
26412641
assert adapter.drop_schema.call_args_list == [

0 commit comments

Comments
 (0)