Skip to content

Commit ae1f2e6

Browse files
wuliang229 and copybara-github
authored and committed
fix(live): treat input transcription as user message
During agent transfer, the input transcription is represented as user input. Also, the previous implementation is incorrect in that the LlmResponse could never have content.role == 'user'. Modified the docstring to match the behavior.

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 900980704
1 parent 110aecf commit ae1f2e6

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -720,19 +720,21 @@ async def _receive_from_model(
720720
) -> AsyncGenerator[Event, None]:
721721
"""Receive data from model and process events using BaseLlmConnection."""
722722

723-
def get_author_for_event(llm_response):
723+
def get_author_for_event(llm_response: LlmResponse) -> str:
724724
"""Get the author of the event.
725725
726-
When the model returns transcription, the author is "user". Otherwise, the
727-
author is the agent name(not 'model').
726+
When the model returns input transcription, the author is set to "user".
727+
Otherwise, the author is the agent name (not 'model').
728728
729729
Args:
730730
llm_response: The LLM response from the LLM call.
731+
732+
Returns:
733+
The author of the event as a string, either "user" or the agent's name.
731734
"""
732-
if (
733-
llm_response
734-
and llm_response.content
735-
and llm_response.content.role == 'user'
735+
if llm_response and (
736+
llm_response.input_transcription
737+
or (llm_response.content and llm_response.content.role == 'user')
736738
):
737739
return 'user'
738740
else:

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,3 +925,59 @@ async def test_postprocess_live_session_resumption_update():
925925
assert len(events) == 1
926926
assert events[0].live_session_resumption_update is not None
927927
assert events[0].live_session_resumption_update.new_handle == 'test_handle'
928+
929+
930+
@pytest.mark.asyncio
931+
async def test_receive_from_model_author_attribution():
932+
"""Test that _receive_from_model sets the correct author for events based on LlmResponse."""
933+
agent = Agent(name='test_agent')
934+
invocation_context = await testing_utils.create_invocation_context(
935+
agent=agent
936+
)
937+
flow = BaseLlmFlowForTesting()
938+
939+
mock_connection = mock.AsyncMock()
940+
941+
# Case 1: input_transcription is set -> author should be 'user'
942+
response_1 = LlmResponse(
943+
input_transcription=types.Transcription(text='test', finished=True)
944+
)
945+
946+
# Case 2: default -> author should be agent.name
947+
response_2 = LlmResponse(
948+
content=types.Content(
949+
role='model', parts=[types.Part.from_text(text='hello')]
950+
)
951+
)
952+
953+
# Case 3: content.role is 'user' -> author should be 'user'
954+
response_3 = LlmResponse(
955+
content=types.Content(
956+
role='user', parts=[types.Part.from_text(text='user text')]
957+
)
958+
)
959+
960+
class StopTest(Exception):
961+
pass
962+
963+
async def mock_receive():
964+
yield response_1
965+
yield response_2
966+
yield response_3
967+
raise StopTest()
968+
969+
mock_connection.receive = mock.Mock(side_effect=mock_receive)
970+
971+
events = []
972+
try:
973+
async for event in flow._receive_from_model(
974+
mock_connection, 'event_id', invocation_context, LlmRequest()
975+
):
976+
events.append(event)
977+
except StopTest:
978+
pass
979+
980+
assert len(events) == 3
981+
assert events[0].author == 'user'
982+
assert events[1].author == 'test_agent'
983+
assert events[2].author == 'user'

0 commit comments

Comments (0)