diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index d9d049f80..743d569b9 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -21,7 +21,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import io.reactivex.rxjava3.core.Flowable; import java.util.List; -import java.util.Optional; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,7 +34,7 @@ public class LoopAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(LoopAgent.class); - private final Optional maxIterations; + private final @Nullable Integer maxIterations; /** * Constructor for LoopAgent. @@ -50,7 +50,7 @@ private LoopAgent( String name, String description, List subAgents, - Optional maxIterations, + @Nullable Integer maxIterations, List beforeAgentCallback, List afterAgentCallback) { @@ -60,16 +60,10 @@ private LoopAgent( /** Builder for {@link LoopAgent}. */ public static class Builder extends BaseAgent.Builder { - private Optional maxIterations = Optional.empty(); + private @Nullable Integer maxIterations; @CanIgnoreReturnValue - public Builder maxIterations(int maxIterations) { - this.maxIterations = Optional.of(maxIterations); - return this; - } - - @CanIgnoreReturnValue - public Builder maxIterations(Optional maxIterations) { + public Builder maxIterations(@Nullable Integer maxIterations) { this.maxIterations = maxIterations; return this; } @@ -124,7 +118,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { return Flowable.fromIterable(subAgents) .concatMap(subAgent -> subAgent.runAsync(invocationContext)) - .repeat(maxIterations.orElse(Integer.MAX_VALUE)) + .repeat(maxIterations != null ? maxIterations : Integer.MAX_VALUE) .takeUntil(LoopAgent::hasEscalateAction); } @@ -137,4 +131,8 @@ protected Flowable runLiveImpl(InvocationContext invocationContext) { private static boolean hasEscalateAction(Event event) { return event.actions().escalate().orElse(false); } + + public @Nullable Integer maxIterations() { + return maxIterations; + } } diff --git a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java index 5c04ac74b..b2d0778c6 100644 --- a/core/src/test/java/com/google/adk/agents/LoopAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LoopAgentTest.java @@ -33,7 +33,6 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -165,11 +164,7 @@ public void runAsync_withNoMaxIterations_keepsLooping() { Event event2 = createEvent("event2"); TestBaseAgent subAgent = createSubAgent("subAgent", () -> Flowable.just(event1, event2)); LoopAgent loopAgent = - LoopAgent.builder() - .name("loopAgent") - .subAgents(ImmutableList.of(subAgent)) - .maxIterations(Optional.empty()) - .build(); + LoopAgent.builder().name("loopAgent").subAgents(ImmutableList.of(subAgent)).build(); InvocationContext invocationContext = createInvocationContext(loopAgent); Iterable result = loopAgent.runAsync(invocationContext).blockingIterable();