55
66import aiomysql
77from sqlalchemy .dialects .mysql import pymysql
8- from sqlalchemy .engine .interfaces import Dialect
8+ from sqlalchemy .engine .interfaces import Dialect , ExecutionContext
9+ from sqlalchemy .engine .result import ResultMetaData , RowProxy
910from sqlalchemy .sql import ClauseElement
1011from sqlalchemy .types import TypeEngine
1112
@@ -45,28 +46,9 @@ def connection(self) -> "MySQLConnection":
4546 return MySQLConnection (self ._pool , self ._dialect )
4647
4748
48- class Record :
49- def __init__ (self , row : tuple , result_columns : tuple , dialect : Dialect ) -> None :
50- self ._row = row
51- self ._result_columns = result_columns
52- self ._dialect = dialect
53- self ._column_map = {
54- column_name : (idx , datatype )
55- for idx , (column_name , _ , _ , datatype ) in enumerate (self ._result_columns )
56- }
57-
58- def __getitem__ (self , key : str ) -> typing .Any :
59- idx , datatype = self ._column_map [key ]
60- raw = self ._row [idx ]
61- try :
62- processor = _result_processors [datatype ]
63- except KeyError :
64- processor = datatype .result_processor (self ._dialect , None )
65- _result_processors [datatype ] = processor
66-
67- if processor is not None :
68- return processor (raw )
69- return raw
49+ class CompilationContext :
50+ def __init__ (self , context : ExecutionContext ):
51+ self .context = context
7052
7153
7254class MySQLConnection (ConnectionBackend ):
@@ -84,33 +66,38 @@ async def release(self) -> None:
8466 await self ._pool .release (self ._connection )
8567 self ._connection = None
8668
87- async def fetch_all (self , query : ClauseElement ) -> typing .Any :
69+ async def fetch_all (self , query : ClauseElement ) -> typing .List [ RowProxy ] :
8870 assert self ._connection is not None , "Connection is not acquired"
89- query , args , result_columns = self ._compile (query )
71+ query , args , context = self ._compile (query )
9072 cursor = await self ._connection .cursor ()
9173 try :
9274 await cursor .execute (query , args )
9375 rows = await cursor .fetchall ()
94- return [Record (row , result_columns , self ._dialect ) for row in rows ]
76+ metadata = ResultMetaData (context , cursor .description )
77+ return [
78+ RowProxy (metadata , row , metadata ._processors , metadata ._keymap )
79+ for row in rows
80+ ]
9581 finally :
9682 await cursor .close ()
9783
98- async def fetch_one (self , query : ClauseElement ) -> typing . Any :
84+ async def fetch_one (self , query : ClauseElement ) -> RowProxy :
9985 assert self ._connection is not None , "Connection is not acquired"
100- query , args , result_columns = self ._compile (query )
86+ query , args , context = self ._compile (query )
10187 cursor = await self ._connection .cursor ()
10288 try :
10389 await cursor .execute (query , args )
10490 row = await cursor .fetchone ()
105- return Record (row , result_columns , self ._dialect )
91+ metadata = ResultMetaData (context , cursor .description )
92+ return RowProxy (metadata , row , metadata ._processors , metadata ._keymap )
10693 finally :
10794 await cursor .close ()
10895
10996 async def execute (self , query : ClauseElement , values : dict = None ) -> None :
11097 assert self ._connection is not None , "Connection is not acquired"
11198 if values is not None :
11299 query = query .values (values )
113- query , args , result_columns = self ._compile (query )
100+ query , args , context = self ._compile (query )
114101 cursor = await self ._connection .cursor ()
115102 try :
116103 await cursor .execute (query , args )
@@ -123,7 +110,7 @@ async def execute_many(self, query: ClauseElement, values: list) -> None:
123110 try :
124111 for item in values :
125112 single_query = query .values (item )
126- single_query , args , result_columns = self ._compile (single_query )
113+ single_query , args , context = self ._compile (single_query )
127114 await cursor .execute (single_query , args )
128115 finally :
129116 await cursor .close ()
@@ -132,26 +119,38 @@ async def iterate(
132119 self , query : ClauseElement
133120 ) -> typing .AsyncGenerator [typing .Any , None ]:
134121 assert self ._connection is not None , "Connection is not acquired"
135- query , args , result_columns = self ._compile (query )
122+ query , args , context = self ._compile (query )
136123 cursor = await self ._connection .cursor ()
137124 try :
138125 await cursor .execute (query , args )
126+ metadata = ResultMetaData (context , cursor .description )
139127 async for row in cursor :
140- yield Record ( row , result_columns , self . _dialect )
128+ yield RowProxy ( metadata , row , metadata . _processors , metadata . _keymap )
141129 finally :
142130 await cursor .close ()
143131
144132 def transaction (self ) -> TransactionBackend :
145133 return MySQLTransaction (self )
146134
147- def _compile (self , query : ClauseElement ) -> typing .Tuple [str , list , tuple ]:
135+ def _compile (
136+ self , query : ClauseElement
137+ ) -> typing .Tuple [str , list , CompilationContext ]:
148138 compiled = query .compile (dialect = self ._dialect )
149139 args = compiled .construct_params ()
150140 logger .debug (compiled .string , args )
151141 for key , val in args .items ():
152142 if key in compiled ._bind_processors :
153143 args [key ] = compiled ._bind_processors [key ](val )
154- return compiled .string , args , compiled ._result_columns
144+
145+ execution_context = self ._dialect .execution_ctx_cls ()
146+ execution_context .dialect = self ._dialect
147+ execution_context .result_column_struct = (
148+ compiled ._result_columns ,
149+ compiled ._ordered_columns ,
150+ compiled ._textual_ordered_columns ,
151+ )
152+
153+ return compiled .string , args , CompilationContext (execution_context )
155154
156155
157156class MySQLTransaction (TransactionBackend ):
0 commit comments