Skip to content

Commit 7fa517e

Browse files
Refactors; revise to keep model gateway in SnapshotTableInfo
1 parent 2dcb38f commit 7fa517e

16 files changed

Lines changed: 42 additions & 94 deletions

sqlmesh/core/context.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,23 +2311,11 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
23112311
def _cleanup_environments(self) -> None:
23122312
expired_environments = self.state_sync.delete_expired_environments()
23132313

2314-
environment_snapshot_adapters: t.Dict[str, t.Dict[str, EngineAdapter]] = {}
2315-
for environment in expired_environments:
2316-
snapshot_adapters: t.Dict[str, EngineAdapter] = {}
2317-
if environment.gateway_managed_virtual_layer:
2318-
snapshots = self.state_sync.get_snapshots(environment.snapshots)
2319-
for snapshot_id, snapshot in snapshots.items():
2320-
if snapshot.is_model and not snapshot.is_symbolic:
2321-
snapshot_adapters[snapshot_id.name] = self._get_engine_adapter(
2322-
snapshot.model_gateway
2323-
)
2324-
environment_snapshot_adapters[environment.name] = snapshot_adapters
2325-
23262314
cleanup_expired_views(
23272315
adapter=self.engine_adapter,
23282316
environments=expired_environments,
23292317
console=self.console,
2330-
environment_snapshot_adapters=environment_snapshot_adapters,
2318+
engine_adapters=self.engine_adapters,
23312319
)
23322320

23332321
def _try_connection(self, connection_name: str, validator: t.Callable[[], None]) -> None:

