Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit aa6686c

Browse files
Merge pull request #36 from encode/sqlite-support
SQLite support
2 parents d1b0b8e + 96e26db commit aa6686c

8 files changed

Lines changed: 278 additions & 60 deletions

File tree

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ python:
99
- "3.7"
1010

1111
env:
12-
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database"
12+
- TEST_DATABASE_URLS="postgresql://localhost/test_database, mysql://localhost/test_database, sqlite:///test.db"
1313

1414
services:
1515
- postgresql

databases/backends/mysql.py

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import aiomysql
77
from sqlalchemy.dialects.mysql import pymysql
8-
from sqlalchemy.engine.interfaces import Dialect
8+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
9+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
910
from sqlalchemy.sql import ClauseElement
1011
from sqlalchemy.types import TypeEngine
1112

@@ -14,12 +15,10 @@
1415

1516
logger = logging.getLogger("databases")
1617

17-
_result_processors = {} # type: dict
18-
1918

2019
class MySQLBackend(DatabaseBackend):
21-
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
22-
self._database_url = DatabaseURL(database_url)
20+
def __init__(self, database_url: DatabaseURL) -> None:
21+
self._database_url = database_url
2322
self._dialect = pymysql.dialect(paramstyle="pyformat")
2423
self._pool = None
2524

@@ -45,28 +44,9 @@ def connection(self) -> "MySQLConnection":
4544
return MySQLConnection(self._pool, self._dialect)
4645

4746

48-
class Record:
49-
def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
50-
self._row = row
51-
self._result_columns = result_columns
52-
self._dialect = dialect
53-
self._column_map = {
54-
column_name: (idx, datatype)
55-
for idx, (column_name, _, _, datatype) in enumerate(self._result_columns)
56-
}
57-
58-
def __getitem__(self, key: str) -> typing.Any:
59-
idx, datatype = self._column_map[key]
60-
raw = self._row[idx]
61-
try:
62-
processor = _result_processors[datatype]
63-
except KeyError:
64-
processor = datatype.result_processor(self._dialect, None)
65-
_result_processors[datatype] = processor
66-
67-
if processor is not None:
68-
return processor(raw)
69-
return raw
47+
class CompilationContext:
48+
def __init__(self, context: ExecutionContext):
49+
self.context = context
7050

7151

7252
class MySQLConnection(ConnectionBackend):
@@ -84,33 +64,38 @@ async def release(self) -> None:
8464
await self._pool.release(self._connection)
8565
self._connection = None
8666

87-
async def fetch_all(self, query: ClauseElement) -> typing.Any:
67+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
8868
assert self._connection is not None, "Connection is not acquired"
89-
query, args, result_columns = self._compile(query)
69+
query, args, context = self._compile(query)
9070
cursor = await self._connection.cursor()
9171
try:
9272
await cursor.execute(query, args)
9373
rows = await cursor.fetchall()
94-
return [Record(row, result_columns, self._dialect) for row in rows]
74+
metadata = ResultMetaData(context, cursor.description)
75+
return [
76+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
77+
for row in rows
78+
]
9579
finally:
9680
await cursor.close()
9781

98-
async def fetch_one(self, query: ClauseElement) -> typing.Any:
82+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
9983
assert self._connection is not None, "Connection is not acquired"
100-
query, args, result_columns = self._compile(query)
84+
query, args, context = self._compile(query)
10185
cursor = await self._connection.cursor()
10286
try:
10387
await cursor.execute(query, args)
10488
row = await cursor.fetchone()
105-
return Record(row, result_columns, self._dialect)
89+
metadata = ResultMetaData(context, cursor.description)
90+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
10691
finally:
10792
await cursor.close()
10893

10994
async def execute(self, query: ClauseElement, values: dict = None) -> None:
11095
assert self._connection is not None, "Connection is not acquired"
11196
if values is not None:
11297
query = query.values(values)
113-
query, args, result_columns = self._compile(query)
98+
query, args, context = self._compile(query)
11499
cursor = await self._connection.cursor()
115100
try:
116101
await cursor.execute(query, args)
@@ -123,7 +108,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
123108
try:
124109
for item in values:
125110
single_query = query.values(item)
126-
single_query, args, result_columns = self._compile(single_query)
111+
single_query, args, context = self._compile(single_query)
127112
await cursor.execute(single_query, args)
128113
finally:
129114
await cursor.close()
@@ -132,26 +117,38 @@ async def iterate(
132117
self, query: ClauseElement
133118
) -> typing.AsyncGenerator[typing.Any, None]:
134119
assert self._connection is not None, "Connection is not acquired"
135-
query, args, result_columns = self._compile(query)
120+
query, args, context = self._compile(query)
136121
cursor = await self._connection.cursor()
137122
try:
138123
await cursor.execute(query, args)
124+
metadata = ResultMetaData(context, cursor.description)
139125
async for row in cursor:
140-
yield Record(row, result_columns, self._dialect)
126+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
141127
finally:
142128
await cursor.close()
143129

