Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(
"""
self.tool_filter = tool_filter
self.tool_name_prefix = tool_name_prefix
self._cached_invocation_id: Optional[str] = None
self._cached_prefixed_tools: Optional[list[BaseTool]] = None

@abstractmethod
async def get_tools(
Expand Down Expand Up @@ -112,9 +114,19 @@ async def get_tools_with_prefix(
Returns:
list[BaseTool]: A list of tools with prefixed names if tool_name_prefix is provided.
"""
invocation_id = readonly_context.invocation_id if readonly_context else None

if (
self._cached_prefixed_tools is not None
and self._cached_invocation_id == invocation_id
):
return self._cached_prefixed_tools

tools = await self.get_tools(readonly_context)

if not self.tool_name_prefix:
self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = tools
return tools

prefix = self.tool_name_prefix
Expand Down Expand Up @@ -147,6 +159,8 @@ def _get_prefixed_declaration():
tool_copy._get_declaration = _create_prefixed_declaration()
prefixed_tools.append(tool_copy)

self._cached_invocation_id = invocation_id
self._cached_prefixed_tools = prefixed_tools
return prefixed_tools

async def close(self) -> None:
Expand Down
56 changes: 54 additions & 2 deletions tests/unittests/tools/test_base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,58 @@ async def test_no_duplicate_prefixing():
original_tools = await toolset.get_tools()
assert original_tools[0].name == 'original'

# The prefixed tools should be different instances
assert prefixed_tools_1[0] is not prefixed_tools_2[0]
# The prefixed tools should be the same instance when cached
assert prefixed_tools_1[0] is prefixed_tools_2[0]
assert prefixed_tools_1[0] is not original_tools[0]


@pytest.mark.asyncio
async def test_get_tools_with_prefix_caching():
"""Test that get_tools_with_prefix caches results within the same invocation."""
tool1 = _TestingTool(name='tool1', description='Test tool 1')
toolset = _TestingToolset(tools=[tool1], tool_name_prefix='test')

session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
agent = SequentialAgent(name='test_agent')
invocation_context1 = InvocationContext(
invocation_id='inv-1',
agent=agent,
session=session,
session_service=session_service,
)
readonly_context1 = ReadonlyContext(invocation_context1)

# First call
tools1 = await toolset.get_tools_with_prefix(
readonly_context=readonly_context1
)
assert len(tools1) == 1
assert tools1[0].name == 'test_tool1'

# Second call with same context/invocation_id
tools2 = await toolset.get_tools_with_prefix(
readonly_context=readonly_context1
)
assert len(tools2) == 1
assert (
tools2 is tools1
) # Should return the exact same list instance (from cache)

# Third call with different invocation_id
invocation_context2 = InvocationContext(
invocation_id='inv-2',
agent=agent,
session=session,
session_service=session_service,
)
readonly_context2 = ReadonlyContext(invocation_context2)

tools3 = await toolset.get_tools_with_prefix(
readonly_context=readonly_context2
)
assert len(tools3) == 1
assert tools3 is not tools1 # Should be a new list instance
assert tools3[0].name == 'test_tool1'
Loading