Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit f7f595b

Browse files
Merge pull request #28 from encode/support-global-force-rollback
Support global force rollback
2 parents a0568f8 + 0ceb8c7 commit f7f595b

6 files changed

Lines changed: 243 additions & 103 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
.coverage
33
.pytest_cache/
44
.mypy_cache/
5-
starlette.egg-info/
5+
*.egg-info/
6+
htmlcov/
67
venv/

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ rows = await database.fetch_all(query)
9090
query = notes.select()
9191
row = await database.fetch_one(query)
9292

93-
# Fetch multiple rows without loading everything into memory at once.
93+
# Fetch multiple rows without loading them all into memory at once
9494
query = notes.select()
9595
async for row in database.iterate(query):
9696
...

databases/core.py

Lines changed: 126 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -15,118 +15,57 @@
1515
from aiocontextvars import ContextVar
1616

1717

18-
class DatabaseURL:
19-
def __init__(self, url: typing.Union[str, "DatabaseURL"]):
20-
if isinstance(url, DatabaseURL):
21-
self._url = str(url)
22-
else:
23-
self._url = url
24-
25-
@property
26-
def components(self) -> SplitResult:
27-
if not hasattr(self, "_components"):
28-
self._components = urlsplit(self._url)
29-
return self._components
30-
31-
@property
32-
def dialect(self) -> str:
33-
return self.components.scheme.split("+")[0]
34-
35-
@property
36-
def driver(self) -> str:
37-
if "+" not in self.components.scheme:
38-
return ""
39-
return self.components.scheme.split("+", 1)[1]
40-
41-
@property
42-
def username(self) -> typing.Optional[str]:
43-
return self.components.username
44-
45-
@property
46-
def password(self) -> typing.Optional[str]:
47-
return self.components.password
48-
49-
@property
50-
def hostname(self) -> typing.Optional[str]:
51-
return self.components.hostname
52-
53-
@property
54-
def port(self) -> typing.Optional[int]:
55-
return self.components.port
56-
57-
@property
58-
def database(self) -> str:
59-
return self.components.path.lstrip("/")
60-
61-
def replace(self, **kwargs: typing.Any) -> "DatabaseURL":
62-
if (
63-
"username" in kwargs
64-
or "password" in kwargs
65-
or "hostname" in kwargs
66-
or "port" in kwargs
67-
):
68-
hostname = kwargs.pop("hostname", self.hostname)
69-
port = kwargs.pop("port", self.port)
70-
username = kwargs.pop("username", self.username)
71-
password = kwargs.pop("password", self.password)
72-
73-
netloc = hostname
74-
if port is not None:
75-
netloc += f":{port}"
76-
if username is not None:
77-
userpass = username
78-
if password is not None:
79-
userpass += f":{password}"
80-
netloc = f"{userpass}@{netloc}"
81-
82-
kwargs["netloc"] = netloc
83-
84-
if "database" in kwargs:
85-
kwargs["path"] = "/" + kwargs.pop("database")
86-
87-
if "dialect" in kwargs or "driver" in kwargs:
88-
dialect = kwargs.pop("dialect", self.dialect)
89-
driver = kwargs.pop("driver", self.driver)
90-
kwargs["scheme"] = f"{dialect}+{driver}" if driver else dialect
91-
92-
components = self.components._replace(**kwargs)
93-
return self.__class__(components.geturl())
94-
95-
def __str__(self) -> str:
96-
return self._url
97-
98-
def __repr__(self) -> str:
99-
url = str(self)
100-
if self.password:
101-
url = str(self.replace(password="********"))
102-
return f"{self.__class__.__name__}({repr(url)})"
103-
104-
10518
class Database:
10619
SUPPORTED_BACKENDS = {
10720
"postgresql": "databases.backends.postgres:PostgresBackend",
10821
"mysql": "databases.backends.mysql:MySQLBackend",
10922
}
11023

111-
def __init__(self, url: typing.Union[str, DatabaseURL]):
112-
self.url = DatabaseURL(url)
24+
def __init__(
25+
self, url: typing.Union[str, "DatabaseURL"], force_rollback: bool = False
26+
):
27+
self._url = DatabaseURL(url)
28+
self._force_rollback = force_rollback
29+
11330
self.is_connected = False
11431

