Skip to content

Commit 6fb12ed

Browse files
committed
add bakery and fix tests
1 parent 3774ba0 commit 6fb12ed

6 files changed

Lines changed: 192 additions & 19 deletions

File tree

mysql_tests/test_bakery.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import pytest
2+
import sqlalchemy
3+
4+
from gino import UninitializedError, create_engine, InitializedError
5+
from gino.bakery import Bakery, BakedQuery
6+
from .models import db, User, MYSQL_URL
7+
8+
pytestmark = pytest.mark.asyncio
9+
10+
11+
@pytest.mark.parametrize(
12+
"query",
13+
[
14+
User.query.where(User.id == db.bindparam("uid")),
15+
sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"),
16+
"SELECT * FROM gino_users WHERE id = :uid",
17+
lambda: User.query.where(User.id == db.bindparam("uid")),
18+
lambda: sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"),
19+
lambda: "SELECT * FROM gino_users WHERE id = :uid",
20+
],
21+
)
22+
@pytest.mark.parametrize("options", [dict(return_model=False), dict(loader=User)])
23+
@pytest.mark.parametrize("api", [True, False])
24+
@pytest.mark.parametrize("timeout", [None, 1])
25+
async def test(query, options, sa_engine, api, timeout):
26+
uid = sa_engine.execute(User.insert()).lastrowid
27+
if timeout:
28+
options["timeout"] = timeout
29+
30+
if api:
31+
b = db._bakery
32+
qs = [db.bake(query, **options)]
33+
if callable(query):
34+
qs.append(db.bake(**options)(query))
35+
else:
36+
b = Bakery()
37+
qs = [b.bake(query, **options)]
38+
if callable(query):
39+
qs.append(b.bake(**options)(query))
40+
41+
for q in qs:
42+
assert isinstance(q, BakedQuery)
43+
assert q in list(b)
44+
assert q.sql is None
45+
assert q.compiled_sql is None
46+
47+
with pytest.raises(UninitializedError):
48+
q.bind.first()
49+
with pytest.raises(UninitializedError):
50+
await q.first()
51+
52+
for k, v in options.items():
53+
assert q.query.get_execution_options()[k] == v
54+
55+
if api:
56+
e = await db.set_bind(MYSQL_URL, minsize=1)
57+
else:
58+
e = await create_engine(MYSQL_URL, bakery=b, minsize=1)
59+
60+
with pytest.raises(InitializedError):
61+
b.bake("SELECT now()")
62+
63+
with pytest.raises(InitializedError):
64+
await create_engine(MYSQL_URL, bakery=b, minsize=0)
65+
66+
try:
67+
for q in qs:
68+
assert q.sql is not None
69+
assert q.compiled_sql is not None
70+
71+
if api:
72+
assert q.bind is e
73+
else:
74+
with pytest.raises(UninitializedError):
75+
q.bind.first()
76+
with pytest.raises(UninitializedError):
77+
await q.first()
78+
79+
if api:
80+
rv = await q.first(uid=uid)
81+
else:
82+
rv = await e.first(q, uid=uid)
83+
84+
if options.get("return_model", True):
85+
assert isinstance(rv, User)
86+
assert rv.id == uid
87+
else:
88+
assert rv[0] == rv[User.id] == rv["id"] == uid
89+
90+
eq = q.execution_options(return_model=True, loader=User)
91+
assert eq is not q
92+
assert isinstance(eq, BakedQuery)
93+
assert type(eq) is not BakedQuery
94+
assert eq in list(b)
95+
assert eq.sql == q.sql
96+
assert eq.compiled_sql is not q.compiled_sql
97+
98+
if api:
99+
assert q.bind is e
100+
else:
101+
with pytest.raises(UninitializedError):
102+
eq.bind.first()
103+
with pytest.raises(UninitializedError):
104+
await eq.first()
105+
106+
assert eq.query.get_execution_options()["return_model"]
107+
assert eq.query.get_execution_options()["loader"] is User
108+
109+
if api:
110+
rv = await eq.first(uid=uid)
111+
non = await eq.first(uid=uid + 1)
112+
rvl = await eq.all(uid=uid)
113+
else:
114+
rv = await e.first(eq, uid=uid)
115+
non = await e.first(eq, uid=uid + 1)
116+
rvl = await e.all(eq, uid=uid)
117+
118+
assert isinstance(rv, User)
119+
assert rv.id == uid
120+
121+
assert non is None
122+
123+
assert len(rvl) == 1
124+
assert rvl[0].id == uid
125+
126+
# original query is not affected
127+
if api:
128+
rv = await q.first(uid=uid)
129+
else:
130+
rv = await e.first(q, uid=uid)
131+
132+
if options.get("return_model", True):
133+
assert isinstance(rv, User)
134+
assert rv.id == uid
135+
else:
136+
assert rv[0] == rv[User.id] == rv["id"] == uid
137+
138+
finally:
139+
if api:
140+
await db.pop_bind().close()
141+
else:
142+
await e.close()
143+
144+
145+
async def test_class_level_bake():
146+
class BakeOnClass(db.Model):
147+
__tablename__ = "bake_on_class_test"
148+
149+
name = db.Column(db.String(255), primary_key=True)
150+
151+
@db.bake
152+
def getter(cls):
153+
return cls.query.where(cls.name == db.bindparam("name"))
154+
155+
async with db.with_bind(MYSQL_URL, prebake=False, autocommit=True):
156+
await db.gino.create_all()
157+
try:
158+
await BakeOnClass.create(name="exist")
159+
assert (await BakeOnClass.getter.one(name="exist")).name == "exist"
160+
assert (await BakeOnClass.getter.one_or_none(name="nonexist")) is None
161+
finally:
162+
await db.gino.drop_all()

