Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 4651671

Browse files
committed
[WIP] aiopg support #39
1 parent 9cb2202 commit 4651671

5 files changed

Lines changed: 316 additions & 2 deletions

File tree

databases/backends/aiopg.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import logging
2+
import typing
3+
import uuid
4+
5+
import aiopg
6+
7+
from sqlalchemy.dialects.postgresql import pypostgresql
8+
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
9+
from sqlalchemy.engine.result import ResultMetaData, RowProxy
10+
from sqlalchemy.sql import ClauseElement
11+
from sqlalchemy.types import TypeEngine
12+
13+
from databases.core import DatabaseURL
14+
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
15+
16+
logger = logging.getLogger("databases")
17+
18+
19+
class AiopgBackend(DatabaseBackend):
20+
def __init__(
21+
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
22+
) -> None:
23+
self._database_url = DatabaseURL(database_url)
24+
self._options = options
25+
self._dialect = self._get_dialect()
26+
self._pool = None
27+
28+
def _get_dialect(self) -> Dialect:
29+
dialect = pypostgresql.dialect(paramstyle="pyformat")
30+
31+
dialect.implicit_returning = True
32+
dialect.supports_native_enum = True
33+
dialect.supports_smallserial = True # 9.2+
34+
dialect._backslash_escapes = False
35+
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
36+
dialect._has_native_hstore = True
37+
dialect.supports_native_decimal = True
38+
39+
return dialect
40+
41+
def _get_connection_kwargs(self) -> dict: # TODO move to `core.py`
42+
url_options = self._database_url.options
43+
44+
kwargs = {}
45+
min_size = url_options.get("min_size")
46+
max_size = url_options.get("max_size")
47+
ssl = url_options.get("ssl")
48+
49+
if min_size is not None:
50+
kwargs["minsize"] = int(min_size)
51+
if max_size is not None:
52+
kwargs["maxsize"] = int(max_size)
53+
if ssl is not None:
54+
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]
55+
56+
for key, value in self._options.items():
57+
# Coerce 'min_size' and 'max_size' for consistency.
58+
if key == "min_size":
59+
key = "minsize"
60+
elif key == "max_size":
61+
key = "maxsize"
62+
kwargs[key] = value
63+
64+
return kwargs
65+
66+
async def connect(self) -> None: # TODO as MySQL one?
67+
assert self._pool is None, "DatabaseBackend is already running"
68+
kwargs = self._get_connection_kwargs()
69+
self._pool = await aiopg.create_pool(
70+
host=self._database_url.hostname,
71+
port=self._database_url.port,
72+
user=self._database_url.username or getpass.getuser(),
73+
password=self._database_url.password,
74+
database=self._database_url.database,
75+
# autocommit=True,
76+
**kwargs,
77+
)
78+
79+
async def disconnect(self) -> None:
80+
assert self._pool is not None, "DatabaseBackend is not running"
81+
self._pool.close()
82+
await self._pool.wait_closed()
83+
self._pool = None
84+
85+
def connection(self) -> "AiopgConnection":
86+
return AiopgConnection(self, self._dialect)
87+
88+
89+
class CompilationContext:
90+
def __init__(self, context: ExecutionContext):
91+
self.context = context
92+
93+
94+
class AiopgConnection(ConnectionBackend):
95+
def __init__(self, database: AiopgBackend, dialect: Dialect):
96+
self._database = database
97+
self._dialect = dialect
98+
self._connection = None # type: typing.Optional[aiopg.Connection]
99+
100+
async def acquire(self) -> None:
101+
assert self._connection is None, "Connection is already acquired"
102+
assert self._database._pool is not None, "DatabaseBackend is not running"
103+
self._connection = await self._database._pool.acquire()
104+
105+
async def release(self) -> None:
106+
assert self._connection is not None, "Connection is not acquired"
107+
assert self._database._pool is not None, "DatabaseBackend is not running"
108+
await self._database._pool.release(self._connection)
109+
self._connection = None
110+
111+
async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
112+
assert self._connection is not None, "Connection is not acquired"
113+
query, args, context = self._compile(query)
114+
cursor = await self._connection.cursor()
115+
# TODO
116+
import pdb; pdb.set_trace()
117+
try:
118+
await cursor.execute(query, args)
119+
rows = await cursor.fetchall()
120+
metadata = ResultMetaData(context, cursor.description)
121+
return [
122+
RowProxy(metadata, row, metadata._processors, metadata._keymap)
123+
for row in rows
124+
]
125+
finally:
126+
cursor.close()
127+
128+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
129+
assert self._connection is not None, "Connection is not acquired"
130+
query, args, context = self._compile(query)
131+
cursor = await self._connection.cursor()
132+
# TODO
133+
import pdb; pdb.set_trace()
134+
try:
135+
await cursor.execute(query, args)
136+
row = await cursor.fetchone()
137+
if row is None:
138+
return None
139+
metadata = ResultMetaData(context, cursor.description)
140+
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
141+
finally:
142+
cursor.close()
143+
144+
async def execute(self, query: ClauseElement) -> typing.Any:
145+
assert self._connection is not None, "Connection is not acquired"
146+
query, args, context = self._compile(query)
147+
cursor = await self._connection.cursor()
148+
# TODO
149+
import pdb; pdb.set_trace()
150+
try:
151+
await cursor.execute(query, args)
152+
return cursor.lastrowid
153+
finally:
154+
cursor.close()
155+
156+
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
157+
assert self._connection is not None, "Connection is not acquired"
158+
cursor = await self._connection.cursor()
159+
# TODO
160+
import pdb; pdb.set_trace()
161+
try:
162+
for single_query in queries:
163+
single_query, args, context = self._compile(single_query)
164+
await cursor.execute(single_query, args)
165+
finally:
166+
cursor.close()
167+
168+
async def iterate(
169+
self, query: ClauseElement
170+
) -> typing.AsyncGenerator[typing.Any, None]:
171+
assert self._connection is not None, "Connection is not acquired"
172+
query, args, context = self._compile(query)
173+
cursor = await self._connection.cursor()
174+
# TODO
175+
import pdb; pdb.set_trace()
176+
try:
177+
await cursor.execute(query, args)
178+
metadata = ResultMetaData(context, cursor.description)
179+
async for row in cursor:
180+
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
181+
finally:
182+
cursor.close()
183+
184+
def transaction(self) -> TransactionBackend:
185+
return AiopgTransaction(self)
186+
187+
def _compile(
188+
self, query: ClauseElement
189+
) -> typing.Tuple[str, dict, CompilationContext]:
190+
compiled = query.compile(dialect=self._dialect)
191+
args = compiled.construct_params()
192+
for key, val in args.items():
193+
if key in compiled._bind_processors:
194+
args[key] = compiled._bind_processors[key](val)
195+
196+
execution_context = self._dialect.execution_ctx_cls()
197+
execution_context.dialect = self._dialect
198+
execution_context.result_column_struct = (
199+
compiled._result_columns,
200+
compiled._ordered_columns,
201+
compiled._textual_ordered_columns,
202+
)
203+
204+
logger.debug("Query: %s\nArgs: %s", compiled.string, args)
205+
return compiled.string, args, CompilationContext(execution_context)
206+
207+
@property
208+
def raw_connection(self) -> aiopg.connection.Connection:
209+
assert self._connection is not None, "Connection is not acquired"
210+
return self._connection
211+
212+
213+
class AiopgTransaction(TransactionBackend):
214+
def __init__(self, connection: AiopgConnection):
215+
self._connection = connection
216+
self._is_root = False
217+
self._savepoint_name = ""
218+
219+
async def start(self, is_root: bool) -> None:
220+
import pdb; pdb.set_trace()
221+
assert self._connection._connection is not None, "Connection is not acquired"
222+
self._is_root = is_root
223+
cursor = await self._connection._connection.cursor()
224+
if self._is_root:
225+
# await self._connection._connection.begin()
226+
await cursor.execute("BEGIN")
227+
else:
228+
id = str(uuid.uuid4()).replace("-", "_")
229+
self._savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
230+
# cursor = await self._connection._connection.cursor()
231+
try:
232+
await cursor.execute(f"SAVEPOINT {self._savepoint_name}")
233+
finally:
234+
cursor.close()
235+
236+
async def commit(self) -> None:
237+
assert self._connection._connection is not None, "Connection is not acquired"
238+
cursor = await self._connection._connection.cursor()
239+
if self._is_root:
240+
# await self._connection._connection.commit()
241+
await cursor.execute("COMMIT")
242+
else:
243+
# cursor = await self._connection._connection.cursor()
244+
try:
245+
await cursor.execute(f"RELEASE SAVEPOINT {self._savepoint_name}")
246+
finally:
247+
cursor.close()
248+
249+
async def rollback(self) -> None:
250+
assert self._connection._connection is not None, "Connection is not acquired"
251+
cursor = await self._connection._connection.cursor()
252+
if self._is_root:
253+
# await self._connection._connection.rollback()
254+
await cursor.execute("ROLLBACK")
255+
else:
256+
# cursor = await self._connection._connection.cursor()
257+
try:
258+
await cursor.execute(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}")
259+
finally:
260+
cursor.close()

