Skip to content

Commit 9ca8c38

Browse files
haiyuan-eng-googlecopybara-github
authored andcommitted
fix: Resolve BigQuery plugin issues with A2A transfers, spans, and metadata
Fixes #5073, #5310, and #5311 with three targeted updates to the `BigQueryAgentAnalyticsPlugin` (no changes to ADK core): - Classifies `TransferToAgentTool` transfers to `RemoteA2aAgent` as `TRANSFER_A2A` instead of the generic `TRANSFER_AGENT` by resolving the target agent at the call level. - Ensures a self-consistent BigQuery span tree by preferring the plugin's internal span stack over ambient OTel spans, resolving dangling `parent_span_id` references. - Surfaces remote A2A interaction metadata (`a2a:request`, `a2a:response`, etc.) in BigQuery by detecting them in custom metadata and logging new `A2A_INTERACTION` events. Co-authored-by: Haiyuan Cao <haiyuan@google.com> PiperOrigin-RevId: 900224778
1 parent 47fa7b7 commit 9ca8c38

File tree

2 files changed

+655
-86
lines changed

2 files changed

+655
-86
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 167 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,57 @@ def _format_content(
168168
return " | ".join(parts), truncated
169169

170170

171-
def _get_tool_origin(tool: "BaseTool") -> str:
171+
def _find_transfer_target(agent, agent_name: str):
172+
"""Find a transfer target agent by name in the accessible agent tree.
173+
174+
Searches the current agent's sub-agents, parent, and peer agents
175+
to locate the transfer target.
176+
177+
Args:
178+
agent: The current agent executing the transfer.
179+
agent_name: The name of the transfer target to find.
180+
181+
Returns:
182+
The matching agent object, or None if not found.
183+
"""
184+
for sub in getattr(agent, "sub_agents", []):
185+
if sub.name == agent_name:
186+
return sub
187+
parent = getattr(agent, "parent_agent", None)
188+
if parent is not None and parent.name == agent_name:
189+
return parent
190+
if parent is not None:
191+
for peer in getattr(parent, "sub_agents", []):
192+
if peer.name == agent_name and peer.name != agent.name:
193+
return peer
194+
return None
195+
196+
197+
def _get_tool_origin(
198+
tool: "BaseTool",
199+
tool_args: Optional[dict[str, Any]] = None,
200+
tool_context: Optional["ToolContext"] = None,
201+
) -> str:
172202
"""Returns the provenance category of a tool.
173203
174204
Uses lazy imports to avoid circular dependencies.
175205
206+
For ``TransferToAgentTool`` the classification is **call-level**: when
207+
*tool_args* and *tool_context* are supplied the selected
208+
``agent_name`` is resolved against the agent tree so that transfers
209+
to a ``RemoteA2aAgent`` are labelled ``TRANSFER_A2A`` rather than
210+
the generic ``TRANSFER_AGENT``.
211+
176212
Args:
177213
tool: The tool instance.
214+
tool_args: Optional tool arguments, used for call-level classification of
215+
TransferToAgentTool.
216+
tool_context: Optional tool context, used to access the agent tree for
217+
TransferToAgentTool classification.
178218
179219
Returns:
180-
One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN.
220+
One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT,
221+
TRANSFER_A2A, or UNKNOWN.
181222
"""
182223
# Import lazily to avoid circular dependencies.
183224
# pylint: disable=g-import-not-at-top
@@ -199,6 +240,15 @@ def _get_tool_origin(tool: "BaseTool") -> str:
199240
if McpTool is not None and isinstance(tool, McpTool):
200241
return "MCP"
201242
if isinstance(tool, TransferToAgentTool):
243+
if RemoteA2aAgent is not None and tool_args and tool_context:
244+
agent_name = tool_args.get("agent_name")
245+
if agent_name:
246+
target = _find_transfer_target(
247+
tool_context._invocation_context.agent,
248+
agent_name,
249+
)
250+
if target is not None and isinstance(target, RemoteA2aAgent):
251+
return "TRANSFER_A2A"
202252
return "TRANSFER_AGENT"
203253
if isinstance(tool, AgentTool):
204254
if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent):
@@ -1825,6 +1875,25 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
18251875
"JSON_VALUE(content, '$.tool') AS tool_name",
18261876
"JSON_QUERY(content, '$.args') AS tool_args",
18271877
],
1878+
"A2A_INTERACTION": [
1879+
"content AS response_content",
1880+
(
1881+
"JSON_VALUE(attributes,"
1882+
" '$.a2a_metadata.\"a2a:task_id\"') AS a2a_task_id"
1883+
),
1884+
(
1885+
"JSON_VALUE(attributes,"
1886+
" '$.a2a_metadata.\"a2a:context_id\"') AS a2a_context_id"
1887+
),
1888+
(
1889+
"JSON_QUERY(attributes,"
1890+
" '$.a2a_metadata.\"a2a:request\"') AS a2a_request"
1891+
),
1892+
(
1893+
"JSON_QUERY(attributes,"
1894+
" '$.a2a_metadata.\"a2a:response\"') AS a2a_response"
1895+
),
1896+
],
18281897
}
18291898

