Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 16bc5d7

Browse files
SQLite support
1 parent 76a7167 commit 16bc5d7

7 files changed

Lines changed: 231 additions & 19 deletions

File tree

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ python:
99
- "3.7"
1010

1111
env:
12-
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database"
12+
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database", sqlite:///test.db
1313

1414
services:
1515
- postgresql

databases/backends/mysql.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515

1616
logger = logging.getLogger("databases")
1717

18-
_result_processors = {} # type: dict
19-
2018

2119
class MySQLBackend(DatabaseBackend):
22-
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
23-
self._database_url = DatabaseURL(database_url)
20+
def __init__(self, database_url: DatabaseURL) -> None:
21+
self._database_url = database_url
2422
self._dialect = pymysql.dialect(paramstyle="pyformat")
2523
self._pool = None
2624

@@ -134,10 +132,9 @@ def transaction(self) -> TransactionBackend:
134132

135133
def _compile(
136134
self, query: ClauseElement
137-
) -> typing.Tuple[str, list, CompilationContext]:
135+
) -> typing.Tuple[str, dict, CompilationContext]:
138136
compiled = query.compile(dialect=self._dialect)
139137
args = compiled.construct_params()
140-
logger.debug(compiled.string, args)
141138
for key, val in args.items():
142139
if key in compiled._bind_processors:
143140
args[key] = compiled._bind_processors[key](val)
@@ -150,6 +147,7 @@ def _compile(
150147
compiled._textual_ordered_columns,
151148
)
152149

150+
logger.debug(compiled.string, args)
153151
return compiled.string, args, CompilationContext(execution_context)
154152

155153

@@ -175,10 +173,7 @@ async def start(self, is_root: bool) -> None:
175173

176174
async def commit(self) -> None:
177175
assert self._connection._connection is not None, "Connection is not acquired"
178-
if self._is_root: # pragma: no cover
179-
# In test cases the root transaction is never committed,
180-
# since we *always* wrap the test case up in a transaction
181-
# and rollback to a clean state at the end.
176+
if self._is_root:
182177
await self._connection._connection.commit()
183178
else:
184179
cursor = await self._connection._connection.cursor()

databases/backends/postgres.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717

1818
class PostgresBackend(DatabaseBackend):
19-
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
20-
self._database_url = DatabaseURL(database_url)
19+
def __init__(self, database_url: DatabaseURL) -> None:
20+
self._database_url = database_url
2121
self._dialect = self._get_dialect()
2222
self._pool = None
2323

