Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions src/google/adk/agents/parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,25 @@ def _create_branch_ctx_for_sub_agent(
return invocation_context


def _has_escalate_action(event: Event) -> bool:
"""Returns whether the event asks the parent workflow to exit early."""
return bool(event.actions.escalate)


def _cancel_tasks(tasks: list[asyncio.Task[None]]) -> None:
"""Cancels still-running merge tasks."""
for task in tasks:
if not task.done():
task.cancel()


async def _merge_agent_run(
agent_runs: list[AsyncGenerator[Event, None]],
) -> AsyncGenerator[Event, None]:
"""Merges agent runs using asyncio.TaskGroup on Python 3.11+."""
sentinel = object()
queue = asyncio.Queue()
tasks: list[asyncio.Task[None]] = []

# Agents are processed in parallel.
# Events for each agent are put on queue sequentially.
Expand All @@ -70,7 +83,7 @@ async def process_an_agent(events_for_one_agent):

async with asyncio.TaskGroup() as tg:
for events_for_one_agent in agent_runs:
tg.create_task(process_an_agent(events_for_one_agent))
tasks.append(tg.create_task(process_an_agent(events_for_one_agent)))

sentinel_count = 0
# Run until all agents finished processing.
Expand All @@ -81,6 +94,9 @@ async def process_an_agent(events_for_one_agent):
sentinel_count += 1
else:
yield event
if _has_escalate_action(event):
_cancel_tasks(tasks)
return
# Signal to agent that it should generate next event.
resume_signal.set()

Expand Down Expand Up @@ -124,7 +140,7 @@ async def process_an_agent(events_for_one_agent):
# Mark agent as finished.
await queue.put((sentinel, None))

tasks = []
tasks: list[asyncio.Task[None]] = []
try:
for events_for_one_agent in agent_runs:
tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent)))
Expand All @@ -139,12 +155,16 @@ async def process_an_agent(events_for_one_agent):
sentinel_count += 1
else:
yield event
if _has_escalate_action(event):
_cancel_tasks(tasks)
return
# Signal to agent that event has been processed by runner and it can
# continue now.
resume_signal.set()
finally:
for task in tasks:
task.cancel()
_cancel_tasks(tasks)
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)


