diff --git a/core/src/main/java/com/google/adk/apps/App.java b/core/src/main/java/com/google/adk/apps/App.java index 897e24490..b0fec1176 100644 --- a/core/src/main/java/com/google/adk/apps/App.java +++ b/core/src/main/java/com/google/adk/apps/App.java @@ -78,6 +78,10 @@ public ContextCacheConfig contextCacheConfig() { return contextCacheConfig; } + public Builder toBuilder() { + return new Builder(this); + } + /** Builder for {@link App}. */ public static class Builder { private String name; @@ -86,6 +90,16 @@ public static class Builder { @Nullable private EventsCompactionConfig eventsCompactionConfig; @Nullable private ContextCacheConfig contextCacheConfig; + private Builder() {} + + private Builder(App app) { + this.name = app.name; + this.rootAgent = app.rootAgent; + this.plugins = app.plugins; + this.eventsCompactionConfig = app.eventsCompactionConfig; + this.contextCacheConfig = app.contextCacheConfig; + } + @CanIgnoreReturnValue public Builder name(String name) { this.name = name; diff --git a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java index 58741003c..43edf7075 100644 --- a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java +++ b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java @@ -17,7 +17,9 @@ package com.google.adk.runner; import com.google.adk.agents.BaseAgent; +import com.google.adk.apps.App; import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.memory.BaseMemoryService; import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.plugins.Plugin; import com.google.adk.sessions.InMemorySessionService; @@ -38,12 +40,14 @@ public InMemoryRunner(BaseAgent agent, String appName) { } public InMemoryRunner(BaseAgent agent, String appName, List plugins) { - super( - agent, - appName, - new InMemoryArtifactService(), - new InMemorySessionService(), - new InMemoryMemoryService(), - plugins); + this(App.builder().rootAgent(agent).name(appName).plugins(plugins).build()); + } + + public InMemoryRunner(App app) { + this(app, new InMemoryMemoryService()); + } + + public InMemoryRunner(App app, BaseMemoryService memoryService) { + super(app, new InMemoryArtifactService(), new InMemorySessionService(), memoryService); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 1f7d924ab..5a72592c9 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -68,43 +68,44 @@ /** The main class for the GenAI Agents runner. */ public class Runner { - private final BaseAgent agent; - private final String appName; + private final App app; private final BaseArtifactService artifactService; private final BaseSessionService sessionService; @Nullable private final BaseMemoryService memoryService; private final PluginManager pluginManager; @Nullable private final EventsCompactionConfig eventsCompactionConfig; - @Nullable private final ContextCacheConfig contextCacheConfig; /** Builder for {@link Runner}. */ public static class Builder { - private App app; - private BaseAgent agent; - private String appName; + private App.Builder appBuilder = App.builder(); private BaseArtifactService artifactService = new InMemoryArtifactService(); private BaseSessionService sessionService = new InMemorySessionService(); @Nullable private BaseMemoryService memoryService = null; - private List plugins = ImmutableList.of(); + + private Builder() {} + + private Builder(Runner runner) { + this.appBuilder = runner.app().toBuilder(); + this.artifactService = runner.artifactService(); + this.sessionService = runner.sessionService(); + this.memoryService = runner.memoryService(); + } @CanIgnoreReturnValue public Builder app(App app) { - Preconditions.checkState(this.agent == null, "app() cannot be called when agent() is set."); - this.app = app; + this.appBuilder = app.toBuilder(); return this; } @CanIgnoreReturnValue public Builder agent(BaseAgent agent) { - Preconditions.checkState(this.app == null, "agent() cannot be called when app is set."); - this.agent = agent; + this.appBuilder.rootAgent(agent); return this; } @CanIgnoreReturnValue public Builder appName(String appName) { - Preconditions.checkState(this.app == null, "appName() cannot be called when app is set."); - this.appName = appName; + this.appBuilder.name(appName); return this; } @@ -128,66 +129,26 @@ public Builder memoryService(BaseMemoryService memoryService) { @CanIgnoreReturnValue public Builder plugins(List plugins) { - Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set."); - this.plugins = plugins; + this.appBuilder.plugins(plugins); return this; } @CanIgnoreReturnValue public Builder plugins(Plugin... plugins) { - Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set."); - this.plugins = ImmutableList.copyOf(plugins); + this.appBuilder.plugins(plugins); return this; } public Runner build() { - BaseAgent buildAgent; - String buildAppName; - List buildPlugins; - EventsCompactionConfig buildEventsCompactionConfig; - ContextCacheConfig buildContextCacheConfig; - - if (this.app != null) { - if (this.agent != null) { - throw new IllegalStateException("agent() cannot be called when app() is called."); - } - if (!this.plugins.isEmpty()) { - throw new IllegalStateException("plugins() cannot be called when app() is called."); - } - buildAgent = this.app.rootAgent(); - buildPlugins = this.app.plugins(); - buildAppName = this.appName == null ? this.app.name() : this.appName; - buildEventsCompactionConfig = this.app.eventsCompactionConfig(); - buildContextCacheConfig = this.app.contextCacheConfig(); - } else { - buildAgent = this.agent; - buildAppName = this.appName; - buildPlugins = this.plugins; - buildEventsCompactionConfig = null; - buildContextCacheConfig = null; - } + App app = this.appBuilder.build(); - if (buildAgent == null) { - throw new IllegalStateException("Agent must be provided via app() or agent()."); - } - if (buildAppName == null) { - throw new IllegalStateException("App name must be provided via app() or appName()."); - } if (artifactService == null) { throw new IllegalStateException("Artifact service must be provided."); } if (sessionService == null) { throw new IllegalStateException("Session service must be provided."); } - return new Runner( - buildAgent, - buildAppName, - artifactService, - sessionService, - memoryService, - buildPlugins, - buildEventsCompactionConfig, - buildContextCacheConfig); + return new Runner(app, artifactService, sessionService, memoryService); } } @@ -207,7 +168,11 @@ public Runner( BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService) { - this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of()); + this( + App.builder().rootAgent(agent).name(appName).build(), + artifactService, + sessionService, + memoryService); } /** @@ -223,7 +188,11 @@ public Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins) { - this(agent, appName, artifactService, sessionService, memoryService, plugins, null, null); + this( + App.builder().rootAgent(agent).name(appName).plugins(plugins).build(), + artifactService, + sessionService, + memoryService); } /** @@ -233,22 +202,17 @@ public Runner( */ @Deprecated protected Runner( - BaseAgent agent, - String appName, + App app, BaseArtifactService artifactService, BaseSessionService sessionService, - @Nullable BaseMemoryService memoryService, - List plugins, - @Nullable EventsCompactionConfig eventsCompactionConfig, - @Nullable ContextCacheConfig contextCacheConfig) { - this.agent = agent; - this.appName = appName; + @Nullable BaseMemoryService memoryService) { + this.app = app; this.artifactService = artifactService; this.sessionService = sessionService; this.memoryService = memoryService; - this.pluginManager = new PluginManager(plugins); - this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig); - this.contextCacheConfig = contextCacheConfig; + this.pluginManager = new PluginManager(app.plugins()); + this.eventsCompactionConfig = + createEventsCompactionConfig(app.rootAgent(), app.eventsCompactionConfig()); } /** @@ -265,12 +229,16 @@ public Runner( this(agent, appName, artifactService, sessionService, null); } + public App app() { + return this.app; + } + public BaseAgent agent() { - return this.agent; + return this.app.rootAgent(); } public String appName() { - return this.appName; + return this.app.name(); } public BaseArtifactService artifactService() { @@ -290,10 +258,24 @@ public PluginManager pluginManager() { return this.pluginManager; } + @Nullable + public EventsCompactionConfig eventsCompactionConfig() { + return this.eventsCompactionConfig; + } + + @Nullable + public ContextCacheConfig contextCacheConfig() { + return this.app.contextCacheConfig(); + } + + public Builder toBuilder() { + return new Builder(this); + } + /** Closes all plugins, code executors, and releases any resources. */ public Completable close() { List completables = new ArrayList<>(); - completables.add(agent.close()); + completables.add(app.rootAgent().close()); completables.add(this.pluginManager.close()); return Completable.mergeDelayError(completables); } @@ -326,7 +308,7 @@ private Single appendNewMessageToSession( saveArtifactsFlow = saveArtifactsFlow.andThen( this.artifactService - .saveArtifact(this.appName, session.userId(), session.id(), fileName, part) + .saveArtifact(this.app.name(), session.userId(), session.id(), fileName, part) .ignoreElement()); newMessage @@ -383,13 +365,13 @@ public Flowable runAsync( return Flowable.defer( () -> this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) + .getSession(this.app.name(), userId, sessionId, Optional.empty()) .switchIfEmpty( Single.defer( () -> { if (runConfig.autoCreateSession()) { return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); + this.app.name(), userId, (Map) null, sessionId); } return Single.error( new IllegalArgumentException( @@ -475,7 +457,7 @@ protected Flowable runAsyncImpl( Context capturedContext = Context.current(); return Flowable.defer( () -> { - BaseAgent rootAgent = this.agent; + BaseAgent rootAgent = this.app.rootAgent(); String invocationId = InvocationContext.newInvocationContextId(); // Create initial context @@ -634,7 +616,7 @@ private InvocationContext newInvocationContextForLive( } private InvocationContext.Builder newInvocationContextBuilder(Session session) { - BaseAgent rootAgent = this.agent; + BaseAgent rootAgent = this.app.rootAgent(); return InvocationContext.builder() .sessionService(this.sessionService) .artifactService(this.artifactService) @@ -643,7 +625,7 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { .agent(rootAgent) .session(session) .eventsCompactionConfig(this.eventsCompactionConfig) - .contextCacheConfig(this.contextCacheConfig) + .contextCacheConfig(this.app.contextCacheConfig()) .agent(this.findAgentToRun(session, rootAgent)); } @@ -663,13 +645,13 @@ public Flowable runLive( return Flowable.defer( () -> this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) + .getSession(this.app.name(), userId, sessionId, Optional.empty()) .switchIfEmpty( Single.defer( () -> { if (runConfig.autoCreateSession()) { return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); + this.app.name(), userId, (Map) null, sessionId); } return Single.error( new IllegalArgumentException( diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index a3e21cb73..e05a289c2 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -596,6 +596,19 @@ public void runAsync_withSessionKey_success() { assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); } + @Test + public void toBuilder_success() { + Runner newRunner = runner.toBuilder().appName("new_app").build(); + + assertThat(newRunner.appName()).isEqualTo("new_app"); + assertThat(newRunner.agent()).isEqualTo(runner.agent()); + assertThat(newRunner.artifactService()).isEqualTo(runner.artifactService()); + assertThat(newRunner.sessionService()).isEqualTo(runner.sessionService()); + assertThat(newRunner.memoryService()).isEqualTo(runner.memoryService()); + assertThat(newRunner.pluginManager().getPlugins()) + .containsExactlyElementsIn(runner.pluginManager().getPlugins()); + } + @Test public void runAsync_withStateDelta_mergesStateIntoSession() { ImmutableMap stateDelta = ImmutableMap.of("key1", "value1", "key2", 42);