115-
backend_str = self.SUPPORTED_BACKENDS[self.url.dialect]
32+
backend_str = self.SUPPORTED_BACKENDS[self._url.dialect]
11633
backend_cls = import_from_string(backend_str)
11734
assert issubclass(backend_cls, DatabaseBackend)
118-
self._backend = backend_cls(self.url)
35+
self._backend = backend_cls(self._url)
36+
37+
# Connections are stored as task-local state.
11938
self._connection_context = ContextVar("connection_context") # type: ContextVar
12039

40+
# When `force_rollback=True` is used, we use a single global
41+
# connection, within a transaction that always rolls back.
42+
self._global_connection = None # type: typing.Optional[Connection]
43+
self._global_transaction = None # type: typing.Optional[Transaction]
44+
12145
async def connect(self) -> None:
122-
if not self.is_connected:
123-
await self._backend.connect()
124-
self.is_connected = True
46+
assert not self.is_connected, "Already connected."
47+
48+
await self._backend.connect()
49+
self.is_connected = True
50+
51+
if self._force_rollback:
52+
self._global_connection = Connection(self._backend)
53+
self._global_transaction = self._global_connection.transaction(
54+
force_rollback=True
55+
)
56+
await self._global_transaction.__aenter__()
12557

12658
async def disconnect(self) -> None:
127-
if self.is_connected:
128-
await self._backend.disconnect()
129-
self.is_connected = False
59+
assert self.is_connected, "Already disconnected."
60+
61+
if self._force_rollback:
62+
assert self._global_transaction is not None
63+
await self._global_transaction.__aexit__()
64+
self._global_transaction = None
65+
self._global_connection = None
66+
67+
await self._backend.disconnect()
68+
self.is_connected = False
13069