class ParallelAgent(BaseAgent):
Expand Down Expand Up @@ -181,6 +201,7 @@ async def _run_async_impl(
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
agent_runs.append(sub_agent.run_async(sub_agent_ctx))

escalated = False
pause_invocation = False
try:
merge_func = (
Expand All @@ -191,15 +212,18 @@ async def _run_async_impl(
async with Aclosing(merge_func(agent_runs)) as agen:
async for event in agen:
yield event
if ctx.should_pause_invocation(event):
if _has_escalate_action(event):
escalated = True
elif ctx.should_pause_invocation(event):
pause_invocation = True

if pause_invocation:
if pause_invocation and not escalated:
return

# Once all sub-agents are done, mark the ParallelAgent as final.
if ctx.is_resumable and all(
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
if ctx.is_resumable and (
escalated
or all(ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents)
):
ctx.set_agent_state(self.name, end_of_agent=True)
yield self._create_agent_state_event(ctx)
Expand Down
111 changes: 109 additions & 2 deletions tests/unittests/agents/test_parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Tests for the ParallelAgent."""

import asyncio
from types import SimpleNamespace
from typing import AsyncGenerator

from google.adk.agents import parallel_agent as parallel_agent_module
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.base_agent import BaseAgentState
from google.adk.agents.invocation_context import InvocationContext
Expand All @@ -25,6 +27,7 @@
from google.adk.agents.sequential_agent import SequentialAgentState
from google.adk.apps.app import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
import pytest
Expand All @@ -36,14 +39,21 @@ class _TestingAgent(BaseAgent):
delay: float = 0
"""The delay before the agent generates an event."""

def event(self, ctx: InvocationContext):
def event(
self,
ctx: InvocationContext,
*,
text: str | None = None,
actions: EventActions | None = None,
):
return Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
content=types.Content(
parts=[types.Part(text=f'Hello, async {self.name}!')]
parts=[types.Part(text=text or f'Hello, async {self.name}!')]
),
actions=actions if actions is not None else EventActions(),
)

@override
Expand Down Expand Up @@ -342,6 +352,22 @@ async def _run_async_impl(
yield self.event(ctx)


class _TestingAgentWithEscalateAction(_TestingAgent):
"""Mock agent for testing escalation short-circuit behavior."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self.event(
ctx,
text=f'Escalating from {self.name}!',
actions=EventActions(escalate=True),
)
yield self.event(ctx, text='This event should be cancelled after escalation.')


@pytest.mark.asyncio
async def test_stop_agent_if_sub_agent_fails(
request: pytest.FixtureRequest,
Expand Down Expand Up @@ -373,3 +399,84 @@ async def test_stop_agent_if_sub_agent_fails(
async for _ in agen:
# The infinite agent could iterate a few times depending on scheduling.
pass


@pytest.mark.asyncio
@pytest.mark.parametrize('is_resumable', [True, False])
@pytest.mark.parametrize('use_pre_3_11_merge', [False, True])
async def test_run_async_short_circuits_other_agents_on_escalate_action(
request: pytest.FixtureRequest,
monkeypatch: pytest.MonkeyPatch,
is_resumable: bool,
use_pre_3_11_merge: bool,
):
if use_pre_3_11_merge:
monkeypatch.setattr(
parallel_agent_module,
'sys',
SimpleNamespace(version_info=(3, 10)),
)

fast_agent = _TestingAgent(
name=f'{request.function.__name__}_test_fast_agent',
delay=0.05,
)
escalating_agent = _TestingAgentWithEscalateAction(
name=f'{request.function.__name__}_test_escalating_agent',
delay=0.1,
)
slow_agent = _TestingAgent(
name=f'{request.function.__name__}_test_slow_agent',
delay=0.5,
)
parallel_agent = ParallelAgent(
name=f'{request.function.__name__}_test_parallel_agent',
sub_agents=[fast_agent, escalating_agent, slow_agent],
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, parallel_agent, is_resumable=is_resumable
)

events = [e async for e in parallel_agent.run_async(parent_ctx)]

assert all(event.author != slow_agent.name for event in events)
assert all(
not event.content
or not event.content.parts
or event.content.parts[0].text
!= 'This event should be cancelled after escalation.'
for event in events
)

if is_resumable:
assert len(events) == 4

assert events[0].author == parallel_agent.name
assert not events[0].actions.end_of_agent

assert events[1].author == fast_agent.name
assert events[1].branch == f'{parallel_agent.name}.{fast_agent.name}'
assert events[1].content.parts[0].text == f'Hello, async {fast_agent.name}!'

assert events[2].author == escalating_agent.name
assert events[2].branch == f'{parallel_agent.name}.{escalating_agent.name}'
assert events[2].content.parts[0].text == (
f'Escalating from {escalating_agent.name}!'
)
assert events[2].actions.escalate

assert events[3].author == parallel_agent.name
assert events[3].actions.end_of_agent
else:
assert len(events) == 2

assert events[0].author == fast_agent.name
assert events[0].branch == f'{parallel_agent.name}.{fast_agent.name}'
assert events[0].content.parts[0].text == f'Hello, async {fast_agent.name}!'

assert events[1].author == escalating_agent.name
assert events[1].branch == f'{parallel_agent.name}.{escalating_agent.name}'
assert events[1].content.parts[0].text == (
f'Escalating from {escalating_agent.name}!'
)
assert events[1].actions.escalate
110 changes: 110 additions & 0 deletions tests/unittests/runners/test_resume_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
# limitations under the License.
"""Tests for edge cases of resuming invocations."""

import asyncio
import copy
from typing import AsyncGenerator

from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.tools.long_running_tool import LongRunningFunctionTool
from google.genai import types
from google.genai.types import FunctionResponse
from google.genai.types import Part
import pytest
Expand All @@ -41,6 +49,43 @@ def test_tool() -> dict[str, str]:
return {"result": "test tool result"}


test_tool.__test__ = False


class _ParallelEscalationTestingAgent(BaseAgent):
"""A testing agent that emits a single event after a delay."""

delay: float = 0
response_text: str = ""
escalate: bool = False
emit_follow_up_after_first_event: bool = False

def _create_event(
self,
ctx: InvocationContext,
text: str,
*,
escalate: bool = False,
) -> Event:
return Event(
author=self.name,
branch=ctx.branch,
invocation_id=ctx.invocation_id,
content=types.Content(role="model", parts=[types.Part(text=text)]),
actions=EventActions(escalate=True) if escalate else EventActions(),
)

async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
await asyncio.sleep(self.delay)
yield self._create_event(
ctx, self.response_text, escalate=self.escalate
)
if self.emit_follow_up_after_first_event:
yield self._create_event(ctx, "This event should not be emitted.")


@pytest.mark.asyncio
async def test_resume_invocation_from_sub_agent():
"""A test case for an edge case, where an invocation-to-resume starts from a sub-agent.
Expand Down Expand Up @@ -252,3 +297,68 @@ async def test_resume_any_invocation():
),
(root_agent.name, testing_utils.END_OF_AGENT),
]


@pytest.mark.asyncio
async def test_resumable_parallel_agent_escalation_short_circuits_persisted_run():
"""Runner persists fast+escalating events and marks the parent run complete."""
fast_agent = _ParallelEscalationTestingAgent(
name="fast_agent",
delay=0.05,
response_text="fast response",
)
escalating_agent = _ParallelEscalationTestingAgent(
name="escalating_agent",
delay=0.1,
response_text="escalating response",
escalate=True,
emit_follow_up_after_first_event=True,
)
slow_agent = _ParallelEscalationTestingAgent(
name="slow_agent",
delay=0.5,
response_text="slow response",
)
runner = testing_utils.InMemoryRunner(
app=App(
name="test_app",
root_agent=ParallelAgent(
name="root_agent",
sub_agents=[fast_agent, escalating_agent, slow_agent],
),
resumability_config=ResumabilityConfig(is_resumable=True),
)
)

invocation_events = await runner.run_async("test user query")
simplified_events = testing_utils.simplify_resumable_app_events(
copy.deepcopy(invocation_events)
)

assert simplified_events == [
("root_agent", {}),
("fast_agent", "fast response"),
("escalating_agent", "escalating response"),
("root_agent", testing_utils.END_OF_AGENT),
]

session = await runner.runner.session_service.get_session(
app_name=runner.app_name,
user_id="test_user",
session_id=runner.session_id,
)
persisted_events = [
event
for event in session.events
if event.invocation_id == invocation_events[0].invocation_id
and event.author != "user"
]
assert testing_utils.simplify_resumable_app_events(
copy.deepcopy(persisted_events)
) == simplified_events
assert all(event.author != "slow_agent" for event in persisted_events)

# A completed resumable invocation should not restart cancelled siblings.
assert not await runner.run_async(
invocation_id=invocation_events[0].invocation_id
)
Loading