Skip to content

Commit 529e52f

Browse files
authored
Fix(snowflake): Allow models that utilize Snowpark to execute concurrently (#4431)
1 parent d2cb20f commit 529e52f

3 files changed

Lines changed: 71 additions & 7 deletions

File tree

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
6868
},
6969
)
7070
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
71+
SNOWPARK = "snowpark"
7172

7273
@contextlib.contextmanager
7374
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -104,9 +105,16 @@ def _current_warehouse(self) -> exp.Identifier:
104105
@property
105106
def snowpark(self) -> t.Optional[SnowparkSession]:
106107
if snowpark:
107-
return snowpark.Session.builder.configs(
108-
{"connection": self._connection_pool.get()}
109-
).getOrCreate()
108+
if not self._connection_pool.get_attribute(self.SNOWPARK):
109+
# Snowpark sessions are not thread-safe, so we create a session per thread to prevent them from interfering with each other
110+
# The sessions are cleaned up when close() is called
111+
new_session = snowpark.Session.builder.configs(
112+
{"connection": self._connection_pool.get()}
113+
).create()
114+
self._connection_pool.set_attribute(self.SNOWPARK, new_session)
115+
116+
return self._connection_pool.get_attribute(self.SNOWPARK)
117+
110118
return None
111119

112120
@property
@@ -584,3 +592,10 @@ def _columns_to_types(
584592
return columns_to_types_from_dtypes(query_or_df.sample(n=1).to_pandas().dtypes.items())
585593

586594
return super()._columns_to_types(query_or_df, columns_to_types)
595+
596+
def close(self) -> t.Any:
597+
if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK):
598+
snowpark_session.close() # type: ignore
599+
self._connection_pool.set_attribute(self.SNOWPARK, None)
600+
601+
return super().close()

tests/core/engine_adapter/integration/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,9 +707,7 @@ def cleanup(self, ctx: t.Optional[Context] = None):
707707
schema_name=schema_name, ignore_if_not_exists=True, cascade=True
708708
)
709709

710-
if snowpark := self.engine_adapter.snowpark:
711-
# ensure that the next test gets a fresh Snowpark session
712-
snowpark.close()
710+
self.engine_adapter.close()
713711

714712
def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
715713
if not self._context:

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from pytest import FixtureRequest
44
from sqlglot import exp
5+
from pathlib import Path
56
from sqlglot.optimizer.qualify_columns import quote_identifiers
67
from sqlglot.helper import seq_get
78
from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
@@ -10,6 +11,9 @@
1011
from sqlmesh.core.model import SqlModel, load_sql_based_model
1112
from sqlmesh.core.plan import Plan
1213
from tests.core.engine_adapter.integration import TestContext
14+
from sqlmesh import model, ExecutionContext
15+
from sqlmesh.core.model import ModelKindName
16+
from datetime import datetime
1317

1418
from tests.core.engine_adapter.integration import (
1519
TestContext,
@@ -19,7 +23,9 @@
1923
)
2024

2125

22-
@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"])))
26+
@pytest.fixture(
27+
params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"], show_variant_in_test_id=False))
28+
)
2329
def ctx(
2430
request: FixtureRequest,
2531
create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]],
@@ -220,3 +226,48 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineA
220226
result = sqlmesh.plan(auto_apply=True)
221227

222228
assert len(result.new_snapshots) == 2
229+
230+
231+
def test_snowpark_concurrency(ctx: TestContext) -> None:
232+
from snowflake.snowpark import DataFrame
233+
234+
table = ctx.table("my_model")
235+
236+
# This model will insert 10 records in batches of 1, with up to 4 batches running concurrently at a time
237+
@model(
238+
name=table.sql(),
239+
kind=dict(
240+
name=ModelKindName.INCREMENTAL_BY_TIME_RANGE,
241+
time_column="ds",
242+
batch_size=1,
243+
batch_concurrency=4,
244+
),
245+
columns={"id": "int", "ds": "date"},
246+
start="2020-01-01",
247+
end="2020-01-10",
248+
)
249+
def execute(context: ExecutionContext, start: datetime, **kwargs) -> DataFrame:
250+
if snowpark := context.snowpark:
251+
return snowpark.create_dataframe([(start.day, start.date())], schema=["id", "ds"])
252+
253+
raise ValueError("Snowpark not present!")
254+
255+
m = model.get_registry()[table.sql().lower()].model(
256+
module_path=Path("."), path=Path("."), dialect="snowflake"
257+
)
258+
259+
sqlmesh = ctx.create_context()
260+
261+
# verify that we are actually running in multithreaded mode
262+
assert sqlmesh.concurrent_tasks > 1
263+
assert ctx.engine_adapter._multithreaded
264+
265+
sqlmesh.upsert_model(m)
266+
267+
plan = sqlmesh.plan(auto_apply=True)
268+
269+
assert len(plan.new_snapshots) == 1
270+
271+
query = exp.select("*").from_(table)
272+
df = ctx.engine_adapter.fetchdf(query, quote_identifiers=True)
273+
assert len(df) == 10

0 commit comments

Comments
 (0)