Skip to content

Commit ab5a68a

Browse files
authored
Merge pull request #497 from fantix/starlette
fix #486, add Starlette support
2 parents b4ad66a + e384f25 commit ab5a68a

6 files changed

Lines changed: 352 additions & 3 deletions

File tree

gino/ext/aiohttp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def init_app(self, app, config=None, *, db_attr_name='db'):
122122
else:
123123
config = config.copy()
124124

125-
async def before_server_start(app_):
126-
if config.get('dsn'):
125+
async def before_server_start(_):
126+
if 'dsn' in config:
127127
dsn = config['dsn']
128128
else:
129129
dsn = URL(
@@ -144,7 +144,7 @@ async def before_server_start(app_):
144144
**config.setdefault('kwargs', dict()),
145145
)
146146

147-
async def after_server_stop(app_):
147+
async def after_server_stop(_):
148148
await self.pop_bind().close()
149149

150150
app.on_startup.append(before_server_start)

gino/ext/quart.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22

3+
# noinspection PyPackageRequirements
34
from quart import Quart, request
5+
# noinspection PyPackageRequirements
46
from quart.exceptions import NotFound
57
from sqlalchemy.engine.url import URL
68

gino/ext/starlette.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# noinspection PyPackageRequirements
2+
from starlette.applications import Starlette
3+
# noinspection PyPackageRequirements
4+
from starlette.types import Message, Receive, Scope, Send
5+
# noinspection PyPackageRequirements
6+
from starlette.exceptions import HTTPException
7+
# noinspection PyPackageRequirements
8+
from starlette import status
9+
from sqlalchemy.engine.url import URL
10+
11+
from ..api import Gino as _Gino, GinoExecutor as _Executor
12+
from ..engine import GinoConnection as _Connection, GinoEngine as _Engine
13+
from ..strategies import GinoStrategy
14+
15+
16+
class StarletteModelMixin:
17+
@classmethod
18+
async def get_or_404(cls, *args, **kwargs):
19+
# noinspection PyUnresolvedReferences
20+
rv = await cls.get(*args, **kwargs)
21+
if rv is None:
22+
raise HTTPException(status.HTTP_404_NOT_FOUND,
23+
'{} is not found'.format(cls.__name__))
24+
return rv
25+
26+
27+
# noinspection PyClassHasNoInit
28+
class GinoExecutor(_Executor):
29+
async def first_or_404(self, *args, **kwargs):
30+
rv = await self.first(*args, **kwargs)
31+
if rv is None:
32+
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
33+
return rv
34+
35+
36+
# noinspection PyClassHasNoInit
37+
class GinoConnection(_Connection):
38+
async def first_or_404(self, *args, **kwargs):
39+
rv = await self.first(*args, **kwargs)
40+
if rv is None:
41+
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
42+
return rv
43+
44+
45+
# noinspection PyClassHasNoInit
46+
class GinoEngine(_Engine):
47+
connection_cls = GinoConnection
48+
49+
async def first_or_404(self, *args, **kwargs):
50+
rv = await self.first(*args, **kwargs)
51+
if rv is None:
52+
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
53+
return rv
54+
55+
56+
class StarletteStrategy(GinoStrategy):
57+
name = 'starlette'
58+
engine_cls = GinoEngine
59+
60+
61+
StarletteStrategy()
62+
63+
64+
class _Middleware:
65+
def __init__(self, app, db):
66+
self.app = app
67+
self.db = db
68+
69+
async def __call__(self, scope: Scope, receive: Receive,
70+
send: Send) -> None:
71+
if (scope['type'] == 'http' and
72+
self.db.config['use_connection_for_request']):
73+
scope['connection'] = await self.db.acquire(lazy=True)
74+
await self.app(scope, receive, send)
75+
conn = scope.pop('connection', None)
76+
if conn is not None:
77+
await conn.release()
78+
return
79+
80+
if scope['type'] == 'lifespan':
81+
async def receiver() -> Message:
82+
message = await receive()
83+
if message["type"] == "lifespan.startup":
84+
await self.db.set_bind(
85+
self.db.config['dsn'],
86+
echo=self.db.config['echo'],
87+
min_size=self.db.config['min_size'],
88+
max_size=self.db.config['max_size'],
89+
ssl=self.db.config['ssl'],
90+
**self.db.config['kwargs'],
91+
)
92+
elif message["type"] == "lifespan.shutdown":
93+
await self.db.pop_bind().close()
94+
return message
95+
await self.app(scope, receiver, send)
96+
return
97+
98+
await self.app(scope, receive, send)
99+
100+
101+
class Gino(_Gino):
102+
"""Support Starlette server.
103+
104+
The common usage looks like this::
105+
106+
from starlette.applications import Starlette
107+
from gino.ext.starlette import Gino
108+
109+
app = Starlette()
110+
db = Gino(app, **kwargs)
111+
112+
GINO adds a middleware to the Starlette app to setup and cleanup database
113+
according to the configurations that passed in the ``kwargs`` parameter.
114+
115+
The config includes:
116+
117+
* ``driver`` - the database driver, default is ``asyncpg``.
118+
* ``host`` - database server host, default is ``localhost``.
119+
* ``port`` - database server port, default is ``5432``.
120+
* ``user`` - database server user, default is ``postgres``.
121+
* ``password`` - database server password, default is empty.
122+
* ``database`` - database name, default is ``postgres``.
123+
* ``dsn`` - a SQLAlchemy database URL to create the engine, its existence
124+
will replace all previous connect arguments.
125+
* ``pool_min_size`` - the initial number of connections of the db pool.
126+
* ``pool_max_size`` - the maximum number of connections in the db pool.
127+
* ``echo`` - enable SQLAlchemy echo mode.
128+
* ``ssl`` - SSL context passed to ``asyncpg.connect``, default is ``None``.
129+
* ``use_connection_for_request`` - flag to set up lazy connection for
130+
requests.
131+
* ``kwargs`` - other parameters passed to the specified dialects,
132+
like ``asyncpg``. Unrecognized parameters will cause exceptions.
133+
134+
If ``use_connection_for_request`` is set to be True, then a lazy connection
135+
is available at ``request['connection']``. By default, a database
136+
connection is borrowed on the first query, shared in the same execution
137+
context, and returned to the pool on response. If you need to release the
138+
connection early in the middle to do some long-running tasks, you can
139+
simply do this::
140+
141+
await request['connection'].release(permanent=False)
142+
143+
"""
144+
model_base_classes = _Gino.model_base_classes + (StarletteModelMixin,)
145+
query_executor = GinoExecutor
146+
147+
def __init__(self, app: Starlette, *args, **kwargs):
148+
self.config = dict()
149+
if 'dsn' in kwargs:
150+
self.config['dsn'] = kwargs.pop('dsn')
151+
else:
152+
self.config['dsn'] = URL(
153+
drivername=kwargs.pop('driver', 'asyncpg'),
154+
host=kwargs.pop('host', 'localhost'),
155+
port=kwargs.pop('port', 5432),
156+
username=kwargs.pop('user', 'postgres'),
157+
password=kwargs.pop('password', ''),
158+
database=kwargs.pop('database', 'postgres'),
159+
)
160+
self.config['echo'] = kwargs.pop('echo', False)
161+
self.config['min_size'] = kwargs.pop('pool_min_size', 5)
162+
self.config['max_size'] = kwargs.pop('pool_max_size', 10)
163+
self.config['ssl'] = kwargs.pop('ssl', None)
164+
self.config['use_connection_for_request'] = \
165+
kwargs.pop('use_connection_for_request', True)
166+
self.config['kwargs'] = kwargs.pop('kwargs', dict())
167+
168+
super().__init__(*args, **kwargs)
169+
170+
app.add_middleware(_Middleware, db=self)
171+
172+
async def first_or_404(self, *args, **kwargs):
173+
rv = await self.first(*args, **kwargs)
174+
if rv is None:
175+
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
176+
return rv
177+
178+
async def set_bind(self, bind, loop=None, **kwargs):
179+
kwargs.setdefault('strategy', 'starlette')
180+
return await super().set_bind(bind, loop=loop, **kwargs)