databases/backends/sqlite.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import logging
2+
import typing
3+
import uuid
4+
5+
import aiosqlite
6+
from sqlalchemy.dialects.sqlite import pysqlite
7+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
8+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
9+
from sqlalchemy.sql import ClauseElement
10+
from sqlalchemy.types import TypeEngine
11+
12+
from databases.core import DatabaseURL
13+
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
14+
15+
logger = logging.getLogger("databases")
16+
17+
18+
class SQLiteBackend(DatabaseBackend):
19+
def __init__(self, database_url: DatabaseURL) -> None:
20+
self._database_url = database_url
21+
self._dialect = pysqlite.dialect(paramstyle="qmark")
22+
self._pool = SQLitePool(database_url)
23+
24+
async def connect(self) -> None:
25+
pass
26+
# assert self._pool is None, "DatabaseBackend is already running"
27+
# self._pool = await aiomysql.create_pool(
28+
# host=self._database_url.hostname,
29+
# port=self._database_url.port or 3306,
30+
# user=self._database_url.username or getpass.getuser(),
31+
# password=self._database_url.password,
32+
# db=self._database_url.database,
33+
# autocommit=True,
34+
# )
35+
36+
async def disconnect(self) -> None:
37+
pass
38+
# assert self._pool is not None, "DatabaseBackend is not running"
39+
# self._pool.close()
40+
# await self._pool.wait_closed()
41+
# self._pool = None
42+
43+
def connection(self) -> "SQLiteConnection":
44+
return SQLiteConnection(self._pool, self._dialect)
45+
46+
47+
class SQLitePool:
48+
def __init__(self, url: DatabaseURL) -> None:
49+
self._url = url
50+
51+
async def acquire(self) -> aiosqlite.Connection:
52+
connection = aiosqlite.connect(
53+
database=self._url.database, isolation_level=None
54+
)
55+
await connection.__aenter__()
56+
return connection
57+
58+
async def release(self, connection: aiosqlite.Connection) -> None:
59+
await connection.__aexit__(None, None, None)
60+
61+
62+
class CompilationContext:
63+
def __init__(self, context: ExecutionContext):
64+
self.context = context
65+
66+
67+
class SQLiteConnection(ConnectionBackend):
68+
def __init__(self, pool: SQLitePool, dialect: Dialect):
69+
self._pool = pool
70+
self._dialect = dialect
71+
self._connection = None
72+
73+
async def acquire(self) -> None:
74+
assert self._connection is None, "Connection is already acquired"
75+
self._connection = await self._pool.acquire()
76+
77+
async def release(self) -> None:
78+
assert self._connection is not None, "Connection is not acquired"
79+
await self._pool.release(self._connection)
80+
self._connection = None
81+
82+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
83+
assert self._connection is not None, "Connection is not acquired"
84+
query, args, context = self._compile(query)
85+
86+
async with self._connection.execute(query, args) as cursor:
87+
rows = await cursor.fetchall()
88+
metadata = ResultMetaData(context, cursor.description)
89+
return [
90+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
91+
for row in rows
92+
]
93+
94+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
95+
assert self._connection is not None, "Connection is not acquired"
96+
query, args, context = self._compile(query)
97+
98+
async with self._connection.execute(query, args) as cursor:
99+
row = await cursor.fetchone()
100+
metadata = ResultMetaData(context, cursor.description)
101+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
102+
103+
async def execute(self, query: ClauseElement, values: dict = None) -> None:
104+
assert self._connection is not None, "Connection is not acquired"
105+
if values is not None:
106+
query = query.values(values)
107+
query, args, context = self._compile(query)
108+
cursor = await self._connection.execute(query, args)
109+
await cursor.close()
110+
111+
async def execute_many(self, query: ClauseElement, values: list) -> None:
112+
assert self._connection is not None, "Connection is not acquired"
113+
for value in values:
114+
await self.execute(query, value)
115+
116+
async def iterate(
117+
self, query: ClauseElement
118+
) -> typing.AsyncGenerator[typing.Any, None]:
119+
assert self._connection is not None, "Connection is not acquired"
120+
query, args, context = self._compile(query)
121+
cursor = await self._connection.cursor()
122+
async with self._connection.execute(query, args) as cursor:
123+
metadata = ResultMetaData(context, cursor.description)
124+
async for row in cursor:
125+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
126+
127+
def transaction(self) -> TransactionBackend:
128+
return SQLiteTransaction(self)
129+
130+
def _compile(
131+
self, query: ClauseElement
132+
) -> typing.Tuple[str, list, CompilationContext]:
133+
compiled = query.compile(dialect=self._dialect)
134+
args = []
135+
for key, raw_val in compiled.construct_params().items():
136+
if key in compiled._bind_processors:
137+
val = compiled._bind_processors[key](raw_val)
138+
else:
139+
val = raw_val
140+
args.append(val)
141+
142+
execution_context = self._dialect.execution_ctx_cls()
143+
execution_context.dialect = self._dialect
144+
execution_context.result_column_struct = (
145+
compiled._result_columns,
146+
compiled._ordered_columns,
147+
compiled._textual_ordered_columns,
148+
)
149+
150+
logger.debug(compiled.string, args)
151+
return compiled.string, args, CompilationContext(execution_context)
152+
153+
154+
class SQLiteTransaction(TransactionBackend):
155+
def __init__(self, connection: SQLiteConnection):
156+
self._connection = connection
157+
self._is_root = False
158+
self._savepoint_name = ""
159+
160+
async def start(self, is_root: bool) -> None:
161+
assert self._connection._connection is not None, "Connection is not acquired"
162+
self._is_root = is_root
163+
if self._is_root:
164+
cursor = await self._connection._connection.execute("BEGIN")
165+
await cursor.close()
166+
else:
167+
id = str(uuid.uuid4()).replace("-", "_")
168+
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
169+
cursor = await self._connection._connection.execute(
170+
f"SAVEPOINT {self._savepoint_name}"
171+
)
172+
await cursor.close()
173+
174+
async def commit(self) -> None:
175+
assert self._connection._connection is not None, "Connection is not acquired"
176+
if self._is_root:
177+
cursor = await self._connection._connection.execute("COMMIT")
178+
await cursor.close()
179+
else:
180+
cursor = await self._connection._connection.execute(
181+
f"RELEASE SAVEPOINT {self._savepoint_name}"
182+
)
183+
await cursor.close()
184+
185+
async def rollback(self) -> None:
186+
assert self._connection._connection is not None, "Connection is not acquired"
187+
if self._is_root:
188+
cursor = await self._connection._connection.execute("ROLLBACK")
189+
await cursor.close()
190+
else:
191+
cursor = await self._connection._connection.execute(
192+
f"ROLLBACK TO SAVEPOINT {self._savepoint_name}"
193+
)
194+
await cursor.close()

databases/core.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class Database:
2121
SUPPORTED_BACKENDS = {
2222
"postgresql": "databases.backends.postgres:PostgresBackend",
2323
"mysql": "databases.backends.mysql:MySQLBackend",
24+
"sqlite": "databases.backends.sqlite:SQLiteBackend",
2425
}
2526

2627
def __init__(
@@ -246,10 +247,7 @@ async def rollback(self) -> None:
246247

247248
class DatabaseURL:
248249
def __init__(self, url: typing.Union[str, "DatabaseURL"]):
249-
if isinstance(url, DatabaseURL):
250-
self._url = str(url)
251-
else:
252-
self._url = url
250+
self._url = str(url)
253251

254252
@property
255253
def components(self) -> SplitResult:

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
sqlalchemy
1+
sqlalchemy==1.3.0b3
22
aiocontextvars;python_version<"3.7"
33

44
# Async database drivers
5-
asyncpg
65
aiomysql
6+
aiosqlite
7+
asyncpg
78

89
# Sync database drivers for standard tooling around setup/teardown/migrations.
910
psycopg2-binary

tests/test_databases.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,30 @@ async def test_connections_isolation(database_url):
376376
await database.execute(query)
377377

378378

379+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
380+
@async_adapter
381+
async def test_commit_on_root_transaction(database_url):
382+
"""
383+
Because our tests are generally wrapped in rollback-islation, they
384+
don't have coverage for commiting the root transaction.
385+
386+
Deal with this here, and delete the records rather than rolling back.
387+
"""
388+
389+
async with Database(database_url) as database:
390+
try:
391+
async with database.transaction():
392+
query = notes.insert().values(text="example1", completed=True)
393+
await database.execute(query)
394+
395+
query = notes.select()
396+
results = await database.fetch_all(query=query)
397+
assert len(results) == 1
398+
finally:
399+
query = notes.delete()
400+
await database.execute(query)
401+
402+
379403
@pytest.mark.parametrize("database_url", DATABASE_URLS)
380404
@async_adapter
381405
async def test_connect_and_disconnect(database_url):

0 commit comments

Comments
 (0)