|
15 | 15 | from aiocontextvars import ContextVar |
16 | 16 |
|
17 | 17 |
|
18 | | -class DatabaseURL: |
19 | | - def __init__(self, url: typing.Union[str, "DatabaseURL"]): |
20 | | - if isinstance(url, DatabaseURL): |
21 | | - self._url = str(url) |
22 | | - else: |
23 | | - self._url = url |
24 | | - |
25 | | - @property |
26 | | - def components(self) -> SplitResult: |
27 | | - if not hasattr(self, "_components"): |
28 | | - self._components = urlsplit(self._url) |
29 | | - return self._components |
30 | | - |
31 | | - @property |
32 | | - def dialect(self) -> str: |
33 | | - return self.components.scheme.split("+")[0] |
34 | | - |
35 | | - @property |
36 | | - def driver(self) -> str: |
37 | | - if "+" not in self.components.scheme: |
38 | | - return "" |
39 | | - return self.components.scheme.split("+", 1)[1] |
40 | | - |
41 | | - @property |
42 | | - def username(self) -> typing.Optional[str]: |
43 | | - return self.components.username |
44 | | - |
45 | | - @property |
46 | | - def password(self) -> typing.Optional[str]: |
47 | | - return self.components.password |
48 | | - |
49 | | - @property |
50 | | - def hostname(self) -> typing.Optional[str]: |
51 | | - return self.components.hostname |
52 | | - |
53 | | - @property |
54 | | - def port(self) -> typing.Optional[int]: |
55 | | - return self.components.port |
56 | | - |
57 | | - @property |
58 | | - def database(self) -> str: |
59 | | - return self.components.path.lstrip("/") |
60 | | - |
61 | | - def replace(self, **kwargs: typing.Any) -> "DatabaseURL": |
62 | | - if ( |
63 | | - "username" in kwargs |
64 | | - or "password" in kwargs |
65 | | - or "hostname" in kwargs |
66 | | - or "port" in kwargs |
67 | | - ): |
68 | | - hostname = kwargs.pop("hostname", self.hostname) |
69 | | - port = kwargs.pop("port", self.port) |
70 | | - username = kwargs.pop("username", self.username) |
71 | | - password = kwargs.pop("password", self.password) |
72 | | - |
73 | | - netloc = hostname |
74 | | - if port is not None: |
75 | | - netloc += f":{port}" |
76 | | - if username is not None: |
77 | | - userpass = username |
78 | | - if password is not None: |
79 | | - userpass += f":{password}" |
80 | | - netloc = f"{userpass}@{netloc}" |
81 | | - |
82 | | - kwargs["netloc"] = netloc |
83 | | - |
84 | | - if "database" in kwargs: |
85 | | - kwargs["path"] = "/" + kwargs.pop("database") |
86 | | - |
87 | | - if "dialect" in kwargs or "driver" in kwargs: |
88 | | - dialect = kwargs.pop("dialect", self.dialect) |
89 | | - driver = kwargs.pop("driver", self.driver) |
90 | | - kwargs["scheme"] = f"{dialect}+{driver}" if driver else dialect |
91 | | - |
92 | | - components = self.components._replace(**kwargs) |
93 | | - return self.__class__(components.geturl()) |
94 | | - |
95 | | - def __str__(self) -> str: |
96 | | - return self._url |
97 | | - |
98 | | - def __repr__(self) -> str: |
99 | | - url = str(self) |
100 | | - if self.password: |
101 | | - url = str(self.replace(password="********")) |
102 | | - return f"{self.__class__.__name__}({repr(url)})" |
103 | | - |
104 | | - |
105 | 18 | class Database: |
106 | 19 | SUPPORTED_BACKENDS = { |
107 | 20 | "postgresql": "databases.backends.postgres:PostgresBackend", |
108 | 21 | "mysql": "databases.backends.mysql:MySQLBackend", |
109 | 22 | } |
110 | 23 |
|
111 | | - def __init__(self, url: typing.Union[str, DatabaseURL]): |
112 | | - self.url = DatabaseURL(url) |
| 24 | + def __init__( |
| 25 | + self, url: typing.Union[str, "DatabaseURL"], force_rollback: bool = False |
| 26 | + ): |
| 27 | + self._url = DatabaseURL(url) |
| 28 | + self._force_rollback = force_rollback |
| 29 | + |
113 | 30 | self.is_connected = False |
114 | 31 |
|
115 | | - backend_str = self.SUPPORTED_BACKENDS[self.url.dialect] |
| 32 | + backend_str = self.SUPPORTED_BACKENDS[self._url.dialect] |
116 | 33 | backend_cls = import_from_string(backend_str) |
117 | 34 | assert issubclass(backend_cls, DatabaseBackend) |
118 | | - self._backend = backend_cls(self.url) |
| 35 | + self._backend = backend_cls(self._url) |
| 36 | + |
| 37 | + # Connections are stored as task-local state. |
119 | 38 | self._connection_context = ContextVar("connection_context") # type: ContextVar |
120 | 39 |
|
| 40 | + # When `force_rollback=True` is used, we use a single global |
| 41 | + # connection, within a transaction that always rolls back. |
| 42 | + self._global_connection = None # type: typing.Optional[Connection] |
| 43 | + self._global_transaction = None # type: typing.Optional[Transaction] |
| 44 | + |
121 | 45 | async def connect(self) -> None: |
122 | | - if not self.is_connected: |
123 | | - await self._backend.connect() |
124 | | - self.is_connected = True |
| 46 | + assert not self.is_connected, "Already connected." |
| 47 | + |
| 48 | + await self._backend.connect() |
| 49 | + self.is_connected = True |
| 50 | + |
| 51 | + if self._force_rollback: |
| 52 | + self._global_connection = Connection(self._backend) |
| 53 | + self._global_transaction = self._global_connection.transaction( |
| 54 | + force_rollback=True |
| 55 | + ) |
| 56 | + await self._global_transaction.__aenter__() |
125 | 57 |
|
126 | 58 | async def disconnect(self) -> None: |
127 | | - if self.is_connected: |
128 | | - await self._backend.disconnect() |
129 | | - self.is_connected = False |
| 59 | + assert self.is_connected, "Already disconnected." |
| 60 | + |
| 61 | + if self._force_rollback: |
| 62 | + assert self._global_transaction is not None |
| 63 | + await self._global_transaction.__aexit__() |
| 64 | + self._global_transaction = None |
| 65 | + self._global_connection = None |
| 66 | + |
| 67 | + await self._backend.disconnect() |
| 68 | + self.is_connected = False |
130 | 69 |
|
131 | 70 | async def __aenter__(self) -> "Database": |
132 | 71 | await self.connect() |
@@ -164,6 +103,9 @@ async def iterate( |
164 | 103 | yield record |
165 | 104 |
|
166 | 105 | def connection(self) -> "Connection": |
| 106 | + if self._global_connection is not None: |
| 107 | + return self._global_connection |
| 108 | + |
167 | 109 | try: |
168 | 110 | return self._connection_context.get() |
169 | 111 | except LookupError: |
@@ -267,3 +209,90 @@ async def rollback(self) -> None: |
267 | 209 | self._connection._transaction_stack.pop() |
268 | 210 | await self._transaction.rollback() |
269 | 211 | await self._connection.__aexit__() |
| 212 | + |
| 213 | + |
| 214 | +class DatabaseURL: |
| 215 | + def __init__(self, url: typing.Union[str, "DatabaseURL"]): |
| 216 | + if isinstance(url, DatabaseURL): |
| 217 | + self._url = str(url) |
| 218 | + else: |
| 219 | + self._url = url |
| 220 | + |
| 221 | + @property |
| 222 | + def components(self) -> SplitResult: |
| 223 | + if not hasattr(self, "_components"): |
| 224 | + self._components = urlsplit(self._url) |
| 225 | + return self._components |
| 226 | + |
| 227 | + @property |
| 228 | + def dialect(self) -> str: |
| 229 | + return self.components.scheme.split("+")[0] |
| 230 | + |
| 231 | + @property |
| 232 | + def driver(self) -> str: |
| 233 | + if "+" not in self.components.scheme: |
| 234 | + return "" |
| 235 | + return self.components.scheme.split("+", 1)[1] |
| 236 | + |
| 237 | + @property |
| 238 | + def username(self) -> typing.Optional[str]: |
| 239 | + return self.components.username |
| 240 | + |
| 241 | + @property |
| 242 | + def password(self) -> typing.Optional[str]: |
| 243 | + return self.components.password |
| 244 | + |
| 245 | + @property |
| 246 | + def hostname(self) -> typing.Optional[str]: |
| 247 | + return self.components.hostname |
| 248 | + |
| 249 | + @property |
| 250 | + def port(self) -> typing.Optional[int]: |
| 251 | + return self.components.port |
| 252 | + |
| 253 | + @property |
| 254 | + def database(self) -> str: |
| 255 | + return self.components.path.lstrip("/") |
| 256 | + |
| 257 | + def replace(self, **kwargs: typing.Any) -> "DatabaseURL": |
| 258 | + if ( |
| 259 | + "username" in kwargs |
| 260 | + or "password" in kwargs |
| 261 | + or "hostname" in kwargs |
| 262 | + or "port" in kwargs |
| 263 | + ): |
| 264 | + hostname = kwargs.pop("hostname", self.hostname) |
| 265 | + port = kwargs.pop("port", self.port) |
| 266 | + username = kwargs.pop("username", self.username) |
| 267 | + password = kwargs.pop("password", self.password) |
| 268 | + |
| 269 | + netloc = hostname |
| 270 | + if port is not None: |
| 271 | + netloc += f":{port}" |
| 272 | + if username is not None: |
| 273 | + userpass = username |
| 274 | + if password is not None: |
| 275 | + userpass += f":{password}" |
| 276 | + netloc = f"{userpass}@{netloc}" |
| 277 | + |
| 278 | + kwargs["netloc"] = netloc |
| 279 | + |
| 280 | + if "database" in kwargs: |
| 281 | + kwargs["path"] = "/" + kwargs.pop("database") |
| 282 | + |
| 283 | + if "dialect" in kwargs or "driver" in kwargs: |
| 284 | + dialect = kwargs.pop("dialect", self.dialect) |
| 285 | + driver = kwargs.pop("driver", self.driver) |
| 286 | + kwargs["scheme"] = f"{dialect}+{driver}" if driver else dialect |
| 287 | + |
| 288 | + components = self.components._replace(**kwargs) |
| 289 | + return self.__class__(components.geturl()) |
| 290 | + |
| 291 | + def __str__(self) -> str: |
| 292 | + return self._url |
| 293 | + |
| 294 | + def __repr__(self) -> str: |
| 295 | + url = str(self) |
| 296 | + if self.password: |
| 297 | + url = str(self.replace(password="********")) |
| 298 | + return f"{self.__class__.__name__}({repr(url)})" |
0 commit comments