requirements_dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ aiohttp==3.5.0 # pyup: update minor
1010
tornado==6.0 # pyup: update minor
1111
async_generator==1.10 # pyup: update minor
1212
quart==0.9.1;python_version>="3.7" # pyup: update minor
13+
starlette==0.12.0;python_version>="3.6" # pyup: update minor
1314

1415
# tests
1516
coverage==4.5.1 # pyup: update minor

tests/test_sanic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
from async_generator import yield_, async_generator
24
import pytest
35
import sanic
@@ -11,6 +13,11 @@
1113
_MAX_INACTIVE_CONNECTION_LIFETIME = 59.0
1214

1315

16+
def teardown_module():
17+
# sanic server will close the loop during shutdown
18+
asyncio.set_event_loop(asyncio.new_event_loop())
19+
20+
1421
# noinspection PyShadowingNames
1522
async def _app(config):
1623
app = sanic.Sanic()

tests/test_starlette.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import sys
2+
3+
import pytest
4+
5+
# Starlette only supports Python 3.6 or later
6+
if sys.version_info < (3, 6):
7+
raise pytest.skip(allow_module_level=True)
8+
9+
from async_generator import yield_, async_generator
10+
import pytest
11+
from starlette.applications import Starlette
12+
from starlette.responses import JSONResponse, PlainTextResponse
13+
from starlette.testclient import TestClient
14+
15+
import gino
16+
from gino.ext.starlette import Gino
17+
18+
from .models import DB_ARGS, PG_URL
19+
20+
_MAX_INACTIVE_CONNECTION_LIFETIME = 59.0
21+
22+
23+
# noinspection PyShadowingNames
24+
async def _app(**kwargs):
25+
app = Starlette()
26+
kwargs.update({
27+
'kwargs': dict(
28+
max_inactive_connection_lifetime=_MAX_INACTIVE_CONNECTION_LIFETIME,
29+
),
30+
})
31+
db = Gino(app, **kwargs)
32+
33+
class User(db.Model):
34+
__tablename__ = 'gino_users'
35+
36+
id = db.Column(db.BigInteger(), primary_key=True)
37+
nickname = db.Column(db.Unicode(), default='noname')
38+
39+
@app.route('/')
40+
async def root(request):
41+
conn = await request['connection'].get_raw_connection()
42+
# noinspection PyProtectedMember
43+
assert conn._holder._max_inactive_time == \
44+
_MAX_INACTIVE_CONNECTION_LIFETIME
45+
return PlainTextResponse('Hello, world!')
46+
47+
@app.route('/users/{uid:int}')
48+
async def get_user(request):
49+
uid = request.path_params.get('uid')
50+
method = request.query_params.get('method')
51+
q = User.query.where(User.id == uid)
52+
if method == '1':
53+
return JSONResponse((await q.gino.first_or_404()).to_dict())
54+
elif method == '2':
55+
return JSONResponse(
56+
(await request['connection'].first_or_404(q)).to_dict())
57+
elif method == '3':
58+
return JSONResponse(
59+
(await db.bind.first_or_404(q)).to_dict())
60+
elif method == '4':
61+
return JSONResponse(
62+
(await db.first_or_404(q)).to_dict())
63+
else:
64+
return JSONResponse((await User.get_or_404(uid)).to_dict())
65+
66+
@app.route('/users', methods=['POST'])
67+
async def add_user(request):
68+
u = await User.create(nickname=(await request.json()).get('name'))
69+
await u.query.gino.first_or_404()
70+
await db.first_or_404(u.query)
71+
await db.bind.first_or_404(u.query)
72+
await request['connection'].first_or_404(u.query)
73+
return JSONResponse(u.to_dict())
74+
75+
e = await gino.create_engine(PG_URL)
76+
try:
77+
try:
78+
await db.gino.create_all(e)
79+
await yield_(app)
80+
finally:
81+
await db.gino.drop_all(e)
82+
finally:
83+
await e.close()
84+
85+
86+
@pytest.fixture
87+
@async_generator
88+
async def app():
89+
await _app(
90+
host=DB_ARGS['host'],
91+
port=DB_ARGS['port'],
92+
user=DB_ARGS['user'],
93+
password=DB_ARGS['password'],
94+
database=DB_ARGS['database'],
95+
)
96+
97+
98+
@pytest.fixture
99+
@async_generator
100+
async def app_ssl(ssl_ctx):
101+
await _app(
102+
host=DB_ARGS['host'],
103+
port=DB_ARGS['port'],
104+
user=DB_ARGS['user'],
105+
password=DB_ARGS['password'],
106+
database=DB_ARGS['database'],
107+
ssl=ssl_ctx,
108+
)
109+
110+
111+
@pytest.fixture
112+
@async_generator
113+
async def app_dsn():
114+
await _app(dsn=PG_URL)
115+
116+
117+
def _test_index_returns_200(app):
118+
client = TestClient(app)
119+
with client:
120+
response = client.get('/')
121+
assert response.status_code == 200
122+
assert response.text == 'Hello, world!'
123+
124+
125+
def test_index_returns_200(app):
126+
_test_index_returns_200(app)
127+
128+
129+
def test_index_returns_200_dsn(app_dsn):
130+
_test_index_returns_200(app_dsn)
131+
132+
133+
def _test(app):
134+
client = TestClient(app)
135+
with client:
136+
for method in '01234':
137+
response = client.get('/users/1?method=' + method)
138+
assert response.status_code == 404
139+
140+
response = client.post('/users', json=dict(name='fantix'))
141+
assert response.status_code == 200
142+
assert response.json() == dict(id=1, nickname='fantix')
143+
144+
for method in '01234':
145+
response = client.get('/users/1?method=' + method)
146+
assert response.status_code == 200
147+
assert response.json() == dict(id=1, nickname='fantix')
148+
149+
150+
def test(app):
151+
_test(app)
152+
153+
154+
def test_ssl(app_ssl):
155+
_test(app_ssl)
156+
157+
158+
def test_dsn(app_dsn):
159+
_test(app_dsn)

0 commit comments

Comments
 (0)