Skip to content

Commit 92fa73b

Browse files
committed
fix test_schema
1 parent 21ab49a commit 92fa73b

3 files changed

Lines changed: 50 additions & 124 deletions

File tree

mysql_tests/test_engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import pymysql
66
import aiomysql
7-
import asyncpg
87
from gino import create_engine, UninitializedError
98
import pytest
109
from sqlalchemy.exc import ObjectNotExecutableError

mysql_tests/test_schema.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
import gino
6-
from gino.dialects.asyncpg import AsyncEnum
6+
from gino.dialects.aiomysql import AsyncEnum
77

88
pytestmark = pytest.mark.asyncio
99
db = gino.Gino()
@@ -18,7 +18,7 @@ class Blog(db.Model):
1818
__tablename__ = "s_blog"
1919

2020
id = db.Column(db.BigInteger(), primary_key=True)
21-
title = db.Column(db.Unicode(), index=True, comment="Title Comment")
21+
title = db.Column(db.Unicode(255), index=True, comment="Title Comment")
2222
visits = db.Column(db.BigInteger(), default=0)
2323
comment_id = db.Column(db.ForeignKey("s_comment.id"))
2424
number = db.Column(db.Enum(MyEnum), nullable=False, default=MyEnum.TWO)
@@ -36,44 +36,31 @@ class Comment(db.Model):
3636

3737

3838
async def test(engine, define=True):
39-
try:
40-
async with engine.acquire() as conn:
41-
assert not await engine.dialect.has_schema(conn, "schema_test")
42-
assert not await engine.dialect.has_table(conn, "non_exist")
43-
assert not await engine.dialect.has_sequence(conn, "non_exist")
44-
assert not await engine.dialect.has_type(conn, "non_exist")
45-
assert not await engine.dialect.has_type(
46-
conn, "non_exist", schema="schema_test"
47-
)
48-
await engine.status("create schema schema_test")
49-
Blog.__table__.schema = "schema_test"
50-
Blog.__table__.comment = "Blog Comment"
51-
Comment.__table__.schema = "schema_test"
52-
db.bind = engine
53-
await db.gino.create_all()
54-
await Blog.number.type.create_async(engine, checkfirst=True)
55-
await Blog.number2.type.create_async(engine, checkfirst=True)
56-
await db.gino.create_all(tables=[Blog.__table__], checkfirst=True)
57-
await blog_seq.gino.create(checkfirst=True)
58-
await Blog.__table__.gino.create(checkfirst=True)
59-
await db.gino.drop_all()
60-
await db.gino.drop_all(tables=[Blog.__table__], checkfirst=True)
61-
await Blog.__table__.gino.drop(checkfirst=True)
62-
await blog_seq.gino.drop(checkfirst=True)
63-
64-
if define:
65-
66-
class Comment2(db.Model):
67-
__tablename__ = "s_comment_2"
68-
69-
id = db.Column(db.BigInteger(), primary_key=True)
70-
blog_id = db.Column(db.ForeignKey("s_blog.id"))
71-
72-
Comment2.__table__.schema = "schema_test"
73-
await db.gino.create_all()
74-
await db.gino.drop_all()
75-
finally:
76-
await engine.status("drop schema schema_test cascade")
39+
async with engine.acquire() as conn:
40+
assert not await engine.dialect.has_table(conn, "non_exist")
41+
Blog.__table__.comment = "Blog Comment"
42+
db.bind = engine
43+
await db.gino.create_all()
44+
await Blog.number.type.create_async(engine, checkfirst=True)
45+
await Blog.number2.type.create_async(engine, checkfirst=True)
46+
await db.gino.create_all(tables=[Blog.__table__], checkfirst=True)
47+
await blog_seq.gino.create(checkfirst=True)
48+
await Blog.__table__.gino.create(checkfirst=True)
49+
await db.gino.drop_all()
50+
await db.gino.drop_all(tables=[Blog.__table__], checkfirst=True)
51+
await Blog.__table__.gino.drop(checkfirst=True)
52+
await blog_seq.gino.drop(checkfirst=True)
53+
54+
if define:
55+
56+
class Comment2(db.Model):
57+
__tablename__ = "s_comment_2"
58+
59+
id = db.Column(db.BigInteger(), primary_key=True)
60+
blog_id = db.Column(db.ForeignKey("s_blog.id"))
61+
62+
await db.gino.create_all()
63+
await db.gino.drop_all()
7764

7865

7966
async def test_no_alter(engine, mocker):

src/gino/dialects/aiomysql.py

Lines changed: 23 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aiomysql
99
from sqlalchemy import util, exc, sql
10-
from sqlalchemy.dialects.mysql import (JSON, json)
10+
from sqlalchemy.dialects.mysql import (JSON, ENUM)
1111
# from sqlalchemy.dialects.postgresql import ( # noqa: F401
1212
# ARRAY,
1313
# CreateEnumType,
@@ -353,6 +353,26 @@ async def rollback(self):
353353
await self._conn.rollback()
354354

