Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 76a7167

Browse files
Refactoring toards using SQLAlchemy's standard RowProxy for results
1 parent d1b0b8e commit 76a7167

3 files changed

Lines changed: 48 additions & 42 deletions

File tree

databases/backends/mysql.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import aiomysql
77
from sqlalchemy.dialects.mysql import pymysql
8-
from sqlalchemy.engine.interfaces import Dialect
8+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
9+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
910
from sqlalchemy.sql import ClauseElement
1011
from sqlalchemy.types import TypeEngine
1112

@@ -45,28 +46,9 @@ def connection(self) -> "MySQLConnection":
4546
return MySQLConnection(self._pool, self._dialect)
4647

4748

48-
class Record:
49-
def __init__(self, row: tuple, result_columns: tuple, dialect: Dialect) -> None:
50-
self._row = row
51-
self._result_columns = result_columns
52-
self._dialect = dialect
53-
self._column_map = {
54-
column_name: (idx, datatype)
55-
for idx, (column_name, _, _, datatype) in enumerate(self._result_columns)
56-
}
57-
58-
def __getitem__(self, key: str) -> typing.Any:
59-
idx, datatype = self._column_map[key]
60-
raw = self._row[idx]
61-
try:
62-
processor = _result_processors[datatype]
63-
except KeyError:
64-
processor = datatype.result_processor(self._dialect, None)
65-
_result_processors[datatype] = processor
66-
67-
if processor is not None:
68-
return processor(raw)
69-
return raw
49+
class CompilationContext:
50+
def __init__(self, context: ExecutionContext):
51+
self.context = context
7052

7153

7254
class MySQLConnection(ConnectionBackend):
@@ -84,33 +66,38 @@ async def release(self) -> None:
8466
await self._pool.release(self._connection)
8567
self._connection = None
8668

87-
async def fetch_all(self, query: ClauseElement) -> typing.Any:
69+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
8870
assert self._connection is not None, "Connection is not acquired"
89-
query, args, result_columns = self._compile(query)
71+
query, args, context = self._compile(query)
9072
cursor = await self._connection.cursor()
9173
try:
9274
await cursor.execute(query, args)
9375
rows = await cursor.fetchall()
94-
return [Record(row, result_columns, self._dialect) for row in rows]
76+
metadata = ResultMetaData(context, cursor.description)
77+
return [
78+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
79+
for row in rows
80+
]
9581
finally:
9682
await cursor.close()
9783

98-
async def fetch_one(self, query: ClauseElement) -> typing.Any:
84+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
9985
assert self._connection is not None, "Connection is not acquired"
100-
query, args, result_columns = self._compile(query)
86+
query, args, context = self._compile(query)
10187
cursor = await self._connection.cursor()
10288
try:
10389
await cursor.execute(query, args)
10490
row = await cursor.fetchone()
105-
return Record(row, result_columns, self._dialect)
91+
metadata = ResultMetaData(context, cursor.description)
92+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
10693
finally:
10794
await cursor.close()
10895

10996
async def execute(self, query: ClauseElement, values: dict = None) -> None:
11097
assert self._connection is not None, "Connection is not acquired"
11198
if values is not None:
11299
query = query.values(values)
113-
query, args, result_columns = self._compile(query)
100+
query, args, context = self._compile(query)
114101
cursor = await self._connection.cursor()
115102
try:
116103
await cursor.execute(query, args)
@@ -123,7 +110,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
123110
try:
124111
for item in values:
125112
single_query = query.values(item)
126-
single_query, args, result_columns = self._compile(single_query)
113+
single_query, args, context = self._compile(single_query)
127114
await cursor.execute(single_query, args)
128115
finally:
129116
await cursor.close()
@@ -132,26 +119,38 @@ async def iterate(
132119
self, query: ClauseElement
133120
) -> typing.AsyncGenerator[typing.Any, None]:
134121
assert self._connection is not None, "Connection is not acquired"
135-
query, args, result_columns = self._compile(query)
122+
query, args, context = self._compile(query)
136123
cursor = await self._connection.cursor()
137124
try:
138125
await cursor.execute(query, args)
126+
metadata = ResultMetaData(context, cursor.description)
139127
async for row in cursor:
140-
yield Record(row, result_columns, self._dialect)
128+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
141129
finally:
142130
await cursor.close()
143131

