|
15 | 15 | """Testings for the SequentialAgent.""" |
16 | 16 |
|
17 | 17 | from typing import AsyncGenerator |
| 18 | +from unittest.mock import patch |
18 | 19 |
|
19 | 20 | from google.adk.agents.base_agent import BaseAgent |
20 | 21 | from google.adk.agents.invocation_context import InvocationContext |
@@ -249,3 +250,38 @@ async def test_run_async_with_escalate_action( |
249 | 250 | ), |
250 | 251 | ] |
251 | 252 | assert simplified_events == expected_events |
| 253 | + |
| 254 | + |
| 255 | +@pytest.mark.asyncio |
| 256 | +async def test_run_async_with_pause_preserves_sub_agent_state( |
| 257 | + request: pytest.FixtureRequest, |
| 258 | +): |
| 259 | + """Test that the sub-agent state is preserved when the loop agent pauses.""" |
| 260 | + agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') |
| 261 | + loop_agent = LoopAgent( |
| 262 | + name=f'{request.function.__name__}_test_loop_agent', |
| 263 | + max_iterations=2, |
| 264 | + sub_agents=[agent], |
| 265 | + ) |
| 266 | + parent_ctx = await _create_parent_invocation_context( |
| 267 | + request.function.__name__, loop_agent, resumable=True |
| 268 | + ) |
| 269 | + |
| 270 | + # Set some dummy state for the sub-agent |
| 271 | + parent_ctx.agent_states[agent.name] = {'some_key': 'some_value'} |
| 272 | + |
| 273 | + # Mock should_pause_invocation to return True for the agent's event |
| 274 | + def mock_should_pause(event): |
| 275 | + return event.author == agent.name |
| 276 | + |
| 277 | + with patch.object( |
| 278 | + InvocationContext, |
| 279 | + 'should_pause_invocation', |
| 280 | + side_effect=mock_should_pause, |
| 281 | + ): |
| 282 | + async for _ in loop_agent.run_async(parent_ctx): |
| 283 | + pass # Consume the async generator |
| 284 | + |
| 285 | + # Verify that the sub-agent state was NOT reset |
| 286 | + assert agent.name in parent_ctx.agent_states |
| 287 | + assert parent_ctx.agent_states[agent.name] == {'some_key': 'some_value'} |
0 commit comments