33import functools
44import logging
55import typing
6- from contextvars import ContextVar
76from types import TracebackType
8- from typing import Dict , Tuple
97from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
108
119from sqlalchemy import text
3533
3634logger = logging .getLogger ("databases" )
3735
38- # Connections are stored as task-local state, but care must be taken to ensure
39- # that two database instances in the same task do not overwrite each other's connections.
40- # For this reason, the dict key comprises the database instance and the current task.
41- _connection_contextmap : ContextVar [
42- Dict [Tuple ["Database" , asyncio .Task ], "Connection" ]
43- ] = ContextVar ("databases:Connection" )
44-
45-
46- def _get_connection_contextmap () -> Dict [Tuple ["Database" , asyncio .Task ], "Connection" ]:
47- connections = _connection_contextmap .get (None )
48- if connections is None :
49- connections = {}
50- _connection_contextmap .set (connections )
51- return connections
52-
5336
5437class Database :
5538 SUPPORTED_BACKENDS = {
@@ -79,6 +62,9 @@ def __init__(
7962 assert issubclass (backend_cls , DatabaseBackend )
8063 self ._backend = backend_cls (self .url , ** self .options )
8164
65+ # Connections are stored per asyncio task
66+ self ._connections : typing .Dict [typing .Optional [asyncio .Task ], Connection ] = {}
67+
8268 # When `force_rollback=True` is used, we use a single global
8369 # connection, within a transaction that always rolls back.
8470 self ._global_connection : typing .Optional [Connection ] = None
@@ -126,11 +112,10 @@ async def disconnect(self) -> None:
126112 self ._global_transaction = None
127113 self ._global_connection = None
128114 else :
129- task = asyncio .current_task ()
130- assert task is not None , "Not running in an asyncio task"
131- connections = _get_connection_contextmap ()
132- if (self , task ) in connections :
133- del connections [self , task ]
115+ current_task = asyncio .current_task ()
116+ assert current_task is not None , "No currently running task"
117+ if current_task in self ._connections :
118+ del self ._connections [current_task ]
134119
135120 await self ._backend .disconnect ()
136121 logger .info (
@@ -204,13 +189,12 @@ def connection(self) -> "Connection":
204189 if self ._global_connection is not None :
205190 return self ._global_connection
206191
207- task = asyncio .current_task ()
208- assert task is not None , "Not running in an asyncio task"
209- connections = _get_connection_contextmap ()
210- if (self , task ) not in connections :
211- connections [self , task ] = Connection (self ._backend )
192+ current_task = asyncio .current_task ()
193+ assert current_task is not None , "No currently running task"
194+ if current_task not in self ._connections :
195+ self ._connections [current_task ] = Connection (self ._backend )
212196
213- return connections [ self , task ]
197+ return self . _connections [ current_task ]
214198
215199 def transaction (
216200 self , * , force_rollback : bool = False , ** kwargs : typing .Any
@@ -351,19 +335,6 @@ def _build_query(
351335
352336_CallableType = typing .TypeVar ("_CallableType" , bound = typing .Callable )
353337
354- _transaction_contextmap : ContextVar [
355- Dict ["Transaction" , TransactionBackend ]
356- ] = ContextVar ("databases:Transactions" )
357-
358-
359- def _get_transaction_contextmap () -> Dict ["Transaction" , TransactionBackend ]:
360- transactions = _transaction_contextmap .get (None )
361- if transactions is None :
362- transactions = {}
363- _transaction_contextmap .set (transactions )
364-
365- return transactions
366-
367338
368339class Transaction :
369340 def __init__ (
@@ -376,6 +347,11 @@ def __init__(
376347 self ._force_rollback = force_rollback
377348 self ._extra_options = kwargs
378349
350+ # Transactions are stored per asyncio task
351+ self ._transactions : typing .Dict [
352+ typing .Optional [asyncio .Task ], TransactionBackend
353+ ] = {}
354+
379355 async def __aenter__ (self ) -> "Transaction" :
380356 """
381357 Called when entering `async with database.transaction()`
@@ -417,9 +393,10 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
417393
418394 async def start (self ) -> "Transaction" :
419395 connection = self ._connection_callable ()
396+ current_task = asyncio .current_task ()
397+ assert current_task is not None , "No currently running task"
420398 transaction = connection ._connection .transaction ()
421- transactions = _get_transaction_contextmap ()
422- transactions [self ] = transaction
399+ self ._transactions [current_task ] = transaction
423400 async with connection ._transaction_lock :
424401 is_root = not connection ._transaction_stack
425402 await connection .__aenter__ ()
@@ -429,27 +406,27 @@ async def start(self) -> "Transaction":
429406
430407 async def commit (self ) -> None :
431408 connection = self ._connection_callable ()
432- transactions = _get_transaction_contextmap ()
433- transaction = transactions . get (self , None )
409+ current_task = asyncio . current_task ()
410+ transaction = self . _transactions . get (current_task , None )
434411 assert transaction is not None , "Transaction not found in current task"
435412 async with connection ._transaction_lock :
436413 assert connection ._transaction_stack [- 1 ] is self
437414 connection ._transaction_stack .pop ()
438415 await transaction .commit ()
439416 await connection .__aexit__ ()
440- del transactions [ self ]
417+ del self . _transactions [ current_task ]
441418
442419 async def rollback (self ) -> None :
443420 connection = self ._connection_callable ()
444- transactions = _get_transaction_contextmap ()
445- transaction = transactions . get (self , None )
421+ current_task = asyncio . current_task ()
422+ transaction = self . _transactions . get (current_task , None )
446423 assert transaction is not None , "Transaction not found in current task"
447424 async with connection ._transaction_lock :
448425 assert connection ._transaction_stack [- 1 ] is self
449426 connection ._transaction_stack .pop ()
450427 await transaction .rollback ()
451428 await connection .__aexit__ ()
452- del transactions [ self ]
429+ del self . _transactions [ current_task ]
453430
454431
455432class _EmptyNetloc (str ):
0 commit comments