@@ -789,49 +789,69 @@ async def _query_and_update(bind, item, query, cols, execution_opts):
789789 if bind ._dialect .support_returning :
790790 # noinspection PyArgumentList
791791 query = query .returning (* cols )
792- row = await bind .first (query )
792+
793+ async def _execute_and_fetch (conn , query ):
794+ context , row = await conn ._first_with_context (query )
795+ if not bind ._dialect .support_returning :
796+ if context .isinsert :
797+ table = context .compiled .statement .table
798+ key_getter = context .compiled ._key_getters_for_crud_column [2 ]
799+ compiled_params = context .compiled_parameters [0 ]
800+ last_row_id = context .get_lastrowid ()
801+ if last_row_id is not None :
802+ lookup_conds = [
803+ c == last_row_id
804+ if c is table ._autoincrement_column
805+ else c == _cast_json (
806+ c , compiled_params .get (key_getter (c ), None ))
807+ for c in table .primary_key
808+ ]
809+ else :
810+ lookup_conds = [
811+ c == _cast_json (
812+ c , compiled_params .get (key_getter (c ), None ))
813+ for c in table .columns
814+ ]
815+ query = sa .select (table .columns ).where (
816+ sa .and_ (* lookup_conds )).execution_options (** execution_opts )
817+ row = await conn .first (query )
818+ elif context .isupdate :
819+ if context .get_affected_rows () == 0 :
820+ raise NoSuchRowError ()
821+ table = context .compiled .statement .table
822+ if len (table .primary_key ) > 0 :
823+ lookup_conds = [
824+ c == _cast_json (
825+ c , item .__values__ [
826+ item ._column_name_map .invert_get (c .name )])
827+ for c in table .primary_key
828+ ]
829+ else :
830+ lookup_conds = [
831+ c == _cast_json (
832+ c , item .__values__ [
833+ item ._column_name_map .invert_get (c .name )])
834+ for c in table .columns
835+ ]
836+ query = sa .select (table .columns ).where (
837+ sa .and_ (* lookup_conds )).execution_options (** execution_opts )
838+ row = await conn .first (query )
839+ return row
840+
841+ if isinstance (bind , GinoConnection ):
842+ row = await _execute_and_fetch (bind , query )
793843 else :
794- # CAVEAT: MySQL doesn't support RETURNING. The workaround here is
795- # to get lastrowid and load it after insertion.
796- # Note that this only works for tables with AUTO_INCREMENT column
797- # For update queries, update using its primary key
798-
799- # make insertion and select in one transaction to get the might-be
800- # "dirty" row
801- release_conn = False
802- if not isinstance (bind , GinoConnection ):
803- conn = await bind .acquire (reuse = True )
804- release_conn = True
805- else :
806- conn = bind
807- try :
808- lastrowid , affected_rows = await conn .all (
809- query .execution_options (return_affected_rows = True )
810- )
811- if not lastrowid and not affected_rows :
812- raise NoSuchRowError ()
813- # It's insertion and primary key is AUTO_INCREMENT
814- if lastrowid :
815- pkey = cls .__table__ .primary_key
816- query = (
817- sa .select (cols )
818- .where (pkey .columns .values ()[0 ] == lastrowid )
819- .execution_options (** execution_opts )
820- )
821- else :
822- try :
823- query = (
824- sa .select (cols )
825- .where (item .lookup ())
826- .execution_options (** execution_opts )
827- )
828- except LookupError : # no primary key
829- return None
830- row = await conn .first (query )
831- finally :
832- if release_conn :
833- await conn .release ()
844+ async with bind .acquire (reuse = True ) as conn :
845+ row = await _execute_and_fetch (conn , query )
834846 if not row :
835847 raise NoSuchRowError ()
836848 for k , v in row .items ():
837849 item .__values__ [item ._column_name_map .invert_get (k )] = v
850+
851+
852+ def _cast_json (column , value ):
853+ # FIXME: for MySQL, json string in WHERE clause needs to be cast to JSON type
854+ if (isinstance (column .type , sa .JSON ) or
855+ isinstance (getattr (column .type , 'impl' , None ), sa .JSON )):
856+ return sa .cast (value , sa .JSON )
857+ return value
0 commit comments