144130
def transaction(self) -> TransactionBackend:
145131
return MySQLTransaction(self)
146132

147-
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
133+
def _compile(
134+
self, query: ClauseElement
135+
) -> typing.Tuple[str, dict, CompilationContext]:
148136
compiled = query.compile(dialect=self._dialect)
149137
args = compiled.construct_params()
150-
logger.debug(compiled.string, args)
151138
for key, val in args.items():
152139
if key in compiled._bind_processors:
153140
args[key] = compiled._bind_processors[key](val)
154-
return compiled.string, args, compiled._result_columns
141+
142+
execution_context = self._dialect.execution_ctx_cls()
143+
execution_context.dialect = self._dialect
144+
execution_context.result_column_struct = (
145+
compiled._result_columns,
146+
compiled._ordered_columns,
147+
compiled._textual_ordered_columns,
148+
)
149+
150+
logger.debug(compiled.string, args)
151+
return compiled.string, args, CompilationContext(execution_context)
155152

156153

157154
class MySQLTransaction(TransactionBackend):
@@ -176,10 +173,7 @@ async def start(self, is_root: bool) -> None:
176173

177174
async def commit(self) -> None:
178175
assert self._connection._connection is not None, "Connection is not acquired"
179-
if self._is_root: # pragma: no cover
180-
# In test cases the root transaction is never committed,
181-
# since we *always* wrap the test case up in a transaction
182-
# and rollback to a clean state at the end.
176+
if self._is_root:
183177
await self._connection._connection.commit()
184178
else:
185179
cursor = await self._connection._connection.cursor()

databases/backends/postgres.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717

1818
class PostgresBackend(DatabaseBackend):
19-
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
20-
self._database_url = DatabaseURL(database_url)
19+
def __init__(self, database_url: DatabaseURL) -> None:
20+
self._database_url = database_url
2121
self._dialect = self._get_dialect()
2222
self._pool = None
2323

