11import asyncio
2+ import functools
23import sys
34import typing
45from types import TracebackType
@@ -22,7 +23,7 @@ class Database:
2223 }
2324
2425 def __init__ (
25- self , url : typing .Union [str , "DatabaseURL" ], force_rollback : bool = False
26+ self , url : typing .Union [str , "DatabaseURL" ], * , force_rollback : bool = False
2627 ):
2728 self ._url = DatabaseURL (url )
2829 self ._force_rollback = force_rollback
@@ -113,8 +114,8 @@ def connection(self) -> "Connection":
113114 self ._connection_context .set (connection )
114115 return connection
115116
116- def transaction (self , force_rollback : bool = False ) -> "Transaction" :
117- return self .connection ().transaction (force_rollback )
117+ def transaction (self , * , force_rollback : bool = False ) -> "Transaction" :
118+ return self .connection ().transaction (force_rollback = force_rollback )
118119
119120
120121class Connection :
@@ -135,7 +136,12 @@ async def __aenter__(self) -> "Connection":
135136 await self ._connection .acquire ()
136137 return self
137138
138- async def __aexit__ (self , * args : typing .Any , ** kwargs : typing .Any ) -> None :
139+ async def __aexit__ (
140+ self ,
141+ exc_type : typing .Type [BaseException ] = None ,
142+ exc_value : BaseException = None ,
143+ traceback : TracebackType = None ,
144+ ) -> None :
139145 async with self ._connection_lock :
140146 assert self ._connection is not None
141147 self ._connection_counter -= 1
@@ -160,7 +166,7 @@ async def iterate(
160166 async for record in self ._connection .iterate (query ):
161167 yield record
162168
163- def transaction (self , force_rollback : bool = False ) -> "Transaction" :
169+ def transaction (self , * , force_rollback : bool = False ) -> "Transaction" :
164170 return Transaction (self , force_rollback )
165171
166172
@@ -171,6 +177,9 @@ def __init__(self, connection: Connection, force_rollback: bool) -> None:
171177 self ._transaction = connection ._connection .transaction ()
172178
173179 async def __aenter__ (self ) -> "Transaction" :
180+ """
181+ Called when entering `async with database.transaction()`
182+ """
174183 await self .start ()
175184 return self
176185
@@ -180,14 +189,32 @@ async def __aexit__(
180189 exc_value : BaseException = None ,
181190 traceback : TracebackType = None ,
182191 ) -> None :
192+ """
193+ Called when exiting `async with database.transaction()`
194+ """
183195 if exc_type is not None or self ._force_rollback :
184196 await self .rollback ()
185197 else :
186198 await self .commit ()
187199
188200 def __await__ (self ) -> typing .Generator :
201+ """
202+ Called if using the low-level `transaction = await database.transaction()`
203+ """
189204 return self .start ().__await__ ()
190205
206+ def __call__ (self , func : typing .Callable ) -> typing .Callable :
207+ """
208+ Called if using `@database.transaction()` as a decorator.
209+ """
210+
211+ @functools .wraps (func )
212+ async def wrapper (* args : typing .Any , ** kwargs : typing .Any ) -> typing .Any :
213+ async with self :
214+ return await func (* args , ** kwargs )
215+
216+ return wrapper
217+
191218 async def start (self ) -> "Transaction" :
192219 async with self ._connection ._transaction_lock :
193220 is_root = not self ._connection ._transaction_stack
0 commit comments