Skip to content

Commit 25a1121

Browse files
committed
Fix: Properly support engines that can share a single connection instance across threads
1 parent 561e4fd commit 25a1121

3 files changed

Lines changed: 29 additions & 40 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ class ConnectionConfig(abc.ABC, BaseConfig):
5151
pre_ping: bool
5252
pretty_sql: bool = False
5353

54+
# Whether to share a single connection across threads or create a new connection per thread.
55+
shared_connection: t.ClassVar[bool] = False
56+
5457
@property
5558
@abc.abstractmethod
5659
def _connection_kwargs_keys(self) -> t.Set[str]:
@@ -94,13 +97,21 @@ def is_forbidden_for_state_sync(self) -> bool:
9497
@property
9598
def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]:
9699
"""A function that is called to return a connection object for the given Engine Adapter"""
97-
return partial(
100+
factory = partial(
98101
self._connection_factory,
99102
**{
100103
**self._static_connection_kwargs,
101104
**{k: v for k, v in self.dict().items() if k in self._connection_kwargs_keys},
102105
},
103106
)
107+
if self.shared_connection:
108+
# Make sure that a single connection is created and returned
109+
@lru_cache
110+
def _cached_connection() -> t.Any:
111+
return factory()
112+
113+
return _cached_connection
114+
return factory
104115

105116
def connection_validator(self) -> t.Callable[[], None]:
106117
"""A function that validates the connection configuration"""
@@ -116,6 +127,7 @@ def create_engine_adapter(self, register_comments_override: bool = False) -> Eng
116127
register_comments=register_comments_override or self.register_comments,
117128
pre_ping=self.pre_ping,
118129
pretty_sql=self.pretty_sql,
130+
shared_connection=self.shared_connection,
119131
**self._extra_engine_config,
120132
)
121133

@@ -182,6 +194,8 @@ class BaseDuckDBConnectionConfig(ConnectionConfig):
182194

183195
token: t.Optional[str] = None
184196

197+
shared_connection: t.ClassVar[bool] = True
198+
185199
_data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {}
186200

187201
@model_validator(mode="before")
@@ -212,43 +226,6 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
212226
def _connection_factory(self) -> t.Callable:
213227
import duckdb
214228

215-
if self.concurrent_tasks > 1:
216-
# ensures a single connection instance is used across threads rather than a new connection being established per thread
217-
# this is in line with https://duckdb.org/docs/guides/python/multiple_threads.html
218-
# the important thing is that the *cursor*'s are per thread, but the connection should be shared
219-
@lru_cache
220-
def _factory(*args: t.Any, **kwargs: t.Any) -> t.Any:
221-
class ConnWrapper:
222-
def __init__(self, conn: duckdb.DuckDBPyConnection):
223-
self.conn = conn
224-
225-
def __getattr__(self, attr: str) -> t.Any:
226-
return getattr(self.conn, attr)
227-
228-
def close(self) -> None:
229-
# This overrides conn.close() to be a no-op to work with ThreadLocalConnectionPool which assumes that a new connection should
230-
# be created per thread. However, DuckDB expects the same connection instance to be shared across threads. There is a pattern
231-
# in the SQLMesh codebase that `EngineAdapter.recycle()` is called after doing things like merging intervals. This in turn causes
232-
# `ThreadLocalConnectionPool.close_all(exclude_calling_thread=True)` to be called.
233-
#
234-
# The problem with sharing a connection across threads and then allowing it to be closed for every thread except the current one
235-
# is that it gets closed for the current one too because it's shared. This causes any ":memory:" databases to be discarded.
236-
# ":memory:" databases are convenient and are used heavily in our test suite amongst other things.
237-
#
238-
# Ok, so why not have a connection per thread as is the default for ThreadLocalConnectionPool? Two reasons:
239-
# - It makes any ":memory:" databases unique to that thread. So if one thread creates tables, another thread can't see them
240-
# - If you use local files instead (eg point each connection to the same db file) then all the connection instances
241-
# fight over locks to the same file and performance tanks heavily
242-
#
243-
# From what I can tell, DuckDB expects the single process reading / writing the database from multiple
244-
# threads to /share the same connection/ and just use thread-local cursors. In order to support ":memory:" databases
245-
# and remove lock contention, the connection needs to live for the life of the application and not be closed
246-
pass
247-
248-
return ConnWrapper(duckdb.connect(*args, **kwargs))
249-
250-
return _factory
251-
252229
return duckdb.connect
253230

254231
@property

sqlmesh/core/engine_adapter/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,18 @@ def __init__(
119119
register_comments: bool = True,
120120
pre_ping: bool = False,
121121
pretty_sql: bool = False,
122+
shared_connection: bool = False,
122123
**kwargs: t.Any,
123124
):
124125
self.dialect = dialect.lower() or self.DIALECT
125126
self._connection_pool = (
126127
connection_factory_or_pool
127128
if isinstance(connection_factory_or_pool, ConnectionPool)
128129
else create_connection_pool(
129-
connection_factory_or_pool, multithreaded, cursor_init=cursor_init
130+
connection_factory_or_pool,
131+
multithreaded,
132+
shared_connection=shared_connection,
133+
cursor_init=cursor_init,
130134
)
131135
)
132136
self._sql_gen_kwargs = sql_gen_kwargs or {}

sqlmesh/utils/connection_pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class ThreadLocalConnectionPool(_TransactionManagementMixin):
115115
def __init__(
116116
self,
117117
connection_factory: t.Callable[[], t.Any],
118+
shared_connection: bool = False,
118119
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
119120
):
120121
self._connection_factory = connection_factory
@@ -125,6 +126,7 @@ def __init__(
125126
self._thread_connections_lock = Lock()
126127
self._thread_cursors_lock = Lock()
127128
self._thread_transactions_lock = Lock()
129+
self._shared_connection = shared_connection
128130
self._cursor_init = cursor_init
129131

130132
def get_cursor(self) -> t.Any:
@@ -187,6 +189,9 @@ def close(self) -> None:
187189
self._thread_attributes.pop(thread_id, None)
188190

189191
def close_all(self, exclude_calling_thread: bool = False) -> None:
192+
if exclude_calling_thread and self._shared_connection:
193+
return
194+
190195
calling_thread_id = get_ident()
191196
with self._thread_cursors_lock, self._thread_connections_lock:
192197
for thread_id, connection in self._thread_connections.copy().items():
@@ -269,10 +274,13 @@ def close_all(self, exclude_calling_thread: bool = False) -> None:
269274
def create_connection_pool(
270275
connection_factory: t.Callable[[], t.Any],
271276
multithreaded: bool,
277+
shared_connection: bool = False,
272278
cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
273279
) -> ConnectionPool:
274280
return (
275-
ThreadLocalConnectionPool(connection_factory, cursor_init=cursor_init)
281+
ThreadLocalConnectionPool(
282+
connection_factory, shared_connection=shared_connection, cursor_init=cursor_init
283+
)
276284
if multithreaded
277285
else SingletonConnectionPool(connection_factory, cursor_init=cursor_init)
278286
)

0 commit comments

Comments
 (0)