Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 44dbf7b

Browse files
Merge pull request #30 from encode/transaction-decorator
Add transaction decorator
2 parents 9675bf8 + 19e9428 commit 44dbf7b

2 files changed

Lines changed: 61 additions & 5 deletions

File tree

databases/core.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import functools
23
import sys
34
import typing
45
from types import TracebackType
@@ -22,7 +23,7 @@ class Database:
2223
}
2324

2425
def __init__(
25-
self, url: typing.Union[str, "DatabaseURL"], force_rollback: bool = False
26+
self, url: typing.Union[str, "DatabaseURL"], *, force_rollback: bool = False
2627
):
2728
self._url = DatabaseURL(url)
2829
self._force_rollback = force_rollback
@@ -113,8 +114,8 @@ def connection(self) -> "Connection":
113114
self._connection_context.set(connection)
114115
return connection
115116

116-
def transaction(self, force_rollback: bool = False) -> "Transaction":
117-
return self.connection().transaction(force_rollback)
117+
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
118+
return self.connection().transaction(force_rollback=force_rollback)
118119

119120

120121
class Connection:
@@ -135,7 +136,12 @@ async def __aenter__(self) -> "Connection":
135136
await self._connection.acquire()
136137
return self
137138

138-
async def __aexit__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
139+
async def __aexit__(
140+
self,
141+
exc_type: typing.Type[BaseException] = None,
142+
exc_value: BaseException = None,
143+
traceback: TracebackType = None,
144+
) -> None:
139145
async with self._connection_lock:
140146
assert self._connection is not None
141147
self._connection_counter -= 1
@@ -160,7 +166,7 @@ async def iterate(
160166
async for record in self._connection.iterate(query):
161167
yield record
162168

163-
def transaction(self, force_rollback: bool = False) -> "Transaction":
169+
def transaction(self, *, force_rollback: bool = False) -> "Transaction":
164170
return Transaction(self, force_rollback)
165171

166172

@@ -171,6 +177,9 @@ def __init__(self, connection: Connection, force_rollback: bool) -> None:
171177
self._transaction = connection._connection.transaction()
172178

173179
async def __aenter__(self) -> "Transaction":
180+
"""
181+
Called when entering `async with database.transaction()`
182+
"""
174183
await self.start()
175184
return self
176185

@@ -180,14 +189,32 @@ async def __aexit__(
180189
exc_value: BaseException = None,
181190
traceback: TracebackType = None,
182191
) -> None:
192+
"""
193+
Called when exiting `async with database.transaction()`
194+
"""
183195
if exc_type is not None or self._force_rollback:
184196
await self.rollback()
185197
else:
186198
await self.commit()
187199

188200
def __await__(self) -> typing.Generator:
201+
"""
202+
Called if using the low-level `transaction = await database.transaction()`
203+
"""
189204
return self.start().__await__()
190205

206+
def __call__(self, func: typing.Callable) -> typing.Callable:
207+
"""
208+
Called if using `@database.transaction()` as a decorator.
209+
"""
210+
211+
@functools.wraps(func)
212+
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
213+
async with self:
214+
return await func(*args, **kwargs)
215+
216+
return wrapper
217+
191218
async def start(self) -> "Transaction":
192219
async with self._connection._transaction_lock:
193220
is_root = not self._connection._transaction_stack

tests/test_databases.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,35 @@ async def test_transaction_rollback_low_level(database_url):
256256
assert len(results) == 0
257257

258258

259+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
260+
@async_adapter
261+
async def test_transaction_decorator(database_url):
262+
"""
263+
Ensure that @database.transaction() is supported.
264+
"""
265+
async with Database(database_url, force_rollback=True) as database:
266+
267+
@database.transaction()
268+
async def insert_data(raise_exception):
269+
query = notes.insert().values(text="example", completed=True)
270+
await database.execute(query)
271+
if raise_exception:
272+
raise RuntimeError()
273+
274+
with pytest.raises(RuntimeError):
275+
await insert_data(raise_exception=True)
276+
277+
query = notes.select()
278+
results = await database.fetch_all(query=query)
279+
assert len(results) == 0
280+
281+
await insert_data(raise_exception=False)
282+
283+
query = notes.select()
284+
results = await database.fetch_all(query=query)
285+
assert len(results) == 1
286+
287+
259288
@pytest.mark.parametrize("database_url", DATABASE_URLS)
260289
@async_adapter
261290
async def test_datetime_field(database_url):

0 commit comments

Comments
 (0)