Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 02a9acb

Browse files
committed
feat: reimplement concurrency system with contextvar and weakmap
1 parent 1d4896f commit 02a9acb

File tree

1 file changed

+79
-47
lines changed

1 file changed

+79
-47
lines changed

databases/core.py

Lines changed: 79 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import contextlib
3+
from contextvars import ContextVar
34
import functools
45
import logging
56
import typing
67
from types import TracebackType
78
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
8-
9+
import weakref
910
from sqlalchemy import text
1011
from sqlalchemy.sql import ClauseElement
1112

@@ -33,6 +34,14 @@
3334

3435
logger = logging.getLogger("databases")
3536

37+
_ACTIVE_CONNECTIONS: ContextVar[
38+
typing.Optional[weakref.WeakKeyDictionary["Database", "Connection"]]
39+
] = ContextVar("databases:open_connections", default=None)
40+
41+
_ACTIVE_TRANSACTIONS: ContextVar[
42+
typing.Optional[weakref.WeakKeyDictionary["Transaction", "TransactionBackend"]]
43+
] = ContextVar("databases:open_transactions", default=None)
44+
3645

3746
class Database:
3847
SUPPORTED_BACKENDS = {
@@ -62,14 +71,31 @@ def __init__(
6271
assert issubclass(backend_cls, DatabaseBackend)
6372
self._backend = backend_cls(self.url, **self.options)
6473

65-
# Connections are stored per asyncio task
66-
self._connections: typing.Dict[asyncio.Task, Connection] = {}
67-
6874
# When `force_rollback=True` is used, we use a single global
6975
# connection, within a transaction that always rolls back.
7076
self._global_connection: typing.Optional[Connection] = None
7177
self._global_transaction: typing.Optional[Transaction] = None
7278

79+
@property
80+
def _connection(self) -> typing.Optional["Connection"]:
81+
connections = _ACTIVE_CONNECTIONS.get()
82+
if connections is None:
83+
return None
84+
85+
return connections.get(self, None)
86+
87+
@_connection.setter
88+
def _connection(
89+
self, connection: typing.Optional["Connection"]
90+
) -> typing.Optional["Connection"]:
91+
connections = _ACTIVE_CONNECTIONS.get()
92+
if connections is None:
93+
connections = weakref.WeakKeyDictionary()
94+
_ACTIVE_CONNECTIONS.set(connections)
95+
96+
connections[self] = connection
97+
return connections[self]
98+
7399
async def connect(self) -> None:
74100
"""
75101
Establish the connection pool.
@@ -112,10 +138,7 @@ async def disconnect(self) -> None:
112138
self._global_transaction = None
113139
self._global_connection = None
114140
else:
115-
current_task = asyncio.current_task()
116-
assert current_task is not None, "No currently running task"
117-
if current_task in self._connections:
118-
del self._connections[current_task]
141+
self._connection = None
119142

120143
await self._backend.disconnect()
121144
logger.info(
@@ -189,12 +212,10 @@ def connection(self) -> "Connection":
189212
if self._global_connection is not None:
190213
return self._global_connection
191214

192-
current_task = asyncio.current_task()
193-
assert current_task is not None, "No currently running task"
194-
if current_task not in self._connections:
195-
self._connections[current_task] = Connection(self._backend)
215+
if not self._connection:
216+
self._connection = Connection(self._backend)
196217

197-
return self._connections[current_task]
218+
return self._connection
198219

199220
def transaction(
200221
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -347,10 +368,30 @@ def __init__(
347368
self._force_rollback = force_rollback
348369
self._extra_options = kwargs
349370

350-
# Transactions are stored per asyncio task
351-
self._transactions: typing.Dict[
352-
typing.Optional[asyncio.Task], TransactionBackend
353-
] = {}
371+
@property
372+
def _connection(self) -> "Connection":
373+
# Returns the same connection if called multiple times
374+
return self._connection_callable()
375+
376+
@property
377+
def _transaction(self) -> typing.Optional["TransactionBackend"]:
378+
transactions = _ACTIVE_TRANSACTIONS.get()
379+
if transactions is None:
380+
return None
381+
382+
return transactions.get(self, None)
383+
384+
@_transaction.setter
385+
def _transaction(
386+
self, transaction: typing.Optional["TransactionBackend"]
387+
) -> typing.Optional["TransactionBackend"]:
388+
transactions = _ACTIVE_TRANSACTIONS.get()
389+
if transactions is None:
390+
transactions = weakref.WeakKeyDictionary()
391+
_ACTIVE_TRANSACTIONS.set(transactions)
392+
393+
transactions[self] = transaction
394+
return transactions[self]
354395

355396
async def __aenter__(self) -> "Transaction":
356397
"""
@@ -392,41 +433,32 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
392433
return wrapper # type: ignore
393434

394435
async def start(self) -> "Transaction":
395-
connection = self._connection_callable()
396-
current_task = asyncio.current_task()
397-
assert current_task is not None, "No currently running task"
398-
transaction = connection._connection.transaction()
399-
self._transactions[current_task] = transaction
400-
async with connection._transaction_lock:
401-
is_root = not connection._transaction_stack
402-
await connection.__aenter__()
403-
await transaction.start(is_root=is_root, extra_options=self._extra_options)
404-
connection._transaction_stack.append(self)
436+
self._transaction = self._connection._connection.transaction()
437+
438+
async with self._connection._transaction_lock:
439+
is_root = not self._connection._transaction_stack
440+
await self._connection.__aenter__()
441+
await self._transaction.start(
442+
is_root=is_root, extra_options=self._extra_options
443+
)
444+
self._connection._transaction_stack.append(self)
405445
return self
406446

407447
async def commit(self) -> None:
408-
connection = self._connection_callable()
409-
current_task = asyncio.current_task()
410-
transaction = self._transactions.get(current_task, None)
411-
assert transaction is not None, "Transaction not found in current task"
412-
async with connection._transaction_lock:
413-
assert connection._transaction_stack[-1] is self
414-
connection._transaction_stack.pop()
415-
await transaction.commit()
416-
await connection.__aexit__()
417-
del self._transactions[current_task]
448+
async with self._connection._transaction_lock:
449+
assert self._connection._transaction_stack[-1] is self
450+
self._connection._transaction_stack.pop()
451+
await self._transaction.commit()
452+
await self._connection.__aexit__()
453+
self._transaction = None
418454

419455
async def rollback(self) -> None:
420-
connection = self._connection_callable()
421-
current_task = asyncio.current_task()
422-
transaction = self._transactions.get(current_task, None)
423-
assert transaction is not None, "Transaction not found in current task"
424-
async with connection._transaction_lock:
425-
assert connection._transaction_stack[-1] is self
426-
connection._transaction_stack.pop()
427-
await transaction.rollback()
428-
await connection.__aexit__()
429-
del self._transactions[current_task]
456+
async with self._connection._transaction_lock:
457+
assert self._connection._transaction_stack[-1] is self
458+
self._connection._transaction_stack.pop()
459+
await self._transaction.rollback()
460+
await self._connection.__aexit__()
461+
self._transaction = None
430462

431463

432464
class _EmptyNetloc(str):

0 commit comments

Comments
 (0)