Skip to content

Commit e12b0af

Browse files
wukathcopybara-github
authored andcommitted
fix: Pass in auth headers with header provider instead of connection params
This allows headers to refresh on each request. Also only add auth headers if no auth_scheme or auth_credential is specified Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 900900323
1 parent 7774a0f commit e12b0af

File tree

2 files changed

+75
-14
lines changed

2 files changed

+75
-14
lines changed

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,6 @@ def get_mcp_toolset(
329329
f"MCP Server endpoint URI not found for: {mcp_server_name}"
330330
)
331331

332-
headers = self._get_auth_headers() if _is_google_api(endpoint_uri) else None
333332
if mcp_server_id and not auth_scheme:
334333
try:
335334
bindings_data = self._make_request("bindings")
@@ -349,13 +348,25 @@ def get_mcp_toolset(
349348

350349
connection_params = StreamableHTTPConnectionParams(
351350
url=endpoint_uri,
352-
headers=headers,
353351
)
352+
353+
def combined_header_provider(context: ReadonlyContext) -> Dict[str, str]:
354+
headers = {}
355+
if (
356+
not auth_scheme
357+
and not auth_credential
358+
and _is_google_api(endpoint_uri)
359+
):
360+
headers.update(self._get_auth_headers())
361+
if self._header_provider:
362+
headers.update(self._header_provider(context))
363+
return headers
364+
354365
return AgentRegistrySingleMcpToolset(
355366
destination_resource_id=mcp_server_id,
356367
connection_params=connection_params,
357368
tool_name_prefix=name,
358-
header_provider=self._header_provider,
369+
header_provider=combined_header_provider,
359370
auth_scheme=auth_scheme,
360371
auth_credential=auth_credential,
361372
)

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -321,16 +321,17 @@ def test_get_endpoint(self, mock_httpx, registry):
321321
assert server == {"name": "test-endpoint"}
322322

323323
@pytest.mark.parametrize(
324-
"url, expected_auth",
324+
"url, expected_auth, use_custom_provider",
325325
[
326-
("https://mcp.com", False),
327-
("https://mcp.googleapis.com/v1", True),
328-
("https://example.com/googleapis/v1", False),
326+
("https://mcp.com", False, False),
327+
("https://mcp.googleapis.com/v1", True, False),
328+
("https://example.com/googleapis/v1", False, False),
329+
("https://mcp.googleapis.com/v1", True, True),
329330
],
330331
)
331332
@patch("httpx.Client")
332333
def test_get_mcp_toolset_auth_headers(
333-
self, mock_httpx, registry, url, expected_auth
334+
self, mock_httpx, registry, url, expected_auth, use_custom_provider
334335
):
335336
mock_response = MagicMock()
336337
mock_response.json.return_value = {
@@ -345,19 +346,34 @@ def test_get_mcp_toolset_auth_headers(
345346
mock_response
346347
)
347348

349+
if use_custom_provider:
350+
custom_header_provider = lambda context: {
351+
"Authorization": "Bearer custom_token"
352+
}
353+
with patch(
354+
"google.auth.default", return_value=(MagicMock(), "project-id")
355+
):
356+
registry = AgentRegistry(
357+
project_id="test-project",
358+
location="global",
359+
header_provider=custom_header_provider,
360+
)
361+
348362
registry._credentials.token = "token"
349363
registry._credentials.refresh = MagicMock()
350364

351365
toolset = registry.get_mcp_toolset("test-mcp")
352366
assert isinstance(toolset, McpToolset)
353367
assert toolset.tool_name_prefix == "TestPrefix"
354-
if expected_auth:
355-
assert toolset._connection_params.headers is not None
356-
assert (
357-
toolset._connection_params.headers["Authorization"] == "Bearer token"
358-
)
368+
assert toolset._connection_params.headers is None
369+
headers = toolset._header_provider(MagicMock())
370+
371+
if use_custom_provider:
372+
assert headers.get("Authorization") == "Bearer custom_token"
373+
elif expected_auth:
374+
assert headers.get("Authorization") == "Bearer token"
359375
else:
360-
assert toolset._connection_params.headers is None
376+
assert "Authorization" not in headers
361377

362378
@patch("httpx.Client")
363379
def test_get_mcp_toolset_with_auth(self, mock_httpx, registry):
@@ -392,6 +408,40 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry):
392408
assert auth_config.auth_scheme == auth_scheme
393409
assert auth_config.raw_auth_credential == auth_credential
394410

411+
@patch("httpx.Client")
412+
def test_get_mcp_toolset_with_auth_blocks_gcp_headers(
413+
self, mock_httpx, registry
414+
):
415+
mock_response = MagicMock()
416+
mock_response.json.return_value = {
417+
"displayName": "TestPrefix",
418+
"interfaces": [{
419+
"url": "https://mcp.googleapis.com/v1",
420+
"protocolBinding": "JSONRPC",
421+
}],
422+
}
423+
mock_response.raise_for_status = MagicMock()
424+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
425+
mock_response
426+
)
427+
428+
registry._credentials.token = "token"
429+
registry._credentials.refresh = MagicMock()
430+
431+
auth_scheme = OAuth2(flows={})
432+
auth_credential = AuthCredential(
433+
auth_type="oauth2",
434+
oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"),
435+
)
436+
437+
toolset = registry.get_mcp_toolset(
438+
"test-mcp", auth_scheme=auth_scheme, auth_credential=auth_credential
439+
)
440+
assert isinstance(toolset, McpToolset)
441+
442+
headers = toolset._header_provider(MagicMock())
443+
assert "Authorization" not in headers
444+
395445
@patch("httpx.Client")
396446
def test_get_remote_a2a_agent(self, mock_httpx, registry):
397447
mock_response = MagicMock()

0 commit comments

Comments
 (0)