Skip to content

Commit 0f3850f

Browse files
wukathcopybara-github
authored andcommitted
fix: Fixes for initializing RemoteA2aAgent - passing in preferred transport, protocol version, and auth headers
Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 896068855
1 parent 3a374ce commit 0f3850f

2 files changed

Lines changed: 140 additions & 24 deletions

File tree

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
from collections.abc import Generator
1920
from enum import Enum
2021
import logging
2122
import os
@@ -55,6 +56,12 @@
5556

5657
AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha"
5758

59+
_TRANSPORT_MAPPING = {
60+
"HTTP_JSON": A2ATransport.http_json,
61+
"JSONRPC": A2ATransport.jsonrpc,
62+
"GRPC": A2ATransport.grpc,
63+
}
64+
5865

5966
# An MCPToolset for a single registered MCP server. Adds the special
6067
# gcp.mcp.server.destination.id custom_metadata key on each returned tool. This special key is
@@ -232,13 +239,15 @@ def _get_connection_uri(
232239
for p in protocols:
233240
if protocol_type and p.get("type") != protocol_type:
234241
continue
242+
protocol_version = p.get("protocolVersion")
235243
for i in p.get("interfaces", []):
236-
if protocol_binding and i.get("protocolBinding") != protocol_binding:
244+
mapped_binding = _TRANSPORT_MAPPING.get(i.get("protocolBinding"))
245+
if protocol_binding and mapped_binding != protocol_binding:
237246
continue
238247
if url := i.get("url"):
239-
return url
248+
return url, protocol_version, mapped_binding
240249

241-
return None
250+
return None, None, None
242251

243252
def _clean_name(self, name: str) -> str:
244253
"""Cleans a string to be a valid Python identifier for agent names."""
@@ -284,11 +293,13 @@ def get_mcp_toolset(
284293
if not isinstance(mcp_server_id, str):
285294
mcp_server_id = None
286295

287-
endpoint_uri = self._get_connection_uri(
296+
endpoint_uri, _, _ = self._get_connection_uri(
288297
server_details, protocol_binding=A2ATransport.jsonrpc
289-
) or self._get_connection_uri(
290-
server_details, protocol_binding=A2ATransport.http_json
291298
)
299+
if not endpoint_uri:
300+
endpoint_uri, _, _ = self._get_connection_uri(
301+
server_details, protocol_binding=A2ATransport.http_json
302+
)
292303
if not endpoint_uri:
293304
raise ValueError(
294305
f"MCP Server endpoint URI not found for: {mcp_server_name}"
@@ -339,7 +350,7 @@ def get_model_name(self, endpoint_name: str) -> str:
339350
projects/.../locations/.../publishers/google/models/...).
340351
"""
341352
endpoint_details = self.get_endpoint(endpoint_name)
342-
uri = self._get_connection_uri(endpoint_details)
353+
uri, _, _ = self._get_connection_uri(endpoint_details)
343354
if not uri:
344355
raise ValueError(
345356
f"Connection URI not found for endpoint: {endpoint_name}"
@@ -378,7 +389,12 @@ def get_agent_info(self, name: str) -> Dict[str, Any]:
378389
"""Retrieves detailed metadata of a specific A2A Agent."""
379390
return self._make_request(name)
380391

381-
def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent:
392+
def get_remote_a2a_agent(
393+
self,
394+
agent_name: str,
395+
*,
396+
httpx_client: httpx.AsyncClient | None = None,
397+
) -> RemoteA2aAgent:
382398
"""Creates a RemoteA2aAgent instance for a registered A2A Agent."""
383399
agent_info = self.get_agent_info(agent_name)
384400

@@ -389,17 +405,19 @@ def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent:
389405
agent_card = AgentCard(**card_content)
390406
# Clean the name to be a valid identifier
391407
name = self._clean_name(agent_card.name)
408+
392409
return RemoteA2aAgent(
393410
name=name,
394411
agent_card=agent_card,
395412
description=agent_card.description,
413+
httpx_client=httpx_client,
396414
)
397415

398416
name = self._clean_name(agent_info.get("displayName", agent_name))
399417
description = agent_info.get("description", "")
400418
version = agent_info.get("version", "")
401419

402-
url = self._get_connection_uri(
420+
url, protocol_version, protocol_binding = self._get_connection_uri(
403421
agent_info, protocol_type=_ProtocolType.A2A_AGENT
404422
)
405423
if not url:
@@ -421,6 +439,8 @@ def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent:
421439
name=name,
422440
description=description,
423441
version=version,
442+
preferredTransport=protocol_binding or A2ATransport.http_json,
443+
protocolVersion=protocol_version or "0.3.0",
424444
url=url,
425445
skills=skills,
426446
capabilities=AgentCapabilities(streaming=False, polling=False),
@@ -432,4 +452,5 @@ def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent:
432452
name=name,
433453
agent_card=agent_card,
434454
description=description,
455+
httpx_client=httpx_client,
435456
)

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def test_get_mcp_toolset_adds_destination_id(
6060
),
6161
"interfaces": [{
6262
"url": "https://mcp.com",
63-
"protocolBinding": A2ATransport.jsonrpc,
63+
"protocolBinding": "JSONRPC",
6464
}],
6565
}
6666
mock_httpx.return_value.__enter__.return_value.get.return_value = (
@@ -126,7 +126,7 @@ async def test_get_mcp_toolset_handles_missing_destination_id(
126126
# "mcpServerId" is intentionally omitted
127127
"interfaces": [{
128128
"url": "https://mcp.com",
129-
"protocolBinding": A2ATransport.jsonrpc,
129+
"protocolBinding": "JSONRPC",
130130
}],
131131
}
132132
mock_httpx.return_value.__enter__.return_value.get.return_value = (
@@ -176,25 +176,29 @@ def test_get_connection_uri_mcp_interfaces_top_level(self, registry):
176176
{"url": "https://mcp-v1main.com", "protocolBinding": "JSONRPC"}
177177
]
178178
}
179-
uri = registry._get_connection_uri(
179+
uri, version, binding = registry._get_connection_uri(
180180
resource_details, protocol_binding=A2ATransport.jsonrpc
181181
)
182182
assert uri == "https://mcp-v1main.com"
183+
assert version is None
184+
assert binding == "JSONRPC"
183185

184186
def test_get_connection_uri_agent_nested_protocols(self, registry):
185187
resource_details = {
186188
"protocols": [{
187189
"type": _ProtocolType.A2A_AGENT,
188190
"interfaces": [{
189191
"url": "https://my-agent.com",
190-
"protocolBinding": A2ATransport.jsonrpc,
192+
"protocolBinding": "JSONRPC",
191193
}],
192194
}]
193195
}
194-
uri = registry._get_connection_uri(
196+
uri, version, binding = registry._get_connection_uri(
195197
resource_details, protocol_type=_ProtocolType.A2A_AGENT
196198
)
197199
assert uri == "https://my-agent.com"
200+
assert version is None
201+
assert binding == A2ATransport.jsonrpc
198202

199203
def test_get_connection_uri_filtering(self, registry):
200204
resource_details = {
@@ -207,42 +211,52 @@ def test_get_connection_uri_filtering(self, registry):
207211
"type": _ProtocolType.A2A_AGENT,
208212
"interfaces": [{
209213
"url": "https://my-agent.com",
210-
"protocolBinding": A2ATransport.http_json,
214+
"protocolBinding": "HTTP_JSON",
211215
}],
212216
},
213217
]
214218
}
215219
# Filter by type
216-
uri = registry._get_connection_uri(
220+
uri, version, binding = registry._get_connection_uri(
217221
resource_details, protocol_type=_ProtocolType.A2A_AGENT
218222
)
219223
assert uri == "https://my-agent.com"
224+
assert version is None
225+
assert binding == A2ATransport.http_json
220226

221227
# Filter by binding
222-
uri = registry._get_connection_uri(
228+
uri, version, binding = registry._get_connection_uri(
223229
resource_details, protocol_binding=A2ATransport.http_json
224230
)
225231
assert uri == "https://my-agent.com"
232+
assert version is None
233+
assert binding == A2ATransport.http_json
226234

227235
# No match
228-
uri = registry._get_connection_uri(
236+
uri, version, binding = registry._get_connection_uri(
229237
resource_details,
230238
protocol_type=_ProtocolType.A2A_AGENT,
231239
protocol_binding=A2ATransport.jsonrpc,
232240
)
233241
assert uri is None
242+
assert version is None
243+
assert binding is None
234244

235245
def test_get_connection_uri_returns_none_if_no_interfaces(self, registry):
236246
resource_details = {}
237-
uri = registry._get_connection_uri(resource_details)
247+
uri, version, binding = registry._get_connection_uri(resource_details)
238248
assert uri is None
249+
assert version is None
250+
assert binding is None
239251

240252
def test_get_connection_uri_returns_none_if_no_url_in_interfaces(
241253
self, registry
242254
):
243255
resource_details = {"interfaces": [{"protocolBinding": "HTTP"}]}
244-
uri = registry._get_connection_uri(resource_details)
256+
uri, version, binding = registry._get_connection_uri(resource_details)
245257
assert uri is None
258+
assert version is None
259+
assert binding is None
246260

247261
@patch("httpx.Client")
248262
def test_list_agents(self, mock_httpx, registry):
@@ -313,7 +327,7 @@ def test_get_mcp_toolset(self, mock_httpx, registry):
313327
"displayName": "TestPrefix",
314328
"interfaces": [{
315329
"url": "https://mcp.com",
316-
"protocolBinding": A2ATransport.jsonrpc,
330+
"protocolBinding": "JSONRPC",
317331
}],
318332
}
319333
mock_response.raise_for_status = MagicMock()
@@ -335,7 +349,7 @@ def test_get_mcp_toolset_with_auth(self, mock_httpx, registry):
335349
"displayName": "TestPrefix",
336350
"interfaces": [{
337351
"url": "https://mcp.com",
338-
"protocolBinding": A2ATransport.jsonrpc,
352+
"protocolBinding": "JSONRPC",
339353
}],
340354
}
341355
mock_response.raise_for_status = MagicMock()
@@ -370,9 +384,10 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry):
370384
"version": "1.0",
371385
"protocols": [{
372386
"type": _ProtocolType.A2A_AGENT,
387+
"protocolVersion": "0.4.0",
373388
"interfaces": [{
374389
"url": "https://my-agent.com",
375-
"protocolBinding": A2ATransport.jsonrpc,
390+
"protocolBinding": "HTTP_JSON",
376391
}],
377392
}],
378393
"skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}],
@@ -393,6 +408,35 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry):
393408
assert agent._agent_card.version == "1.0"
394409
assert len(agent._agent_card.skills) == 1
395410
assert agent._agent_card.skills[0].name == "Skill 1"
411+
assert agent._agent_card.preferred_transport == A2ATransport.http_json
412+
assert agent._agent_card.protocol_version == "0.4.0"
413+
414+
@patch("httpx.Client")
415+
def test_get_remote_a2a_agent_defaults(self, mock_httpx, registry):
416+
mock_response = MagicMock()
417+
mock_response.json.return_value = {
418+
"displayName": "TestAgent",
419+
"description": "Test Desc",
420+
"version": "1.0",
421+
"protocols": [{
422+
"type": _ProtocolType.A2A_AGENT,
423+
"interfaces": [{
424+
"url": "https://my-agent.com",
425+
}],
426+
}],
427+
}
428+
mock_response.raise_for_status = MagicMock()
429+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
430+
mock_response
431+
)
432+
433+
registry._credentials.token = "token"
434+
registry._credentials.refresh = MagicMock()
435+
436+
agent = registry.get_remote_a2a_agent("test-agent")
437+
assert isinstance(agent, RemoteA2aAgent)
438+
assert agent._agent_card.preferred_transport == A2ATransport.http_json
439+
assert agent._agent_card.protocol_version == "0.3.0"
396440

397441
@patch("httpx.Client")
398442
def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry):
@@ -436,6 +480,57 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry):
436480
assert len(agent._agent_card.skills) == 1
437481
assert agent._agent_card.skills[0].name == "S1"
438482

483+
@patch("httpx.Client")
484+
def test_get_remote_a2a_agent_with_httpx_client(self, mock_httpx, registry):
485+
mock_response = MagicMock()
486+
mock_response.json.return_value = {
487+
"displayName": "TestAgent",
488+
"description": "Test Desc",
489+
"version": "1.0",
490+
"protocols": [{
491+
"type": _ProtocolType.A2A_AGENT,
492+
"interfaces": [{
493+
"url": "https://my-agent.com",
494+
}],
495+
}],
496+
}
497+
mock_response.raise_for_status = MagicMock()
498+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
499+
mock_response
500+
)
501+
502+
custom_client = httpx.AsyncClient()
503+
agent = registry.get_remote_a2a_agent(
504+
"test-agent", httpx_client=custom_client
505+
)
506+
assert agent._httpx_client is custom_client
507+
508+
@patch("httpx.Client")
509+
def test_get_remote_a2a_agent_configures_transports(
510+
self, mock_httpx, registry
511+
):
512+
mock_response = MagicMock()
513+
mock_response.json.return_value = {
514+
"displayName": "TestAgent",
515+
"protocols": [{
516+
"type": _ProtocolType.A2A_AGENT,
517+
"interfaces": [{
518+
"url": "https://my-agent.com",
519+
"protocolBinding": A2ATransport.jsonrpc,
520+
}],
521+
}],
522+
}
523+
mock_response.raise_for_status = MagicMock()
524+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
525+
mock_response
526+
)
527+
528+
registry._credentials.token = "token"
529+
registry._credentials.refresh = MagicMock()
530+
531+
agent = registry.get_remote_a2a_agent("test-agent")
532+
assert agent._agent_card.preferred_transport == A2ATransport.jsonrpc
533+
439534
def test_get_auth_headers(self, registry):
440535
registry._credentials.token = "fake-token"
441536
registry._credentials.refresh = MagicMock()
@@ -472,7 +567,7 @@ def test_make_request_raises_request_error(self, mock_httpx, registry):
472567
registry._credentials.refresh = MagicMock()
473568

474569
with pytest.raises(
475-
RuntimeError, match="API request failed \(network error\)"
570+
RuntimeError, match=r"API request failed \(network error\)"
476571
):
477572
registry._make_request("test-path")
478573

0 commit comments

Comments
 (0)