18301899
_VIEW_SQL_TEMPLATE = """\
@@ -2552,39 +2621,59 @@ def _resolve_ids(
25522621
) -> tuple[Optional[str], Optional[str], Optional[str]]:
25532622
"""Resolves trace_id, span_id, and parent_span_id for a log row.
25542623
2624+
Resolution rules:
2625+
2626+
* **trace_id** — ambient OTel trace wins (the plugin stack already
2627+
shares the ambient trace when initialised from an ambient span,
2628+
so in practice they agree).
2629+
* **span_id / parent_span_id** — the plugin's internal span stack
2630+
(``TraceManager``) is the preferred source. Ambient OTel spans
2631+
are only used as a fallback when the plugin stack has no span.
2632+
This ensures every ``parent_span_id`` in BigQuery references a
2633+
``span_id`` that is also logged to BigQuery, producing a
2634+
self-consistent execution tree.
2635+
* **Explicit overrides** (``EventData``) always win last — they
2636+
are set by post-pop callbacks that have already captured the
2637+
correct plugin-stack values before the pop.
2638+
25552639
Priority order (highest first):
2556-
1. Explicit ``EventData`` overrides (needed for post-pop callbacks).
2557-
2. Ambient OTel span (the framework's ``start_as_current_span``).
2558-
When present this aligns BQ rows with Cloud Trace / o11y.
2559-
3. Plugin's internal span stack (``TraceManager``).
2640+
1. Explicit ``EventData`` overrides.
2641+
2. Plugin's internal span stack (``TraceManager``) for
2642+
``span_id`` / ``parent_span_id``.
2643+
3. Ambient OTel span — always used for ``trace_id``; used for
2644+
``span_id`` / ``parent_span_id`` only when the plugin stack
2645+
has no span.
25602646
4. ``invocation_id`` fallback for trace_id.
25612647
25622648
Returns:
25632649
(trace_id, span_id, parent_span_id)
25642650
"""
2565-
# --- Layer 3: plugin stack baseline ---
2651+
# --- Plugin stack: span_id / parent_span_id baseline ---
25662652
trace_id = TraceManager.get_trace_id(callback_context)
25672653
plugin_span_id, plugin_parent_span_id = (
25682654
TraceManager.get_current_span_and_parent()
25692655
)
25702656
span_id = plugin_span_id
25712657
parent_span_id = plugin_parent_span_id
25722658

2573-
# --- Layer 2: ambient OTel span ---
2659+
# --- Ambient OTel: trace_id always; span fallback only ---
25742660
ambient = trace.get_current_span()
25752661
ambient_ctx = ambient.get_span_context()
25762662
if ambient_ctx.is_valid:
25772663
trace_id = format(ambient_ctx.trace_id, "032x")
2578-
span_id = format(ambient_ctx.span_id, "016x")
2579-
# Reset parent — stale plugin-stack parent must not leak through
2580-
# when the ambient span is a root (no parent).
2581-
parent_span_id = None
2582-
# SDK spans expose .parent; non-recording spans do not.
2583-
parent_ctx = getattr(ambient, "parent", None)
2584-
if parent_ctx is not None and parent_ctx.span_id:
2585-
parent_span_id = format(parent_ctx.span_id, "016x")
2586-
2587-
# --- Layer 1: explicit EventData overrides ---
2664+
# Only use ambient span IDs when the plugin stack has no span.
2665+
# Framework-internal spans (execute_tool, call_llm, etc.) are
2666+
# never written to BQ, so deriving parent_span_id from them
2667+
# creates phantom references. The plugin stack guarantees
2668+
# that both span_id and parent_span_id reference BQ rows.
2669+
if span_id is None:
2670+
span_id = format(ambient_ctx.span_id, "016x")
2671+
parent_span_id = None
2672+
parent_ctx = getattr(ambient, "parent", None)
2673+
if parent_ctx is not None and parent_ctx.span_id:
2674+
parent_span_id = format(parent_ctx.span_id, "016x")
2675+
2676+
# --- Explicit EventData overrides (post-pop callbacks) ---
25882677
if event_data.trace_id_override is not None:
25892678
trace_id = event_data.trace_id_override
25902679
if event_data.span_id_override is not None:
@@ -2813,13 +2902,18 @@ async def on_event_callback(
28132902
invocation_context: InvocationContext,
28142903
event: "Event",
28152904
) -> None:
2816-
"""Logs state changes and HITL events from the event stream.
2905+
"""Logs state changes, HITL events, and A2A interactions.
28172906
28182907
- Checks each event for a non-empty state_delta and logs it as a
28192908
STATE_DELTA event.
28202909
- Detects synthetic ``adk_request_*`` function calls (HITL pause
28212910
events) and their corresponding function responses (HITL
28222911
completions) and emits dedicated HITL event types.
2912+
- Detects events carrying A2A interaction metadata
2913+
(``a2a:request`` / ``a2a:response`` in ``custom_metadata``)
2914+
and logs them as ``A2A_INTERACTION`` events so the remote
2915+
agent's response and cross-reference IDs (``a2a:task_id``,
2916+
``a2a:context_id``) are visible in BigQuery.
28232917
28242918
The HITL detection must happen here (not in tool callbacks) because
28252919
``adk_request_credential``, ``adk_request_confirmation``, and
@@ -2883,6 +2977,45 @@ async def on_event_callback(
28832977
is_truncated=is_truncated,
28842978
)
28852979

2980+
# --- A2A interaction logging ---
2981+
# RemoteA2aAgent attaches cross-reference metadata to events:
2982+
# a2a:task_id, a2a:context_id — correlation keys
2983+
# a2a:request, a2a:response — full interaction payload
2984+
# Log an A2A_INTERACTION event when meaningful payload is present
2985+
# so the supervisor's BQ trace contains the remote agent's
2986+
# response and cross-reference IDs for JOINs.
2987+
meta = getattr(event, "custom_metadata", None)
2988+
if meta and (
2989+
meta.get("a2a:request") is not None
2990+
or meta.get("a2a:response") is not None
2991+
):
2992+
a2a_keys = {k: v for k, v in meta.items() if k.startswith("a2a:")}
2993+
a2a_truncated, is_truncated = _recursive_smart_truncate(
2994+
a2a_keys, self.config.max_content_length
2995+
)
2996+
# Use the a2a:response as the event content when available,
2997+
# so the remote agent's answer is visible in the content
2998+
# column.
2999+
response_payload = a2a_keys.get("a2a:response")
3000+
content_dict = None
3001+
content_truncated = False
3002+
if response_payload is not None:
3003+
content_dict, content_truncated = _recursive_smart_truncate(
3004+
response_payload,
3005+
self.config.max_content_length,
3006+
)
3007+
await self._log_event(
3008+
"A2A_INTERACTION",
3009+
callback_ctx,
3010+
raw_content=content_dict,
3011+
is_truncated=is_truncated or content_truncated,
3012+
event_data=EventData(
3013+
extra_attributes={
3014+
"a2a_metadata": a2a_truncated,
3015+
},
3016+
),
3017+
)
3018+
28863019
return None
28873020

28883021
async def on_state_change_callback(
@@ -2940,19 +3073,14 @@ async def after_run_callback(
29403073
span_id, duration = TraceManager.pop_span()
29413074
parent_span_id = TraceManager.get_current_span_id()
29423075

2943-
# Only override span IDs when no ambient OTel span exists.
2944-
# When ambient exists, _resolve_ids Layer 2 uses the framework's
2945-
# span IDs, keeping STARTING/COMPLETED pairs consistent.
2946-
has_ambient = trace.get_current_span().get_span_context().is_valid
2947-
29483076
await self._log_event(
29493077
"INVOCATION_COMPLETED",
29503078
callback_ctx,
29513079
event_data=EventData(
29523080
trace_id_override=trace_id,
29533081
latency_ms=duration,
2954-
span_id_override=None if has_ambient else span_id,
2955-
parent_span_id_override=None if has_ambient else parent_span_id,
3082+
span_id_override=span_id,
3083+
parent_span_id_override=parent_span_id,
29563084
),
29573085
)
29583086
finally:
@@ -2995,18 +3123,13 @@ async def after_agent_callback(
29953123
span_id, duration = TraceManager.pop_span()
29963124
parent_span_id, _ = TraceManager.get_current_span_and_parent()
29973125

2998-
# Only override span IDs when no ambient OTel span exists.
2999-
# When ambient exists, _resolve_ids Layer 2 uses the framework's
3000-
# span IDs, keeping STARTING/COMPLETED pairs consistent.
3001-
has_ambient = trace.get_current_span().get_span_context().is_valid
3002-
30033126
await self._log_event(
30043127
"AGENT_COMPLETED",
30053128
callback_context,
30063129
event_data=EventData(
30073130
latency_ms=duration,
3008-
span_id_override=None if has_ambient else span_id,
3009-
parent_span_id_override=None if has_ambient else parent_span_id,
3131+
span_id_override=span_id,
3132+
parent_span_id_override=parent_span_id,
30103133
),
30113134
)
30123135

@@ -3156,12 +3279,6 @@ async def after_model_callback(
31563279
# Otherwise log_event will fetch current stack (which is parent).
31573280
span_id = popped_span_id or span_id
31583281

3159-
# Only override span IDs when no ambient OTel span exists.
3160-
# When ambient exists, _resolve_ids Layer 2 uses the framework's
3161-
# span IDs, keeping LLM_REQUEST/LLM_RESPONSE pairs consistent.
3162-
has_ambient = trace.get_current_span().get_span_context().is_valid
3163-
use_override = is_popped and not has_ambient
3164-
31653282
await self._log_event(
31663283
"LLM_RESPONSE",
31673284
callback_context,
@@ -3172,8 +3289,8 @@ async def after_model_callback(
31723289
time_to_first_token_ms=tfft,
31733290
model_version=llm_response.model_version,
31743291
usage_metadata=llm_response.usage_metadata,
3175-
span_id_override=span_id if use_override else None,
3176-
parent_span_id_override=parent_span_id if use_override else None,
3292+
span_id_override=span_id if is_popped else None,
3293+
parent_span_id_override=(parent_span_id if is_popped else None),
31773294
),
31783295
)
31793296

@@ -3195,18 +3312,15 @@ async def on_model_error_callback(
31953312
span_id, duration = TraceManager.pop_span()
31963313
parent_span_id, _ = TraceManager.get_current_span_and_parent()
31973314

3198-
# Only override span IDs when no ambient OTel span exists.
3199-
has_ambient = trace.get_current_span().get_span_context().is_valid
3200-
32013315
await self._log_event(
32023316
"LLM_ERROR",
32033317
callback_context,
32043318
event_data=EventData(
32053319
status="ERROR",
32063320
error_message=str(error),
32073321
latency_ms=duration,
3208-
span_id_override=None if has_ambient else span_id,
3209-
parent_span_id_override=None if has_ambient else parent_span_id,
3322+
span_id_override=span_id,
3323+
parent_span_id_override=parent_span_id,
32103324
),
32113325
)
32123326

@@ -3228,7 +3342,7 @@ async def before_tool_callback(
32283342
args_truncated, is_truncated = _recursive_smart_truncate(
32293343
tool_args, self.config.max_content_length
32303344
)
3231-
tool_origin = _get_tool_origin(tool)
3345+
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
32323346
content_dict = {
32333347
"tool": tool.name,
32343348
"args": args_truncated,
@@ -3262,7 +3376,7 @@ async def after_tool_callback(
32623376
resp_truncated, is_truncated = _recursive_smart_truncate(
32633377
result, self.config.max_content_length
32643378
)
3265-
tool_origin = _get_tool_origin(tool)
3379+
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
32663380
content_dict = {
32673381
"tool": tool.name,
32683382
"result": resp_truncated,
@@ -3271,13 +3385,10 @@ async def after_tool_callback(
32713385
span_id, duration = TraceManager.pop_span()
32723386
parent_span_id, _ = TraceManager.get_current_span_and_parent()
32733387

3274-
# Only override span IDs when no ambient OTel span exists.
3275-
has_ambient = trace.get_current_span().get_span_context().is_valid
3276-
32773388
event_data = EventData(
32783389
latency_ms=duration,
3279-
span_id_override=None if has_ambient else span_id,
3280-
parent_span_id_override=None if has_ambient else parent_span_id,
3390+
span_id_override=span_id,
3391+
parent_span_id_override=parent_span_id,
32813392
)
32823393
await self._log_event(
32833394
"TOOL_COMPLETED",
@@ -3307,7 +3418,7 @@ async def on_tool_error_callback(
33073418
args_truncated, is_truncated = _recursive_smart_truncate(
33083419
tool_args, self.config.max_content_length
33093420
)
3310-
tool_origin = _get_tool_origin(tool)
3421+
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
33113422
content_dict = {
33123423
"tool": tool.name,
33133424
"args": args_truncated,
@@ -3316,9 +3427,6 @@ async def on_tool_error_callback(
33163427
span_id, duration = TraceManager.pop_span()
33173428
parent_span_id, _ = TraceManager.get_current_span_and_parent()
33183429

3319-
# Only override span IDs when no ambient OTel span exists.
3320-
has_ambient = trace.get_current_span().get_span_context().is_valid
3321-
33223430
await self._log_event(
33233431
"TOOL_ERROR",
33243432
tool_context,
@@ -3328,7 +3436,7 @@ async def on_tool_error_callback(
33283436
status="ERROR",
33293437
error_message=str(error),
33303438
latency_ms=duration,
3331-
span_id_override=None if has_ambient else span_id,
3332-
parent_span_id_override=None if has_ambient else parent_span_id,
3439+
span_id_override=span_id,
3440+
parent_span_id_override=parent_span_id,
33333441
),
33343442
)

0 commit comments

Comments
 (0)