|
1 | 1 | import asyncio |
2 | 2 | import contextlib |
| 3 | +from contextvars import ContextVar |
3 | 4 | import functools |
4 | 5 | import logging |
5 | 6 | import typing |
6 | 7 | from types import TracebackType |
7 | 8 | from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit |
8 | | - |
| 9 | +import weakref |
9 | 10 | from sqlalchemy import text |
10 | 11 | from sqlalchemy.sql import ClauseElement |
11 | 12 |
|
|
33 | 34 |
|
34 | 35 | logger = logging.getLogger("databases") |
35 | 36 |
|
| 37 | +_ACTIVE_CONNECTIONS: ContextVar[ |
| 38 | + typing.Optional[weakref.WeakKeyDictionary["Database", "Connection"]] |
| 39 | +] = ContextVar("databases:open_connections", default=None) |
| 40 | + |
| 41 | +_ACTIVE_TRANSACTIONS: ContextVar[ |
| 42 | + typing.Optional[weakref.WeakKeyDictionary["Transaction", "TransactionBackend"]] |
| 43 | +] = ContextVar("databases:open_transactions", default=None) |
| 44 | + |
36 | 45 |
|
37 | 46 | class Database: |
38 | 47 | SUPPORTED_BACKENDS = { |
@@ -62,14 +71,31 @@ def __init__( |
62 | 71 | assert issubclass(backend_cls, DatabaseBackend) |
63 | 72 | self._backend = backend_cls(self.url, **self.options) |
64 | 73 |
|
65 | | - # Connections are stored per asyncio task |
66 | | - self._connections: typing.Dict[asyncio.Task, Connection] = {} |
67 | | - |
68 | 74 | # When `force_rollback=True` is used, we use a single global |
69 | 75 | # connection, within a transaction that always rolls back. |
70 | 76 | self._global_connection: typing.Optional[Connection] = None |
71 | 77 | self._global_transaction: typing.Optional[Transaction] = None |
72 | 78 |
|
| 79 | + @property |
| 80 | + def _connection(self) -> typing.Optional["Connection"]: |
| 81 | + connections = _ACTIVE_CONNECTIONS.get() |
| 82 | + if connections is None: |
| 83 | + return None |
| 84 | + |
| 85 | + return connections.get(self, None) |
| 86 | + |
| 87 | + @_connection.setter |
| 88 | + def _connection( |
| 89 | + self, connection: typing.Optional["Connection"] |
| 90 | + ) -> typing.Optional["Connection"]: |
| 91 | + connections = _ACTIVE_CONNECTIONS.get() |
| 92 | + if connections is None: |
| 93 | + connections = weakref.WeakKeyDictionary() |
| 94 | + _ACTIVE_CONNECTIONS.set(connections) |
| 95 | + |
| 96 | + connections[self] = connection |
| 97 | + return connections[self] |
| 98 | + |
73 | 99 | async def connect(self) -> None: |
74 | 100 | """ |
75 | 101 | Establish the connection pool. |
@@ -112,10 +138,7 @@ async def disconnect(self) -> None: |
112 | 138 | self._global_transaction = None |
113 | 139 | self._global_connection = None |
114 | 140 | else: |
115 | | - current_task = asyncio.current_task() |
116 | | - assert current_task is not None, "No currently running task" |
117 | | - if current_task in self._connections: |
118 | | - del self._connections[current_task] |
| 141 | + self._connection = None |
119 | 142 |
|
120 | 143 | await self._backend.disconnect() |
121 | 144 | logger.info( |
@@ -189,12 +212,10 @@ def connection(self) -> "Connection": |
189 | 212 | if self._global_connection is not None: |
190 | 213 | return self._global_connection |
191 | 214 |
|
192 | | - current_task = asyncio.current_task() |
193 | | - assert current_task is not None, "No currently running task" |
194 | | - if current_task not in self._connections: |
195 | | - self._connections[current_task] = Connection(self._backend) |
| 215 | + if not self._connection: |
| 216 | + self._connection = Connection(self._backend) |
196 | 217 |
|
197 | | - return self._connections[current_task] |
| 218 | + return self._connection |
198 | 219 |
|
199 | 220 | def transaction( |
200 | 221 | self, *, force_rollback: bool = False, **kwargs: typing.Any |
@@ -347,10 +368,30 @@ def __init__( |
347 | 368 | self._force_rollback = force_rollback |
348 | 369 | self._extra_options = kwargs |
349 | 370 |
|
350 | | - # Transactions are stored per asyncio task |
351 | | - self._transactions: typing.Dict[ |
352 | | - typing.Optional[asyncio.Task], TransactionBackend |
353 | | - ] = {} |
| 371 | + @property |
| 372 | + def _connection(self) -> "Connection": |
| 373 | + # Returns the same connection if called multiple times |
| 374 | + return self._connection_callable() |
| 375 | + |
| 376 | + @property |
| 377 | + def _transaction(self) -> typing.Optional["TransactionBackend"]: |
| 378 | + transactions = _ACTIVE_TRANSACTIONS.get() |
| 379 | + if transactions is None: |
| 380 | + return None |
| 381 | + |
| 382 | + return transactions.get(self, None) |
| 383 | + |
| 384 | + @_transaction.setter |
| 385 | + def _transaction( |
| 386 | + self, transaction: typing.Optional["TransactionBackend"] |
| 387 | + ) -> typing.Optional["TransactionBackend"]: |
| 388 | + transactions = _ACTIVE_TRANSACTIONS.get() |
| 389 | + if transactions is None: |
| 390 | + transactions = weakref.WeakKeyDictionary() |
| 391 | + _ACTIVE_TRANSACTIONS.set(transactions) |
| 392 | + |
| 393 | + transactions[self] = transaction |
| 394 | + return transactions[self] |
354 | 395 |
|
355 | 396 | async def __aenter__(self) -> "Transaction": |
356 | 397 | """ |
@@ -392,41 +433,32 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: |
392 | 433 | return wrapper # type: ignore |
393 | 434 |
|
394 | 435 | async def start(self) -> "Transaction": |
395 | | - connection = self._connection_callable() |
396 | | - current_task = asyncio.current_task() |
397 | | - assert current_task is not None, "No currently running task" |
398 | | - transaction = connection._connection.transaction() |
399 | | - self._transactions[current_task] = transaction |
400 | | - async with connection._transaction_lock: |
401 | | - is_root = not connection._transaction_stack |
402 | | - await connection.__aenter__() |
403 | | - await transaction.start(is_root=is_root, extra_options=self._extra_options) |
404 | | - connection._transaction_stack.append(self) |
| 436 | + self._transaction = self._connection._connection.transaction() |
| 437 | + |
| 438 | + async with self._connection._transaction_lock: |
| 439 | + is_root = not self._connection._transaction_stack |
| 440 | + await self._connection.__aenter__() |
| 441 | + await self._transaction.start( |
| 442 | + is_root=is_root, extra_options=self._extra_options |
| 443 | + ) |
| 444 | + self._connection._transaction_stack.append(self) |
405 | 445 | return self |
406 | 446 |
|
407 | 447 | async def commit(self) -> None: |
408 | | - connection = self._connection_callable() |
409 | | - current_task = asyncio.current_task() |
410 | | - transaction = self._transactions.get(current_task, None) |
411 | | - assert transaction is not None, "Transaction not found in current task" |
412 | | - async with connection._transaction_lock: |
413 | | - assert connection._transaction_stack[-1] is self |
414 | | - connection._transaction_stack.pop() |
415 | | - await transaction.commit() |
416 | | - await connection.__aexit__() |
417 | | - del self._transactions[current_task] |
| 448 | + async with self._connection._transaction_lock: |
| 449 | + assert self._connection._transaction_stack[-1] is self |
| 450 | + self._connection._transaction_stack.pop() |
| 451 | + await self._transaction.commit() |
| 452 | + await self._connection.__aexit__() |
| 453 | + self._transaction = None |
418 | 454 |
|
419 | 455 | async def rollback(self) -> None: |
420 | | - connection = self._connection_callable() |
421 | | - current_task = asyncio.current_task() |
422 | | - transaction = self._transactions.get(current_task, None) |
423 | | - assert transaction is not None, "Transaction not found in current task" |
424 | | - async with connection._transaction_lock: |
425 | | - assert connection._transaction_stack[-1] is self |
426 | | - connection._transaction_stack.pop() |
427 | | - await transaction.rollback() |
428 | | - await connection.__aexit__() |
429 | | - del self._transactions[current_task] |
| 456 | + async with self._connection._transaction_lock: |
| 457 | + assert self._connection._transaction_stack[-1] is self |
| 458 | + self._connection._transaction_stack.pop() |
| 459 | + await self._transaction.rollback() |
| 460 | + await self._connection.__aexit__() |
| 461 | + self._transaction = None |
430 | 462 |
|
431 | 463 |
|
432 | 464 | class _EmptyNetloc(str): |
|
0 commit comments