sqlmesh/core/context_diff.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,7 @@ def create(
121121
env = state_reader.get_environment(environment)
122122

123123
create_from_env_exists = False
124-
if (
125-
env is None
126-
or env.expired
127-
or env.gateway_managed_virtual_layer != gateway_managed_virtual_layer
128-
):
124+
if env is None or env.expired or env.gateway_managed != gateway_managed_virtual_layer:
129125
env = state_reader.get_environment(create_from.lower())
130126

131127
if not env and create_from != c.PROD:

sqlmesh/core/environment.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sqlmesh.utils.date import TimeLike, now_timestamp
1717
from sqlmesh.utils.jinja import JinjaMacroRegistry
1818
from sqlmesh.utils.metaprogramming import Executable
19-
from sqlmesh.utils.pydantic import PydanticModel, field_validator
19+
from sqlmesh.utils.pydantic import PydanticModel, field_validator, ValidationInfo
2020

2121
T = t.TypeVar("T", bound="EnvironmentNamingInfo")
2222
PydanticType = t.TypeVar("PydanticType", bound="PydanticModel")
@@ -32,30 +32,27 @@ class EnvironmentNamingInfo(PydanticModel):
3232
catalog_name_override: The name of the catalog to use for this environment if an override was provided
3333
normalize_name: Indicates whether the environment's name will be normalized. For example, if it's
3434
`dev`, then it will become `DEV` when targeting Snowflake.
35-
gateway_managed_virtual_layer: Determines whether the virtual layer's views are created by the model-specific
35+
gateway_managed: Determines whether the virtual layer's views are created by the model-specific
3636
gateways, otherwise the default gateway is used. Default: False.
3737
"""
3838

3939
name: str = c.PROD
4040
suffix_target: EnvironmentSuffixTarget = Field(default=EnvironmentSuffixTarget.SCHEMA)
4141
catalog_name_override: t.Optional[str] = None
4242
normalize_name: bool = True
43-
gateway_managed_virtual_layer: bool = False
43+
gateway_managed: bool = False
4444

4545
@field_validator("name", mode="before")
4646
@classmethod
4747
def _sanitize_name(cls, v: str) -> str:
4848
return word_characters_only(v).lower()
4949

50-
@field_validator("normalize_name", mode="before")
50+
@field_validator("normalize_name", "gateway_managed", mode="before")
5151
@classmethod
52-
def _validate_normalize_name(cls, v: t.Any) -> bool:
53-
return True if v is None else bool(v)
54-
55-
@field_validator("gateway_managed_virtual_layer", mode="before")
56-
@classmethod
57-
def _validate_gateway_managed_virtual_layer(cls, v: t.Any) -> bool:
58-
return False if v is None else bool(v)
52+
def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool:
53+
if v is None:
54+
return info.field_name == "normalize_name"
55+
return bool(v)
5956

6057
@t.overload
6158
@classmethod
@@ -202,7 +199,7 @@ def naming_info(self) -> EnvironmentNamingInfo:
202199
suffix_target=self.suffix_target,
203200
catalog_name_override=self.catalog_name_override,
204201
normalize_name=self.normalize_name,
205-
gateway_managed_virtual_layer=self.gateway_managed_virtual_layer,
202+
gateway_managed=self.gateway_managed,
206203
)
207204

208205
@property

sqlmesh/core/plan/builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def __init__(
151151
name=self._context_diff.environment,
152152
suffix_target=environment_suffix_target,
153153
normalize_name=self._context_diff.normalize_environment_name,
154-
gateway_managed_virtual_layer=self._context_diff.gateway_managed_virtual_layer,
154+
gateway_managed=self._context_diff.gateway_managed_virtual_layer,
155155
)
156156

157157
self._latest_plan: t.Optional[Plan] = None

sqlmesh/core/plan/evaluator.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,8 @@ def _demote_snapshots(
423423
environment_naming_info: EnvironmentNamingInfo,
424424
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None,
425425
) -> None:
426-
# In a multi virtual layer setup we need the gateway info from the snapshots for demotion
427-
snapshots_to_demote: t.List[Snapshot] = []
428-
if environment_naming_info.gateway_managed_virtual_layer:
429-
removed_snapshots = self.state_sync.get_snapshots(target_snapshots)
430-
snapshots_to_demote = [removed_snapshots[s.snapshot_id] for s in target_snapshots]
431-
432426
self.snapshot_evaluator.demote(
433-
snapshots_to_demote or target_snapshots,
427+
target_snapshots,
434428
environment_naming_info,
435429
on_complete=on_complete,
436430
)

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True):
477477
base_table_name_override: t.Optional[str] = None
478478
custom_materialization: t.Optional[str] = None
479479
dev_table_suffix: str
480+
model_gateway: t.Optional[str] = None
480481

481482
def __lt__(self, other: SnapshotTableInfo) -> bool:
482483
return self.name < other.name
@@ -1177,6 +1178,7 @@ def table_info(self) -> SnapshotTableInfo:
11771178
node_type=self.node_type,
11781179
custom_materialization=custom_materialization,
11791180
dev_table_suffix=self.dev_table_suffix,
1181+
model_gateway=self.model_gateway,
11801182
)
11811183

11821184
@property
@@ -1336,7 +1338,6 @@ def __getstate__(self) -> t.Dict[t.Any, t.Any]:
13361338
class SnapshotTableCleanupTask(PydanticModel):
13371339
snapshot: SnapshotTableInfo
13381340
dev_table_only: bool
1339-
gateway: t.Optional[str] = None
13401341

13411342

13421343
SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, Snapshot]

sqlmesh/core/snapshot/evaluator.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,11 @@ def promote(
234234
table = snapshot.qualified_view_name.table_for_environment(
235235
environment_naming_info,
236236
dialect=self._get_adapter(snapshot.model_gateway).dialect
237-
if environment_naming_info.gateway_managed_virtual_layer
237+
if environment_naming_info.gateway_managed
238238
else self.adapter.dialect,
239239
)
240240
tables.append(table)
241-
if environment_naming_info.gateway_managed_virtual_layer:
241+
if environment_naming_info.gateway_managed:
242242
table_schema = d.schema_(table.db, catalog=table.catalog)
243243
gateway_by_schema[table_schema] = snapshot.model_gateway or ""
244244
self._create_schemas(tables=tables, gateways=gateway_by_schema)
@@ -437,15 +437,14 @@ def cleanup(
437437
snapshots_to_dev_table_only = {
438438
t.snapshot.snapshot_id: t.dev_table_only for t in target_snapshots
439439
}
440-
snapshot_gateways = {t.snapshot.snapshot_id: t.gateway for t in target_snapshots}
441440

442441
with self.concurrent_context():
443442
concurrent_apply_to_snapshots(
444443
[t.snapshot for t in target_snapshots],
445444
lambda s: self._cleanup_snapshot(
446445
s,
447446
snapshots_to_dev_table_only[s.snapshot_id],
448-
self._get_adapter(snapshot_gateways[s.snapshot_id]),
447+
self._get_adapter(s.model_gateway),
449448
on_complete,
450449
),
451450
self.ddl_concurrent_tasks,
@@ -928,7 +927,7 @@ def _promote_snapshot(
928927
if snapshot.is_model:
929928
adapter = (
930929
self._get_adapter(snapshot.model_gateway)
931-
if environment_naming_info.gateway_managed_virtual_layer
930+
if environment_naming_info.gateway_managed
932931
else self.adapter
933932
)
934933
table_name = snapshot.table_name(deployability_index.is_representative(snapshot))
@@ -965,8 +964,7 @@ def _demote_snapshot(
965964
) -> None:
966965
adapter = (
967966
self._get_adapter(snapshot.model_gateway)
968-
if environment_naming_info.gateway_managed_virtual_layer
969-
and isinstance(snapshot, Snapshot)
967+
if environment_naming_info.gateway_managed
970968
else self.adapter
971969
)
972970
view_name = snapshot.qualified_view_name.for_environment(

sqlmesh/core/state_sync/common.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
def cleanup_expired_views(
2424
adapter: EngineAdapter,
2525
environments: t.List[Environment],
26-
environment_snapshot_adapters: t.Optional[t.Dict[str, t.Dict[str, EngineAdapter]]] = None,
2726
console: t.Optional[Console] = None,
27+
engine_adapters: t.Optional[t.Dict[str, EngineAdapter]] = None,
2828
) -> None:
2929
expired_schema_environments = [
3030
environment for environment in environments if environment.suffix_target.is_schema
@@ -33,19 +33,16 @@ def cleanup_expired_views(
3333
environment for environment in environments if environment.suffix_target.is_table
3434
]
3535

36+
# We have to use the corresponding adapter if the virtual layer is gateway managed
37+
def get_adapter(gateway_managed: bool, gateway: t.Optional[str] = None) -> EngineAdapter:
38+
if gateway_managed and gateway:
39+
return (engine_adapters or {}).get(gateway, adapter)
40+
return adapter
41+
3642
# Drop the schemas for the expired environments
37-
# Note: We have to use the corresponding adapter if it is a gateway managed virtual layer
3843
for engine_adapter, expired_catalog, expired_schema in {
3944
(
40-
(
41-
engine_adapter := (
42-
(environment_dict.get(snapshot.name) or adapter)
43-
if environment.gateway_managed_virtual_layer
44-
and environment_snapshot_adapters
45-
and (environment_dict := environment_snapshot_adapters.get(environment.name))
46-
else adapter
47-
)
48-
),
45+
(engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)),
4946
snapshot.qualified_view_name.catalog_for_environment(
5047
environment.naming_info, dialect=engine_adapter.dialect
5148
),
@@ -74,15 +71,7 @@ def cleanup_expired_views(
7471
# Drop the views for the expired environments
7572
for engine_adapter, expired_view in {
7673
(
77-
(
78-
engine_adapter := (
79-
(environment_dict.get(snapshot.name) or adapter)
80-
if environment.gateway_managed_virtual_layer
81-
and environment_snapshot_adapters
82-
and (environment_dict := environment_snapshot_adapters.get(environment.name))
83-
else adapter
84-
)
85-
),
74+
(engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)),
8675
snapshot.qualified_view_name.for_environment(
8776
environment.naming_info, dialect=engine_adapter.dialect
8877
),

sqlmesh/core/state_sync/db/environment.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
"catalog_name_override": exp.DataType.build("text"),
4949
"previous_finalized_snapshots": exp.DataType.build(blob_type),
5050
"normalize_name": exp.DataType.build("boolean"),
51-
"gateway_managed_virtual_layer": exp.DataType.build("boolean"),
51+
"gateway_managed": exp.DataType.build("boolean"),
5252
"requirements": exp.DataType.build(blob_type),
5353
}
5454

@@ -168,7 +168,6 @@ def delete_expired_environments(self) -> t.List[Environment]:
168168
Returns:
169169
A list of deleted environments.
170170
"""
171-
172171
now_ts = now_timestamp()
173172
filter_expr = exp.LTE(
174173
this=exp.column("expiration_ts"),
@@ -182,10 +181,8 @@ def delete_expired_environments(self) -> t.List[Environment]:
182181
lock_for_update=True,
183182
),
184183
)
185-
186184
environments = [self._environment_from_row(r) for r in rows]
187185

188-
# Delete the expired environments
189186
self.engine_adapter.delete_from(
190187
self.environments_table,
191188
where=filter_expr,
@@ -331,7 +328,7 @@ def _environment_to_df(environment: Environment) -> pd.DataFrame:
331328
else None
332329
),
333330
"normalize_name": environment.normalize_name,
334-
"gateway_managed_virtual_layer": environment.gateway_managed_virtual_layer,
331+
"gateway_managed": environment.gateway_managed,
335332
"requirements": json.dumps(environment.requirements),
336333
}
337334
]

sqlmesh/core/state_sync/db/facade.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,7 @@ def promote(
201201
}
202202
if (
203203
not existing_environment.expired
204-
and existing_environment.gateway_managed_virtual_layer
205-
== environment.gateway_managed_virtual_layer
204+
and existing_environment.gateway_managed == environment.gateway_managed
206205
):
207206
if environment.previous_plan_id != existing_environment.plan_id:
208207
raise ConflictingPlanError(

0 commit comments

Comments
 (0)