diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index cb8b09f655..5e03169d78 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -48,12 +48,25 @@ def _create_branch_ctx_for_sub_agent( return invocation_context +def _has_escalate_action(event: Event) -> bool: + """Returns whether the event asks the parent workflow to exit early.""" + return bool(event.actions.escalate) + + +def _cancel_tasks(tasks: list[asyncio.Task[None]]) -> None: + """Cancels still-running merge tasks.""" + for task in tasks: + if not task.done(): + task.cancel() + + async def _merge_agent_run( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: """Merges agent runs using asyncio.TaskGroup on Python 3.11+.""" sentinel = object() queue = asyncio.Queue() + tasks: list[asyncio.Task[None]] = [] # Agents are processed in parallel. # Events for each agent are put on queue sequentially. @@ -70,7 +83,7 @@ async def process_an_agent(events_for_one_agent): async with asyncio.TaskGroup() as tg: for events_for_one_agent in agent_runs: - tg.create_task(process_an_agent(events_for_one_agent)) + tasks.append(tg.create_task(process_an_agent(events_for_one_agent))) sentinel_count = 0 # Run until all agents finished processing. @@ -81,6 +94,9 @@ async def process_an_agent(events_for_one_agent): sentinel_count += 1 else: yield event + if _has_escalate_action(event): + _cancel_tasks(tasks) + return # Signal to agent that it should generate next event. resume_signal.set() @@ -124,7 +140,7 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - tasks = [] + tasks: list[asyncio.Task[None]] = [] try: for events_for_one_agent in agent_runs: tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) @@ -139,12 +155,16 @@ async def process_an_agent(events_for_one_agent): sentinel_count += 1 else: yield event + if _has_escalate_action(event): + _cancel_tasks(tasks) + return # Signal to agent that event has been processed by runner and it can # continue now. resume_signal.set() finally: - for task in tasks: - task.cancel() + _cancel_tasks(tasks) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) class ParallelAgent(BaseAgent): @@ -181,6 +201,7 @@ async def _run_async_impl( if not sub_agent_ctx.end_of_agents.get(sub_agent.name): agent_runs.append(sub_agent.run_async(sub_agent_ctx)) + escalated = False pause_invocation = False try: merge_func = ( @@ -191,15 +212,18 @@ async def _run_async_impl( async with Aclosing(merge_func(agent_runs)) as agen: async for event in agen: yield event - if ctx.should_pause_invocation(event): + if _has_escalate_action(event): + escalated = True + elif ctx.should_pause_invocation(event): pause_invocation = True - if pause_invocation: + if pause_invocation and not escalated: return # Once all sub-agents are done, mark the ParallelAgent as final. - if ctx.is_resumable and all( - ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents + if ctx.is_resumable and ( + escalated + or all(ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents) ): ctx.set_agent_state(self.name, end_of_agent=True) yield self._create_agent_state_event(ctx) diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index cad1ce3a83..ec31acc12b 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -15,8 +15,10 @@ """Tests for the ParallelAgent.""" import asyncio +from types import SimpleNamespace from typing import AsyncGenerator +from google.adk.agents import parallel_agent as parallel_agent_module from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgentState from google.adk.agents.invocation_context import InvocationContext @@ -25,6 +27,7 @@ from google.adk.agents.sequential_agent import SequentialAgentState from google.adk.apps.app import ResumabilityConfig from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.genai import types import pytest @@ -36,14 +39,21 @@ class _TestingAgent(BaseAgent): delay: float = 0 """The delay before the agent generates an event.""" - def event(self, ctx: InvocationContext): + def event( + self, + ctx: InvocationContext, + *, + text: str | None = None, + actions: EventActions | None = None, + ): return Event( author=self.name, branch=ctx.branch, invocation_id=ctx.invocation_id, content=types.Content( - parts=[types.Part(text=f'Hello, async {self.name}!')] + parts=[types.Part(text=text or f'Hello, async {self.name}!')] ), + actions=actions if actions is not None else EventActions(), ) @override @@ -342,6 +352,22 @@ async def _run_async_impl( yield self.event(ctx) +class _TestingAgentWithEscalateAction(_TestingAgent): + """Mock agent for testing escalation short-circuit behavior.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + await asyncio.sleep(self.delay) + yield self.event( + ctx, + text=f'Escalating from {self.name}!', + actions=EventActions(escalate=True), + ) + yield self.event(ctx, text='This event should be cancelled after escalation.') + + @pytest.mark.asyncio async def test_stop_agent_if_sub_agent_fails( request: pytest.FixtureRequest, @@ -373,3 +399,84 @@ async def test_stop_agent_if_sub_agent_fails( async for _ in agen: # The infinite agent could iterate a few times depending on scheduling. pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize('is_resumable', [True, False]) +@pytest.mark.parametrize('use_pre_3_11_merge', [False, True]) +async def test_run_async_short_circuits_other_agents_on_escalate_action( + request: pytest.FixtureRequest, + monkeypatch: pytest.MonkeyPatch, + is_resumable: bool, + use_pre_3_11_merge: bool, +): + if use_pre_3_11_merge: + monkeypatch.setattr( + parallel_agent_module, + 'sys', + SimpleNamespace(version_info=(3, 10)), + ) + + fast_agent = _TestingAgent( + name=f'{request.function.__name__}_test_fast_agent', + delay=0.05, + ) + escalating_agent = _TestingAgentWithEscalateAction( + name=f'{request.function.__name__}_test_escalating_agent', + delay=0.1, + ) + slow_agent = _TestingAgent( + name=f'{request.function.__name__}_test_slow_agent', + delay=0.5, + ) + parallel_agent = ParallelAgent( + name=f'{request.function.__name__}_test_parallel_agent', + sub_agents=[fast_agent, escalating_agent, slow_agent], + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, parallel_agent, is_resumable=is_resumable + ) + + events = [e async for e in parallel_agent.run_async(parent_ctx)] + + assert all(event.author != slow_agent.name for event in events) + assert all( + not event.content + or not event.content.parts + or event.content.parts[0].text + != 'This event should be cancelled after escalation.' + for event in events + ) + + if is_resumable: + assert len(events) == 4 + + assert events[0].author == parallel_agent.name + assert not events[0].actions.end_of_agent + + assert events[1].author == fast_agent.name + assert events[1].branch == f'{parallel_agent.name}.{fast_agent.name}' + assert events[1].content.parts[0].text == f'Hello, async {fast_agent.name}!' + + assert events[2].author == escalating_agent.name + assert events[2].branch == f'{parallel_agent.name}.{escalating_agent.name}' + assert events[2].content.parts[0].text == ( + f'Escalating from {escalating_agent.name}!' + ) + assert events[2].actions.escalate + + assert events[3].author == parallel_agent.name + assert events[3].actions.end_of_agent + else: + assert len(events) == 2 + + assert events[0].author == fast_agent.name + assert events[0].branch == f'{parallel_agent.name}.{fast_agent.name}' + assert events[0].content.parts[0].text == f'Hello, async {fast_agent.name}!' + + assert events[1].author == escalating_agent.name + assert events[1].branch == f'{parallel_agent.name}.{escalating_agent.name}' + assert events[1].content.parts[0].text == ( + f'Escalating from {escalating_agent.name}!' + ) + assert events[1].actions.escalate diff --git a/tests/unittests/runners/test_resume_invocation.py b/tests/unittests/runners/test_resume_invocation.py index 0db3f23bec..8f3931f215 100644 --- a/tests/unittests/runners/test_resume_invocation.py +++ b/tests/unittests/runners/test_resume_invocation.py @@ -13,12 +13,20 @@ # limitations under the License. """Tests for edge cases of resuming invocations.""" +import asyncio import copy +from typing import AsyncGenerator +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.parallel_agent import ParallelAgent from google.adk.apps.app import App from google.adk.apps.app import ResumabilityConfig +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.tools.long_running_tool import LongRunningFunctionTool +from google.genai import types from google.genai.types import FunctionResponse from google.genai.types import Part import pytest @@ -41,6 +49,43 @@ def test_tool() -> dict[str, str]: return {"result": "test tool result"} +test_tool.__test__ = False + + +class _ParallelEscalationTestingAgent(BaseAgent): + """A testing agent that emits a single event after a delay.""" + + delay: float = 0 + response_text: str = "" + escalate: bool = False + emit_follow_up_after_first_event: bool = False + + def _create_event( + self, + ctx: InvocationContext, + text: str, + *, + escalate: bool = False, + ) -> Event: + return Event( + author=self.name, + branch=ctx.branch, + invocation_id=ctx.invocation_id, + content=types.Content(role="model", parts=[types.Part(text=text)]), + actions=EventActions(escalate=True) if escalate else EventActions(), + ) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + await asyncio.sleep(self.delay) + yield self._create_event( + ctx, self.response_text, escalate=self.escalate + ) + if self.emit_follow_up_after_first_event: + yield self._create_event(ctx, "This event should not be emitted.") + + @pytest.mark.asyncio async def test_resume_invocation_from_sub_agent(): """A test case for an edge case, where an invocation-to-resume starts from a sub-agent. @@ -252,3 +297,68 @@ async def test_resume_any_invocation(): ), (root_agent.name, testing_utils.END_OF_AGENT), ] + + +@pytest.mark.asyncio +async def test_resumable_parallel_agent_escalation_short_circuits_persisted_run(): + """Runner persists fast+escalating events and marks the parent run complete.""" + fast_agent = _ParallelEscalationTestingAgent( + name="fast_agent", + delay=0.05, + response_text="fast response", + ) + escalating_agent = _ParallelEscalationTestingAgent( + name="escalating_agent", + delay=0.1, + response_text="escalating response", + escalate=True, + emit_follow_up_after_first_event=True, + ) + slow_agent = _ParallelEscalationTestingAgent( + name="slow_agent", + delay=0.5, + response_text="slow response", + ) + runner = testing_utils.InMemoryRunner( + app=App( + name="test_app", + root_agent=ParallelAgent( + name="root_agent", + sub_agents=[fast_agent, escalating_agent, slow_agent], + ), + resumability_config=ResumabilityConfig(is_resumable=True), + ) + ) + + invocation_events = await runner.run_async("test user query") + simplified_events = testing_utils.simplify_resumable_app_events( + copy.deepcopy(invocation_events) + ) + + assert simplified_events == [ + ("root_agent", {}), + ("fast_agent", "fast response"), + ("escalating_agent", "escalating response"), + ("root_agent", testing_utils.END_OF_AGENT), + ] + + session = await runner.runner.session_service.get_session( + app_name=runner.app_name, + user_id="test_user", + session_id=runner.session_id, + ) + persisted_events = [ + event + for event in session.events + if event.invocation_id == invocation_events[0].invocation_id + and event.author != "user" + ] + assert testing_utils.simplify_resumable_app_events( + copy.deepcopy(persisted_events) + ) == simplified_events + assert all(event.author != "slow_agent" for event in persisted_events) + + # A completed resumable invocation should not restart cancelled siblings. + assert not await runner.run_async( + invocation_id=invocation_events[0].invocation_id + )