|
3 | 3 | import contextlib |
4 | 4 | import logging |
5 | 5 | import typing as t |
6 | | -import threading |
7 | 6 |
|
8 | 7 | import pandas as pd |
9 | 8 | from pandas.api.types import is_datetime64_any_dtype # type: ignore |
@@ -69,10 +68,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi |
69 | 68 | }, |
70 | 69 | ) |
71 | 70 | MANAGED_TABLE_KIND = "DYNAMIC TABLE" |
72 | | - |
73 | | - def __init__(self, *args: t.Any, **kwargs: t.Any): |
74 | | - super().__init__(*args, **kwargs) |
75 | | - self._snowpark_threadlocal = threading.local() |
| 71 | + SNOWPARK = "snowpark" |
76 | 72 |
|
77 | 73 | @contextlib.contextmanager |
78 | 74 | def session(self, properties: SessionProperties) -> t.Iterator[None]: |
@@ -109,15 +105,16 @@ def _current_warehouse(self) -> exp.Identifier: |
109 | 105 | @property |
110 | 106 | def snowpark(self) -> t.Optional[SnowparkSession]: |
111 | 107 | if snowpark: |
112 | | - # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other |
113 | | - # The sessions are cleaned up when close() is called |
114 | | - if not hasattr(self._snowpark_threadlocal, "session"): |
| 108 | + if not self._connection_pool.get_attribute(self.SNOWPARK): |
| 109 | + # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other |
| 110 | + # The sessions are cleaned up when close() is called |
115 | 111 | new_session = snowpark.Session.builder.configs( |
116 | 112 | {"connection": self._connection_pool.get()} |
117 | 113 | ).create() |
118 | | - self._snowpark_threadlocal.session = new_session |
| 114 | + self._connection_pool.set_attribute(self.SNOWPARK, new_session) |
| 115 | + |
| 116 | + return self._connection_pool.get_attribute(self.SNOWPARK) |
119 | 117 |
|
120 | | - return self._snowpark_threadlocal.session |
121 | 118 | return None |
122 | 119 |
|
123 | 120 | @property |
@@ -596,14 +593,9 @@ def _columns_to_types( |
596 | 593 |
|
597 | 594 | return super()._columns_to_types(query_or_df, columns_to_types) |
598 | 595 |
|
599 | | - def _cleanup_snowpark(self) -> None: |
600 | | - if hasattr(self._snowpark_threadlocal, "session") and ( |
601 | | - session := self._snowpark_threadlocal.session |
602 | | - ): |
603 | | - session.close() |
604 | | - delattr(self._snowpark_threadlocal, "session") |
605 | | - |
606 | 596 | def close(self) -> t.Any: |
607 | | - self._cleanup_snowpark() |
| 597 | + if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK): |
| 598 | + snowpark_session.close() # type: ignore |
| 599 | + self._connection_pool.set_attribute(self.SNOWPARK, None) |
608 | 600 |
|
609 | 601 | return super().close() |
0 commit comments