databases/backends/sqlite.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import logging
2+
import typing
3+
import uuid
4+
5+
import aiosqlite
6+
from sqlalchemy.dialects.sqlite import pysqlite
7+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
8+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
9+
from sqlalchemy.sql import ClauseElement
10+
from sqlalchemy.types import TypeEngine
11+
12+
from databases.core import DatabaseURL
13+
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
14+
15+
logger = logging.getLogger("databases")
16+
17+
18+
class SQLiteBackend(DatabaseBackend):
19+
def __init__(self, database_url: DatabaseURL) -> None:
20+
self._database_url = database_url
21+
self._dialect = pysqlite.dialect(paramstyle="qmark")
22+
self._pool = SQLitePool(database_url)
23+
24+
async def connect(self) -> None:
25+
pass
26+
# assert self._pool is None, "DatabaseBackend is already running"
27+
# self._pool = await aiomysql.create_pool(
28+
# host=self._database_url.hostname,
29+
# port=self._database_url.port or 3306,
30+
# user=self._database_url.username or getpass.getuser(),
31+
# password=self._database_url.password,
32+
# db=self._database_url.database,
33+
# autocommit=True,
34+
# )
35+
36+
async def disconnect(self) -> None:
37+
pass
38+
# assert self._pool is not None, "DatabaseBackend is not running"
39+
# self._pool.close()
40+
# await self._pool.wait_closed()
41+
# self._pool = None
42+
43+
def connection(self) -> "SQLiteConnection":
44+
return SQLiteConnection(self._pool, self._dialect)
45+
46+
47+
class SQLitePool:
48+
def __init__(self, url: DatabaseURL) -> None:
49+
self._url = url
50+
51+
async def acquire(self) -> aiosqlite.Connection:
52+
connection = aiosqlite.connect(
53+
database=self._url.database, isolation_level=None
54+
)
55+
await connection.__aenter__()
56+
return connection
57+
58+
async def release(self, connection: aiosqlite.Connection) -> None:
59+
await connection.__aexit__(None, None, None)
60+
61+
62+
class CompilationContext:
63+
def __init__(self, context: ExecutionContext):
64+
self.context = context
65+
66+
67+
class SQLiteConnection(ConnectionBackend):
68+
def __init__(self, pool: SQLitePool, dialect: Dialect):
69+
self._pool = pool
70+
self._dialect = dialect
71+
self._connection = None
72+
73+
async def acquire(self) -> None:
74+
assert self._connection is None, "Connection is already acquired"
75+
self._connection = await self._pool.acquire()
76+
77+
async def release(self) -> None:
78+
assert self._connection is not None, "Connection is not acquired"
79+
await self._pool.release(self._connection)
80+
self._connection = None
81+
82+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
83+
assert self._connection is not None, "Connection is not acquired"
84+
query, args, context = self._compile(query)
85+
86+
async with self._connection.execute(query, args) as cursor:
87+
rows = await cursor.fetchall()
88+
metadata = ResultMetaData(context, cursor.description)
89+
return [
90+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
91+
for row in rows
92+
]
93+
94+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
95+
assert self._connection is not None, "Connection is not acquired"
96+
query, args, context = self._compile(query)
97+
98+
async with self._connection.execute(query, args) as cursor:
99+
row = await cursor.fetchone()
100+
metadata = ResultMetaData(context, cursor.description)
101+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
102+
103+
async def execute(self, query: ClauseElement, values: dict = None) -> None:
104+
assert self._connection is not None, "Connection is not acquired"
105+
if values is not None:
106+
query = query.values(values)
107+
query, args, context = self._compile(query)
108+
cursor = await self._connection.execute(query, args)
109+
await cursor.close()
110+
111+
async def execute_many(self, query: ClauseElement, values: list) -> None:
112+
assert self._connection is not None, "Connection is not acquired"
113+
for value in values:
114+
await self.execute(query, value)
115+
116+
async def iterate(
117+
self, query: ClauseElement
118+
) -> typing.AsyncGenerator[typing.Any, None]:
119+
assert self._connection is not None, "Connection is not acquired"
120+
query, args, context = self._compile(query)
121+
cursor = await self._connection.cursor()
122+
async with self._connection.execute(query, args) as cursor:
123+
metadata = ResultMetaData(context, cursor.description)
124+
async for row in cursor:
125+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
126+
127+
def transaction(self) -> TransactionBackend:
128+
return SQLiteTransaction(self)
129+
130+
def _compile(
131+
self, query: ClauseElement
132+
) -> typing.Tuple[str, list, CompilationContext]:
133+
compiled = query.compile(dialect=self._dialect)
134+
args = []
135+
for key, raw_val in compiled.construct_params().items():
136+
if key in compiled._bind_processors:
137+
val = compiled._bind_processors[key](raw_val)
138+
else:
139+
val = raw_val
140+
args.append(val)
141+
142+
execution_context = self._dialect.execution_ctx_cls()
143+
execution_context.dialect = self._dialect
144+
execution_context.result_column_struct = (
145+
compiled._result_columns,
146+
compiled._ordered_columns,
147+
compiled._textual_ordered_columns,
148+
)
149+
150+
logger.debug(compiled.string, args)
151+
return compiled.string, args, CompilationContext(execution_context)
152+
153+
154+
class SQLiteTransaction(TransactionBackend):
155+
def __init__(self, connection: SQLiteConnection):
156+
self._connection = connection
157+
self._is_root = False
158+
self._savepoint_name = ""
159+
160+
async def start(self, is_root: bool) -> None:
161+
assert self._connection._connection is not None, "Connection is not acquired"
162+
self._is_root = is_root
163+
if self._is_root:
164+
cursor = await self._connection._connection.execute("BEGIN")
165+
await cursor.close()
166+
else:
167+
id = str(uuid.uuid4()).replace("-", "_")
168+
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
169+
cursor = await self._connection._connection.execute(
170+
f"SAVEPOINT {self._savepoint_name}"
171+
)
172+
await cursor.close()
173+
174+
async def commit(self) -> None:
175+
assert self._connection._connection is not None, "Connection is not acquired"
176+
if self._is_root:
177+
cursor = await self._connection._connection.execute("COMMIT")
178+
await cursor.close()
179+
else:
180+
cursor = await self._connection._connection.execute(
181+
f"RELEASE SAVEPOINT {self._savepoint_name}"
182+
)
183+
await cursor.close()
184+
185+
async def rollback(self) -> None:
186+
assert self._connection._connection is not None, "Connection is not acquired"
187+
if self._is_root:
188+
cursor = await self._connection._connection.execute("ROLLBACK")
189+
await cursor.close()
190+
else:
191+
cursor = await self._connection._connection.execute(
192+
f"ROLLBACK TO SAVEPOINT {self._savepoint_name}"
193+
)
194+
await cursor.close()

0 commit comments

Comments
 (0)