13170
async def __aenter__(self) -> "Database":
13271
await self.connect()
@@ -164,6 +103,9 @@ async def iterate(
164103
yield record
165104

166105
def connection(self) -> "Connection":
106+
if self._global_connection is not None:
107+
return self._global_connection
108+
167109
try:
168110
return self._connection_context.get()
169111
except LookupError:
@@ -267,3 +209,90 @@ async def rollback(self) -> None:
267209
self._connection._transaction_stack.pop()
268210
await self._transaction.rollback()
269211
await self._connection.__aexit__()
212+
213+
214+
class DatabaseURL:
215+
def __init__(self, url: typing.Union[str, "DatabaseURL"]):
216+
if isinstance(url, DatabaseURL):
217+
self._url = str(url)
218+
else:
219+
self._url = url
220+
221+
@property
222+
def components(self) -> SplitResult:
223+
if not hasattr(self, "_components"):
224+
self._components = urlsplit(self._url)
225+
return self._components
226+
227+
@property
228+
def dialect(self) -> str:
229+
return self.components.scheme.split("+")[0]
230+
231+
@property
232+
def driver(self) -> str:
233+
if "+" not in self.components.scheme:
234+
return ""
235+
return self.components.scheme.split("+", 1)[1]
236+
237+
@property
238+
def username(self) -> typing.Optional[str]:
239+
return self.components.username
240+
241+
@property
242+
def password(self) -> typing.Optional[str]:
243+
return self.components.password
244+
245+
@property
246+
def hostname(self) -> typing.Optional[str]:
247+
return self.components.hostname
248+
249+
@property
250+
def port(self) -> typing.Optional[int]:
251+
return self.components.port
252+
253+
@property
254+
def database(self) -> str:
255+
return self.components.path.lstrip("/")
256+
257+
def replace(self, **kwargs: typing.Any) -> "DatabaseURL":
258+
if (
259+
"username" in kwargs
260+
or "password" in kwargs
261+
or "hostname" in kwargs
262+
or "port" in kwargs
263+
):
264+
hostname = kwargs.pop("hostname", self.hostname)
265+
port = kwargs.pop("port", self.port)
266+
username = kwargs.pop("username", self.username)
267+
password = kwargs.pop("password", self.password)
268+
269+
netloc = hostname
270+
if port is not None:
271+
netloc += f":{port}"
272+
if username is not None:
273+
userpass = username
274+
if password is not None:
275+
userpass += f":{password}"
276+
netloc = f"{userpass}@{netloc}"
277+
278+
kwargs["netloc"] = netloc
279+
280+
if "database" in kwargs:
281+
kwargs["path"] = "/" + kwargs.pop("database")
282+
283+
if "dialect" in kwargs or "driver" in kwargs:
284+
dialect = kwargs.pop("dialect", self.dialect)
285+
driver = kwargs.pop("driver", self.driver)
286+
kwargs["scheme"] = f"{dialect}+{driver}" if driver else dialect
287+
288+
components = self.components._replace(**kwargs)
289+
return self.__class__(components.geturl())
290+
291+
def __str__(self) -> str:
292+
return self._url
293+
294+
def __repr__(self) -> str:
295+
url = str(self)
296+
if self.password:
297+
url = str(self.replace(password="********"))
298+
return f"{self.__class__.__name__}({repr(url)})"

requirements.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,5 @@ isort
1717
mypy
1818
pytest
1919
pytest-cov
20-
21-
# Documentation
22-
mkdocs
23-
mkdocs-material
20+
starlette
21+
requests

tests/test_databases.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,18 @@ async def test_connections_isolation(database_url):
345345
finally:
346346
query = notes.delete()
347347
await database.execute(query)
348+
349+
350+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
351+
@async_adapter
352+
async def test_connect_and_disconnect(database_url):
353+
"""
354+
Test explicit connect() and disconnect().
355+
"""
356+
database = Database(database_url)
357+
358+
assert not database.is_connected
359+
await database.connect()
360+
assert database.is_connected
361+
await database.disconnect()
362+
assert not database.is_connected

tests/test_integration.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
3+
import pytest
4+
import sqlalchemy
5+
from starlette.applications import Starlette
6+
from starlette.responses import JSONResponse
7+
from starlette.testclient import TestClient
8+
9+
from databases import Database, DatabaseURL
10+
11+
assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set."
12+
13+
DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
14+
15+
metadata = sqlalchemy.MetaData()
16+
17+
notes = sqlalchemy.Table(
18+
"notes",
19+
metadata,
20+
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
21+
sqlalchemy.Column("text", sqlalchemy.String(length=100)),
22+
sqlalchemy.Column("completed", sqlalchemy.Boolean),
23+
)
24+
25+
26+
@pytest.fixture(autouse=True, scope="module")
27+
def create_test_database():
28+
# Create test databases
29+
for url in DATABASE_URLS:
30+
database_url = DatabaseURL(url)
31+
if database_url.dialect == "mysql":
32+
url = str(database_url.replace(driver="pymysql"))
33+
engine = sqlalchemy.create_engine(url)
34+
metadata.create_all(engine)
35+
36+
# Run the test suite
37+
yield
38+
39+
# Drop test databases
40+
for url in DATABASE_URLS:
41+
database_url = DatabaseURL(url)
42+
if database_url.dialect == "mysql":
43+
url = str(database_url.replace(driver="pymysql"))
44+
engine = sqlalchemy.create_engine(url)
45+
metadata.drop_all(engine)
46+
47+
48+
def get_app(database_url):
49+
database = Database(database_url, force_rollback=True)
50+
app = Starlette()
51+
52+
@app.on_event("startup")
53+
async def startup():
54+
await database.connect()
55+
56+
@app.on_event("shutdown")
57+
async def shutdown():
58+
await database.disconnect()
59+
60+
@app.route("/notes", methods=["GET"])
61+
async def list_notes(request):
62+
query = notes.select()
63+
results = await database.fetch_all(query)
64+
content = [
65+
{"text": result["text"], "completed": result["completed"]}
66+
for result in results
67+
]
68+
return JSONResponse(content)
69+
70+
@app.route("/notes", methods=["POST"])
71+
async def add_note(request):
72+
data = await request.json()
73+
query = notes.insert().values(text=data["text"], completed=data["completed"])
74+
await database.execute(query)
75+
return JSONResponse({"text": data["text"], "completed": data["completed"]})
76+
77+
return app
78+
79+
80+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
81+
def test_integration(database_url):
82+
app = get_app(database_url)
83+
84+
with TestClient(app) as client:
85+
response = client.post("/notes", json={"text": "example", "completed": True})
86+
assert response.status_code == 200
87+
assert response.json() == {"text": "example", "completed": True}
88+
89+
response = client.get("/notes")
90+
assert response.status_code == 200
91+
assert response.json() == [{"text": "example", "completed": True}]
92+
93+
with TestClient(app) as client:
94+
# Ensure sessions are isolated
95+
response = client.get("/notes")
96+
assert response.status_code == 200
97+
assert response.json() == []

0 commit comments

Comments
 (0)