Skip to content

Commit 83393ab

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Propagate context to thread pools
PiperOrigin-RevId: 896063584
1 parent 6a1c90b commit 83393ab

3 files changed

Lines changed: 60 additions & 3 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import base64
2121
import binascii
2222
from concurrent.futures import ThreadPoolExecutor
23+
import contextvars
2324
import copy
2425
import functools
2526
import inspect
@@ -140,6 +141,7 @@ async def _call_tool_in_thread_pool(
140141
"""
141142
from ...tools.function_tool import FunctionTool
142143

144+
ctx = contextvars.copy_context()
143145
loop = asyncio.get_running_loop()
144146
executor = _get_tool_thread_pool(max_workers)
145147

@@ -160,7 +162,9 @@ def run_sync_tool():
160162
# For other sync tool types, we can't easily run them in thread pool
161163
return None
162164

163-
result = await loop.run_in_executor(executor, run_sync_tool)
165+
result = await loop.run_in_executor(
166+
executor, lambda: ctx.run(run_sync_tool)
167+
)
164168
if result is not None:
165169
return result
166170
else:
@@ -171,7 +175,9 @@ def run_async_tool_in_new_loop():
171175
# Create a new event loop for this thread
172176
return asyncio.run(tool.run_async(args=args, tool_context=tool_context))
173177

174-
return await loop.run_in_executor(executor, run_async_tool_in_new_loop)
178+
return await loop.run_in_executor(
179+
executor, lambda: ctx.run(run_async_tool_in_new_loop)
180+
)
175181

176182
# Fall back to normal async execution for non-FunctionTool sync tools
177183
return await tool.run_async(args=args, tool_context=tool_context)

src/google/adk/optimization/gepa_root_agent_prompt_optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import asyncio
18+
import contextvars
1819
import logging
1920
from typing import Any
2021
from typing import Optional
@@ -298,7 +299,8 @@ def run_gepa():
298299

299300
_logger.info("Running the GEPA optimizer...")
300301

301-
gepa_results = await loop.run_in_executor(None, run_gepa)
302+
ctx = contextvars.copy_context()
303+
gepa_results = await loop.run_in_executor(None, lambda: ctx.run(run_gepa))
302304

303305
_logger.info("GEPA optimization finished. Preparing final results...")
304306

tests/unittests/flows/llm_flows/test_functions_thread_pool.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests for thread pool execution of tools in Live API mode."""
1616

1717
import asyncio
18+
import contextvars
1819
import threading
1920
import time
2021

@@ -349,6 +350,54 @@ def sync_func() -> dict:
349350
pool = _get_tool_thread_pool(max_workers=12)
350351
assert pool is not None
351352

353+
@pytest.mark.asyncio
354+
async def test_contextvars_propagation_sync_tool(self):
355+
"""Test that contextvars propagate to sync tools in thread pool."""
356+
test_var = contextvars.ContextVar('test_var', default='default')
357+
test_var.set('main_thread_value')
358+
359+
def sync_func() -> dict[str, str]:
360+
return {'value': test_var.get()}
361+
362+
tool = FunctionTool(sync_func)
363+
model = testing_utils.MockModel.create(responses=[])
364+
agent = Agent(name='test_agent', model=model, tools=[tool])
365+
invocation_context = await testing_utils.create_invocation_context(
366+
agent=agent, user_content=''
367+
)
368+
tool_context = ToolContext(
369+
invocation_context=invocation_context,
370+
function_call_id='test_id',
371+
)
372+
373+
result = await _call_tool_in_thread_pool(tool, {}, tool_context)
374+
375+
assert result == {'value': 'main_thread_value'}
376+
377+
@pytest.mark.asyncio
378+
async def test_contextvars_propagation_async_tool(self):
379+
"""Test that contextvars propagate to async tools in thread pool."""
380+
test_var = contextvars.ContextVar('test_var', default='default')
381+
test_var.set('main_thread_value')
382+
383+
async def async_func() -> dict[str, str]:
384+
return {'value': test_var.get()}
385+
386+
tool = FunctionTool(async_func)
387+
model = testing_utils.MockModel.create(responses=[])
388+
agent = Agent(name='test_agent', model=model, tools=[tool])
389+
invocation_context = await testing_utils.create_invocation_context(
390+
agent=agent, user_content=''
391+
)
392+
tool_context = ToolContext(
393+
invocation_context=invocation_context,
394+
function_call_id='test_id',
395+
)
396+
397+
result = await _call_tool_in_thread_pool(tool, {}, tool_context)
398+
399+
assert result == {'value': 'main_thread_value'}
400+
352401

353402
class TestToolThreadPoolConfig:
354403
"""Tests for the tool_thread_pool_config in RunConfig."""

0 commit comments

Comments
 (0)