databases/core.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,17 @@
1818

1919

2020
class Database:
21+
# TODO Nested schema?
22+
# {
23+
# "postgresql": {
24+
# "asyncpg": "...", # Default
25+
# "aiopg": "..."
26+
# }
27+
# }
2128
SUPPORTED_BACKENDS = {
29+
# TODO `postgresql+asyncpg`?
2230
"postgresql": "databases.backends.postgres:PostgresBackend",
31+
"postgresql+psycopg2": "databases.backends.aiopg:AiopgBackend",
2332
"mysql": "databases.backends.mysql:MySQLBackend",
2433
"sqlite": "databases.backends.sqlite:SQLiteBackend",
2534
}
@@ -37,7 +46,7 @@ def __init__(
3746

3847
self._force_rollback = force_rollback
3948

40-
backend_str = self.SUPPORTED_BACKENDS[self.url.dialect]
49+
backend_str = self.SUPPORTED_BACKENDS[self.url.scheme]
4150
backend_cls = import_from_string(backend_str)
4251
assert issubclass(backend_cls, DatabaseBackend)
4352
self._backend = backend_cls(self.url, **self.options)
@@ -330,6 +339,10 @@ def components(self) -> SplitResult:
330339
self._components = urlsplit(self._url)
331340
return self._components
332341

342+
@property
343+
def scheme(self) -> str:
344+
return self.components.scheme
345+
333346
@property
334347
def dialect(self) -> str:
335348
return self.components.scheme.split("+")[0]

tests/conftest.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
3+
import pytest
4+
import sqlalchemy
5+
6+
from databases import DatabaseURL
7+
8+
assert "TEST_DATABASE_URLS" in os.environ, "TEST_DATABASE_URLS is not set."
9+
10+
DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]
11+
12+
13+
# @pytest.fixture(autouse=True, scope="module")
14+
# def metadata():
15+
# yield sqlalchemy.MetaData()
16+
17+
18+
19+
# @pytest.fixture(autouse=True, scope="module")
20+
# def create_test_database():
21+
# # Create test databases
22+
# import pdb; pdb.set_trace()
23+
# for url in DATABASE_URLS:
24+
# database_url = DatabaseURL(url)
25+
# if database_url.dialect == "mysql":
26+
# url = str(database_url.replace(driver="pymysql"))
27+
# engine = sqlalchemy.create_engine(url)
28+
# metadata.create_all(engine)
29+
30+
# # Run the test suite
31+
# yield
32+
33+
# # Drop test databases
34+
# for url in DATABASE_URLS:
35+
# database_url = DatabaseURL(url)
36+
# if database_url.dialect == "mysql":
37+
# url = str(database_url.replace(driver="pymysql"))
38+
# engine = sqlalchemy.create_engine(url)
39+
# metadata.drop_all(engine)

tests/test_databases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,10 @@ def process_result_value(self, value, dialect):
7171
)
7272

7373

74+
# TODO Move to `conftest.py`
7475
@pytest.fixture(autouse=True, scope="module")
7576
def create_test_database():
76-
# Create test databases
77+
# Create test databases with tables creation
7778
for url in DATABASE_URLS:
7879
database_url = DatabaseURL(url)
7980
if database_url.dialect == "mysql":

tests/test_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424

2525

26+
# TODO Move to `conftest.py` with tables creation
2627
@pytest.fixture(autouse=True, scope="module")
2728
def create_test_database():
2829
# Create test databases

0 commit comments

Comments
 (0)