Skip to content

Commit f2c68eb

Browse files
wukathcopybara-github
authored andcommitted
feat: Add Auth Provider support to agent registry
Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 899104727
1 parent 547766a commit f2c68eb

2 files changed

Lines changed: 80 additions & 1 deletion

File tree

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
3737
from google.adk.auth.auth_credential import AuthCredential
3838
from google.adk.auth.auth_schemes import AuthScheme
39+
from google.adk.integrations.agent_identity.gcp_auth_provider_scheme import GcpAuthProviderScheme
3940
from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID
4041
from google.adk.tools.base_tool import BaseTool
4142
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
@@ -292,8 +293,24 @@ def get_mcp_toolset(
292293
mcp_server_name: str,
293294
auth_scheme: AuthScheme | None = None,
294295
auth_credential: AuthCredential | None = None,
296+
*,
297+
continue_uri: str | None = None,
295298
) -> McpToolset:
296-
"""Constructs an McpToolset instance from a registered MCP Server."""
299+
"""Constructs an McpToolset from a registered MCP Server.
300+
301+
If `auth_scheme` is omitted, it is automatically resolved from the server's
302+
IAM bindings via `GcpAuthProviderScheme`.
303+
304+
Args:
305+
mcp_server_name: Resource name of the MCP Server.
306+
auth_scheme: Optional auth scheme. Resolved via bindings if omitted.
307+
auth_credential: Optional auth credential.
308+
continue_uri: Optional continue URI to override what is in the auth
309+
provider.
310+
311+
Returns:
312+
An McpToolset for the MCP server.
313+
"""
297314
server_details = self.get_mcp_server(mcp_server_name)
298315
name = self._clean_name(server_details.get("displayName", mcp_server_name))
299316
mcp_server_id = server_details.get("mcpServerId")
@@ -313,6 +330,23 @@ def get_mcp_toolset(
313330
)
314331

315332
headers = self._get_auth_headers() if _is_google_api(endpoint_uri) else None
333+
if mcp_server_id and not auth_scheme:
334+
try:
335+
bindings_data = self._make_request("bindings")
336+
for b in bindings_data.get("bindings", []):
337+
target_id = b.get("target", {}).get("identifier", "")
338+
if target_id.endswith(mcp_server_id):
339+
auth_provider = b.get("authProviderBinding", {}).get("authProvider")
340+
if auth_provider:
341+
auth_scheme = GcpAuthProviderScheme(
342+
name=auth_provider, continue_uri=continue_uri
343+
)
344+
break
345+
except Exception as e:
346+
logger.warning(
347+
f"Failed to fetch bindings for MCP Server {mcp_server_name}: {e}"
348+
)
349+
316350
connection_params = StreamableHTTPConnectionParams(
317351
url=endpoint_uri,
318352
headers=headers,

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,3 +637,48 @@ def test_get_model_name_raises_value_error_if_no_uri(
637637
mock_get_endpoint.return_value = {}
638638
with pytest.raises(ValueError, match="Connection URI not found"):
639639
registry.get_model_name("test-endpoint")
640+
641+
@patch.object(AgentRegistry, "_make_request")
642+
def test_get_mcp_toolset_with_binding(self, mock_make_request, registry):
643+
def side_effect(*args, **kwargs):
644+
if args[0] == "test-mcp":
645+
return {
646+
"displayName": "TestPrefix",
647+
"mcpServerId": "server-456",
648+
"interfaces": [{
649+
"url": "https://mcp.com",
650+
"protocolBinding": "JSONRPC",
651+
}],
652+
}
653+
if args[0] == "bindings":
654+
return {
655+
"bindings": [{
656+
"target": {
657+
"identifier": (
658+
"urn:mcp:projects-123:projects:123:locations:l:mcpServers:server-456"
659+
)
660+
},
661+
"authProviderBinding": {
662+
"authProvider": (
663+
"projects/123/locations/l/authProviders/ap-789"
664+
)
665+
},
666+
}]
667+
}
668+
return {}
669+
670+
mock_make_request.side_effect = side_effect
671+
672+
registry._credentials.token = "token"
673+
registry._credentials.refresh = MagicMock()
674+
675+
toolset = registry.get_mcp_toolset(
676+
"test-mcp", continue_uri="https://override.com/continue"
677+
)
678+
assert isinstance(toolset, McpToolset)
679+
assert toolset._auth_scheme is not None
680+
assert (
681+
toolset._auth_scheme.name
682+
== "projects/123/locations/l/authProviders/ap-789"
683+
)
684+
assert toolset._auth_scheme.continue_uri == "https://override.com/continue"

0 commit comments

Comments
 (0)