diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 53c3b309d8..0fe06e7363 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -80,6 +80,8 @@ def __init__( """ self.tool_filter = tool_filter self.tool_name_prefix = tool_name_prefix + self._cached_invocation_id: Optional[str] = None + self._cached_prefixed_tools: Optional[list[BaseTool]] = None @abstractmethod async def get_tools( @@ -112,9 +114,19 @@ async def get_tools_with_prefix( Returns: list[BaseTool]: A list of tools with prefixed names if tool_name_prefix is provided. """ + invocation_id = readonly_context.invocation_id if readonly_context else None + + if ( + self._cached_prefixed_tools is not None + and self._cached_invocation_id == invocation_id + ): + return self._cached_prefixed_tools + tools = await self.get_tools(readonly_context) if not self.tool_name_prefix: + self._cached_invocation_id = invocation_id + self._cached_prefixed_tools = tools return tools prefix = self.tool_name_prefix @@ -147,6 +159,8 @@ def _get_prefixed_declaration(): tool_copy._get_declaration = _create_prefixed_declaration() prefixed_tools.append(tool_copy) + self._cached_invocation_id = invocation_id + self._cached_prefixed_tools = prefixed_tools return prefixed_tools async def close(self) -> None: diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py index 7c4ef3cfdd..d41dce63d3 100644 --- a/tests/unittests/tools/test_base_toolset.py +++ b/tests/unittests/tools/test_base_toolset.py @@ -383,6 +383,58 @@ async def test_no_duplicate_prefixing(): original_tools = await toolset.get_tools() assert original_tools[0].name == 'original' - # The prefixed tools should be different instances - assert prefixed_tools_1[0] is not prefixed_tools_2[0] + # The prefixed tools should be the same instance when cached + assert prefixed_tools_1[0] is prefixed_tools_2[0] assert prefixed_tools_1[0] is not original_tools[0] + + +@pytest.mark.asyncio +async def test_get_tools_with_prefix_caching(): + """Test that get_tools_with_prefix caches results within the same invocation.""" + tool1 = _TestingTool(name='tool1', description='Test tool 1') + toolset = _TestingToolset(tools=[tool1], tool_name_prefix='test') + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context1 = InvocationContext( + invocation_id='inv-1', + agent=agent, + session=session, + session_service=session_service, + ) + readonly_context1 = ReadonlyContext(invocation_context1) + + # First call + tools1 = await toolset.get_tools_with_prefix( + readonly_context=readonly_context1 + ) + assert len(tools1) == 1 + assert tools1[0].name == 'test_tool1' + + # Second call with same context/invocation_id + tools2 = await toolset.get_tools_with_prefix( + readonly_context=readonly_context1 + ) + assert len(tools2) == 1 + assert ( + tools2 is tools1 + ) # Should return the exact same list instance (from cache) + + # Third call with different invocation_id + invocation_context2 = InvocationContext( + invocation_id='inv-2', + agent=agent, + session=session, + session_service=session_service, + ) + readonly_context2 = ReadonlyContext(invocation_context2) + + tools3 = await toolset.get_tools_with_prefix( + readonly_context=readonly_context2 + ) + assert len(tools3) == 1 + assert tools3 is not tools1 # Should be a new list instance + assert tools3[0].name == 'test_tool1'