@@ -360,3 +360,44 @@ async def test_connect_and_disconnect(database_url):
360360 assert database .is_connected
361361 await database .disconnect ()
362362 assert not database .is_connected
363+
364+
365+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
366+ @async_adapter
367+ async def test_connection_context (database_url ):
368+ """
369+ Test connection contexts are task-local.
370+ """
371+ async with Database (database_url ) as database :
372+ async with database .connection () as connection_1 :
373+ async with database .connection () as connection_2 :
374+ assert connection_1 is connection_2
375+
376+ async with Database (database_url ) as database :
377+ connection_1 = None
378+ connection_2 = None
379+ test_complete = asyncio .Event ()
380+
381+ async def get_connection_1 ():
382+ nonlocal connection_1
383+
384+ async with database .connection () as connection :
385+ connection_1 = connection
386+ await test_complete .wait ()
387+
388+ async def get_connection_2 ():
389+ nonlocal connection_2
390+
391+ async with database .connection () as connection :
392+ connection_2 = connection
393+ await test_complete .wait ()
394+
395+ loop = asyncio .get_event_loop ()
396+ task_1 = loop .create_task (get_connection_1 ())
397+ task_2 = loop .create_task (get_connection_2 ())
398+ while connection_1 is None or connection_2 is None :
399+ await asyncio .sleep (0.000001 )
400+ assert connection_1 is not connection_2
401+ test_complete .set ()
402+ await task_1
403+ await task_2
0 commit comments