355355

356+
class AsyncEnum(ENUM):
357+
async def create_async(self, bind=None, checkfirst=True):
358+
pass
359+
360+
async def drop_async(self, bind=None, checkfirst=True):
361+
pass
362+
363+
async def _on_table_create_async(self, target, bind, checkfirst=False, **kw):
364+
pass
365+
366+
async def _on_table_drop_async(self, target, bind, checkfirst=False, **kw):
367+
pass
368+
369+
async def _on_metadata_create_async(self, target, bind, checkfirst=False, **kw):
370+
pass
371+
372+
async def _on_metadata_drop_async(self, target, bind, checkfirst=False, **kw):
373+
pass
374+
375+
356376
class GinoNullType(sqltypes.NullType):
357377
def result_processor(self, dialect, coltype):
358378
if coltype == JSON_COLTYPE:
@@ -379,8 +399,8 @@ class AiomysqlDialect(MySQLDialect, base.AsyncDialectMixin):
379399
colspecs = util.update_copy(
380400
MySQLDialect.colspecs,
381401
{
382-
# ENUM: AsyncEnum,
383-
# sqltypes.Enum: AsyncEnum,
402+
ENUM: AsyncEnum,
403+
sqltypes.Enum: AsyncEnum,
384404
sqltypes.NullType: GinoNullType,
385405
},
386406
)
@@ -485,20 +505,6 @@ def _parse_server_version(self, val):
485505
version.append(n)
486506
return tuple(version)
487507

488-
489-
# async def has_schema(self, connection, schema):
490-
# row = await connection.first(
491-
# sql.text(
492-
# "select nspname from pg_namespace " "where lower(nspname)=:schema"
493-
# ).bindparams(
494-
# sql.bindparam(
495-
# "schema", util.text_type(schema.lower()), type_=sqltypes.Unicode,
496-
# )
497-
# )
498-
# )
499-
#
500-
# return bool(row)
501-
502508
async def has_table(self, connection, table_name, schema=None):
503509
full_name = ".".join(
504510
self.identifier_preparer._quote_free_identifiers(
@@ -518,72 +524,6 @@ def _extract_error_code(self, exception):
518524
if isinstance(exception.args[0], Exception):
519525
exception = exception.args[0]
520526
return exception.args[0]
521-
#
522-
# async def has_sequence(self, connection, sequence_name, schema=None):
523-
# if schema is None:
524-
# row = await connection.first(
525-
# sql.text(
526-
# "SELECT relname FROM pg_class c join pg_namespace n on "
527-
# "n.oid=c.relnamespace where relkind='S' and "
528-
# "n.nspname=current_schema() "
529-
# "and relname=:name"
530-
# ).bindparams(
531-
# sql.bindparam(
532-
# "name", util.text_type(sequence_name), type_=sqltypes.Unicode,
533-
# )
534-
# )
535-
# )
536-
# else:
537-
# row = await connection.first(
538-
# sql.text(
539-
# "SELECT relname FROM pg_class c join pg_namespace n on "
540-
# "n.oid=c.relnamespace where relkind='S' and "
541-
# "n.nspname=:schema and relname=:name"
542-
# ).bindparams(
543-
# sql.bindparam(
544-
# "name", util.text_type(sequence_name), type_=sqltypes.Unicode,
545-
# ),
546-
# sql.bindparam(
547-
# "schema", util.text_type(schema), type_=sqltypes.Unicode,
548-
# ),
549-
# )
550-
# )
551-
#
552-
# return bool(row)
553-
#
554-
# async def has_type(self, connection, type_name, schema=None):
555-
# if schema is not None:
556-
# query = """
557-
# SELECT EXISTS (
558-
# SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
559-
# WHERE t.typnamespace = n.oid
560-
# AND t.typname = :typname
561-
# AND n.nspname = :nspname
562-
# )
563-
# """
564-
# query = sql.text(query)
565-
# else:
566-
# query = """
567-
# SELECT EXISTS (
568-
# SELECT * FROM pg_catalog.pg_type t
569-
# WHERE t.typname = :typname
570-
# AND pg_type_is_visible(t.oid)
571-
# )
572-
# """
573-
# query = sql.text(query)
574-
# query = query.bindparams(
575-
# sql.bindparam(
576-
# "typname", util.text_type(type_name), type_=sqltypes.Unicode,
577-
# ),
578-
# )
579-
# if schema is not None:
580-
# query = query.bindparams(
581-
# sql.bindparam(
582-
# "nspname", util.text_type(schema), type_=sqltypes.Unicode,
583-
# ),
584-
# )
585-
# return bool(await connection.scalar(query))
586-
587527

588528
def _escape_args(args, conn):
589529
if isinstance(args, (tuple, list)):

0 commit comments

Comments
 (0)