Skip to content

Commit c5d809e

Browse files
gautamvarmadatlaxuanyang15
authored andcommitted
fix: populate required for Pydantic BaseModel parameters in FunctionTool
Merge #4778 Closes: #4777 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=#4778 from gautamvarmadatla:fix/pydantic-basemodel-required-fields 45db622 PiperOrigin-RevId: 888905953
1 parent 3153e6d commit c5d809e

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,17 @@ def _parse_schema_from_parameter(
399399
),
400400
func_name,
401401
)
402+
403+
required_fields = [
404+
field_name
405+
for field_name, field_info in param.annotation.model_fields.items()
406+
if field_info.is_required()
407+
]
408+
if required_fields:
409+
schema.required = required_fields
402410
_raise_if_schema_unsupported(variant, schema)
403411
return schema
412+
404413
if inspect.isclass(param.annotation) and issubclass(
405414
param.annotation, ToolContext
406415
):

tests/unittests/tools/test_build_function_declaration.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,65 @@ def simple_function(input: CustomInput) -> str:
124124
)
125125

126126

127+
def test_basemodel_required_fields():
128+
class SearchRequest(BaseModel):
129+
query: str
130+
max_results: int
131+
filter: str = ''
132+
133+
def search(request: SearchRequest) -> list:
134+
return []
135+
136+
function_decl = _automatic_function_calling_util.build_function_declaration(
137+
func=search
138+
)
139+
140+
inner = function_decl.parameters.properties['request']
141+
assert set(inner.required) == {'query', 'max_results'}
142+
assert 'filter' not in (inner.required or [])
143+
144+
145+
def test_basemodel_all_optional_fields_no_required():
146+
class Config(BaseModel):
147+
timeout: int = 30
148+
retries: int = 3
149+
150+
def run(config: Config) -> str:
151+
return ''
152+
153+
function_decl = _automatic_function_calling_util.build_function_declaration(
154+
func=run
155+
)
156+
157+
inner = function_decl.parameters.properties['config']
158+
assert not inner.required
159+
160+
161+
def test_nested_basemodel_required_fields():
162+
class Inner(BaseModel):
163+
x: int
164+
y: int = 0
165+
166+
class Outer(BaseModel):
167+
inner: Inner
168+
label: str = ''
169+
170+
def process(data: Outer) -> str:
171+
return ''
172+
173+
function_decl = _automatic_function_calling_util.build_function_declaration(
174+
func=process
175+
)
176+
177+
outer = function_decl.parameters.properties['data']
178+
assert set(outer.required) == {'inner'}
179+
assert 'label' not in (outer.required or [])
180+
181+
inner = outer.properties['inner']
182+
assert set(inner.required) == {'x'}
183+
assert 'y' not in (inner.required or [])
184+
185+
127186
def test_toolcontext_ignored():
128187
def simple_function(input_str: str, tool_context: ToolContext) -> str:
129188
return {'result': input_str}

0 commit comments

Comments
 (0)