Skip to content

Commit aa0c50f

Browse files
authored
Fix: ensure session-scoped Snowflake warehouse is rolled back on failure (#4640)
1 parent 5a1be92 commit aa0c50f

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
9393
return
9494

9595
self.execute(f"USE WAREHOUSE {warehouse_sql}")
96-
yield
97-
self.execute(f"USE WAREHOUSE {current_warehouse_sql}")
96+
try:
97+
yield
98+
finally:
99+
self.execute(f"USE WAREHOUSE {current_warehouse_sql}")
98100

99101
@property
100102
def _current_warehouse(self) -> exp.Identifier:

tests/core/engine_adapter/test_snowflake.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def test_session(
9898
adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter)
9999
adapter.cursor.fetchone.return_value = (current_warehouse,)
100100

101+
# Test normal execution
101102
with adapter.session({"warehouse": configured_warehouse}):
102103
pass
103104

@@ -114,6 +115,27 @@ def test_session(
114115

115116
assert to_sql_calls(adapter) == expected_calls
116117

118+
# Test exception handling - warehouse should still be reset
119+
if should_change:
120+
adapter.cursor.execute.reset_mock()
121+
adapter.cursor.fetchone.return_value = (current_warehouse,)
122+
123+
try:
124+
with adapter.session({"warehouse": configured_warehouse}):
125+
adapter.execute("SELECT 1")
126+
raise RuntimeError("Test exception")
127+
except RuntimeError:
128+
pass
129+
130+
expected_exception_calls = [
131+
"SELECT CURRENT_WAREHOUSE()",
132+
f"USE WAREHOUSE {configured_warehouse_exp}",
133+
"SELECT 1",
134+
f"USE WAREHOUSE {current_warehouse_exp}",
135+
]
136+
137+
assert to_sql_calls(adapter) == expected_exception_calls
138+
117139

118140
def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
119141
adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter)

0 commit comments

Comments
 (0)