@@ -73,12 +73,6 @@ def __init__(
7373 self ._global_connection = None # type: typing.Optional[Connection]
7474 self ._global_transaction = None # type: typing.Optional[Transaction]
7575
76- if self ._force_rollback :
77- self ._global_connection = Connection (self ._backend )
78- self ._global_transaction = self ._global_connection .transaction (
79- force_rollback = True
80- )
81-
8276 async def connect (self ) -> None :
8377 """
8478 Establish the connection pool.
@@ -92,7 +86,14 @@ async def connect(self) -> None:
9286 self .is_connected = True
9387
9488 if self ._force_rollback :
95- assert self ._global_transaction is not None
89+ assert self ._global_connection is None
90+ assert self ._global_transaction is None
91+
92+ self ._global_connection = Connection (self ._backend )
93+ self ._global_transaction = self ._global_connection .transaction (
94+ force_rollback = True
95+ )
96+
9697 await self ._global_transaction .__aenter__ ()
9798
9899 async def disconnect (self ) -> None :
@@ -102,9 +103,14 @@ async def disconnect(self) -> None:
102103 assert self .is_connected , "Already disconnected."
103104
104105 if self ._force_rollback :
106+ assert self ._global_connection is not None
105107 assert self ._global_transaction is not None
108+
106109 await self ._global_transaction .__aexit__ ()
107110
111+ self ._global_transaction = None
112+ self ._global_connection = None
113+
108114 await self ._backend .disconnect ()
109115 logger .info (
110116 "Disconnected from database %s" ,
0 commit comments