
Commit ebf88ed

Fix(snowflake): Allow models that utilize Snowpark to execute concurrently
1 parent: 5d7f21d

3 files changed: 74 additions & 7 deletions


sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 26 additions & 3 deletions
@@ -3,6 +3,7 @@
 import contextlib
 import logging
 import typing as t
+import threading
 
 import pandas as pd
 from pandas.api.types import is_datetime64_any_dtype  # type: ignore
@@ -69,6 +70,10 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
     )
     MANAGED_TABLE_KIND = "DYNAMIC TABLE"
 
+    def __init__(self, *args: t.Any, **kwargs: t.Any):
+        super().__init__(*args, **kwargs)
+        self._snowpark_threadlocal = threading.local()
+
     @contextlib.contextmanager
     def session(self, properties: SessionProperties) -> t.Iterator[None]:
         warehouse = properties.get("warehouse")
@@ -104,9 +109,15 @@ def _current_warehouse(self) -> exp.Identifier:
     @property
     def snowpark(self) -> t.Optional[SnowparkSession]:
         if snowpark:
-            return snowpark.Session.builder.configs(
-                {"connection": self._connection_pool.get()}
-            ).getOrCreate()
+            # Snowpark sessions are not thread-safe, so we create a session per
+            # thread to prevent them from interfering with each other.
+            # The sessions are cleaned up when close() is called.
+            if not hasattr(self._snowpark_threadlocal, "session"):
+                new_session = snowpark.Session.builder.configs(
+                    {"connection": self._connection_pool.get()}
+                ).create()
+                self._snowpark_threadlocal.session = new_session
+
+            return self._snowpark_threadlocal.session
         return None
 
     @property
@@ -584,3 +595,15 @@ def _columns_to_types(
             return columns_to_types_from_dtypes(query_or_df.sample(n=1).to_pandas().dtypes.items())
 
         return super()._columns_to_types(query_or_df, columns_to_types)
+
+    def _cleanup_snowpark(self) -> None:
+        if hasattr(self._snowpark_threadlocal, "session") and (
+            session := self._snowpark_threadlocal.session
+        ):
+            session.close()
+            delattr(self._snowpark_threadlocal, "session")
+
+    def close(self) -> t.Any:
+        self._cleanup_snowpark()
+
+        return super().close()
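
The heart of the change is a lazily initialized, per-thread Snowpark session. The switch from getOrCreate() to create() matters: getOrCreate() returns the already-active session when one exists, regardless of which thread asks, while create() always opens a fresh one, which is what makes per-thread isolation possible. Below is a minimal, self-contained sketch of the same pattern; it is not part of the commit, and FakeSession is a stand-in for snowflake.snowpark.Session (which would require real credentials).

# Minimal sketch of the per-thread session pattern (not part of the commit).
import threading
from concurrent.futures import ThreadPoolExecutor


class FakeSession:
    """Stand-in for snowflake.snowpark.Session, which is not thread-safe."""

    def __init__(self) -> None:
        self.owner = threading.current_thread().name

    def close(self) -> None:
        print(f"closing session owned by {self.owner}")


class Adapter:
    def __init__(self) -> None:
        self._threadlocal = threading.local()

    @property
    def session(self) -> FakeSession:
        # Lazily create one session per thread; later lookups from the same
        # thread return the same object, while other threads get their own.
        if not hasattr(self._threadlocal, "session"):
            self._threadlocal.session = FakeSession()
        return self._threadlocal.session

    def close(self) -> None:
        # Mirrors _cleanup_snowpark: only the *calling* thread's session is
        # closed; sessions owned by other threads are left to their owners.
        if hasattr(self._threadlocal, "session"):
            self._threadlocal.session.close()
            delattr(self._threadlocal, "session")


adapter = Adapter()
with ThreadPoolExecutor(max_workers=4) as pool:
    owners = set(pool.map(lambda _: adapter.session.owner, range(8)))
print(owners)  # one distinct owner name per worker thread that ran a task
adapter.close()  # no-op here: the main thread never created a session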

tests/core/engine_adapter/integration/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -707,9 +707,7 @@ def cleanup(self, ctx: t.Optional[Context] = None):
             schema_name=schema_name, ignore_if_not_exists=True, cascade=True
         )
 
-        if snowpark := self.engine_adapter.snowpark:
-            # ensure that the next test gets a fresh Snowpark session
-            snowpark.close()
+        self.engine_adapter.close()
 
     def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]:
         if not self._context:

tests/core/engine_adapter/integration/test_integration_snowflake.py

Lines changed: 47 additions & 1 deletion
@@ -2,6 +2,7 @@
 import pytest
 from pytest import FixtureRequest
 from sqlglot import exp
+from pathlib import Path
 from sqlglot.optimizer.qualify_columns import quote_identifiers
 from sqlglot.helper import seq_get
 from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter
@@ -10,6 +11,9 @@
 from sqlmesh.core.model import SqlModel, load_sql_based_model
 from sqlmesh.core.plan import Plan
 from tests.core.engine_adapter.integration import TestContext
+from sqlmesh import model, ExecutionContext
+from sqlmesh.core.model import ModelKindName
+from datetime import datetime
 
 from tests.core.engine_adapter.integration import (
     TestContext,
@@ -19,7 +23,9 @@
 )
 
 
-@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"])))
+@pytest.fixture(
+    params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"], show_variant_in_test_id=False))
+)
 def ctx(
     request: FixtureRequest,
     create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]],
@@ -220,3 +226,43 @@ def test_create_iceberg_table(ctx: TestContext, engine_adapter: SnowflakeEngineA
     result = sqlmesh.plan(auto_apply=True)
 
     assert len(result.new_snapshots) == 2
+
+
+def test_snowpark_concurrency(ctx: TestContext) -> None:
+    from snowflake.snowpark import DataFrame
+
+    @model(
+        name="my_model",
+        kind=dict(
+            name=ModelKindName.INCREMENTAL_BY_TIME_RANGE,
+            time_column="ds",
+            batch_size=1,
+            batch_concurrency=4,
+        ),
+        columns={"id": "int", "ds": "date"},
+        start="2020-01-01",
+        end="2020-01-10",
+    )
+    def execute(context: ExecutionContext, start: datetime, **kwargs) -> DataFrame:
+        if snowpark := context.snowpark:
+            return snowpark.create_dataframe([(start.day, start.date())], schema=["id", "ds"])
+
+        raise ValueError("Snowpark not present!")
+
+    m = model.get_registry()["my_model"].model(
+        module_path=Path("."), path=Path("."), dialect="snowflake"
+    )
+
+    sqlmesh = ctx.create_context()
+
+    # verify that we are actually running in multithreaded mode
+    assert sqlmesh.concurrent_tasks > 1
+    assert ctx.engine_adapter._multithreaded
+
+    sqlmesh.upsert_model(m)
+
+    plan = sqlmesh.plan(auto_apply=True)
+
+    assert len(plan.new_snapshots) == 1
+
+    # todo: read table result
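
The test exercises the fix end-to-end through a concurrent backfill (ten daily batches, batch_concurrency=4). A more direct check, hypothetical and not part of the commit, would be to request the adapter's snowpark property from several threads at once and confirm each thread receives its own session. The sketch below assumes it runs inside the test body, where ctx is in scope; the Barrier forces four distinct worker threads to participate.

    # Hypothetical direct check of per-thread session isolation (not in the commit).
    import threading
    from concurrent.futures import ThreadPoolExecutor

    barrier = threading.Barrier(4)

    def get_session_id(_: int) -> int:
        barrier.wait()  # block until four distinct worker threads arrive
        return id(ctx.engine_adapter.snowpark)

    with ThreadPoolExecutor(max_workers=4) as pool:
        session_ids = set(pool.map(get_session_id, range(4)))

    # Each thread lazily created its own Snowpark session, so all ids differ.
    # (Nothing closes these sessions here; a real test would close them per thread.)
    assert len(session_ids) == 4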
