diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 392e256b33..0dc5920d7c 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -18,6 +18,7 @@ import inspect from types import FunctionType import typing +from typing import Annotated from typing import Any from typing import Callable from typing import Dict @@ -31,6 +32,7 @@ from pydantic import BaseModel from pydantic import create_model from pydantic import fields as pydantic_fields +from pydantic.fields import FieldInfo from . import _function_parameter_parse_util from . import _function_tool_declarations @@ -39,62 +41,256 @@ from ..utils.variant_utils import GoogleLLMVariant _py_type_2_schema_type = { - 'str': types.Type.STRING, - 'int': types.Type.INTEGER, - 'float': types.Type.NUMBER, - 'bool': types.Type.BOOLEAN, - 'string': types.Type.STRING, - 'integer': types.Type.INTEGER, - 'number': types.Type.NUMBER, - 'boolean': types.Type.BOOLEAN, - 'list': types.Type.ARRAY, - 'array': types.Type.ARRAY, - 'tuple': types.Type.ARRAY, - 'object': types.Type.OBJECT, - 'Dict': types.Type.OBJECT, - 'List': types.Type.ARRAY, - 'Tuple': types.Type.ARRAY, - 'Any': types.Type.TYPE_UNSPECIFIED, + "str": types.Type.STRING, + "int": types.Type.INTEGER, + "float": types.Type.NUMBER, + "bool": types.Type.BOOLEAN, + "string": types.Type.STRING, + "integer": types.Type.INTEGER, + "number": types.Type.NUMBER, + "boolean": types.Type.BOOLEAN, + "list": types.Type.ARRAY, + "array": types.Type.ARRAY, + "tuple": types.Type.ARRAY, + "object": types.Type.OBJECT, + "Dict": types.Type.OBJECT, + "List": types.Type.ARRAY, + "Tuple": types.Type.ARRAY, + "Any": types.Type.TYPE_UNSPECIFIED, } +def _extract_field_info_from_annotated( + annotation: Any, +) -> Optional[FieldInfo]: + """Extract pydantic FieldInfo from Annotated[T, Field(...)] if present. + + Args: + annotation: The type annotation to inspect. + + Returns: + The FieldInfo instance if found in Annotated metadata, None otherwise. + """ + if get_origin(annotation) is Annotated: + for metadata in get_args(annotation)[1:]: + if isinstance(metadata, FieldInfo): + return metadata + return None + + +def _extract_base_type_from_annotated(annotation: Any) -> Any: + """Extract the base type from Annotated[T, ...]. + + Args: + annotation: The type annotation to unwrap. + + Returns: + The base type T if annotation is Annotated[T, ...], otherwise the original + annotation. + """ + if get_origin(annotation) is Annotated: + return get_args(annotation)[0] + return annotation + + +def _resolve_pydantic_refs(schema: Dict[str, Any]) -> Dict[str, Any]: + """Resolve $ref pointers in Pydantic JSON schema and inline nested objects. + + Pydantic generates JSON schemas with $ref pointers to $defs for nested + BaseModel classes. This function resolves these references and inlines + nested properties so that Field descriptions from nested models are + directly accessible in the schema sent to the LLM. + + This is similar to the reference resolution in openapi_spec_parser.py but + optimized for Pydantic v2 schema structure (handles allOf wrappers). + + Args: + schema: Pydantic model_json_schema() output with $defs. + + Returns: + Schema with all $ref resolved and nested properties inlined. The $defs + section is removed as all definitions are now inlined. + + Example: + Input: + { + "properties": { + "user": {"allOf": [{"$ref": "#/$defs/Person"}], "description": "User"} + }, + "$defs": { + "Person": {"properties": {"name": {"description": "Name"}}} + } + } + + Output: + { + "properties": { + "user": { + "type": "object", + "description": "User", + "properties": {"name": {"description": "Name"}} + } + } + } + """ + import copy + + schema = copy.deepcopy(schema) + defs = schema.get("$defs", {}) + + def resolve_ref(ref_string: str) -> Optional[Dict]: + """Resolve a $ref string like '#/$defs/Person'.""" + if not ref_string.startswith("#/$defs/"): + return None + def_name = ref_string.split("/")[-1] + return defs.get(def_name) + + def resolve_property( + prop_schema: Dict, seen_refs: Optional[set] = None + ) -> Dict: + """Recursively resolve $ref in a property schema. + + Args: + prop_schema: A property schema that may contain $ref or allOf with $ref. + seen_refs: Set of already-visited $ref strings to prevent circular refs. + + Returns: + Property schema with all $ref resolved and nested properties inlined. + """ + if seen_refs is None: + seen_refs = set() + + prop_schema = copy.deepcopy(prop_schema) + + # Handle allOf wrapper (Pydantic v2 pattern: {"allOf": [{"$ref": "..."}]}) + if "allOf" in prop_schema and len(prop_schema["allOf"]) == 1: + ref_item = prop_schema["allOf"][0] + if "$ref" in ref_item: + ref_string = ref_item["$ref"] + + # Prevent circular references + if ref_string in seen_refs: + # Return schema without allOf to break the cycle + return {k: v for k, v in prop_schema.items() if k != "allOf"} + + seen_refs_copy = seen_refs.copy() + seen_refs_copy.add(ref_string) + + resolved = resolve_ref(ref_string) + if resolved: + resolved = copy.deepcopy(resolved) + + # Preserve parameter-level description (takes precedence over model docstring) + param_description = prop_schema.get("description") + + # Recursively resolve nested properties within the resolved definition + if "properties" in resolved: + for nested_name, nested_schema in resolved["properties"].items(): + resolved["properties"][nested_name] = resolve_property( + nested_schema, seen_refs_copy + ) + + # If there was a parameter-level description, keep it + # (e.g., "User info" instead of model's docstring "Person model") + if param_description: + resolved["description"] = param_description + + return resolved + + # Handle direct $ref (less common in Pydantic v2, but supported for completeness) + elif "$ref" in prop_schema: + ref_string = prop_schema["$ref"] + if ref_string not in seen_refs: + seen_refs_copy = seen_refs.copy() + seen_refs_copy.add(ref_string) + resolved = resolve_ref(ref_string) + if resolved: + return resolve_property(copy.deepcopy(resolved), seen_refs_copy) + + # Recursively resolve nested properties (for already-inlined objects) + if "properties" in prop_schema: + for nested_name in list(prop_schema["properties"].keys()): + prop_schema["properties"][nested_name] = resolve_property( + prop_schema["properties"][nested_name], seen_refs + ) + + # Handle arrays with items that might have refs + if "items" in prop_schema: + prop_schema["items"] = resolve_property(prop_schema["items"], seen_refs) + + return prop_schema + + # Resolve all top-level properties + if "properties" in schema: + for prop_name in list(schema["properties"].keys()): + schema["properties"][prop_name] = resolve_property( + schema["properties"][prop_name] + ) + + # Clean up $defs since all definitions are now inlined + schema.pop("$defs", None) + + return schema + + def _get_fields_dict(func: Callable) -> Dict: + """Build a dictionary of field definitions for Pydantic model creation. + + This function extracts parameter information from a callable and creates + field definitions compatible with Pydantic's create_model. It supports + parameter descriptions via Annotated[T, Field(description=...)] syntax. + + Args: + func: The callable to extract parameters from. + + Returns: + A dictionary mapping parameter names to (type, FieldInfo) tuples. + """ param_signature = dict(inspect.signature(func).parameters) - fields_dict = { - name: ( - # 1. We infer the argument type here: use Any rather than None so - # it will not try to auto-infer the type based on the default value. - ( - param.annotation - if param.annotation != inspect.Parameter.empty - else Any - ), - pydantic.Field( - # 2. We do not support default values for now. - default=( - param.default - if param.default != inspect.Parameter.empty - # ! Need to use Undefined instead of None - else pydantic_fields.PydanticUndefined - ), - # 3. Do not support parameter description for now. - description=None, - ), - ) - for name, param in param_signature.items() - # We do not support *args or **kwargs - if param.kind - in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_ONLY, - ) - } + fields_dict = {} + + for name, param in param_signature.items(): + # We do not support *args or **kwargs + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ): + continue + + annotation = ( + param.annotation if param.annotation != inspect.Parameter.empty else Any + ) + + # Extract FieldInfo from Annotated[T, Field(...)] if present + field_info = _extract_field_info_from_annotated(annotation) + + # Extract the base type from Annotated[T, ...] for the model field + base_type = _extract_base_type_from_annotated(annotation) + + # Determine the default value + default = ( + param.default + if param.default != inspect.Parameter.empty + else pydantic_fields.PydanticUndefined + ) + + # Get description from FieldInfo if available + description = field_info.description if field_info else None + + fields_dict[name] = ( + base_type, + pydantic.Field( + default=default, + description=description, + ), + ) + return fields_dict def _annotate_nullable_fields(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): + for _, property_schema in schema.get("properties", {}).items(): # for Optional[T], the pydantic schema is: # { # "type": "object", @@ -109,45 +305,45 @@ def _annotate_nullable_fields(schema: Dict): # ] # } # } - for type_ in property_schema.get('anyOf', []): - if type_.get('type') == 'null': - property_schema['nullable'] = True - property_schema['anyOf'].remove(type_) + for type_ in property_schema.get("anyOf", []): + if type_.get("type") == "null": + property_schema["nullable"] = True + property_schema["anyOf"].remove(type_) break def _annotate_required_fields(schema: Dict): required = [ field_name - for field_name, field_schema in schema.get('properties', {}).items() - if not field_schema.get('nullable') and 'default' not in field_schema + for field_name, field_schema in schema.get("properties", {}).items() + if not field_schema.get("nullable") and "default" not in field_schema ] - schema['required'] = required + schema["required"] = required def _remove_any_of(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): - union_types = property_schema.pop('anyOf', None) + for _, property_schema in schema.get("properties", {}).items(): + union_types = property_schema.pop("anyOf", None) # Take the first non-null type. if union_types: for type_ in union_types: - if type_.get('type') != 'null': + if type_.get("type") != "null": property_schema.update(type_) def _remove_default(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): - property_schema.pop('default', None) + for _, property_schema in schema.get("properties", {}).items(): + property_schema.pop("default", None) def _remove_nullable(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): - property_schema.pop('nullable', None) + for _, property_schema in schema.get("properties", {}).items(): + property_schema.pop("nullable", None) def _remove_title(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): - property_schema.pop('title', None) + for _, property_schema in schema.get("properties", {}).items(): + property_schema.pop("title", None) def _get_pydantic_schema(func: Callable) -> Dict: @@ -155,10 +351,18 @@ def _get_pydantic_schema(func: Callable) -> Dict: fields_dict = _get_fields_dict(func) # Remove context parameter (detected by type or fallback to 'tool_context' name) - context_param = find_context_parameter(func) or 'tool_context' + context_param = find_context_parameter(func) or "tool_context" if context_param in fields_dict.keys(): fields_dict.pop(context_param) - return pydantic.create_model(func.__name__, **fields_dict).model_json_schema() + + schema = pydantic.create_model( + func.__name__, **fields_dict + ).model_json_schema() + + # Resolve $ref for nested Pydantic models to inline Field descriptions + schema = _resolve_pydantic_refs(schema) + + return schema def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict: @@ -173,24 +377,24 @@ def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict: def _map_pydantic_type_to_property_schema(property_schema: Dict): - if 'type' in property_schema: - property_schema['type'] = _py_type_2_schema_type.get( - property_schema['type'], 'TYPE_UNSPECIFIED' + if "type" in property_schema: + property_schema["type"] = _py_type_2_schema_type.get( + property_schema["type"], "TYPE_UNSPECIFIED" ) - if property_schema['type'] == 'ARRAY': - _map_pydantic_type_to_property_schema(property_schema['items']) - for type_ in property_schema.get('anyOf', []): - if 'type' in type_: - type_['type'] = _py_type_2_schema_type.get( - type_['type'], 'TYPE_UNSPECIFIED' + if property_schema["type"] == "ARRAY": + _map_pydantic_type_to_property_schema(property_schema["items"]) + for type_ in property_schema.get("anyOf", []): + if "type" in type_: + type_["type"] = _py_type_2_schema_type.get( + type_["type"], "TYPE_UNSPECIFIED" ) # TODO: To investigate. Unclear why a Type is needed with 'anyOf' to # avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT. - property_schema['type'] = type_['type'] + property_schema["type"] = type_["type"] def _map_pydantic_type_to_schema_type(schema: Dict): - for _, property_schema in schema.get('properties', {}).items(): + for _, property_schema in schema.get("properties", {}).items(): _map_pydantic_type_to_property_schema(property_schema) @@ -266,13 +470,13 @@ def build_function_declaration_for_langchain( vertexai: bool, name, description, func, param_pydantic_schema ) -> types.FunctionDeclaration: param_pydantic_schema = _process_pydantic_schema( - vertexai, {'properties': param_pydantic_schema} - )['properties'] + vertexai, {"properties": param_pydantic_schema} + )["properties"] param_copy = param_pydantic_schema.copy() - required_fields = param_copy.pop('required', []) + required_fields = param_copy.pop("required", []) before_param_pydantic_schema = { - 'properties': param_copy, - 'required': required_fields, + "properties": param_copy, + "required": required_fields, } return build_function_declaration_util( vertexai, name, description, func, before_param_pydantic_schema @@ -295,10 +499,10 @@ def build_function_declaration_util( vertexai: bool, name, description, func, before_param_pydantic_schema ) -> types.FunctionDeclaration: _map_pydantic_type_to_schema_type(before_param_pydantic_schema) - properties = before_param_pydantic_schema.get('properties', {}) + properties = before_param_pydantic_schema.get("properties", {}) function_declaration = types.FunctionDeclaration( parameters=types.Schema( - type='OBJECT', + type="OBJECT", properties=properties, ) if properties @@ -317,7 +521,7 @@ def build_function_declaration_util( def from_function_with_options( func: Callable, variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API, -) -> 'types.FunctionDeclaration': +) -> "types.FunctionDeclaration": parameters_properties = {} parameters_json_schema = {} @@ -379,7 +583,7 @@ def from_function_with_options( ) if parameters_properties: declaration.parameters = types.Schema( - type='OBJECT', + type="OBJECT", properties=parameters_properties, ) declaration.parameters.required = ( @@ -389,7 +593,7 @@ def from_function_with_options( ) elif parameters_json_schema: declaration.parameters = types.Schema( - type='OBJECT', + type="OBJECT", properties=parameters_json_schema, ) @@ -416,7 +620,7 @@ def from_function_with_options( if return_annotation is inspect._empty: # Functions with no return annotation can return any type return_value = inspect.Parameter( - 'return_value', + "return_value", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=typing.Any, ) @@ -433,11 +637,11 @@ def from_function_with_options( if ( return_annotation is None or return_annotation is type(None) - or (isinstance(return_annotation, str) and return_annotation == 'None') + or (isinstance(return_annotation, str) and return_annotation == "None") ): # Create a response schema for None/null return return_value = inspect.Parameter( - 'return_value', + "return_value", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=None, ) @@ -451,13 +655,13 @@ def from_function_with_options( return declaration return_value = inspect.Parameter( - 'return_value', + "return_value", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=return_annotation, ) if isinstance(return_value.annotation, str): return_value = return_value.replace( - annotation=typing.get_type_hints(func)['return'] + annotation=typing.get_type_hints(func)["return"] ) response_schema: Optional[types.Schema] = None diff --git a/tests/unittests/tools/test_annotated_parameter_descriptions.py b/tests/unittests/tools/test_annotated_parameter_descriptions.py new file mode 100644 index 0000000000..63ec849df3 --- /dev/null +++ b/tests/unittests/tools/test_annotated_parameter_descriptions.py @@ -0,0 +1,644 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for per-parameter descriptions via Annotated[T, Field(description=...)].""" + +from typing import Annotated +from typing import Optional + +from google.adk.tools._automatic_function_calling_util import _extract_base_type_from_annotated +from google.adk.tools._automatic_function_calling_util import _extract_field_info_from_annotated +from google.adk.tools._automatic_function_calling_util import _get_fields_dict +from google.adk.tools._automatic_function_calling_util import build_function_declaration +from google.adk.utils.variant_utils import GoogleLLMVariant +from pydantic import Field +import pytest + + +class TestExtractFieldInfoFromAnnotated: + """Tests for _extract_field_info_from_annotated helper function.""" + + def test_extract_field_info_with_description(self): + """Test extracting FieldInfo with description from Annotated type.""" + annotation = Annotated[str, Field(description="A test description")] + field_info = _extract_field_info_from_annotated(annotation) + + assert field_info is not None + assert field_info.description == "A test description" + + def test_extract_field_info_without_description(self): + """Test extracting FieldInfo without description from Annotated type.""" + annotation = Annotated[str, Field()] + field_info = _extract_field_info_from_annotated(annotation) + + assert field_info is not None + assert field_info.description is None + + def test_extract_field_info_not_annotated(self): + """Test that non-Annotated types return None.""" + field_info = _extract_field_info_from_annotated(str) + assert field_info is None + + def test_extract_field_info_annotated_without_field(self): + """Test that Annotated without Field returns None.""" + annotation = Annotated[str, "some_metadata"] + field_info = _extract_field_info_from_annotated(annotation) + assert field_info is None + + def test_extract_field_info_with_multiple_metadata(self): + """Test that Field is found even with multiple metadata items.""" + annotation = Annotated[ + str, + "some_string_metadata", + Field(description="Found it!"), + 42, + ] + field_info = _extract_field_info_from_annotated(annotation) + + assert field_info is not None + assert field_info.description == "Found it!" + + +class TestExtractBaseTypeFromAnnotated: + """Tests for _extract_base_type_from_annotated helper function.""" + + def test_extract_base_type_from_annotated(self): + """Test extracting base type from Annotated.""" + annotation = Annotated[str, Field(description="test")] + base_type = _extract_base_type_from_annotated(annotation) + assert base_type is str + + def test_extract_base_type_from_non_annotated(self): + """Test that non-Annotated types are returned as-is.""" + base_type = _extract_base_type_from_annotated(int) + assert base_type is int + + def test_extract_base_type_complex_annotated(self): + """Test extracting complex base types from Annotated.""" + from typing import List + + annotation = Annotated[List[str], Field(description="A list of strings")] + base_type = _extract_base_type_from_annotated(annotation) + assert base_type == List[str] + + +class TestGetFieldsDict: + """Tests for _get_fields_dict with Annotated parameter descriptions.""" + + def test_get_fields_dict_with_annotated_description(self): + """Test that _get_fields_dict extracts descriptions from Annotated.""" + + def sample_func( + repo: Annotated[ + str, + Field(description="Repository URL from get_repository_info"), + ], + branch: Annotated[ + str, + Field(description="Base branch for development"), + ], + ) -> dict: + return {} + + fields = _get_fields_dict(sample_func) + + assert "repo" in fields + assert "branch" in fields + + # Check that descriptions are extracted + repo_type, repo_field = fields["repo"] + branch_type, branch_field = fields["branch"] + + assert repo_type is str + assert repo_field.description == "Repository URL from get_repository_info" + assert branch_type is str + assert branch_field.description == "Base branch for development" + + def test_get_fields_dict_mixed_annotations(self): + """Test _get_fields_dict with mix of Annotated and regular params.""" + + def sample_func( + annotated_param: Annotated[ + str, + Field(description="This has a description"), + ], + regular_param: str, + ) -> None: + pass + + fields = _get_fields_dict(sample_func) + + annotated_type, annotated_field = fields["annotated_param"] + regular_type, regular_field = fields["regular_param"] + + assert annotated_type is str + assert annotated_field.description == "This has a description" + assert regular_type is str + assert regular_field.description is None + + def test_get_fields_dict_with_default_values(self): + """Test that default values are preserved with Annotated types.""" + + def sample_func( + required_param: Annotated[ + str, + Field(description="Required parameter"), + ], + optional_param: Annotated[ + str, + Field(description="Optional parameter"), + ] = "default_value", + ) -> None: + pass + + fields = _get_fields_dict(sample_func) + + _, required_field = fields["required_param"] + _, optional_field = fields["optional_param"] + + assert required_field.description == "Required parameter" + assert optional_field.description == "Optional parameter" + assert optional_field.default == "default_value" + + def test_get_fields_dict_with_optional_annotated(self): + """Test Annotated with Optional type.""" + + def sample_func( + optional_param: Annotated[ + Optional[str], + Field(description="Optional string parameter"), + ] = None, + ) -> None: + pass + + fields = _get_fields_dict(sample_func) + + param_type, param_field = fields["optional_param"] + assert param_field.description == "Optional string parameter" + assert param_field.default is None + + +class TestBuildFunctionDeclaration: + """Tests for build_function_declaration with Annotated descriptions.""" + + def test_build_declaration_with_annotated_params(self): + """Test that build_function_declaration includes parameter descriptions.""" + + def create_task( + repository: Annotated[ + str, + Field( + description=( + "Full GitLab repository URL. " + "MUST be obtained from get_repository_info." + ) + ), + ], + base_branch: Annotated[ + str, + Field( + description=( + "Base branch for development (e.g. 'main', 'develop'). " + "MUST be obtained from get_repository_info." + ) + ), + ], + ) -> dict: + """Create a new task in the repository.""" + return {} + + declaration = build_function_declaration( + create_task, + variant=GoogleLLMVariant.GEMINI_API, + ) + + assert declaration.name == "create_task" + assert declaration.description == "Create a new task in the repository." + assert declaration.parameters is not None + assert declaration.parameters.properties is not None + + # Check that descriptions are in the schema + repo_schema = declaration.parameters.properties.get("repository") + branch_schema = declaration.parameters.properties.get("base_branch") + + assert repo_schema is not None + assert branch_schema is not None + + # The descriptions should be present in the schema + assert repo_schema.description is not None + assert "GitLab repository URL" in repo_schema.description + assert branch_schema.description is not None + assert "Base branch for development" in branch_schema.description + + def test_build_declaration_without_annotated_params(self): + """Test build_function_declaration without Annotated still works.""" + + def simple_func(name: str, count: int) -> str: + """A simple function.""" + return name * count + + declaration = build_function_declaration( + simple_func, + variant=GoogleLLMVariant.GEMINI_API, + ) + + assert declaration.name == "simple_func" + assert declaration.parameters is not None + assert declaration.parameters.properties is not None + assert "name" in declaration.parameters.properties + assert "count" in declaration.parameters.properties + + +class TestIntegrationWithFunctionTool: + """Integration tests for FunctionTool with Annotated descriptions.""" + + def test_function_tool_with_annotated_params(self): + """Test that FunctionTool works with Annotated parameter descriptions.""" + from google.adk.tools.function_tool import FunctionTool + + def search_repos( + query: Annotated[ + str, + Field(description="Search query for repositories"), + ], + limit: Annotated[ + int, + Field(description="Maximum number of results to return"), + ] = 10, + ) -> list: + """Search for repositories matching the query.""" + return [] + + tool = FunctionTool(search_repos) + + assert tool.name == "search_repos" + assert tool.description == "Search for repositories matching the query." + + # Get the function declaration (internal method) + declaration = tool._get_declaration() + assert declaration is not None + assert declaration.parameters is not None + + @pytest.mark.asyncio + async def test_function_tool_execution_with_annotated_params(self): + """Test that FunctionTool executes correctly with Annotated params.""" + from unittest.mock import MagicMock + + from google.adk.agents.invocation_context import InvocationContext + from google.adk.sessions.session import Session + from google.adk.tools.function_tool import FunctionTool + from google.adk.tools.tool_context import ToolContext + + def greet( + name: Annotated[ + str, + Field(description="Name of the person to greet"), + ], + greeting: Annotated[ + str, + Field(description="Greeting to use"), + ] = "Hello", + ) -> str: + """Greet a person.""" + return f"{greeting}, {name}!" + + tool = FunctionTool(greet) + + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + tool_context = ToolContext(invocation_context=mock_invocation_context) + + result = await tool.run_async( + args={"name": "World"}, + tool_context=tool_context, + ) + + assert result == "Hello, World!" + + result_custom = await tool.run_async( + args={"name": "Alice", "greeting": "Hi"}, + tool_context=tool_context, + ) + + assert result_custom == "Hi, Alice!" + + +class TestNestedPydanticModels: + """Tests for nested Pydantic BaseModel support with Field descriptions.""" + + def test_single_level_nested_model(self): + """Test that nested Pydantic model Field descriptions are inlined.""" + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class Address(BaseModel): + """User's address information.""" + + street: str = Field(description="Street name and number") + city: str = Field(description="City name") + zipcode: str = Field(description="Postal code (5 digits)") + + def create_user( + address: Annotated[ + Address, Field(description="User's residential address") + ], + ) -> dict: + """Create a new user with address.""" + return {} + + schema = _get_pydantic_schema(create_user) + + # Check that address parameter exists + assert "properties" in schema + assert "address" in schema["properties"] + + address_schema = schema["properties"]["address"] + + # Check that the parameter-level description is preserved + assert address_schema.get("description") == "User's residential address" + + # Check that nested properties are inlined (not using $ref) + assert ( + "properties" in address_schema + ), "Nested properties should be inlined, not using $ref" + assert "$ref" not in address_schema, "Should not have $ref after resolution" + assert ( + "allOf" not in address_schema + ), "Should not have allOf after resolution" + + # Check that nested Field descriptions are present + nested_props = address_schema["properties"] + assert "street" in nested_props + assert "city" in nested_props + assert "zipcode" in nested_props + + assert nested_props["street"].get("description") == "Street name and number" + assert nested_props["city"].get("description") == "City name" + assert ( + nested_props["zipcode"].get("description") == "Postal code (5 digits)" + ) + + # Verify $defs is removed after inlining + assert "$defs" not in schema, "$defs should be removed after inlining" + + def test_multi_level_nested_model(self): + """Test that doubly-nested Pydantic models preserve all descriptions.""" + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class ContactInfo(BaseModel): + """Contact information.""" + + email: str = Field(description="Email address in format user@domain.com") + phone: str = Field(description="Phone number with country code") + + class Person(BaseModel): + """Person information.""" + + name: str = Field(description="Person's full name") + age: int = Field(description="Person's age in years") + contact: ContactInfo = Field(description="Contact information") + + def create_user( + person: Annotated[ + Person, Field(description="User personal information") + ], + ) -> dict: + """Create a new user.""" + return {} + + schema = _get_pydantic_schema(create_user) + + # Check first level (person parameter) + person_schema = schema["properties"]["person"] + assert person_schema.get("description") == "User personal information" + assert "properties" in person_schema + + # Check second level (name, age, contact) + person_props = person_schema["properties"] + assert person_props["name"].get("description") == "Person's full name" + assert person_props["age"].get("description") == "Person's age in years" + assert person_props["contact"].get("description") == "Contact information" + + # Check third level (email, phone within contact) + assert "properties" in person_props["contact"] + contact_props = person_props["contact"]["properties"] + assert ( + contact_props["email"].get("description") + == "Email address in format user@domain.com" + ) + assert ( + contact_props["phone"].get("description") + == "Phone number with country code" + ) + + def test_nested_model_with_list(self): + """Test that List of nested Pydantic models works correctly.""" + from typing import List + + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class Tag(BaseModel): + """A tag.""" + + name: str = Field(description="Tag name") + color: str = Field(description="Tag color in hex format") + + def create_item( + tags: Annotated[ + List[Tag], Field(description="List of tags for the item") + ], + ) -> dict: + """Create an item with tags.""" + return {} + + schema = _get_pydantic_schema(create_item) + + tags_schema = schema["properties"]["tags"] + assert tags_schema.get("description") == "List of tags for the item" + assert tags_schema.get("type") == "array" + assert "items" in tags_schema + + # Check that items schema has inlined properties + items_schema = tags_schema["items"] + assert "properties" in items_schema + assert items_schema["properties"]["name"].get("description") == "Tag name" + assert ( + items_schema["properties"]["color"].get("description") + == "Tag color in hex format" + ) + + def test_nested_model_with_optional(self): + """Test that Optional nested Pydantic models preserve descriptions.""" + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class Metadata(BaseModel): + """Metadata information.""" + + key: str = Field(description="Metadata key") + value: str = Field(description="Metadata value") + + def create_item( + metadata: Annotated[ + Optional[Metadata], Field(description="Optional metadata") + ] = None, + ) -> dict: + """Create an item with optional metadata.""" + return {} + + schema = _get_pydantic_schema(create_item) + + metadata_schema = schema["properties"]["metadata"] + + # Optional handling might use anyOf, but descriptions should still be there + # Check if properties are accessible (could be in anyOf structure) + if "properties" in metadata_schema: + # Direct properties + assert ( + metadata_schema["properties"]["key"].get("description") + == "Metadata key" + ) + elif "anyOf" in metadata_schema: + # Look for the object definition in anyOf + for variant in metadata_schema["anyOf"]: + if variant.get("type") == "object" and "properties" in variant: + assert ( + variant["properties"]["key"].get("description") == "Metadata key" + ) + break + + def test_mixed_nested_and_simple_params(self): + """Test function with both nested models and simple parameters.""" + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class Config(BaseModel): + """Configuration.""" + + timeout: int = Field(description="Timeout in seconds") + retries: int = Field(description="Number of retries") + + def execute_task( + task_name: Annotated[str, Field(description="Name of the task")], + config: Annotated[Config, Field(description="Task configuration")], + dry_run: Annotated[ + bool, Field(description="Run in dry-run mode") + ] = False, + ) -> dict: + """Execute a task with configuration.""" + return {} + + schema = _get_pydantic_schema(execute_task) + + # Check simple parameters + assert ( + schema["properties"]["task_name"].get("description") + == "Name of the task" + ) + assert ( + schema["properties"]["dry_run"].get("description") + == "Run in dry-run mode" + ) + + # Check nested model + config_schema = schema["properties"]["config"] + assert config_schema.get("description") == "Task configuration" + assert "properties" in config_schema + assert ( + config_schema["properties"]["timeout"].get("description") + == "Timeout in seconds" + ) + assert ( + config_schema["properties"]["retries"].get("description") + == "Number of retries" + ) + + def test_nested_model_circular_reference_handling(self): + """Test that circular references in nested models don't cause infinite loops.""" + from typing import List + + from google.adk.tools._automatic_function_calling_util import _get_pydantic_schema + from pydantic import BaseModel + + class TreeNode(BaseModel): + """A tree node.""" + + value: str = Field(description="Node value") + children: List["TreeNode"] = Field( + default_factory=list, description="Child nodes" + ) + + def create_tree( + root: Annotated[TreeNode, Field(description="Root node of the tree")], + ) -> dict: + """Create a tree structure.""" + return {} + + # This should not raise an error or hang + schema = _get_pydantic_schema(create_tree) + + # Verify schema was generated + assert "properties" in schema + assert "root" in schema["properties"] + + # The function should handle the circular reference gracefully + # (implementation may vary: could inline first level, use ref, or break cycle) + root_schema = schema["properties"]["root"] + assert root_schema.get("description") == "Root node of the tree" + + def test_function_declaration_with_nested_models(self): + """Test that build_function_declaration works with nested Pydantic models.""" + from pydantic import BaseModel + + class Credentials(BaseModel): + """API credentials.""" + + api_key: str = Field(description="API key for authentication") + secret: str = Field(description="API secret") + + def authenticate( + creds: Annotated[ + Credentials, Field(description="Authentication credentials") + ], + ) -> dict: + """Authenticate with API credentials.""" + return {} + + declaration = build_function_declaration( + authenticate, + variant=GoogleLLMVariant.GEMINI_API, + ) + + assert declaration.name == "authenticate" + assert declaration.parameters is not None + assert declaration.parameters.properties is not None + + creds_schema = declaration.parameters.properties.get("creds") + assert creds_schema is not None + assert creds_schema.description == "Authentication credentials" + + # Check that nested properties are accessible + assert creds_schema.properties is not None + assert "api_key" in creds_schema.properties + assert "secret" in creds_schema.properties + + # Verify nested descriptions are present + assert ( + creds_schema.properties["api_key"].description + == "API key for authentication" + ) + assert creds_schema.properties["secret"].description == "API secret"