diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 89d831a884..dda76e9c24 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -16,6 +16,7 @@ import asyncio from typing import Any +from typing import Callable from typing import cast from typing import Optional @@ -51,6 +52,16 @@ class LlmCallsLimitExceededError(Exception): """Error thrown when the number of LLM calls exceed the limit.""" +ToolProgressHandler = Callable[[str, str | None, Any], Any] +"""Callback for UI-only tool progress updates. + +Args: + tool_name: The name of the tool reporting progress. + function_call_id: The function call id if available. + data: The tool-defined progress payload. +""" + + class RealtimeCacheEntry(BaseModel): """Store audio data chunks for caching before flushing.""" @@ -207,6 +218,11 @@ class InvocationContext(BaseModel): live_request_queue: Optional[LiveRequestQueue] = None """The queue to receive live requests.""" + tool_progress_handler: Optional[ToolProgressHandler] = Field( + default=None, exclude=True + ) + """Runtime callback for tool progress updates that should not reach the LLM.""" + active_streaming_tools: Optional[dict[str, ActiveStreamingTool]] = None """The running streaming tools of this invocation.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 5b2f23fec7..f2ad52b8f1 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -243,6 +243,9 @@ def __init__( self.session_service = session_service self.memory_service = memory_service self.credential_service = credential_service + self.on_tool_progress: Optional[Callable[[str, str | None, Any], Any]] = ( + None + ) self.plugin_manager = PluginManager( plugins=app.plugins, close_timeout=plugin_close_timeout ) @@ -2068,6 +2071,7 @@ def _new_invocation_context( session=session, user_content=new_message, live_request_queue=live_request_queue, + tool_progress_handler=self.on_tool_progress, run_config=run_config, resumability_config=self.resumability_config, ) diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 21aa6bfd36..25fb7f34ef 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -86,7 +86,11 @@ def __init__( self.func = func # Detect context parameter by type annotation, fallback to 'tool_context' name self._context_param_name = find_context_parameter(func) or 'tool_context' - self._ignore_params = [self._context_param_name, 'input_stream'] + self._ignore_params = [ + self._context_param_name, + 'input_stream', + 'progress_callback', + ] self._require_confirmation = require_confirmation @override @@ -221,6 +225,10 @@ async def run_async( valid_params = {param for param in signature.parameters} if self._context_param_name in valid_params: args_to_call[self._context_param_name] = tool_context + if 'progress_callback' in valid_params: + args_to_call['progress_callback'] = self._make_progress_callback( + tool_context + ) # Filter args_to_call to only include valid parameters for the function args_to_call = {k: v for k, v in args_to_call.items() if k in valid_params} @@ -297,6 +305,19 @@ async def _invoke_callable( else: return target(**args_to_call) + def _make_progress_callback(self, tool_context: ToolContext) -> Callable: + """Returns a tool-bound progress callback for UI-only status updates.""" + + async def progress_callback(data: Any) -> None: + handler = tool_context._invocation_context.tool_progress_handler + if handler is None: + return + result = handler(self.name, tool_context.function_call_id, data) + if inspect.isawaitable(result): + await result + + return progress_callback + # TODO(hangfei): fix call live for function stream. async def _call_live( self, @@ -319,6 +340,10 @@ async def _call_live( ].stream if self._context_param_name in signature.parameters: args_to_call[self._context_param_name] = tool_context + if 'progress_callback' in signature.parameters: + args_to_call['progress_callback'] = self._make_progress_callback( + tool_context + ) # TODO: support tool confirmation for live mode. async with Aclosing(self.func(**args_to_call)) as agen: diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 2acb254833..601a802b29 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -455,7 +455,7 @@ def my_tool(query: str, ctx: Context) -> str: tool = FunctionTool(my_tool) assert tool._context_param_name == "ctx" - assert tool._ignore_params == ["ctx", "input_stream"] + assert tool._ignore_params == ["ctx", "input_stream", "progress_callback"] def test_context_param_detection_with_tool_context_type(): @@ -466,7 +466,11 @@ def my_tool(query: str, tool_context: ToolContext) -> str: tool = FunctionTool(my_tool) assert tool._context_param_name == "tool_context" - assert tool._ignore_params == ["tool_context", "input_stream"] + assert tool._ignore_params == [ + "tool_context", + "input_stream", + "progress_callback", + ] def test_context_param_detection_with_custom_name(): @@ -477,7 +481,11 @@ def my_tool(query: str, my_custom_context: Context) -> str: tool = FunctionTool(my_tool) assert tool._context_param_name == "my_custom_context" - assert tool._ignore_params == ["my_custom_context", "input_stream"] + assert tool._ignore_params == [ + "my_custom_context", + "input_stream", + "progress_callback", + ] def test_context_param_detection_fallback_to_name(): @@ -488,7 +496,11 @@ def my_tool(query: str, tool_context) -> str: tool = FunctionTool(my_tool) assert tool._context_param_name == "tool_context" - assert tool._ignore_params == ["tool_context", "input_stream"] + assert tool._ignore_params == [ + "tool_context", + "input_stream", + "progress_callback", + ] def test_context_param_detection_no_context(): @@ -499,7 +511,11 @@ def my_tool(query: str, count: int) -> str: tool = FunctionTool(my_tool) assert tool._context_param_name == "tool_context" - assert tool._ignore_params == ["tool_context", "input_stream"] + assert tool._ignore_params == [ + "tool_context", + "input_stream", + "progress_callback", + ] @pytest.mark.asyncio @@ -518,6 +534,96 @@ def my_tool(query: str, ctx: Context) -> dict: assert result == {"query": "test", "has_context": True} +@pytest.mark.asyncio +async def test_run_async_injects_progress_callback(mock_tool_context): + """Test that run_async injects a UI-only progress callback when declared.""" + progress_events = [] + + async def progress_handler(tool_name, function_call_id, data): + progress_events.append((tool_name, function_call_id, data)) + + async def my_tool(query: str, progress_callback) -> dict: + await progress_callback({"step": 1, "message": "working"}) + return {"query": query} + + mock_tool_context.function_call_id = "call-123" + mock_tool_context._invocation_context.tool_progress_handler = progress_handler + + tool = FunctionTool(my_tool) + result = await tool.run_async( + args={"query": "test"}, + tool_context=mock_tool_context, + ) + + assert result == {"query": "test"} + assert progress_events == [ + ("my_tool", "call-123", {"step": 1, "message": "working"}) + ] + + +@pytest.mark.asyncio +async def test_run_async_progress_callback_no_handler_is_noop( + mock_tool_context, +): + """Test that an injected progress callback is a no-op without a handler.""" + + async def my_tool(progress_callback) -> dict: + await progress_callback({"step": 1}) + return {"ok": True} + + mock_tool_context._invocation_context.tool_progress_handler = None + + tool = FunctionTool(my_tool) + result = await tool.run_async(args={}, tool_context=mock_tool_context) + + assert result == {"ok": True} + + +def test_progress_callback_is_hidden_from_declaration(): + """Test that progress_callback is not exposed in the model-facing schema.""" + + def my_tool(query: str, progress_callback) -> str: + """Search with UI-only progress.""" + return query + + declaration = FunctionTool(my_tool)._get_declaration() + + assert declaration.parameters_json_schema is not None + properties = declaration.parameters_json_schema["properties"] + assert "query" in properties + assert "progress_callback" not in properties + + +@pytest.mark.asyncio +async def test_call_live_injects_progress_callback(mock_tool_context): + """Test that live streaming tools receive the progress callback.""" + progress_events = [] + + def progress_handler(tool_name, function_call_id, data): + progress_events.append((tool_name, function_call_id, data)) + + async def my_tool(progress_callback): + await progress_callback({"step": "start"}) + yield {"status": "done"} + + mock_tool_context.function_call_id = "live-call-123" + mock_tool_context._invocation_context.tool_progress_handler = progress_handler + mock_tool_context._invocation_context.active_streaming_tools = {} + + tool = FunctionTool(my_tool) + results = [ + item + async for item in tool._call_live( + args={}, + tool_context=mock_tool_context, + invocation_context=mock_tool_context._invocation_context, + ) + ] + + assert results == [{"status": "done"}] + assert progress_events == [("my_tool", "live-call-123", {"step": "start"})] + + @pytest.mark.asyncio async def test_run_async_with_context_type_annotation(mock_tool_context): """Test that run_async works with Context type annotation."""