mysql_tests/test_json.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ async def test_no_profile():
199199
class Test(db.Model):
200200
__tablename__ = "tests_no_profile"
201201

202+
id = db.Column(db.BigInteger(), primary_key=True)
202203
age = db.IntegerProperty(default=18)
203204

204205

@@ -213,11 +214,12 @@ def process_result_value(self, *_):
213214

214215
class PropsTest(db.Model):
215216
__tablename__ = "props_test_291"
216-
profile1 = db.Column(JSON(), nullable=False, default="{}")
217-
profile2 = db.Column(CustomJSON(), nullable=False, default="{}")
217+
profile = db.Column(JSON(), nullable=False, default={})
218+
profile1 = db.Column(JSON(), nullable=False, default={})
219+
profile2 = db.Column(CustomJSON(), nullable=False, default={})
218220

219-
bool = db.BooleanProperty(prop_name="profile1")
220-
bool1 = db.BooleanProperty(prop_name="profile2")
221+
bool = db.BooleanProperty()
222+
bool1 = db.BooleanProperty(prop_name="profile1")
221223

222224
await PropsTest.gino.create()
223225
try:

mysql_tests/test_schema.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,3 @@ class Comment2(db.Model):
6161

6262
await db.gino.create_all()
6363
await db.gino.drop_all()
64-
65-
66-
async def test_no_alter(engine, mocker):
67-
engine.dialect.supports_alter = False
68-
warn = mocker.patch("warnings.warn")
69-
await test(engine, define=False)
70-
assert warn.called

