Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 1d4896f

Browse files
committed
test: add new tests for upcoming contextvar inheritance/isolation and weakref cleanup
1 parent 8370299 commit 1d4896f

File tree

1 file changed

+208
-1
lines changed

1 file changed

+208
-1
lines changed

tests/test_databases.py

Lines changed: 208 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import itertools
66
import os
77
import re
8+
import gc
89
from unittest.mock import MagicMock, patch
9-
10+
from typing import MutableMapping
1011
import pytest
1112
import sqlalchemy
1213

@@ -478,6 +479,212 @@ async def test_transaction_commit(database_url):
478479
assert len(results) == 1
479480

480481

482+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
483+
@async_adapter
484+
async def test_transaction_context_child_task_interaction(database_url):
485+
"""
486+
Ensure that child tasks may influence inherited transactions.
487+
"""
488+
# This is an practical example of the next test.
489+
async with Database(database_url) as database:
490+
async with database.transaction():
491+
# Create a note
492+
await database.execute(
493+
notes.insert().values(id=1, text="setup", completed=True)
494+
)
495+
496+
# Change the note from the same task
497+
await database.execute(
498+
notes.update().where(notes.c.id == 1).values(text="prior")
499+
)
500+
501+
# Confirm the change
502+
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
503+
assert result.text == "prior"
504+
505+
async def run_update_from_child_task():
506+
# Chage the note from a child task
507+
await database.execute(
508+
notes.update().where(notes.c.id == 1).values(text="test")
509+
)
510+
511+
await asyncio.create_task(run_update_from_child_task())
512+
513+
# Confirm the child's change
514+
result = await database.fetch_one(notes.select().where(notes.c.id == 1))
515+
assert result.text == "test"
516+
517+
518+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
519+
@async_adapter
520+
async def test_transaction_context_child_task_inheritance(database_url):
521+
"""
522+
Ensure that transactions are inherited by child tasks.
523+
"""
524+
async with Database(database_url) as database:
525+
526+
async def check_transaction(transaction, active_transaction):
527+
# Should have inherited the same transaction backend from the parent task
528+
assert transaction._transaction is active_transaction
529+
530+
async with database.transaction() as transaction:
531+
await asyncio.create_task(
532+
check_transaction(transaction, transaction._transaction)
533+
)
534+
535+
536+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
537+
@async_adapter
538+
async def test_transaction_context_sibling_task_isolation(database_url):
539+
"""
540+
Ensure that transactions are isolated between sibling tasks.
541+
"""
542+
start = asyncio.Event()
543+
end = asyncio.Event()
544+
545+
async with Database(database_url) as database:
546+
547+
async def check_transaction(transaction):
548+
await start.wait()
549+
# Parent task is now in a transaction, we should not
550+
# see its transaction backend since this task was
551+
# _started_ in a context where no transaction was active.
552+
assert transaction._transaction is None
553+
end.set()
554+
555+
transaction = database.transaction()
556+
assert transaction._transaction is None
557+
task = asyncio.create_task(check_transaction(transaction))
558+
559+
async with transaction:
560+
start.set()
561+
assert transaction._transaction is not None
562+
await end.wait()
563+
564+
# Cleanup for "Task not awaited" warning
565+
await task
566+
567+
568+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
569+
@async_adapter
570+
async def test_connection_context_cleanup_contextmanager(database_url):
571+
"""
572+
Ensure that contextvar connections are not persisted unecessarily.
573+
"""
574+
from databases.core import _ACTIVE_CONNECTIONS
575+
576+
assert _ACTIVE_CONNECTIONS.get() is None
577+
578+
async with Database(database_url) as database:
579+
# .connect is lazy, it doesn't create a Connection, but .connection does
580+
connection = database.connection()
581+
582+
open_connections = _ACTIVE_CONNECTIONS.get()
583+
assert isinstance(open_connections, MutableMapping)
584+
assert open_connections.get(database) is connection
585+
586+
# Context manager closes, open_connections is cleaned up
587+
open_connections = _ACTIVE_CONNECTIONS.get()
588+
assert isinstance(open_connections, MutableMapping)
589+
assert open_connections.get(database, None) is None
590+
591+
592+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
593+
@async_adapter
594+
async def test_connection_context_cleanup_garbagecollector(database_url):
595+
"""
596+
Ensure that contextvar connections are not persisted unecessarily, even
597+
if exit handlers are not called.
598+
"""
599+
from databases.core import _ACTIVE_CONNECTIONS
600+
601+
assert _ACTIVE_CONNECTIONS.get() is None
602+
603+
database = Database(database_url)
604+
await database.connect()
605+
connection = database.connection()
606+
607+
# Should be tracking the connection
608+
open_connections = _ACTIVE_CONNECTIONS.get()
609+
assert isinstance(open_connections, MutableMapping)
610+
assert open_connections.get(database) is connection
611+
612+
# neither .disconnect nor .__aexit__ are called before deleting the reference
613+
del database
614+
gc.collect()
615+
616+
# Should have dropped reference to connection, even without proper cleanup
617+
open_connections = _ACTIVE_CONNECTIONS.get()
618+
assert isinstance(open_connections, MutableMapping)
619+
assert len(open_connections) == 0
620+
621+
622+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
623+
@async_adapter
624+
async def test_transaction_context_cleanup_contextmanager(database_url):
625+
"""
626+
Ensure that contextvar transactions are not persisted unecessarily.
627+
"""
628+
from databases.core import _ACTIVE_TRANSACTIONS
629+
630+
assert _ACTIVE_TRANSACTIONS.get() is None
631+
632+
async with Database(database_url) as database:
633+
async with database.transaction() as transaction:
634+
635+
open_transactions = _ACTIVE_TRANSACTIONS.get()
636+
assert isinstance(open_transactions, MutableMapping)
637+
assert open_transactions.get(transaction) is transaction._transaction
638+
639+
# Context manager closes, open_transactions is cleaned up
640+
open_transactions = _ACTIVE_TRANSACTIONS.get()
641+
assert isinstance(open_transactions, MutableMapping)
642+
assert open_transactions.get(transaction, None) is None
643+
644+
645+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
646+
@async_adapter
647+
async def test_transaction_context_cleanup_garbagecollector(database_url):
648+
"""
649+
Ensure that contextvar transactions are not persisted unecessarily, even
650+
if exit handlers are not called.
651+
652+
This test should be an XFAIL, but cannot be due to the way that is hangs
653+
during teardown.
654+
"""
655+
from databases.core import _ACTIVE_TRANSACTIONS
656+
657+
assert _ACTIVE_TRANSACTIONS.get() is None
658+
659+
async with Database(database_url) as database:
660+
transaction = database.transaction()
661+
await transaction.start()
662+
663+
# Should be tracking the transaction
664+
open_transactions = _ACTIVE_TRANSACTIONS.get()
665+
assert isinstance(open_transactions, MutableMapping)
666+
assert open_transactions.get(transaction) is transaction._transaction
667+
668+
# neither .commit, .rollback, nor .__aexit__ are called
669+
del transaction
670+
gc.collect()
671+
672+
# TODO(zevisert,review): Could skip instead of using the logic below
673+
# A strong reference to the transaction is kept alive by the connection's
674+
# ._transaction_stack, so it is still be tracked at this point.
675+
assert len(open_transactions) == 1
676+
677+
# If that were magically cleared, the transaction would be cleaned up,
678+
# but as it stands this always causes a hang during teardown at
679+
# `Database(...).disconnect()` if the transaction is not closed.
680+
transaction = database.connection()._transaction_stack[-1]
681+
await transaction.rollback()
682+
del transaction
683+
684+
# Now with the transaction rolled-back, it should be cleaned up.
685+
assert len(open_transactions) == 0
686+
687+
481688
@pytest.mark.parametrize("database_url", DATABASE_URLS)
482689
@async_adapter
483690
async def test_transaction_commit_serializable(database_url):

0 commit comments

Comments
 (0)