Skip to content

Commit b5a81e9

Browse files
Feat!: Introduce multiple gateway virtual layer
1 parent 8db5700 commit b5a81e9

21 files changed

Lines changed: 137 additions & 48 deletions

docs/reference/configuration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Configuration options for SQLMesh environment creation and promotion.
3535
| `physical_schema_override` | (Deprecated) Use `physical_schema_mapping` instead. A mapping from model schema names to names of schemas in which physical tables for the corresponding models will be placed. | dict[string, string] | N |
3636
| `physical_schema_mapping` | A mapping from regular expressions to names of schemas in which physical tables for the corresponding models [will be placed](../guides/configuration.md#physical-table-schemas). (Default physical schema name: `sqlmesh__[model schema]`) | dict[string, string] | N |
3737
| `environment_suffix_target` | Whether SQLMesh views should append their environment name to the `schema` or `table` - [additional details](../guides/configuration.md#view-schema-override). (Default: `schema`) | string | N |
38+
| `gateway_managed_virtual_layer` | Whether SQLMesh views of the virtual layer will be created by the default gateway or model specified gateways - [additional details](../guides/configuration.md#view-schema-override). (Default: False) | boolean | N |
3839
| `environment_catalog_mapping` | A mapping from regular expressions to catalog names. The catalog name is used to determine the target catalog for a given environment. | dict[string, string] | N |
3940
| `log_limit` | The default number of logs to keep (Default: `20`) | int | N |
4041

sqlmesh/core/config/root.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Config(BaseConfig):
7272
model_defaults: Default values for model definitions.
7373
physical_schema_mapping: A mapping from regular expressions to names of schemas in which physical tables for corresponding models will be placed.
7474
environment_suffix_target: Indicates whether to append the environment name to the schema or table name.
75+
gateway_managed_virtual_layer: Whether the models' views in the virtual layer are created by the model-specific gateway rather than the default gateway.
7576
environment_catalog_mapping: A mapping from regular expressions to catalog names. The catalog name is used to determine the target catalog for a given environment.
7677
default_target_environment: The name of the environment that will be the default target for the `sqlmesh plan` and `sqlmesh run` commands.
7778
log_limit: The default number of logs to keep.
@@ -110,6 +111,7 @@ class Config(BaseConfig):
110111
environment_suffix_target: EnvironmentSuffixTarget = Field(
111112
default=EnvironmentSuffixTarget.default
112113
)
114+
gateway_managed_virtual_layer: bool = False
113115
environment_catalog_mapping: t.Dict[re.Pattern, str] = {}
114116
default_target_environment: str = c.PROD
115117
log_limit: int = c.DEFAULT_LOG_LIMIT

sqlmesh/core/context.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ def __init__(
379379
self.pinned_environments = Environment.sanitize_names(self.config.pinned_environments)
380380
self.auto_categorize_changes = self.config.plan.auto_categorize_changes
381381
self.selected_gateway = gateway or self.config.default_gateway_name
382+
self.gateway_managed_virtual_layer = self.config.gateway_managed_virtual_layer
383+
self.catalogs: t.Dict[str, str] = {}
382384

383385
gw_model_defaults = self.config.gateways[self.selected_gateway].model_defaults
384386
if gw_model_defaults:
@@ -584,6 +586,15 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
584586
"""Load all files in the context's path."""
585587
load_start_ts = time.perf_counter()
586588

589+
# In a multi virtual layer setup we need the catalog of each engine
590+
# since there is a possibility of not a shared catalog between them
591+
if self.gateway_managed_virtual_layer:
592+
self.catalogs = {
593+
name: adapter.default_catalog
594+
for name, adapter in self.engine_adapters.items()
595+
if adapter.default_catalog
596+
}
597+
587598
loaded_projects = [loader.load() for loader in self._loaders]
588599

589600
self.dag = DAG()
@@ -2207,16 +2218,6 @@ def _model_tables(self) -> t.Dict[str, str]:
22072218
for fqn, snapshot in self.snapshots.items()
22082219
}
22092220

2210-
@property
2211-
def _snapshot_gateways(self) -> t.Dict[str, str]:
2212-
"""Mapping of snapshot name to the gateway if specified in the model."""
2213-
2214-
return {
2215-
fqn: snapshot.model.gateway
2216-
for fqn, snapshot in self.snapshots.items()
2217-
if snapshot.is_model and snapshot.model.gateway
2218-
}
2219-
22202221
@cached_property
22212222
def engine_adapters(self) -> t.Dict[str, EngineAdapter]:
22222223
"""Returns all the engine adapters for the gateways defined in the configuration."""
@@ -2287,14 +2288,18 @@ def _context_diff(
22872288
ensure_finalized_snapshots=ensure_finalized_snapshots,
22882289
diff_rendered=diff_rendered,
22892290
environment_statements=self._environment_statements,
2291+
gateway_managed_virtual_layer=self.gateway_managed_virtual_layer,
22902292
)
22912293

22922294
def _run_janitor(self, ignore_ttl: bool = False) -> None:
22932295
self._cleanup_environments()
2294-
expired_snapshots = self.state_sync.delete_expired_snapshots(ignore_ttl=ignore_ttl)
2296+
expired_snapshots, snapshot_gateways = self.state_sync.delete_expired_snapshots(
2297+
ignore_ttl=ignore_ttl
2298+
)
2299+
22952300
self.snapshot_evaluator.cleanup(
22962301
expired_snapshots,
2297-
self._snapshot_gateways,
2302+
snapshot_gateways,
22982303
on_complete=self.console.update_cleanup_progress,
22992304
)
23002305

sqlmesh/core/context_diff.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class ContextDiff(PydanticModel):
5353
"""Whether the currently stored environment record is in unfinalized state."""
5454
normalize_environment_name: bool
5555
"""Whether the environment name should be normalized."""
56+
gateway_managed_virtual_layer: bool = False
57+
"""Whether the virtual layer's views will be created by the model specified gateways."""
5658
create_from: str
5759
"""The name of the environment the target environment will be created from if new."""
5860
create_from_env_exists: bool
@@ -96,6 +98,7 @@ def create(
9698
excluded_requirements: t.Optional[t.Set[str]] = None,
9799
diff_rendered: bool = False,
98100
environment_statements: t.Optional[t.List[EnvironmentStatements]] = [],
101+
gateway_managed_virtual_layer: bool = False,
99102
) -> ContextDiff:
100103
"""Create a ContextDiff object.
101104
@@ -118,7 +121,11 @@ def create(
118121
env = state_reader.get_environment(environment)
119122

120123
create_from_env_exists = False
121-
if env is None or env.expired:
124+
if (
125+
env is None
126+
or env.expired
127+
or env.gateway_managed_virtual_layer != gateway_managed_virtual_layer
128+
):
122129
env = state_reader.get_environment(create_from.lower())
123130

124131
if not env and create_from != c.PROD:
@@ -226,6 +233,7 @@ def create(
226233
diff_rendered=diff_rendered,
227234
previous_environment_statements=previous_environment_statements,
228235
environment_statements=environment_statements,
236+
gateway_managed_virtual_layer=gateway_managed_virtual_layer,
229237
)
230238

231239
@classmethod

sqlmesh/core/environment.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@ class EnvironmentNamingInfo(PydanticModel):
3232
catalog_name_override: The name of the catalog to use for this environment if an override was provided
3333
normalize_name: Indicates whether the environment's name will be normalized. For example, if it's
3434
`dev`, then it will become `DEV` when targeting Snowflake.
35+
gateway_managed_virtual_layer: Determines whether the virtual layer's views are created by the model-specific
36+
gateways, otherwise the default gateway is used. Default: False.
3537
"""
3638

3739
name: str = c.PROD
3840
suffix_target: EnvironmentSuffixTarget = Field(default=EnvironmentSuffixTarget.SCHEMA)
3941
catalog_name_override: t.Optional[str] = None
4042
normalize_name: bool = True
43+
gateway_managed_virtual_layer: bool = False
4144

4245
@field_validator("name", mode="before")
4346
@classmethod
@@ -49,6 +52,11 @@ def _sanitize_name(cls, v: str) -> str:
4952
def _validate_normalize_name(cls, v: t.Any) -> bool:
5053
return True if v is None else bool(v)
5154

55+
@field_validator("gateway_managed_virtual_layer", mode="before")
56+
@classmethod
57+
def _validate_gateway_managed_virtual_layer(cls, v: t.Any) -> bool:
58+
return False if v is None else bool(v)
59+
5260
@t.overload
5361
@classmethod
5462
def sanitize_name(cls, v: str) -> str: ...
@@ -194,6 +202,7 @@ def naming_info(self) -> EnvironmentNamingInfo:
194202
suffix_target=self.suffix_target,
195203
catalog_name_override=self.catalog_name_override,
196204
normalize_name=self.normalize_name,
205+
gateway_managed_virtual_layer=self.gateway_managed_virtual_layer,
197206
)
198207

199208
@property

sqlmesh/core/loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def _load() -> t.List[Model]:
468468
default_catalog=self.context.default_catalog,
469469
infer_names=self.config.model_naming.infer_names,
470470
signal_definitions=signals,
471+
catalogs=self.context.catalogs,
471472
)
472473
except Exception as ex:
473474
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
@@ -525,6 +526,7 @@ def _load_python_models(
525526
default_catalog=self.context.default_catalog,
526527
infer_names=self.config.model_naming.infer_names,
527528
audit_definitions=audits,
529+
catalogs=self.context.catalogs,
528530
):
529531
if model.enabled:
530532
models[model.fqn] = model

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,6 +1907,13 @@ def create_models_from_blueprints(
19071907
else:
19081908
gateway_name = None
19091909

1910+
if (
1911+
(catalogs := loader_kwargs.pop("catalogs", None))
1912+
and gateway_name
1913+
and (catalog := catalogs.get(gateway_name))
1914+
):
1915+
loader_kwargs["default_catalog"] = catalog
1916+
19101917
model_blueprints.append(
19111918
loader(
19121919
path=path,

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
name=self._context_diff.environment,
152152
suffix_target=environment_suffix_target,
153153
normalize_name=self._context_diff.normalize_environment_name,
154+
gateway_managed_virtual_layer=self._context_diff.gateway_managed_virtual_layer,
154155
)
155156

156157
self._latest_plan: t.Optional[Plan] = None

sqlmesh/core/plan/evaluator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,16 @@ def _demote_snapshots(
423423
environment_naming_info: EnvironmentNamingInfo,
424424
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None,
425425
) -> None:
426+
# In a multi virtual layer setup we need the gateway info from the snapshots for demotion
427+
snapshots_to_demote: t.List[Snapshot] = []
428+
if environment_naming_info.gateway_managed_virtual_layer:
429+
removed_snapshots = self.state_sync.get_snapshots(target_snapshots)
430+
snapshots_to_demote = [removed_snapshots[s.snapshot_id] for s in target_snapshots]
431+
426432
self.snapshot_evaluator.demote(
427-
target_snapshots, environment_naming_info, on_complete=on_complete
433+
snapshots_to_demote or target_snapshots,
434+
environment_naming_info,
435+
on_complete=on_complete,
428436
)
429437

430438
def _restate(self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]) -> None:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,15 +226,23 @@ def promote(
226226
deployability_index: Determines snapshots that are deployable in the context of this promotion.
227227
on_complete: A callback to call on each successfully promoted snapshot.
228228
"""
229-
self._create_schemas(
230-
[
231-
s.qualified_view_name.table_for_environment(
232-
environment_naming_info, dialect=self.adapter.dialect
229+
230+
gateway_by_schema: t.Dict[t.Any, str] = {}
231+
tables: t.List[t.Any] = []
232+
for snapshot in target_snapshots:
233+
if snapshot.is_model and not snapshot.is_symbolic:
234+
table = snapshot.qualified_view_name.table_for_environment(
235+
environment_naming_info,
236+
dialect=self._get_adapter(snapshot.model_gateway).dialect
237+
if environment_naming_info.gateway_managed_virtual_layer
238+
else self.adapter.dialect,
233239
)
234-
for s in target_snapshots
235-
if s.is_model and not s.is_symbolic
236-
]
237-
)
240+
tables.append(table)
241+
if environment_naming_info.gateway_managed_virtual_layer:
242+
table_schema = d.schema_(table.db, catalog=table.catalog)
243+
gateway_by_schema[table_schema] = snapshot.model_gateway or ""
244+
self._create_schemas(tables=tables, gateways=gateway_by_schema)
245+
238246
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
239247
with self.concurrent_context():
240248
concurrent_apply_to_snapshots(
@@ -920,7 +928,11 @@ def _promote_snapshot(
920928
table_mapping: t.Optional[t.Dict[str, str]] = None,
921929
) -> None:
922930
if snapshot.is_model:
923-
adapter = self.adapter
931+
adapter = (
932+
self._get_adapter(snapshot.model_gateway)
933+
if environment_naming_info.gateway_managed_virtual_layer
934+
else self.adapter
935+
)
924936
table_name = snapshot.table_name(deployability_index.is_representative(snapshot))
925937
view_name = snapshot.qualified_view_name.for_environment(
926938
environment_naming_info, dialect=adapter.dialect
@@ -953,7 +965,12 @@ def _demote_snapshot(
953965
environment_naming_info: EnvironmentNamingInfo,
954966
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]],
955967
) -> None:
956-
adapter = self.adapter
968+
adapter = (
969+
self._get_adapter(snapshot.model_gateway)
970+
if environment_naming_info.gateway_managed_virtual_layer
971+
and isinstance(snapshot, Snapshot)
972+
else self.adapter
973+
)
957974
view_name = snapshot.qualified_view_name.for_environment(
958975
environment_naming_info, dialect=adapter.dialect
959976
)

0 commit comments

Comments
 (0)