Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 2d4554d

Browse files
committed
Update Connection and Transaction to be robust to concurrent use
1 parent b6eba5f commit 2d4554d

1 file changed

Lines changed: 33 additions & 26 deletions

File tree

databases/core.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sqlalchemy.sql import ClauseElement
1212

1313
from databases.importer import import_from_string
14-
from databases.interfaces import DatabaseBackend, Record
14+
from databases.interfaces import DatabaseBackend, Record, TransactionBackend
1515

1616
try: # pragma: no cover
1717
import click
@@ -63,8 +63,8 @@ def __init__(
6363
assert issubclass(backend_cls, DatabaseBackend)
6464
self._backend = backend_cls(self.url, **self.options)
6565

66-
# Connections are stored as task-local state.
67-
self._connection_context: ContextVar = ContextVar("connection_context")
66+
# Connections are stored per asyncio task
67+
self._connections: typing.Dict[typing.Optional[asyncio.Task], Connection]= {}
6868

6969
# When `force_rollback=True` is used, we use a single global
7070
# connection, within a transaction that always rolls back.
@@ -113,7 +113,7 @@ async def disconnect(self) -> None:
113113
self._global_transaction = None
114114
self._global_connection = None
115115
else:
116-
self._connection_context = ContextVar("connection_context")
116+
self._connections.pop(asyncio.current_task(), None)
117117

118118
await self._backend.disconnect()
119119
logger.info(
@@ -187,12 +187,12 @@ def connection(self) -> "Connection":
187187
if self._global_connection is not None:
188188
return self._global_connection
189189

190-
try:
191-
return self._connection_context.get()
192-
except LookupError:
190+
current_task = asyncio.current_task()
191+
if current_task not in self._connections:
193192
connection = Connection(self._backend)
194-
self._connection_context.set(connection)
195-
return connection
193+
self._connections[current_task] = connection
194+
195+
return self._connections[current_task]
196196

197197
def transaction(
198198
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -344,6 +344,9 @@ def __init__(
344344
self._connection_callable = connection_callable
345345
self._force_rollback = force_rollback
346346
self._extra_options = kwargs
347+
348+
# Transactions are stored per asyncio task
349+
self._transactions: typing.Dict[typing.Optional[asyncio.Task], "TransactionBackend"]= {}
347350

348351
async def __aenter__(self) -> "Transaction":
349352
"""
@@ -385,31 +388,35 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
385388
return wrapper # type: ignore
386389

387390
async def start(self) -> "Transaction":
388-
self._connection = self._connection_callable()
389-
self._transaction = self._connection._connection.transaction()
391+
connection = self._connection_callable()
392+
transaction = self._transactions[asyncio.current_task()] = connection._connection.transaction()
390393

391-
async with self._connection._transaction_lock:
392-
is_root = not self._connection._transaction_stack
393-
await self._connection.__aenter__()
394-
await self._transaction.start(
394+
async with connection._transaction_lock:
395+
is_root = not connection._transaction_stack
396+
await connection.__aenter__()
397+
await transaction.start(
395398
is_root=is_root, extra_options=self._extra_options
396399
)
397-
self._connection._transaction_stack.append(self)
400+
connection._transaction_stack.append(self)
398401
return self
399402

400403
async def commit(self) -> None:
401-
async with self._connection._transaction_lock:
402-
assert self._connection._transaction_stack[-1] is self
403-
self._connection._transaction_stack.pop()
404-
await self._transaction.commit()
405-
await self._connection.__aexit__()
404+
connection = self._connection_callable()
405+
transaction = self._transactions[asyncio.current_task()]
406+
async with connection._transaction_lock:
407+
assert connection._transaction_stack[-1] is self
408+
connection._transaction_stack.pop()
409+
await transaction.commit()
410+
await connection.__aexit__()
406411

407412
async def rollback(self) -> None:
408-
async with self._connection._transaction_lock:
409-
assert self._connection._transaction_stack[-1] is self
410-
self._connection._transaction_stack.pop()
411-
await self._transaction.rollback()
412-
await self._connection.__aexit__()
413+
connection = self._connection_callable()
414+
transaction = self._transactions[asyncio.current_task()]
415+
async with connection._transaction_lock:
416+
assert connection._transaction_stack[-1] is self
417+
connection._transaction_stack.pop()
418+
await transaction.rollback()
419+
await connection.__aexit__()
413420

414421

415422
class _EmptyNetloc(str):

0 commit comments

Comments
 (0)