77
88import aiomysql
99from sqlalchemy import util , exc
10- from sqlalchemy .dialects .mysql import ( JSON , ENUM )
10+ from sqlalchemy .dialects .mysql import JSON , ENUM
1111from sqlalchemy .dialects .mysql .base import (
1212 MySQLCompiler ,
1313 MySQLDialect ,
2727#: executemany only suports simple bulk insert.
2828#: You can use it to load large dataset.
2929_RE_INSERT_VALUES = re .compile (
30- r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
31- r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
32- r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z" ,
33- re .IGNORECASE | re .DOTALL )
30+ r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
31+ + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
32+ + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z" ,
33+ re .IGNORECASE | re .DOTALL ,
34+ )
3435
3536#: Max statement size which :meth:`executemany` generates.
3637#:
@@ -45,8 +46,7 @@ class AiomysqlDBAPI(base.BaseDBAPI):
4546
4647
4748# noinspection PyAbstractClass
48- class AiomysqlExecutionContext (base .ExecutionContextOverride ,
49- MySQLExecutionContext ):
49+ class AiomysqlExecutionContext (base .ExecutionContextOverride , MySQLExecutionContext ):
5050 async def _execute_scalar (self , stmt , type_ ):
5151 conn = self .root_connection
5252 if (
@@ -104,8 +104,7 @@ async def _init(self):
104104
105105 async def __anext__ (self ):
106106 await self ._init ()
107- row = await asyncio .wait_for (
108- self ._cursor .fetchone (), self ._context .timeout )
107+ row = await asyncio .wait_for (self ._cursor .fetchone (), self ._context .timeout )
109108 if row is None :
110109 raise StopAsyncIteration
111110 return self ._context .process_rows ([row ])[0 ]
@@ -129,7 +128,7 @@ async def forward(self, n, *, timeout=base.DEFAULT):
129128 await self ._init ()
130129 if timeout is base .DEFAULT :
131130 timeout = self ._context .timeout
132- await asyncio .wait_for (self ._cursor .scroll (n , mode = ' relative' ), timeout )
131+ await asyncio .wait_for (self ._cursor .scroll (n , mode = " relative" ), timeout )
133132
134133
135134class DBAPICursor (base .DBAPICursor ):
@@ -156,7 +155,8 @@ async def async_execute(self, query, timeout, args, limit=0, many=False):
156155 return await self ._async_execute (conn , query , timeout , args )
157156
158157 return await asyncio .wait_for (
159- self ._async_executemany (conn , query , args ), timeout = timeout )
158+ self ._async_executemany (conn , query , args ), timeout = timeout
159+ )
160160
161161 async def _async_execute (self , conn , query , timeout , args ):
162162 if args is not None :
@@ -175,9 +175,10 @@ async def _async_executemany(self, conn, query, args):
175175 if m :
176176 q_prefix = m .group (1 )
177177 q_values = m .group (2 ).rstrip ()
178- q_postfix = m .group (3 ) or ''
179- return (await self ._do_execute_many (
180- conn , q_prefix , q_values , q_postfix , args ))
178+ q_postfix = m .group (3 ) or ""
179+ return await self ._do_execute_many (
180+ conn , q_prefix , q_values , q_postfix , args
181+ )
181182 else :
182183 rows = 0
183184 for arg in args :
@@ -196,19 +197,19 @@ async def _do_execute_many(self, conn, prefix, values, postfix, args):
196197 args = iter (args )
197198 v = values % escape (next (args ), conn )
198199 if isinstance (v , str ):
199- v = v .encode (conn .encoding , ' surrogateescape' )
200+ v = v .encode (conn .encoding , " surrogateescape" )
200201 stmt += v
201202 rows = 0
202203 for arg in args :
203204 v = values % escape (arg , conn )
204205 if isinstance (v , str ):
205- v = v .encode (conn .encoding , ' surrogateescape' )
206+ v = v .encode (conn .encoding , " surrogateescape" )
206207 if len (stmt ) + len (v ) + len (postfix ) + 1 > _MAX_STMT_LENGTH :
207208 await self ._async_execute (conn , stmt + postfix , None , None )
208209 rows += self .affected_rows
209210 stmt = bytearray (prefix )
210211 else :
211- stmt += b','
212+ stmt += b","
212213 stmt += v
213214 await self ._async_execute (conn , stmt + postfix , None , None )
214215 rows += self .affected_rows
@@ -224,8 +225,7 @@ def get_statusmsg(self):
224225
225226 def iterate (self , context ):
226227 # use SSCursor to get server side cursor
227- return AiomysqlIterator (
228- context , aiomysql .SSCursor (self ._conn .raw_connection ))
228+ return AiomysqlIterator (context , aiomysql .SSCursor (self ._conn .raw_connection ))
229229
230230
231231class Pool (base .Pool ):
@@ -288,7 +288,7 @@ def repr(self, color):
288288 + "."
289289 + self ._pool .__class__ .__name__ ,
290290 fg = "green" ,
291- ),
291+ ),
292292 max = click .style (repr (self ._pool .maxsize ), fg = "cyan" ),
293293 min = click .style (repr (self ._pool ._minsize ), fg = "cyan" ),
294294 cur = click .style (repr (self ._pool .size ), fg = "cyan" ),
@@ -298,8 +298,8 @@ def repr(self, color):
298298 # noinspection PyProtectedMember
299299 return "<{classname} max={max} min={min} cur={cur} use={use}>" .format (
300300 classname = self ._pool .__class__ .__module__
301- + "."
302- + self ._pool .__class__ .__name__ ,
301+ + "."
302+ + self ._pool .__class__ .__name__ ,
303303 max = self ._pool .maxsize ,
304304 min = self ._pool ._minsize ,
305305 cur = self ._pool .size ,
@@ -371,14 +371,12 @@ class AiomysqlDialect(MySQLDialect, base.AsyncDialectMixin):
371371 for f in [aiomysql .create_pool , aiomysql .connect ]
372372 ]
373373 )
374- ) - {'echo' } # use SQLAlchemy's echo instead
374+ ) - {
375+ "echo"
376+ } # use SQLAlchemy's echo instead
375377 colspecs = util .update_copy (
376378 MySQLDialect .colspecs ,
377- {
378- ENUM : AsyncEnum ,
379- sqltypes .Enum : AsyncEnum ,
380- sqltypes .NullType : GinoNullType ,
381- },
379+ {ENUM : AsyncEnum , sqltypes .Enum : AsyncEnum , sqltypes .NullType : GinoNullType ,},
382380 )
383381 postfetch_lastrowid = False
384382 support_returning = False
@@ -395,15 +393,16 @@ def __init__(self, *args, **kwargs):
395393 async def init_pool (self , url , loop , pool_class = None ):
396394 if pool_class is None :
397395 pool_class = Pool
398- return await pool_class (
399- url , loop , init = self .on_connect (), ** self ._pool_kwargs )
396+ return await pool_class (url , loop , init = self .on_connect (), ** self ._pool_kwargs )
400397
401398 # noinspection PyMethodMayBeStatic
402399 def transaction (self , raw_conn , args , kwargs ):
403400 _set_isolation = None
404- if 'isolation' in kwargs :
401+ if "isolation" in kwargs :
402+
405403 async def _set_isolation (conn ):
406- await self .set_isolation_level (conn , kwargs ['isolation' ])
404+ await self .set_isolation_level (conn , kwargs ["isolation" ])
405+
407406 return Transaction (raw_conn , _set_isolation )
408407
409408 def on_connect (self ):
@@ -428,15 +427,13 @@ async def _set_isolation_level(self, connection, level):
428427 % (level , self .name , ", " .join (self ._isolation_lookup ))
429428 )
430429 cursor = await connection .cursor ()
431- await cursor .execute (
432- "SET SESSION TRANSACTION ISOLATION LEVEL %s" % level )
430+ await cursor .execute ("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level )
433431 await cursor .execute ("COMMIT" )
434432 await cursor .close ()
435433
436434 async def get_isolation_level (self , connection ):
437435 if self .server_version_info is None :
438- self .server_version_info = await self ._get_server_version_info (
439- connection )
436+ self .server_version_info = await self ._get_server_version_info (connection )
440437 cursor = await connection .cursor ()
441438 if self ._is_mysql and self .server_version_info >= (5 , 7 , 20 ):
442439 await cursor .execute ("SELECT @@transaction_isolation" )
@@ -484,9 +481,7 @@ def _parse_server_version(self, val):
484481
485482 async def has_table (self , connection , table_name , schema = None ):
486483 full_name = "." .join (
487- self .identifier_preparer ._quote_free_identifiers (
488- schema , table_name
489- )
484+ self .identifier_preparer ._quote_free_identifiers (schema , table_name )
490485 )
491486
492487 st = "DESCRIBE %s" % full_name
@@ -502,6 +497,7 @@ def _extract_error_code(self, exception):
502497 exception = exception .args [0 ]
503498 return exception .args [0 ]
504499
500+
505501def _escape_args (args , conn ):
506502 if isinstance (args , (tuple , list )):
507503 return tuple (conn .escape (arg ) for arg in args )
0 commit comments