1- from typing import Type , Optional , Union , Dict
1+ from typing import Hashable , MutableMapping , Type , Optional , Union , Dict
22from itertools import zip_longest
33from contextlib import suppress
4+ import weakref
45import dsnparse
56import toml
67
78from runtype import dataclass
89from typing_extensions import Self
910
1011from ..abcs .mixins import AbstractMixin
11- from ..utils import WeakCache
1212from .base import Database , ThreadedDatabase
1313from .postgresql import PostgreSQL
1414from .mysql import MySQL
@@ -94,11 +94,12 @@ def match_path(self, dsn):
9494
9595class Connect :
9696 """Provides methods for connecting to a supported database using a URL or connection dict."""
97+ conn_cache : MutableMapping [Hashable , Database ]
9798
9899 def __init__ (self , database_by_scheme : Dict [str , Database ] = DATABASE_BY_SCHEME ):
99100 self .database_by_scheme = database_by_scheme
100101 self .match_uri_path = {name : MatchUriPath (cls ) for name , cls in database_by_scheme .items ()}
101- self .conn_cache = WeakCache ()
102+ self .conn_cache = weakref . WeakValueDictionary ()
102103
103104 def for_databases (self , * dbs ) -> Self :
104105 database_by_scheme = {k : db for k , db in self .database_by_scheme .items () if k in dbs }
@@ -263,9 +264,10 @@ def __call__(
263264 >>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
264265 <data_diff.sqeleton.databases.mysql.MySQL object at ...>
265266 """
267+ cache_key = self .__make_cache_key (db_conf )
266268 if shared :
267269 with suppress (KeyError ):
268- conn = self .conn_cache . get ( db_conf )
270+ conn = self .conn_cache [ cache_key ]
269271 if not conn .is_closed :
270272 return conn
271273
@@ -277,5 +279,10 @@ def __call__(
277279 raise TypeError (f"db configuration must be a URI string or a dictionary. Instead got '{ db_conf } '." )
278280
279281 if shared :
280- self .conn_cache . add ( db_conf , conn )
282+ self .conn_cache [ cache_key ] = conn
281283 return conn
284+
285+ def __make_cache_key (self , db_conf : Union [str , dict ]) -> Hashable :
286+ if isinstance (db_conf , dict ):
287+ return tuple (db_conf .items ())
288+ return db_conf
0 commit comments