Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/src/main/java/com/google/adk/apps/App.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ public ContextCacheConfig contextCacheConfig() {
return contextCacheConfig;
}

public Builder toBuilder() {
return new Builder(this);
}

/** Builder for {@link App}. */
public static class Builder {
private String name;
Expand All @@ -86,6 +90,16 @@ public static class Builder {
@Nullable private EventsCompactionConfig eventsCompactionConfig;
@Nullable private ContextCacheConfig contextCacheConfig;

private Builder() {}

private Builder(App app) {
this.name = app.name;
this.rootAgent = app.rootAgent;
this.plugins = app.plugins;
this.eventsCompactionConfig = app.eventsCompactionConfig;
this.contextCacheConfig = app.contextCacheConfig;
}

@CanIgnoreReturnValue
public Builder name(String name) {
this.name = name;
Expand Down
18 changes: 11 additions & 7 deletions core/src/main/java/com/google/adk/runner/InMemoryRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
package com.google.adk.runner;

import com.google.adk.agents.BaseAgent;
import com.google.adk.apps.App;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.memory.InMemoryMemoryService;
import com.google.adk.plugins.Plugin;
import com.google.adk.sessions.InMemorySessionService;
Expand All @@ -38,12 +40,14 @@ public InMemoryRunner(BaseAgent agent, String appName) {
}

public InMemoryRunner(BaseAgent agent, String appName, List<? extends Plugin> plugins) {
super(
agent,
appName,
new InMemoryArtifactService(),
new InMemorySessionService(),
new InMemoryMemoryService(),
plugins);
this(App.builder().rootAgent(agent).name(appName).plugins(plugins).build());
}

public InMemoryRunner(App app) {
this(app, new InMemoryMemoryService());
}

public InMemoryRunner(App app, BaseMemoryService memoryService) {
super(app, new InMemoryArtifactService(), new InMemorySessionService(), memoryService);
}
}
144 changes: 63 additions & 81 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,43 +68,44 @@

/** The main class for the GenAI Agents runner. */
public class Runner {
private final BaseAgent agent;
private final String appName;
private final App app;
private final BaseArtifactService artifactService;
private final BaseSessionService sessionService;
@Nullable private final BaseMemoryService memoryService;
private final PluginManager pluginManager;
@Nullable private final EventsCompactionConfig eventsCompactionConfig;
@Nullable private final ContextCacheConfig contextCacheConfig;

/** Builder for {@link Runner}. */
public static class Builder {
private App app;
private BaseAgent agent;
private String appName;
private App.Builder appBuilder = App.builder();
private BaseArtifactService artifactService = new InMemoryArtifactService();
private BaseSessionService sessionService = new InMemorySessionService();
@Nullable private BaseMemoryService memoryService = null;
private List<? extends Plugin> plugins = ImmutableList.of();

private Builder() {}

private Builder(Runner runner) {
this.appBuilder = runner.app().toBuilder();
this.artifactService = runner.artifactService();
this.sessionService = runner.sessionService();
this.memoryService = runner.memoryService();
}

@CanIgnoreReturnValue
public Builder app(App app) {
Preconditions.checkState(this.agent == null, "app() cannot be called when agent() is set.");
this.app = app;
this.appBuilder = app.toBuilder();
return this;
}

@CanIgnoreReturnValue
public Builder agent(BaseAgent agent) {
Preconditions.checkState(this.app == null, "agent() cannot be called when app is set.");
this.agent = agent;
this.appBuilder.rootAgent(agent);
return this;
}

@CanIgnoreReturnValue
public Builder appName(String appName) {
Preconditions.checkState(this.app == null, "appName() cannot be called when app is set.");
this.appName = appName;
this.appBuilder.name(appName);
return this;
}

Expand All @@ -128,66 +129,26 @@ public Builder memoryService(BaseMemoryService memoryService) {

@CanIgnoreReturnValue
public Builder plugins(List<? extends Plugin> plugins) {
Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set.");
this.plugins = plugins;
this.appBuilder.plugins(plugins);
return this;
}

@CanIgnoreReturnValue
public Builder plugins(Plugin... plugins) {
Preconditions.checkState(this.app == null, "plugins() cannot be called when app is set.");
this.plugins = ImmutableList.copyOf(plugins);
this.appBuilder.plugins(plugins);
return this;
}

public Runner build() {
BaseAgent buildAgent;
String buildAppName;
List<? extends Plugin> buildPlugins;
EventsCompactionConfig buildEventsCompactionConfig;
ContextCacheConfig buildContextCacheConfig;

if (this.app != null) {
if (this.agent != null) {
throw new IllegalStateException("agent() cannot be called when app() is called.");
}
if (!this.plugins.isEmpty()) {
throw new IllegalStateException("plugins() cannot be called when app() is called.");
}
buildAgent = this.app.rootAgent();
buildPlugins = this.app.plugins();
buildAppName = this.appName == null ? this.app.name() : this.appName;
buildEventsCompactionConfig = this.app.eventsCompactionConfig();
buildContextCacheConfig = this.app.contextCacheConfig();
} else {
buildAgent = this.agent;
buildAppName = this.appName;
buildPlugins = this.plugins;
buildEventsCompactionConfig = null;
buildContextCacheConfig = null;
}
App app = this.appBuilder.build();

if (buildAgent == null) {
throw new IllegalStateException("Agent must be provided via app() or agent().");
}
if (buildAppName == null) {
throw new IllegalStateException("App name must be provided via app() or appName().");
}
if (artifactService == null) {
throw new IllegalStateException("Artifact service must be provided.");
}
if (sessionService == null) {
throw new IllegalStateException("Session service must be provided.");
}
return new Runner(
buildAgent,
buildAppName,
artifactService,
sessionService,
memoryService,
buildPlugins,
buildEventsCompactionConfig,
buildContextCacheConfig);
return new Runner(app, artifactService, sessionService, memoryService);
}
}

Expand All @@ -207,7 +168,11 @@ public Runner(
BaseArtifactService artifactService,
BaseSessionService sessionService,
@Nullable BaseMemoryService memoryService) {
this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of());
this(
App.builder().rootAgent(agent).name(appName).build(),
artifactService,
sessionService,
memoryService);
}

/**
Expand All @@ -223,7 +188,11 @@ public Runner(
BaseSessionService sessionService,
@Nullable BaseMemoryService memoryService,
List<? extends Plugin> plugins) {
this(agent, appName, artifactService, sessionService, memoryService, plugins, null, null);
this(
App.builder().rootAgent(agent).name(appName).plugins(plugins).build(),
artifactService,
sessionService,
memoryService);
}

/**
Expand All @@ -233,22 +202,17 @@ public Runner(
*/
@Deprecated
protected Runner(
BaseAgent agent,
String appName,
App app,
BaseArtifactService artifactService,
BaseSessionService sessionService,
@Nullable BaseMemoryService memoryService,
List<? extends Plugin> plugins,
@Nullable EventsCompactionConfig eventsCompactionConfig,
@Nullable ContextCacheConfig contextCacheConfig) {
this.agent = agent;
this.appName = appName;
@Nullable BaseMemoryService memoryService) {
this.app = app;
this.artifactService = artifactService;
this.sessionService = sessionService;
this.memoryService = memoryService;
this.pluginManager = new PluginManager(plugins);
this.eventsCompactionConfig = createEventsCompactionConfig(agent, eventsCompactionConfig);
this.contextCacheConfig = contextCacheConfig;
this.pluginManager = new PluginManager(app.plugins());
this.eventsCompactionConfig =
createEventsCompactionConfig(app.rootAgent(), app.eventsCompactionConfig());
}

/**
Expand All @@ -265,12 +229,16 @@ public Runner(
this(agent, appName, artifactService, sessionService, null);
}

public App app() {
return this.app;
}

public BaseAgent agent() {
return this.agent;
return this.app.rootAgent();
}

public String appName() {
return this.appName;
return this.app.name();
}

public BaseArtifactService artifactService() {
Expand All @@ -290,10 +258,24 @@ public PluginManager pluginManager() {
return this.pluginManager;
}

@Nullable
public EventsCompactionConfig eventsCompactionConfig() {
return this.eventsCompactionConfig;
}

@Nullable
public ContextCacheConfig contextCacheConfig() {
return this.app.contextCacheConfig();
}

public Builder toBuilder() {
return new Builder(this);
}

/** Closes all plugins, code executors, and releases any resources. */
public Completable close() {
List<Completable> completables = new ArrayList<>();
completables.add(agent.close());
completables.add(app.rootAgent().close());
completables.add(this.pluginManager.close());
return Completable.mergeDelayError(completables);
}
Expand Down Expand Up @@ -326,7 +308,7 @@ private Single<Event> appendNewMessageToSession(
saveArtifactsFlow =
saveArtifactsFlow.andThen(
this.artifactService
.saveArtifact(this.appName, session.userId(), session.id(), fileName, part)
.saveArtifact(this.app.name(), session.userId(), session.id(), fileName, part)
.ignoreElement());

newMessage
Expand Down Expand Up @@ -383,13 +365,13 @@ public Flowable<Event> runAsync(
return Flowable.defer(
() ->
this.sessionService
.getSession(appName, userId, sessionId, Optional.empty())
.getSession(this.app.name(), userId, sessionId, Optional.empty())
.switchIfEmpty(
Single.defer(
() -> {
if (runConfig.autoCreateSession()) {
return this.sessionService.createSession(
appName, userId, (Map<String, Object>) null, sessionId);
this.app.name(), userId, (Map<String, Object>) null, sessionId);
}
return Single.error(
new IllegalArgumentException(
Expand Down Expand Up @@ -475,7 +457,7 @@ protected Flowable<Event> runAsyncImpl(
Context capturedContext = Context.current();
return Flowable.defer(
() -> {
BaseAgent rootAgent = this.agent;
BaseAgent rootAgent = this.app.rootAgent();
String invocationId = InvocationContext.newInvocationContextId();

// Create initial context
Expand Down Expand Up @@ -634,7 +616,7 @@ private InvocationContext newInvocationContextForLive(
}

private InvocationContext.Builder newInvocationContextBuilder(Session session) {
BaseAgent rootAgent = this.agent;
BaseAgent rootAgent = this.app.rootAgent();
return InvocationContext.builder()
.sessionService(this.sessionService)
.artifactService(this.artifactService)
Expand All @@ -643,7 +625,7 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) {
.agent(rootAgent)
.session(session)
.eventsCompactionConfig(this.eventsCompactionConfig)
.contextCacheConfig(this.contextCacheConfig)
.contextCacheConfig(this.app.contextCacheConfig())
.agent(this.findAgentToRun(session, rootAgent));
}

Expand All @@ -663,13 +645,13 @@ public Flowable<Event> runLive(
return Flowable.defer(
() ->
this.sessionService
.getSession(appName, userId, sessionId, Optional.empty())
.getSession(this.app.name(), userId, sessionId, Optional.empty())
.switchIfEmpty(
Single.defer(
() -> {
if (runConfig.autoCreateSession()) {
return this.sessionService.createSession(
appName, userId, (Map<String, Object>) null, sessionId);
this.app.name(), userId, (Map<String, Object>) null, sessionId);
}
return Single.error(
new IllegalArgumentException(
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,19 @@ public void runAsync_withSessionKey_success() {
assertThat(simplifyEvents(events)).containsExactly("test agent: from llm");
}

@Test
public void toBuilder_success() {
Runner newRunner = runner.toBuilder().appName("new_app").build();

assertThat(newRunner.appName()).isEqualTo("new_app");
assertThat(newRunner.agent()).isEqualTo(runner.agent());
assertThat(newRunner.artifactService()).isEqualTo(runner.artifactService());
assertThat(newRunner.sessionService()).isEqualTo(runner.sessionService());
assertThat(newRunner.memoryService()).isEqualTo(runner.memoryService());
assertThat(newRunner.pluginManager().getPlugins())
.containsExactlyElementsIn(runner.pluginManager().getPlugins());
}

@Test
public void runAsync_withStateDelta_mergesStateIntoSession() {
ImmutableMap<String, Object> stateDelta = ImmutableMap.of("key1", "value1", "key2", 42);
Expand Down
Loading