src/gino/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def create_engine(*args, **kwargs):
1919
* **Pre-bake** immediately when connected to the database (default).
2020
* No **pre-bake** but create prepared statements lazily when needed for the first
2121
time.
22+
23+
Note: ``prebake`` has no effect in aiomysql
2224
"""
2325

2426
from sqlalchemy import create_engine

src/gino/declarative.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,9 @@ def _init_table(cls, sub_cls):
365365
json_col = getattr(
366366
sub_cls.__dict__.get(v.prop_name), "column", None
367367
)
368-
if not isinstance(json_col, sa.Column) or not isinstance(
369-
json_col.type, sa.JSON
368+
if not (
369+
isinstance(json_col, sa.Column) and
370+
isinstance(json_col.type, sa.JSON)
370371
):
371372
raise AttributeError(
372373
'{} "{}" requires a JSON[B] column "{}" '

src/gino/dialects/aiomysql.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
JSON_COLTYPE = 245
2525

2626
#: Regular expression for :meth:`Cursor.executemany`.
27-
#: executemany only suports simple bulk insert.
27+
#: executemany only supports simple bulk insert.
2828
#: You can use it to load large dataset.
2929
_RE_INSERT_VALUES = re.compile(
3030
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)"
@@ -158,6 +158,10 @@ async def async_execute(self, query, timeout, args, limit=0, many=False):
158158
self._async_executemany(conn, query, args), timeout=timeout
159159
)
160160

161+
async def execute_baked(self, baked_query, timeout, args, one):
162+
# TODO: use prepare when it's supported
163+
return await self.async_execute(baked_query.sql, timeout, args)
164+
161165
async def _async_execute(self, conn, query, timeout, args):
162166
if args is not None:
163167
query = query % _escape_args(args, conn)
@@ -229,12 +233,14 @@ def iterate(self, context):
229233

230234

231235
class Pool(base.Pool):
232-
def __init__(self, url, loop, init=None, **kwargs):
236+
def __init__(self, url, loop, init=None, bakery=None, prebake=True, **kwargs):
233237
self._url = url
234238
self._loop = loop
235239
self._kwargs = kwargs
236240
self._pool = None
237241
self._conn_init = init
242+
self._bakery = bakery
243+
self._prebake = prebake
238244

239245
async def _init(self):
240246
args = self._kwargs.copy()
@@ -366,6 +372,7 @@ class AiomysqlDialect(MySQLDialect, base.AsyncDialectMixin):
366372
cursor_cls = DBAPICursor
367373
init_kwargs = set(
368374
itertools.chain(
375+
("bakery", "prebake"),
369376
*[
370377
inspect.getfullargspec(f).args
371378
for f in [aiomysql.create_pool, aiomysql.connect]
@@ -376,24 +383,30 @@ class AiomysqlDialect(MySQLDialect, base.AsyncDialectMixin):
376383
} # use SQLAlchemy's echo instead
377384
colspecs = util.update_copy(
378385
MySQLDialect.colspecs,
379-
{ENUM: AsyncEnum, sqltypes.Enum: AsyncEnum, sqltypes.NullType: GinoNullType,},
386+
{
387+
ENUM: AsyncEnum,
388+
sqltypes.Enum: AsyncEnum,
389+
sqltypes.NullType: GinoNullType,
390+
},
380391
)
381392
postfetch_lastrowid = False
382393
support_returning = False
383394
support_prepare = False
384395

385-
def __init__(self, *args, **kwargs):
396+
def __init__(self, *args, bakery=None, **kwargs):
386397
self._pool_kwargs = {}
387398
for k in self.init_kwargs:
388399
if k in kwargs:
389400
self._pool_kwargs[k] = kwargs.pop(k)
390401
super().__init__(*args, **kwargs)
391-
self._init_mixin()
402+
self._init_mixin(bakery)
392403

393404
async def init_pool(self, url, loop, pool_class=None):
394405
if pool_class is None:
395406
pool_class = Pool
396-
return await pool_class(url, loop, init=self.on_connect(), **self._pool_kwargs)
407+
return await pool_class(
408+
url, loop, bakery=self._bakery, init=self.on_connect(), **self._pool_kwargs
409+
)
397410

398411
# noinspection PyMethodMayBeStatic
399412
def transaction(self, raw_conn, args, kwargs):

0 commit comments

Comments
 (0)