From 6d009430f8bed8c7f90fb28240da7046e0a2278e Mon Sep 17 00:00:00 2001 From: Zeel Date: Mon, 30 Mar 2026 16:54:34 -0400 Subject: [PATCH 1/2] fix(flows): resume long-running tools after matching responses --- src/google/adk/agents/invocation_context.py | 28 +++++++ src/google/adk/agents/llm_agent.py | 2 +- .../adk/flows/llm_flows/base_llm_flow.py | 17 +--- src/google/adk/flows/llm_flows/functions.py | 29 +++++++ .../agents/test_invocation_context.py | 77 +++++++++++++++++++ .../flows/llm_flows/test_base_llm_flow.py | 42 ++++++++++ 6 files changed, 181 insertions(+), 14 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index b2032c5325..10eee3e9a0 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,6 +396,34 @@ def should_pause_invocation(self, event: Event) -> bool: return False + def has_unresolved_long_running_tool_calls( + self, events: list[Event] + ) -> bool: + """Returns whether any long-running tool call in events is unresolved.""" + if not self.is_resumable or not events: + return False + + function_response_ids = { + function_response.id + for event in events + for function_response in event.get_function_responses() + if function_response.id + } + + for event in reversed(events): + if not self.should_pause_invocation(event): + continue + + paused_function_call_ids = { + function_call.id + for function_call in event.get_function_calls() + if function_call.id in event.long_running_tool_ids + } + if paused_function_call_ids - function_response_ids: + return True + + return False + # TODO: Move this method from invocation_context to a dedicated module. 
def _find_matching_function_call( self, function_response_event: Event diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 96e5043f72..ab6e4ab523 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -496,7 +496,7 @@ async def _run_async_impl( if ctx.is_resumable: events = ctx._get_events(current_invocation=True, current_branch=True) - if events and any(ctx.should_pause_invocation(e) for e in events[-2:]): + if ctx.has_unresolved_long_running_tool_calls(events): return # Only yield an end state if the last event is no longer a long-running # tool call. diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index bd0037bdcb..1a660771c8 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -99,6 +99,9 @@ def _finalize_model_response_event( if finalized_event.content: function_calls = finalized_event.get_function_calls() if function_calls: + functions.preserve_existing_function_call_ids( + model_response_event, finalized_event + ) functions.populate_client_function_call_id(finalized_event) finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls( @@ -785,19 +788,7 @@ async def _run_one_step_async( # Long running tool calls should have been handled before this point. # If there are still long running tool calls, it means the agent is paused # before, and its branch hasn't been resumed yet. - if ( - invocation_context.is_resumable - and events - and len(events) > 1 - # TODO: here we are using the last 2 events to decide whether to pause - # the invocation. But this is just being optimistic, we should find a - # way to pause when the long running tool call is followed by more than - # one text responses. 
- and ( - invocation_context.should_pause_invocation(events[-1]) - or invocation_context.should_pause_invocation(events[-2]) - ) - ): + if invocation_context.has_unresolved_long_running_tool_calls(events): return if ( diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 1f85bee3a8..99a9765f73 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -189,6 +189,35 @@ def populate_client_function_call_id(model_response_event: Event) -> None: function_call.id = generate_client_function_call_id() +def preserve_existing_function_call_ids( + previous_event: Event, model_response_event: Event +) -> None: + """Carries forward function call IDs from a previous streaming event. + + Streaming responses may emit partial and final events for the same function + call sequence. The partial event is sent to clients first, while only the + final event is persisted. Preserving IDs across those events keeps + functionResponse routing stable when the client resumes a long-running tool. + + Args: + previous_event: The in-flight model response event from an earlier chunk. + model_response_event: The newly finalized event for the current chunk. + """ + previous_function_calls = previous_event.get_function_calls() + current_function_calls = model_response_event.get_function_calls() + if not previous_function_calls or not current_function_calls: + return + + for previous_function_call, current_function_call in zip( + previous_function_calls, current_function_calls + ): + if current_function_call.id: + continue + if previous_function_call.name != current_function_call.name: + continue + current_function_call.id = previous_function_call.id + + def remove_client_function_call_id(content: Optional[types.Content]) -> None: """Removes ADK-generated function call IDs from content before sending to LLM. 
diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 87f78b2869..8a0a38a82d 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -24,6 +24,7 @@ from google.adk.sessions.session import Session from google.genai.types import Content from google.genai.types import FunctionCall +from google.genai.types import FunctionResponse from google.genai.types import Part import pytest @@ -210,6 +211,82 @@ def test_should_not_pause_invocation_with_no_function_calls( nonpausable_event ) + def test_has_unresolved_long_running_tool_calls_with_matching_response(self): + """Tests that matching function responses resolve the pause.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + resolved_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id=function_call.id, + ) + ) + ], + ), + ) + + assert not invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, resolved_event] + ) + + def test_has_unresolved_long_running_tool_calls_without_matching_response( + self, + ): + """Tests that unmatched long-running calls still pause the invocation.""" + invocation_context = self._create_test_invocation_context( + ResumabilityConfig(is_resumable=True) + ) + function_call = FunctionCall( + id='tool_call_id_1', + name='long_running_function_call', + args={}, + ) + paused_event = Event( + invocation_id='inv_1', + author='agent', + 
content=testing_utils.ModelContent([Part(function_call=function_call)]), + long_running_tool_ids={function_call.id}, + ) + unrelated_response_event = Event( + invocation_id='inv_1', + author='user', + content=Content( + role='user', + parts=[ + Part( + function_response=FunctionResponse( + name='long_running_function_call', + response={'result': 'done'}, + id='different_tool_call_id', + ) + ) + ], + ), + ) + + assert invocation_context.has_unresolved_long_running_tool_calls( + [paused_event, unrelated_response_event] + ) + def test_is_resumable_true(self): """Tests that is_resumable is True when resumability is enabled.""" invocation_context = self._create_test_invocation_context( diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 3dfadbcabf..8fca948550 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -19,6 +19,7 @@ from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini @@ -41,6 +42,47 @@ class BaseLlmFlowForTesting(BaseLlmFlow): pass +def test_finalize_model_response_event_preserves_function_call_ids(): + """Test that streaming finalization keeps function call IDs stable.""" + previous_event = Event( + id=Event.new_id(), + invocation_id='test_invocation', + author='test_agent', + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='track_execution', + args={'call_id': 'partial'}, + id='adk-existing-id', + ) + ) + ], + ), + partial=True, + ) + llm_response = LlmResponse( + content=types.Content( + role='model', + parts=[ + 
types.Part.from_function_call( + name='track_execution', args={'call_id': 'final'} + ) + ], + ), + partial=False, + ) + + finalized_event = _finalize_model_response_event( + LlmRequest(), llm_response, previous_event + ) + + function_calls = finalized_event.get_function_calls() + assert len(function_calls) == 1 + assert function_calls[0].id == 'adk-existing-id' + + @pytest.mark.asyncio async def test_preprocess_calls_toolset_process_llm_request(): """Test that _preprocess_async calls process_llm_request on toolsets.""" From d361c12b02383d77ee4bd6a6f414ce24008304ee Mon Sep 17 00:00:00 2001 From: Jordan Date: Wed, 1 Apr 2026 22:12:04 +0000 Subject: [PATCH 2/2] fix: 3 fixes for has_unresolved_long_running_tool_calls (PR #5072) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. pyink formatting: collapse method signature to single line 2. Only count author='user' function_responses as resolutions — agent-generated auto-responses from LongRunningFunctionTool should not resolve the pause, only actual user resume responses should 3. Add null guard on event.long_running_tool_ids to fix mypy type error All 5158 unit tests pass with these changes.
--- src/google/adk/agents/invocation_context.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 10eee3e9a0..3287664d29 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -396,9 +396,7 @@ def should_pause_invocation(self, event: Event) -> bool: return False - def has_unresolved_long_running_tool_calls( - self, events: list[Event] - ) -> bool: + def has_unresolved_long_running_tool_calls(self, events: list[Event]) -> bool: """Returns whether any long-running tool call in events is unresolved.""" if not self.is_resumable or not events: return False @@ -407,7 +405,7 @@ def has_unresolved_long_running_tool_calls( function_response.id for event in events for function_response in event.get_function_responses() - if function_response.id + if function_response.id and event.author == 'user' } for event in reversed(events): @@ -417,7 +415,7 @@ def has_unresolved_long_running_tool_calls( paused_function_call_ids = { function_call.id for function_call in event.get_function_calls() - if function_call.id in event.long_running_tool_ids + if event.long_running_tool_ids and function_call.id in event.long_running_tool_ids } if paused_function_call_ids - function_response_ids: return True