Skip to content

Commit 6e1e980

Browse files
committed
neutral state catalog
1 parent c50e9db commit 6e1e980

2 files changed

Lines changed: 76 additions & 13 deletions

File tree

sqlmesh/core/engine_adapter/fabric.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ def _target_catalog(self) -> t.Optional[str]:
5252
def _target_catalog(self, value: t.Optional[str]) -> None:
5353
self._connection_pool.set_attribute("target_catalog", value)
5454

55+
def _normalize_catalog(
56+
self, catalog_name: t.Optional[str]
57+
) -> t.Optional[str]:
58+
if not catalog_name:
59+
return None
60+
61+
default_catalog = (
62+
self._default_catalog or self._extra_config.get("database")
63+
)
64+
if default_catalog and catalog_name == default_catalog:
65+
return None
66+
67+
return catalog_name
68+
69+
def _catalog_state_label(self, catalog_name: t.Optional[str]) -> str:
70+
return catalog_name or "<default>"
71+
5572
@property
5673
def api_client(self) -> FabricHttpClient:
5774
# the requests Session is not guaranteed to be threadsafe
@@ -109,10 +126,10 @@ def _drop_catalog(self, catalog_name: exp.Identifier) -> None:
109126
self.close()
110127

111128
def get_current_catalog(self) -> t.Optional[str]:
112-
"""Return the adapter-managed catalog for Fabric's stateless sessions."""
113-
return self._target_catalog or self._extra_config.get("database")
129+
"""Return the explicit Fabric catalog target for the current thread."""
130+
return self._normalize_catalog(self._target_catalog)
114131

115-
def set_current_catalog(self, catalog_name: str) -> None:
132+
def set_current_catalog(self, catalog_name: t.Optional[str]) -> None:
116133
"""
117134
Set the current catalog for Microsoft Fabric connections.
118135
@@ -121,7 +138,8 @@ def set_current_catalog(self, catalog_name: str) -> None:
121138
recreate them with the new catalog in the connection configuration.
122139
123140
Args:
124-
catalog_name: The name of the catalog (warehouse) to switch to
141+
catalog_name: The name of the catalog (warehouse) to switch to.
142+
The configured default catalog is treated as the neutral state.
125143
126144
Note:
127145
Fabric doesn't support catalog switching via USE statements because each
@@ -132,13 +150,18 @@ def set_current_catalog(self, catalog_name: str) -> None:
132150
https://learn.microsoft.com/en-us/fabric/data-warehouse/sql-query-editor#limitations
133151
"""
134152
current_catalog = self.get_current_catalog()
153+
target_catalog = self._normalize_catalog(catalog_name)
135154

136155
# If already using the requested catalog, do nothing
137-
if current_catalog and current_catalog == catalog_name:
138-
logger.debug(f"Already using catalog '{catalog_name}', no action needed")
156+
if current_catalog == target_catalog:
157+
logger.debug("Already using the requested Fabric catalog state, no action needed")
139158
return
140159

141-
logger.info(f"Switching from catalog '{current_catalog}' to '{catalog_name}'")
160+
logger.info(
161+
"Switching from catalog '%s' to '%s'",
162+
self._catalog_state_label(current_catalog),
163+
self._catalog_state_label(target_catalog),
164+
)
142165

143166
# commit the transaction before closing the connection to help prevent errors like:
144167
# > Snapshot isolation transaction failed in database because the object accessed by the statement has been modified by a
@@ -149,14 +172,14 @@ def set_current_catalog(self, catalog_name: str) -> None:
149172
# note: we call close() on the connection pool instead of self.close() because self.close() calls close_all()
150173
# on the connection pool but we just want to close the connection for this thread
151174
self._connection_pool.close()
152-
self._target_catalog = catalog_name # new connections will use this catalog
175+
self._target_catalog = target_catalog
153176

154177
catalog_after_switch = self.get_current_catalog()
155178

156-
if catalog_after_switch != catalog_name:
179+
if catalog_after_switch != target_catalog:
157180
# We need to raise an error if the catalog switch failed to prevent the operation that needed the catalog switch from being run against the wrong catalog
158181
raise SQLMeshError(
159-
f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}"
182+
f"Unable to switch catalog to {target_catalog}, catalog ended up as {catalog_after_switch}"
160183
)
161184

162185
def alter_table(

tests/core/engine_adapter/test_fabric.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@ def adapter(make_mocked_engine_adapter: t.Callable) -> FabricEngineAdapter:
1919
return make_mocked_engine_adapter(FabricEngineAdapter)
2020

2121

22-
def test_get_current_catalog_uses_target_catalog_or_configured_database(
22+
def test_get_current_catalog_uses_only_explicit_target_catalog(
2323
make_mocked_engine_adapter: t.Callable,
2424
):
2525
adapter = make_mocked_engine_adapter(
2626
FabricEngineAdapter,
2727
database="default_catalog",
2828
)
2929

30-
assert adapter.get_current_catalog() == "default_catalog"
30+
assert adapter.get_current_catalog() is None
3131

3232
adapter._target_catalog = "switched_catalog"
3333

@@ -36,7 +36,7 @@ def test_get_current_catalog_uses_target_catalog_or_configured_database(
3636
adapter._connection_pool.close()
3737

3838
assert adapter._connection_pool.get_attribute("target_catalog") is None
39-
assert adapter.get_current_catalog() == "default_catalog"
39+
assert adapter.get_current_catalog() is None
4040
adapter.cursor.execute.assert_not_called()
4141

4242

@@ -63,6 +63,46 @@ def test_set_current_catalog_does_not_query_database(
6363
adapter.cursor.execute.assert_not_called()
6464

6565

66+
def test_set_current_catalog_to_default_clears_explicit_target(
67+
make_mocked_engine_adapter: t.Callable,
68+
):
69+
adapter = make_mocked_engine_adapter(
70+
FabricEngineAdapter,
71+
default_catalog="core",
72+
database="core",
73+
)
74+
75+
adapter.set_current_catalog("planning")
76+
adapter.set_current_catalog("core")
77+
78+
assert adapter.get_current_catalog() is None
79+
adapter.cursor.execute.assert_not_called()
80+
81+
82+
def test_catalog_scoped_call_restores_to_neutral_state(
83+
make_mocked_engine_adapter: t.Callable,
84+
mocker: MockerFixture,
85+
):
86+
adapter = make_mocked_engine_adapter(
87+
FabricEngineAdapter,
88+
default_catalog="core",
89+
database="core",
90+
)
91+
set_current_catalog_spy = mocker.patch.object(
92+
adapter,
93+
"set_current_catalog",
94+
wraps=adapter.set_current_catalog,
95+
)
96+
adapter.cursor.fetchone.return_value = (1,)
97+
98+
adapter.table_exists("planning.db.table")
99+
100+
assert [
101+
call.args[0] for call in set_current_catalog_spy.call_args_list
102+
] == ["planning", None]
103+
assert adapter.get_current_catalog() is None
104+
105+
66106
def test_columns(adapter: FabricEngineAdapter):
67107
adapter.cursor.fetchall.return_value = [
68108
("decimal_ps", "decimal", None, 5, 4),

0 commit comments

Comments
 (0)