Skip to content

Commit a64a8e4

Browse files
google-genai-bot and copybara-github
authored and committed
feat: Live avatar support in ADK
Testing plan: Added new unit tests: `test_avatar_config_initialization`, `test_avatar_config_with_name`, `test_receive_video_content`, `test_streaming_with_avatar_config`. PiperOrigin-RevId: 899193911
1 parent: cbcb5e6; commit: a64a8e4

File tree

8 files changed

+170
-21
lines changed

8 files changed

+170
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ dependencies = [
4444
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
4545
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
4646
"google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service
47-
"google-genai>=1.64.0, <2.0.0", # Google GenAI SDK
47+
"google-genai>=1.72.0, <2.0.0", # Google GenAI SDK
4848
"graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering
4949
"httpx>=0.27.0, <1.0.0", # HTTP client library
5050
"jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation

src/google/adk/agents/run_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ class RunConfig(BaseModel):
198198
response_modalities: Optional[list[str]] = None
199199
"""The output modalities. If not set, it's default to AUDIO."""
200200

201+
avatar_config: Optional[types.AvatarConfig] = None
202+
"""Avatar configuration for the live agent."""
203+
201204
save_input_blobs_as_artifacts: bool = Field(
202205
default=False,
203206
deprecated=True,

src/google/adk/flows/llm_flows/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _build_basic_request(
9090
llm_request.live_connect_config.context_window_compression = (
9191
invocation_context.run_config.context_window_compression
9292
)
93+
llm_request.live_connect_config.avatar_config = (
94+
invocation_context.run_config.avatar_config
95+
)
9396

9497

9598
class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):

src/google/adk/models/gemini_llm_connection.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,7 @@ async def send_content(self, content: types.Content):
115115
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
116116
self._model_version
117117
)
118-
is_gemini_api = self._api_backend == GoogleLLMVariant.GEMINI_API
119-
120-
# As of now, Gemini 3.1 Flash Live is only available in Gemini API, not
121-
# Vertex AI.
122-
if (
123-
is_gemini_31
124-
and is_gemini_api
125-
and len(content.parts) == 1
126-
and content.parts[0].text
127-
):
118+
if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text:
128119
logger.debug('Using send_realtime_input for Gemini 3.1 text input')
129120
await self._gemini_session.send_realtime_input(
130121
text=content.parts[0].text
@@ -149,11 +140,7 @@ async def send_realtime(self, input: RealtimeInput):
149140
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
150141
self._model_version
151142
)
152-
is_gemini_api = self._api_backend == GoogleLLMVariant.GEMINI_API
153-
154-
# As of now, Gemini 3.1 Flash Live is only available in Gemini API, not
155-
# Vertex AI.
156-
if is_gemini_31 and is_gemini_api:
143+
if is_gemini_31:
157144
if input.mime_type and input.mime_type.startswith('audio/'):
158145
await self._gemini_session.send_realtime_input(audio=input)
159146
elif input.mime_type and input.mime_type.startswith('image/'):

src/google/adk/utils/model_name_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,6 @@ def is_gemini_2_or_above(model_string: Optional[str]) -> bool:
130130
def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool:
131131
"""Check if the model is a Gemini 3.1 Flash Live model.
132132
133-
Note: This is a very specific model name for live bidi streaming, so we check
134-
for exact match.
135-
136133
Args:
137134
model_string: The model name
138135
@@ -141,5 +138,4 @@ def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool:
141138
"""
142139
if not model_string:
143140
return False
144-
145-
return model_string == 'gemini-3.1-flash-live-preview'
141+
return model_string.startswith('gemini-3.1-flash-live')

tests/unittests/agents/test_run_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest.mock import patch
1818

1919
from google.adk.agents.run_config import RunConfig
20+
from google.genai import types
2021
import pytest
2122

2223

@@ -64,3 +65,35 @@ def test_audio_transcription_configs_are_not_shared_between_instances():
6465
assert (
6566
config1.input_audio_transcription is not config2.input_audio_transcription
6667
)
68+
69+
70+
def test_avatar_config_initialization():
71+
custom_avatar = types.CustomizedAvatar(
72+
image_mime_type="image/jpeg", image_data=b"image_bytes"
73+
)
74+
avatar_config = types.AvatarConfig(
75+
audio_bitrate_bps=128000,
76+
video_bitrate_bps=1000000,
77+
customized_avatar=custom_avatar,
78+
)
79+
run_config = RunConfig(avatar_config=avatar_config)
80+
81+
assert run_config.avatar_config == avatar_config
82+
assert run_config.avatar_config.customized_avatar == custom_avatar
83+
assert (
84+
run_config.avatar_config.customized_avatar.image_mime_type == "image/jpeg"
85+
)
86+
assert run_config.avatar_config.customized_avatar.image_data == b"image_bytes"
87+
88+
89+
def test_avatar_config_with_name():
90+
avatar_config = types.AvatarConfig(
91+
audio_bitrate_bps=128000,
92+
video_bitrate_bps=1000000,
93+
avatar_name="test_avatar",
94+
)
95+
run_config = RunConfig(avatar_config=avatar_config)
96+
97+
assert run_config.avatar_config == avatar_config
98+
assert run_config.avatar_config.avatar_name == "test_avatar"
99+
assert run_config.avatar_config.customized_avatar is None

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,3 +1120,43 @@ async def mock_receive_generator():
11201120

11211121
assert len(responses) == 1
11221122
assert responses[0].go_away == mock_go_away
1123+
1124+
1125+
@pytest.mark.asyncio
1126+
async def test_receive_video_content(gemini_connection, mock_gemini_session):
1127+
"""Test receive with video content."""
1128+
mock_content = types.Content(
1129+
role='model',
1130+
parts=[
1131+
types.Part(
1132+
inline_data=types.Blob(data=b'video_data', mime_type='video/mp4')
1133+
)
1134+
],
1135+
)
1136+
mock_server_content = mock.Mock()
1137+
mock_server_content.model_turn = mock_content
1138+
mock_server_content.interrupted = False
1139+
mock_server_content.input_transcription = None
1140+
mock_server_content.output_transcription = None
1141+
mock_server_content.turn_complete = False
1142+
mock_server_content.grounding_metadata = None
1143+
1144+
mock_message = mock.AsyncMock()
1145+
mock_message.usage_metadata = None
1146+
mock_message.server_content = mock_server_content
1147+
mock_message.tool_call = None
1148+
mock_message.session_resumption_update = None
1149+
mock_message.go_away = None
1150+
1151+
async def mock_receive_generator():
1152+
yield mock_message
1153+
1154+
receive_mock = mock.Mock(return_value=mock_receive_generator())
1155+
mock_gemini_session.receive = receive_mock
1156+
1157+
responses = [resp async for resp in gemini_connection.receive()]
1158+
1159+
assert responses
1160+
content_response = next((r for r in responses if r.content), None)
1161+
assert content_response is not None
1162+
assert content_response.content == mock_content

tests/unittests/streaming/test_live_streaming_configs.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,3 +642,90 @@ def test_streaming_with_context_window_compression_config():
642642
llm_request_sent_to_mock.live_connect_config.context_window_compression.sliding_window.target_tokens
643643
== 500
644644
)
645+
646+
647+
def test_streaming_with_avatar_config():
648+
"""Test avatar_config propagation and video content through run_live.
649+
650+
Verifies:
651+
1. avatar_config from RunConfig is propagated to live_connect_config.
652+
2. Video inline_data from the model flows through events correctly.
653+
"""
654+
# Mock model returns video content followed by turn_complete.
655+
video_response = LlmResponse(
656+
content=types.Content(
657+
role='model',
658+
parts=[
659+
types.Part(
660+
inline_data=types.Blob(
661+
data=b'video_data', mime_type='video/mp4'
662+
)
663+
)
664+
],
665+
),
666+
)
667+
turn_complete_response = LlmResponse(
668+
turn_complete=True,
669+
)
670+
671+
mock_model = testing_utils.MockModel.create(
672+
[video_response, turn_complete_response]
673+
)
674+
675+
root_agent = Agent(
676+
name='root_agent',
677+
model=mock_model,
678+
tools=[],
679+
)
680+
681+
runner = testing_utils.InMemoryRunner(
682+
root_agent=root_agent, response_modalities=['VIDEO']
683+
)
684+
685+
run_config = RunConfig(
686+
response_modalities=['VIDEO'],
687+
avatar_config=types.AvatarConfig(avatar_name='Kai'),
688+
)
689+
690+
live_request_queue = LiveRequestQueue()
691+
live_request_queue.send_realtime(
692+
blob=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
693+
)
694+
res_events = runner.run_live(live_request_queue, run_config)
695+
696+
assert res_events is not None, 'Expected a list of events, got None.'
697+
assert (
698+
len(res_events) > 0
699+
), 'Expected at least one response, but got an empty list.'
700+
assert len(mock_model.requests) == 1
701+
702+
# 1. Verify avatar_config was propagated to the live_connect_config.
703+
llm_request_sent_to_mock = mock_model.requests[0]
704+
assert llm_request_sent_to_mock.live_connect_config is not None
705+
assert llm_request_sent_to_mock.live_connect_config.avatar_config is not None
706+
assert (
707+
llm_request_sent_to_mock.live_connect_config.avatar_config.avatar_name
708+
== 'Kai'
709+
)
710+
711+
# 2. Verify video content flows through events.
712+
video_events = [
713+
e
714+
for e in res_events
715+
if e.content
716+
and e.content.parts
717+
and any(
718+
p.inline_data
719+
and p.inline_data.mime_type
720+
and p.inline_data.mime_type.startswith('video/')
721+
for p in e.content.parts
722+
)
723+
]
724+
assert video_events, 'Expected at least one event with video inline_data.'
725+
726+
video_event = video_events[0]
727+
assert video_event.content.role == 'model'
728+
video_part = video_event.content.parts[0]
729+
assert video_part.inline_data is not None
730+
assert video_part.inline_data.data == b'video_data'
731+
assert video_part.inline_data.mime_type == 'video/mp4'

0 commit comments

Comments
 (0)