|
11 | 11 | from sqlalchemy.sql import ClauseElement |
12 | 12 |
|
13 | 13 | from databases.importer import import_from_string |
14 | | -from databases.interfaces import DatabaseBackend, Record |
| 14 | +from databases.interfaces import DatabaseBackend, Record, TransactionBackend |
15 | 15 |
|
16 | 16 | try: # pragma: no cover |
17 | 17 | import click |
@@ -63,8 +63,8 @@ def __init__( |
63 | 63 | assert issubclass(backend_cls, DatabaseBackend) |
64 | 64 | self._backend = backend_cls(self.url, **self.options) |
65 | 65 |
|
66 | | - # Connections are stored as task-local state. |
67 | | - self._connection_context: ContextVar = ContextVar("connection_context") |
| 66 | + # Connections are stored per asyncio task |
| 67 | + self._connections: typing.Dict[typing.Optional[asyncio.Task], Connection]= {} |
68 | 68 |
|
69 | 69 | # When `force_rollback=True` is used, we use a single global |
70 | 70 | # connection, within a transaction that always rolls back. |
@@ -113,7 +113,7 @@ async def disconnect(self) -> None: |
113 | 113 | self._global_transaction = None |
114 | 114 | self._global_connection = None |
115 | 115 | else: |
116 | | - self._connection_context = ContextVar("connection_context") |
| 116 | + self._connections.pop(asyncio.current_task(), None) |
117 | 117 |
|
118 | 118 | await self._backend.disconnect() |
119 | 119 | logger.info( |
@@ -187,12 +187,12 @@ def connection(self) -> "Connection": |
187 | 187 | if self._global_connection is not None: |
188 | 188 | return self._global_connection |
189 | 189 |
|
190 | | - try: |
191 | | - return self._connection_context.get() |
192 | | - except LookupError: |
| 190 | + current_task = asyncio.current_task() |
| 191 | + if current_task not in self._connections: |
193 | 192 | connection = Connection(self._backend) |
194 | | - self._connection_context.set(connection) |
195 | | - return connection |
| 193 | + self._connections[current_task] = connection |
| 194 | + |
| 195 | + return self._connections[current_task] |
196 | 196 |
|
197 | 197 | def transaction( |
198 | 198 | self, *, force_rollback: bool = False, **kwargs: typing.Any |
@@ -344,6 +344,9 @@ def __init__( |
344 | 344 | self._connection_callable = connection_callable |
345 | 345 | self._force_rollback = force_rollback |
346 | 346 | self._extra_options = kwargs |
| 347 | + |
| 348 | + # Transactions are stored per asyncio task |
| 349 | + self._transactions: typing.Dict[typing.Optional[asyncio.Task], "TransactionBackend"]= {} |
347 | 350 |
|
348 | 351 | async def __aenter__(self) -> "Transaction": |
349 | 352 | """ |
@@ -385,31 +388,35 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: |
385 | 388 | return wrapper # type: ignore |
386 | 389 |
|
387 | 390 | async def start(self) -> "Transaction": |
388 | | - self._connection = self._connection_callable() |
389 | | - self._transaction = self._connection._connection.transaction() |
| 391 | + connection = self._connection_callable() |
| 392 | + transaction = self._transactions[asyncio.current_task()] = connection._connection.transaction() |
390 | 393 |
|
391 | | - async with self._connection._transaction_lock: |
392 | | - is_root = not self._connection._transaction_stack |
393 | | - await self._connection.__aenter__() |
394 | | - await self._transaction.start( |
| 394 | + async with connection._transaction_lock: |
| 395 | + is_root = not connection._transaction_stack |
| 396 | + await connection.__aenter__() |
| 397 | + await transaction.start( |
395 | 398 | is_root=is_root, extra_options=self._extra_options |
396 | 399 | ) |
397 | | - self._connection._transaction_stack.append(self) |
| 400 | + connection._transaction_stack.append(self) |
398 | 401 | return self |
399 | 402 |
|
400 | 403 | async def commit(self) -> None: |
401 | | - async with self._connection._transaction_lock: |
402 | | - assert self._connection._transaction_stack[-1] is self |
403 | | - self._connection._transaction_stack.pop() |
404 | | - await self._transaction.commit() |
405 | | - await self._connection.__aexit__() |
| 404 | + connection = self._connection_callable() |
| 405 | + transaction = self._transactions[asyncio.current_task()] |
| 406 | + async with connection._transaction_lock: |
| 407 | + assert connection._transaction_stack[-1] is self |
| 408 | + connection._transaction_stack.pop() |
| 409 | + await transaction.commit() |
| 410 | + await connection.__aexit__() |
406 | 411 |
|
407 | 412 | async def rollback(self) -> None: |
408 | | - async with self._connection._transaction_lock: |
409 | | - assert self._connection._transaction_stack[-1] is self |
410 | | - self._connection._transaction_stack.pop() |
411 | | - await self._transaction.rollback() |
412 | | - await self._connection.__aexit__() |
| 413 | + connection = self._connection_callable() |
| 414 | + transaction = self._transactions[asyncio.current_task()] |
| 415 | + async with connection._transaction_lock: |
| 416 | + assert connection._transaction_stack[-1] is self |
| 417 | + connection._transaction_stack.pop() |
| 418 | + await transaction.rollback() |
| 419 | + await connection.__aexit__() |
413 | 420 |
|
414 | 421 |
|
415 | 422 | class _EmptyNetloc(str): |
|
0 commit comments