Skip to content

Commit 454188d

Browse files
GWealecopybara-github
authored andcommitted
fix: execute on_event_callback before append_event to persist plugin modifications
Previously, on_event_callback ran after append_event, so plugin modifications (e.g. custom_metadata) were only visible in the yielded stream but never persisted to the session store. This moves the callback before persistence, matching the documented contract in base_plugin.py. Fixes #3990 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 899845848
1 parent 67dc2eb commit 454188d

File tree

3 files changed

+136
-18
lines changed

3 files changed

+136
-18
lines changed

src/google/adk/plugins/base_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,10 @@ async def before_run_callback(
155155
async def on_event_callback(
156156
self, *, invocation_context: InvocationContext, event: Event
157157
) -> Optional[Event]:
158-
"""Callback executed after an event is yielded from runner.
158+
"""Callback executed when the runner produces an event.
159159
160-
This is the ideal place to make modification to the event before the event
161-
is handled by the underlying agent app.
160+
This is the ideal place to modify the event before it is persisted to the
161+
session service and yielded to the caller.
162162
163163
Args:
164164
invocation_context: The context for the entire invocation.

src/google/adk/runners.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,34 @@ def _should_append_event(self, event: Event, is_live_call: bool) -> bool:
791791
return False
792792
return True
793793

794+
def _get_output_event(
795+
self,
796+
*,
797+
original_event: Event,
798+
modified_event: Event | None,
799+
run_config: RunConfig | None,
800+
) -> Event:
801+
"""Returns the event that should be persisted and yielded.
802+
803+
Plugins may return a replacement event that only overrides a subset of
804+
fields. Merge those changes onto the original event so the streamed event
805+
and the persisted event stay aligned without losing the original event
806+
identity.
807+
"""
808+
if modified_event is None:
809+
return original_event
810+
811+
_apply_run_config_custom_metadata(modified_event, run_config)
812+
update = {}
813+
for field_name in modified_event.model_fields_set:
814+
if field_name in {'id', 'invocation_id', 'timestamp'}:
815+
continue
816+
update[field_name] = modified_event.__dict__[field_name]
817+
output_event = original_event.model_copy(update=update)
818+
if not output_event.author:
819+
output_event.author = original_event.author
820+
return output_event
821+
794822
async def _exec_with_plugin(
795823
self,
796824
invocation_context: InvocationContext,
@@ -854,13 +882,24 @@ async def _exec_with_plugin(
854882
_apply_run_config_custom_metadata(
855883
event, invocation_context.run_config
856884
)
885+
# Step 3: Run the on_event callbacks before persisting so callback
886+
# changes are stored in the session and match the streamed event.
887+
modified_event = await plugin_manager.run_on_event_callback(
888+
invocation_context=invocation_context, event=event
889+
)
890+
output_event = self._get_output_event(
891+
original_event=event,
892+
modified_event=modified_event,
893+
run_config=invocation_context.run_config,
894+
)
895+
857896
if is_live_call:
858897
if event.partial and _is_transcription(event):
859898
is_transcribing = True
860899
if is_transcribing and _is_tool_call_or_response(event):
861900
# only buffer function call and function response event which is
862901
# non-partial
863-
buffered_events.append(event)
902+
buffered_events.append(output_event)
864903
continue
865904
# Note for live/bidi: for audio response, it's considered as
866905
# non-partial event(event.partial=None)
@@ -881,7 +920,7 @@ async def _exec_with_plugin(
881920
)
882921
if self._should_append_event(event, is_live_call):
883922
await self.session_service.append_event(
884-
session=session, event=event
923+
session=session, event=output_event
885924
)
886925

887926
for buffered_event in buffered_events:
@@ -897,25 +936,15 @@ async def _exec_with_plugin(
897936
if self._should_append_event(event, is_live_call):
898937
logger.debug('Appending non-buffered event: %s', event)
899938
await self.session_service.append_event(
900-
session=session, event=event
939+
session=session, event=output_event
901940
)
902941
else:
903942
if event.partial is not True:
904943
await self.session_service.append_event(
905-
session=session, event=event
944+
session=session, event=output_event
906945
)
907946

908-
# Step 3: Run the on_event callbacks to optionally modify the event.
909-
modified_event = await plugin_manager.run_on_event_callback(
910-
invocation_context=invocation_context, event=event
911-
)
912-
if modified_event:
913-
_apply_run_config_custom_metadata(
914-
modified_event, invocation_context.run_config
915-
)
916-
yield modified_event
917-
else:
918-
yield event
947+
yield output_event
919948

920949
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
921950
# finalizing logs and metrics data.

tests/unittests/test_runners.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class MockPlugin(BasePlugin):
139139
"Modified user message ON_USER_CALLBACK_MSG from MockPlugin"
140140
)
141141
ON_EVENT_CALLBACK_MSG = "Modified event ON_EVENT_CALLBACK_MSG from MockPlugin"
142+
ON_EVENT_CALLBACK_METADATA = {"plugin_key": "plugin_value"}
142143

143144
def __init__(self):
144145
super().__init__(name="mock_plugin")
@@ -184,6 +185,7 @@ async def on_event_callback(
184185
],
185186
role=event.content.role,
186187
),
188+
custom_metadata=self.ON_EVENT_CALLBACK_METADATA,
187189
)
188190

189191

@@ -359,6 +361,60 @@ async def test_run_live_auto_create_session():
359361
assert session is not None
360362

361363

364+
@pytest.mark.asyncio
365+
async def test_run_live_persists_event_callback_modifications():
366+
"""run_live should persist the same event it streams after callback changes."""
367+
session_service = InMemorySessionService()
368+
artifact_service = InMemoryArtifactService()
369+
plugin = MockPlugin()
370+
plugin.enable_event_callback = True
371+
runner = Runner(
372+
app_name="live_app",
373+
agent=MockLiveAgent("live_agent"),
374+
session_service=session_service,
375+
artifact_service=artifact_service,
376+
plugins=[plugin],
377+
)
378+
await session_service.create_session(
379+
app_name="live_app", user_id="user", session_id="live_session"
380+
)
381+
382+
from google.adk.agents.live_request_queue import LiveRequestQueue
383+
384+
live_queue = LiveRequestQueue()
385+
agen = runner.run_live(
386+
user_id="user",
387+
session_id="live_session",
388+
live_request_queue=live_queue,
389+
)
390+
391+
streamed_event = await agen.__anext__()
392+
await agen.aclose()
393+
394+
session = await session_service.get_session(
395+
app_name="live_app", user_id="user", session_id="live_session"
396+
)
397+
persisted_event = session.events[0]
398+
399+
assert streamed_event.author == "live_agent"
400+
assert streamed_event.invocation_id
401+
assert streamed_event.content.parts[0].text == (
402+
MockPlugin.ON_EVENT_CALLBACK_MSG
403+
)
404+
assert streamed_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
405+
406+
assert persisted_event.id == streamed_event.id
407+
assert persisted_event.timestamp == streamed_event.timestamp
408+
assert persisted_event.author == streamed_event.author
409+
assert persisted_event.invocation_id == streamed_event.invocation_id
410+
assert persisted_event.content.parts[0].text == (
411+
MockPlugin.ON_EVENT_CALLBACK_MSG
412+
)
413+
assert (
414+
persisted_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
415+
)
416+
417+
362418
@pytest.mark.asyncio
363419
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
364420
project_root = tmp_path / "workspace"
@@ -747,6 +803,39 @@ async def test_runner_modifies_event_after_execution(self):
747803

748804
assert modified_event_message == MockPlugin.ON_EVENT_CALLBACK_MSG
749805

806+
@pytest.mark.asyncio
807+
async def test_runner_persists_event_callback_modifications(self):
808+
"""Event callback output should be persisted, not only streamed."""
809+
self.plugin.enable_event_callback = True
810+
811+
events = await self.run_test()
812+
streamed_event = events[0]
813+
814+
session = await self.session_service.get_session(
815+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
816+
)
817+
persisted_event = session.events[1]
818+
819+
assert streamed_event.author == "test_agent"
820+
assert streamed_event.invocation_id
821+
assert streamed_event.content.parts[0].text == (
822+
MockPlugin.ON_EVENT_CALLBACK_MSG
823+
)
824+
assert (
825+
streamed_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
826+
)
827+
828+
assert persisted_event.id == streamed_event.id
829+
assert persisted_event.timestamp == streamed_event.timestamp
830+
assert persisted_event.author == streamed_event.author
831+
assert persisted_event.invocation_id == streamed_event.invocation_id
832+
assert persisted_event.content.parts[0].text == (
833+
MockPlugin.ON_EVENT_CALLBACK_MSG
834+
)
835+
assert (
836+
persisted_event.custom_metadata == MockPlugin.ON_EVENT_CALLBACK_METADATA
837+
)
838+
750839
@pytest.mark.asyncio
751840
async def test_runner_close_calls_plugin_close(self):
752841
"""Test that runner.close() calls plugin manager close."""

0 commit comments

Comments
 (0)