55
66import aiomysql
77from sqlalchemy .dialects .mysql import pymysql
8- from sqlalchemy .engine .interfaces import Dialect
8+ from sqlalchemy .engine .interfaces import Dialect , ExecutionContext
9+ from sqlalchemy .engine .result import ResultMetaData , RowProxy
910from sqlalchemy .sql import ClauseElement
1011from sqlalchemy .types import TypeEngine
1112
1415
1516logger = logging .getLogger ("databases" )
1617
17- _result_processors = {} # type: dict
18-
1918
2019class MySQLBackend (DatabaseBackend ):
21- def __init__ (self , database_url : typing . Union [ str , DatabaseURL ] ) -> None :
22- self ._database_url = DatabaseURL ( database_url )
20+ def __init__ (self , database_url : DatabaseURL ) -> None :
21+ self ._database_url = database_url
2322 self ._dialect = pymysql .dialect (paramstyle = "pyformat" )
2423 self ._pool = None
2524
@@ -45,28 +44,9 @@ def connection(self) -> "MySQLConnection":
4544 return MySQLConnection (self ._pool , self ._dialect )
4645
4746
48- class Record :
49- def __init__ (self , row : tuple , result_columns : tuple , dialect : Dialect ) -> None :
50- self ._row = row
51- self ._result_columns = result_columns
52- self ._dialect = dialect
53- self ._column_map = {
54- column_name : (idx , datatype )
55- for idx , (column_name , _ , _ , datatype ) in enumerate (self ._result_columns )
56- }
57-
58- def __getitem__ (self , key : str ) -> typing .Any :
59- idx , datatype = self ._column_map [key ]
60- raw = self ._row [idx ]
61- try :
62- processor = _result_processors [datatype ]
63- except KeyError :
64- processor = datatype .result_processor (self ._dialect , None )
65- _result_processors [datatype ] = processor
66-
67- if processor is not None :
68- return processor (raw )
69- return raw
47+ class CompilationContext :
48+ def __init__ (self , context : ExecutionContext ):
49+ self .context = context
7050
7151
7252class MySQLConnection (ConnectionBackend ):
@@ -84,33 +64,38 @@ async def release(self) -> None:
8464 await self ._pool .release (self ._connection )
8565 self ._connection = None
8666
87- async def fetch_all (self , query : ClauseElement ) -> typing .Any :
67+ async def fetch_all (self , query : ClauseElement ) -> typing .List [ RowProxy ] :
8868 assert self ._connection is not None , "Connection is not acquired"
89- query , args , result_columns = self ._compile (query )
69+ query , args , context = self ._compile (query )
9070 cursor = await self ._connection .cursor ()
9171 try :
9272 await cursor .execute (query , args )
9373 rows = await cursor .fetchall ()
94- return [Record (row , result_columns , self ._dialect ) for row in rows ]
74+ metadata = ResultMetaData (context , cursor .description )
75+ return [
76+ RowProxy (metadata , row , metadata ._processors , metadata ._keymap )
77+ for row in rows
78+ ]
9579 finally :
9680 await cursor .close ()
9781
98- async def fetch_one (self , query : ClauseElement ) -> typing . Any :
82+ async def fetch_one (self , query : ClauseElement ) -> RowProxy :
9983 assert self ._connection is not None , "Connection is not acquired"
100- query , args , result_columns = self ._compile (query )
84+ query , args , context = self ._compile (query )
10185 cursor = await self ._connection .cursor ()
10286 try :
10387 await cursor .execute (query , args )
10488 row = await cursor .fetchone ()
105- return Record (row , result_columns , self ._dialect )
89+ metadata = ResultMetaData (context , cursor .description )
90+ return RowProxy (metadata , row , metadata ._processors , metadata ._keymap )
10691 finally :
10792 await cursor .close ()
10893
10994 async def execute (self , query : ClauseElement , values : dict = None ) -> None :
11095 assert self ._connection is not None , "Connection is not acquired"
11196 if values is not None :
11297 query = query .values (values )
113- query , args , result_columns = self ._compile (query )
98+ query , args , context = self ._compile (query )
11499 cursor = await self ._connection .cursor ()
115100 try :
116101 await cursor .execute (query , args )
@@ -123,7 +108,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
123108 try :
124109 for item in values :
125110 single_query = query .values (item )
126- single_query , args , result_columns = self ._compile (single_query )
111+ single_query , args , context = self ._compile (single_query )
127112 await cursor .execute (single_query , args )
128113 finally :
129114 await cursor .close ()
@@ -132,26 +117,38 @@ async def iterate(
132117 self , query : ClauseElement
133118 ) -> typing .AsyncGenerator [typing .Any , None ]:
134119 assert self ._connection is not None , "Connection is not acquired"
135- query , args , result_columns = self ._compile (query )
120+ query , args , context = self ._compile (query )
136121 cursor = await self ._connection .cursor ()
137122 try :
138123 await cursor .execute (query , args )
124+ metadata = ResultMetaData (context , cursor .description )
139125 async for row in cursor :
140- yield Record ( row , result_columns , self . _dialect )
126+ yield RowProxy ( metadata , row , metadata . _processors , metadata . _keymap )
141127 finally :
142128 await cursor .close ()
143129
144130 def transaction (self ) -> TransactionBackend :
145131 return MySQLTransaction (self )
146132
147- def _compile (self , query : ClauseElement ) -> typing .Tuple [str , list , tuple ]:
133+ def _compile (
134+ self , query : ClauseElement
135+ ) -> typing .Tuple [str , dict , CompilationContext ]:
148136 compiled = query .compile (dialect = self ._dialect )
149137 args = compiled .construct_params ()
150- logger .debug (compiled .string , args )
151138 for key , val in args .items ():
152139 if key in compiled ._bind_processors :
153140 args [key ] = compiled ._bind_processors [key ](val )
154- return compiled .string , args , compiled ._result_columns
141+
142+ execution_context = self ._dialect .execution_ctx_cls ()
143+ execution_context .dialect = self ._dialect
144+ execution_context .result_column_struct = (
145+ compiled ._result_columns ,
146+ compiled ._ordered_columns ,
147+ compiled ._textual_ordered_columns ,
148+ )
149+
150+ logger .debug (compiled .string , args )
151+ return compiled .string , args , CompilationContext (execution_context )
155152
156153
157154class MySQLTransaction (TransactionBackend ):
@@ -176,10 +173,7 @@ async def start(self, is_root: bool) -> None:
176173
177174 async def commit (self ) -> None :
178175 assert self ._connection ._connection is not None , "Connection is not acquired"
179- if self ._is_root : # pragma: no cover
180- # In test cases the root transaction is never committed,
181- # since we *always* wrap the test case up in a transaction
182- # and rollback to a clean state at the end.
176+ if self ._is_root :
183177 await self ._connection ._connection .commit ()
184178 else :
185179 cursor = await self ._connection ._connection .cursor ()
0 commit comments