144132
def transaction(self) -> TransactionBackend:
145133
return MySQLTransaction(self)
146134

147-
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
135+
def _compile(
136+
self, query: ClauseElement
137+
) -> typing.Tuple[str, list, CompilationContext]:
148138
compiled = query.compile(dialect=self._dialect)
149139
args = compiled.construct_params()
150140
logger.debug(compiled.string, args)
151141
for key, val in args.items():
152142
if key in compiled._bind_processors:
153143
args[key] = compiled._bind_processors[key](val)
154-
return compiled.string, args, compiled._result_columns
144+
145+
execution_context = self._dialect.execution_ctx_cls()
146+
execution_context.dialect = self._dialect
147+
execution_context.result_column_struct = (
148+
compiled._result_columns,
149+
compiled._ordered_columns,
150+
compiled._textual_ordered_columns,
151+
)
152+
153+
return compiled.string, args, CompilationContext(execution_context)
155154

156155

157156
class MySQLTransaction(TransactionBackend):

databases/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from types import TracebackType
66
from urllib.parse import SplitResult, urlsplit
77

8+
from sqlalchemy.engine import RowProxy
89
from sqlalchemy.sql import ClauseElement
910

1011
from databases.importer import import_from_string
@@ -27,7 +28,6 @@ def __init__(
2728
):
2829
self._url = DatabaseURL(url)
2930
self._force_rollback = force_rollback
30-
3131
self.is_connected = False
3232

3333
backend_str = self.SUPPORTED_BACKENDS[self._url.dialect]
@@ -44,6 +44,9 @@ def __init__(
4444
self._global_transaction = None # type: typing.Optional[Transaction]
4545

4646
async def connect(self) -> None:
47+
"""
48+
Establish the connection pool.
49+
"""
4750
assert not self.is_connected, "Already connected."
4851

4952
await self._backend.connect()
@@ -57,6 +60,9 @@ async def connect(self) -> None:
5760
await self._global_transaction.__aenter__()
5861

5962
async def disconnect(self) -> None:
63+
"""
64+
Close all connections in the connection pool.
65+
"""
6066
assert self.is_connected, "Already disconnected."
6167

6268
if self._force_rollback:
@@ -80,11 +86,11 @@ async def __aexit__(
8086
) -> None:
8187
await self.disconnect()
8288

83-
async def fetch_all(self, query: ClauseElement) -> typing.Any:
89+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
8490
async with self.connection() as connection:
8591
return await connection.fetch_all(query=query)
8692

87-
async def fetch_one(self, query: ClauseElement) -> typing.Any:
93+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
8894
async with self.connection() as connection:
8995
return await connection.fetch_one(query=query)
9096

@@ -98,7 +104,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
98104

99105
async def iterate(
100106
self, query: ClauseElement
101-
) -> typing.AsyncGenerator[typing.Any, None]:
107+
) -> typing.AsyncGenerator[RowProxy, None]:
102108
async with self.connection() as connection:
103109
async for record in connection.iterate(query):
104110
yield record

databases/interfaces.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing
22

3+
from sqlalchemy.engine import RowProxy
34
from sqlalchemy.sql import ClauseElement
45

56

@@ -21,10 +22,10 @@ async def acquire(self) -> None:
2122
async def release(self) -> None:
2223
raise NotImplementedError() # pragma: no cover
2324

24-
async def fetch_all(self, query: ClauseElement) -> typing.Any:
25+
async def fetch_all(self, query: ClauseElement) -> typing.List[RowProxy]:
2526
raise NotImplementedError() # pragma: no cover
2627

27-
async def fetch_one(self, query: ClauseElement) -> typing.Any:
28+
async def fetch_one(self, query: ClauseElement) -> RowProxy:
2829
raise NotImplementedError() # pragma: no cover
2930

3031
async def execute(self, query: ClauseElement, values: dict = None) -> None:
@@ -35,7 +36,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
3536

3637
async def iterate(
3738
self, query: ClauseElement
38-
) -> typing.AsyncGenerator[typing.Any, None]:
39+
) -> typing.AsyncGenerator[RowProxy, None]:
3940
raise NotImplementedError() # pragma: no cover
4041
# mypy needs async iterators to contain a `yield`
4142
# https://github.com/python/mypy/issues/5385#issuecomment-407281656

0 commit comments

Comments
 (0)