Skip to content

Commit 13645cf

Browse files
Jinni GuJinni Gu
authored andcommitted
feat: Support Sub-agent Escalation event in Parallel Agent (Issue #561)
1 parent 3dee126 commit 13645cf

File tree

2 files changed

+141
-1
lines changed

2 files changed

+141
-1
lines changed

core/src/main/java/com/google/adk/agents/ParallelAgent.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
148148
for (BaseAgent subAgent : currentSubAgents) {
149149
agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler));
150150
}
151-
return Flowable.merge(agentFlowables);
151+
return Flowable.merge(agentFlowables)
152+
.takeUntil((Event event) -> event.actions().escalate().orElse(false));
152153
}
153154

154155
/**
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.agents;
18+
19+
import static com.google.adk.testing.TestUtils.createInvocationContext;
20+
import static com.google.common.truth.Truth.assertThat;
21+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
22+
23+
import com.google.adk.events.Event;
24+
import com.google.adk.events.EventActions;
25+
import com.google.common.collect.ImmutableList;
26+
import com.google.genai.types.Content;
27+
import com.google.genai.types.Part;
28+
import io.reactivex.rxjava3.core.Flowable;
29+
import io.reactivex.rxjava3.core.Scheduler;
30+
import io.reactivex.rxjava3.schedulers.TestScheduler;
31+
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
35+
@RunWith(JUnit4.class)
36+
public final class ParallelAgentEscalationTest {
37+
38+
static class EscalatingAgent extends BaseAgent {
39+
private final long delayMillis;
40+
private final Scheduler scheduler;
41+
42+
private EscalatingAgent(String name, long delayMillis, Scheduler scheduler) {
43+
super(name, "Escalating Agent", ImmutableList.of(), null, null);
44+
this.delayMillis = delayMillis;
45+
this.scheduler = scheduler;
46+
}
47+
48+
@Override
49+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
50+
Flowable<Event> event =
51+
Flowable.fromCallable(
52+
() ->
53+
Event.builder()
54+
.author(name())
55+
.branch(invocationContext.branch().orElse(null))
56+
.invocationId(invocationContext.invocationId())
57+
.content(Content.fromParts(Part.fromText("Escalating!")))
58+
.actions(EventActions.builder().escalate(true).build())
59+
.build());
60+
61+
if (delayMillis > 0) {
62+
return event.delay(delayMillis, MILLISECONDS, scheduler);
63+
}
64+
return event;
65+
}
66+
67+
@Override
68+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
69+
throw new UnsupportedOperationException("Not implemented");
70+
}
71+
}
72+
73+
static class SlowAgent extends BaseAgent {
74+
private final long delayMillis;
75+
private final Scheduler scheduler;
76+
77+
private SlowAgent(String name, long delayMillis, Scheduler scheduler) {
78+
super(name, "Slow Agent", ImmutableList.of(), null, null);
79+
this.delayMillis = delayMillis;
80+
this.scheduler = scheduler;
81+
}
82+
83+
@Override
84+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
85+
Flowable<Event> event =
86+
Flowable.fromCallable(
87+
() ->
88+
Event.builder()
89+
.author(name())
90+
.branch(invocationContext.branch().orElse(null))
91+
.invocationId(invocationContext.invocationId())
92+
.content(Content.fromParts(Part.fromText("Finished")))
93+
.build());
94+
95+
if (delayMillis > 0) {
96+
return event.delay(delayMillis, MILLISECONDS, scheduler);
97+
}
98+
return event;
99+
}
100+
101+
@Override
102+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
103+
throw new UnsupportedOperationException("Not implemented");
104+
}
105+
}
106+
107+
@Test
108+
public void runAsync_escalationEvent_shortCircuitsOtherAgents() {
109+
TestScheduler testScheduler = new TestScheduler();
110+
111+
EscalatingAgent agent1 = new EscalatingAgent("agent1", 100, testScheduler);
112+
SlowAgent agent2 = new SlowAgent("agent2", 500, testScheduler);
113+
114+
ParallelAgent parallelAgent =
115+
ParallelAgent.builder()
116+
.name("parallel_agent")
117+
.subAgents(agent1, agent2)
118+
.scheduler(testScheduler)
119+
.build();
120+
121+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
122+
123+
var subscriber = parallelAgent.runAsync(invocationContext).test();
124+
125+
testScheduler.advanceTimeBy(200, MILLISECONDS);
126+
127+
subscriber.assertValueCount(1);
128+
Event event = subscriber.values().get(0);
129+
assertThat(event.author()).isEqualTo("agent1");
130+
assertThat(event.actions().escalate()).hasValue(true);
131+
132+
subscriber.assertComplete();
133+
testScheduler.advanceTimeBy(1000, MILLISECONDS);
134+
135+
// Test RxJava Disposal behavior: SlowAgent won't emit anything since sequence was forcibly
136+
// terminated!
137+
subscriber.assertValueCount(1);
138+
}
139+
}

0 commit comments

Comments
 (0)