Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 16403c3

Browse files
committed
Merge remote-tracking branch 'madkinsz/example/instance-safe' into fix-transaction-contextvar
2 parents 460f72e + 2d4554d commit 16403c3

1 file changed

Lines changed: 26 additions & 49 deletions

File tree

databases/core.py

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import functools
44
import logging
55
import typing
6-
from contextvars import ContextVar
76
from types import TracebackType
8-
from typing import Dict, Tuple
97
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
108

119
from sqlalchemy import text
@@ -35,21 +33,6 @@
3533

3634
logger = logging.getLogger("databases")
3735

38-
# Connections are stored as task-local state, but care must be taken to ensure
39-
# that two database instances in the same task do not overwrite each other's connections.
40-
# For this reason, the dict key comprises the database instance and the current task.
41-
_connection_contextmap: ContextVar[
42-
Dict[Tuple["Database", asyncio.Task], "Connection"]
43-
] = ContextVar("databases:Connection")
44-
45-
46-
def _get_connection_contextmap() -> Dict[Tuple["Database", asyncio.Task], "Connection"]:
47-
connections = _connection_contextmap.get(None)
48-
if connections is None:
49-
connections = {}
50-
_connection_contextmap.set(connections)
51-
return connections
52-
5336

5437
class Database:
5538
SUPPORTED_BACKENDS = {
@@ -79,6 +62,9 @@ def __init__(
7962
assert issubclass(backend_cls, DatabaseBackend)
8063
self._backend = backend_cls(self.url, **self.options)
8164

65+
# Connections are stored per asyncio task
66+
self._connections: typing.Dict[typing.Optional[asyncio.Task], Connection] = {}
67+
8268
# When `force_rollback=True` is used, we use a single global
8369
# connection, within a transaction that always rolls back.
8470
self._global_connection: typing.Optional[Connection] = None
@@ -126,11 +112,10 @@ async def disconnect(self) -> None:
126112
self._global_transaction = None
127113
self._global_connection = None
128114
else:
129-
task = asyncio.current_task()
130-
assert task is not None, "Not running in an asyncio task"
131-
connections = _get_connection_contextmap()
132-
if (self, task) in connections:
133-
del connections[self, task]
115+
current_task = asyncio.current_task()
116+
assert current_task is not None, "No currently running task"
117+
if current_task in self._connections:
118+
del self._connections[current_task]
134119

135120
await self._backend.disconnect()
136121
logger.info(
@@ -204,13 +189,12 @@ def connection(self) -> "Connection":
204189
if self._global_connection is not None:
205190
return self._global_connection
206191

207-
task = asyncio.current_task()
208-
assert task is not None, "Not running in an asyncio task"
209-
connections = _get_connection_contextmap()
210-
if (self, task) not in connections:
211-
connections[self, task] = Connection(self._backend)
192+
current_task = asyncio.current_task()
193+
assert current_task is not None, "No currently running task"
194+
if current_task not in self._connections:
195+
self._connections[current_task] = Connection(self._backend)
212196

213-
return connections[self, task]
197+
return self._connections[current_task]
214198

215199
def transaction(
216200
self, *, force_rollback: bool = False, **kwargs: typing.Any
@@ -351,19 +335,6 @@ def _build_query(
351335

352336
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
353337

354-
_transaction_contextmap: ContextVar[
355-
Dict["Transaction", TransactionBackend]
356-
] = ContextVar("databases:Transactions")
357-
358-
359-
def _get_transaction_contextmap() -> Dict["Transaction", TransactionBackend]:
360-
transactions = _transaction_contextmap.get(None)
361-
if transactions is None:
362-
transactions = {}
363-
_transaction_contextmap.set(transactions)
364-
365-
return transactions
366-
367338

368339
class Transaction:
369340
def __init__(
@@ -376,6 +347,11 @@ def __init__(
376347
self._force_rollback = force_rollback
377348
self._extra_options = kwargs
378349

350+
# Transactions are stored per asyncio task
351+
self._transactions: typing.Dict[
352+
typing.Optional[asyncio.Task], TransactionBackend
353+
] = {}
354+
379355
async def __aenter__(self) -> "Transaction":
380356
"""
381357
Called when entering `async with database.transaction()`
@@ -417,9 +393,10 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
417393

418394
async def start(self) -> "Transaction":
419395
connection = self._connection_callable()
396+
current_task = asyncio.current_task()
397+
assert current_task is not None, "No currently running task"
420398
transaction = connection._connection.transaction()
421-
transactions = _get_transaction_contextmap()
422-
transactions[self] = transaction
399+
self._transactions[current_task] = transaction
423400
async with connection._transaction_lock:
424401
is_root = not connection._transaction_stack
425402
await connection.__aenter__()
@@ -429,27 +406,27 @@ async def start(self) -> "Transaction":
429406

430407
async def commit(self) -> None:
431408
connection = self._connection_callable()
432-
transactions = _get_transaction_contextmap()
433-
transaction = transactions.get(self, None)
409+
current_task = asyncio.current_task()
410+
transaction = self._transactions.get(current_task, None)
434411
assert transaction is not None, "Transaction not found in current task"
435412
async with connection._transaction_lock:
436413
assert connection._transaction_stack[-1] is self
437414
connection._transaction_stack.pop()
438415
await transaction.commit()
439416
await connection.__aexit__()
440-
del transactions[self]
417+
del self._transactions[current_task]
441418

442419
async def rollback(self) -> None:
443420
connection = self._connection_callable()
444-
transactions = _get_transaction_contextmap()
445-
transaction = transactions.get(self, None)
421+
current_task = asyncio.current_task()
422+
transaction = self._transactions.get(current_task, None)
446423
assert transaction is not None, "Transaction not found in current task"
447424
async with connection._transaction_lock:
448425
assert connection._transaction_stack[-1] is self
449426
connection._transaction_stack.pop()
450427
await transaction.rollback()
451428
await connection.__aexit__()
452-
del transactions[self]
429+
del self._transactions[current_task]
453430

454431

455432
class _EmptyNetloc(str):

0 commit comments

Comments
 (0)