Skip to content

Commit 9b9faa4

Browse files
giulio-leonexuanyang15
authored andcommitted
fix: populate required fields in FunctionDeclaration json_schema fallback
Merge #5000 Fixes #4798 — `required` fields lost in `FunctionDeclaration` when the `parameters_json_schema` fallback path is used. Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#5000 from giulio-leone:fix/required-fields-json-schema-fallback e9783d7 PiperOrigin-RevId: 899243799
1 parent 80a7ecf commit 9b9faa4

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,11 @@ def from_function_with_options(
368368
parameters_json_schema[name] = types.Schema.model_validate(
369369
json_schema_dict
370370
)
371+
if param.default is not inspect.Parameter.empty:
372+
if param.default is not None:
373+
parameters_json_schema[name].default = param.default
374+
else:
375+
parameters_json_schema[name].nullable = True
371376
except Exception as e:
372377
_function_parameter_parse_util._raise_for_unsupported_param(
373378
param, func.__name__, e
@@ -392,6 +397,11 @@ def from_function_with_options(
392397
type='OBJECT',
393398
properties=parameters_json_schema,
394399
)
400+
declaration.parameters.required = (
401+
_function_parameter_parse_util._get_required_fields(
402+
declaration.parameters
403+
)
404+
)
395405

396406
if variant == GoogleLLMVariant.GEMINI_API:
397407
return declaration

tests/unittests/tools/test_from_function_with_options.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,45 @@ async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]:
319319
# VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator
320320
assert declaration.response is not None
321321
assert declaration.response.type == types.Type.OBJECT
322+
323+
324+
def test_required_fields_set_in_json_schema_fallback():
325+
"""Test that required fields are populated when the json_schema fallback path is used.
326+
327+
When a parameter has a complex union type (e.g. list[str] | None) that
328+
_parse_schema_from_parameter can't handle, from_function_with_options falls
329+
back to the parameters_json_schema branch. This test verifies that the
330+
required fields are correctly populated in that fallback branch.
331+
"""
332+
333+
def complex_tool(
334+
query: str,
335+
mode: str = 'default',
336+
tags: list[str] | None = None,
337+
) -> str:
338+
"""A tool where one param has a complex union type."""
339+
return query
340+
341+
declaration = _automatic_function_calling_util.from_function_with_options(
342+
complex_tool, GoogleLLMVariant.GEMINI_API
343+
)
344+
345+
assert declaration.name == 'complex_tool'
346+
assert declaration.parameters == types.Schema(
347+
type=types.Type.OBJECT,
348+
required=['query'],
349+
properties={
350+
'query': types.Schema(type=types.Type.STRING),
351+
'mode': types.Schema(type=types.Type.STRING, default='default'),
352+
'tags': types.Schema(
353+
any_of=[
354+
types.Schema(
355+
items=types.Schema(type=types.Type.STRING),
356+
type=types.Type.ARRAY,
357+
),
358+
types.Schema(type=types.Type.NULL),
359+
],
360+
nullable=True,
361+
),
362+
},
363+
)

0 commit comments

Comments
 (0)