11import asyncio
22import contextlib
3- from contextvars import ContextVar
43import functools
54import logging
65import typing
6+ import weakref
7+ from contextvars import ContextVar
78from types import TracebackType
89from urllib .parse import SplitResult , parse_qsl , unquote , urlsplit
9- import weakref
10+
1011from sqlalchemy import text
1112from sqlalchemy .sql import ClauseElement
1213
@@ -93,8 +94,12 @@ def _connection(
9394 connections = weakref .WeakKeyDictionary ()
9495 _ACTIVE_CONNECTIONS .set (connections )
9596
96- connections [self ] = connection
97- return connections [self ]
97+ if connection is None :
98+ connections .pop (self , None )
99+ else :
100+ connections [self ] = connection
101+
102+ return connections .get (self , None )
98103
99104 async def connect (self ) -> None :
100105 """
@@ -390,8 +395,12 @@ def _transaction(
390395 transactions = weakref .WeakKeyDictionary ()
391396 _ACTIVE_TRANSACTIONS .set (transactions )
392397
393- transactions [self ] = transaction
394- return transactions [self ]
398+ if transaction is None :
399+ transactions .pop (self , None )
400+ else :
401+ transactions [self ] = transaction
402+
403+ return transactions .get (self , None )
395404
396405 async def __aenter__ (self ) -> "Transaction" :
397406 """
@@ -448,6 +457,7 @@ async def commit(self) -> None:
448457 async with self ._connection ._transaction_lock :
449458 assert self ._connection ._transaction_stack [- 1 ] is self
450459 self ._connection ._transaction_stack .pop ()
460+ assert self ._transaction is not None
451461 await self ._transaction .commit ()
452462 await self ._connection .__aexit__ ()
453463 self ._transaction = None
@@ -456,6 +466,7 @@ async def rollback(self) -> None:
456466 async with self ._connection ._transaction_lock :
457467 assert self ._connection ._transaction_stack [- 1 ] is self
458468 self ._connection ._transaction_stack .pop ()
469+ assert self ._transaction is not None
459470 await self ._transaction .rollback ()
460471 await self ._connection .__aexit__ ()
461472 self ._transaction = None
0 commit comments