Skip to content

Commit 991c911

Browse files
tilgalascopybara-github
authored andcommitted
fix: include saveArtifact invocations in event chain
PiperOrigin-RevId: 884451502
1 parent 567fdf0 commit 991c911

2 files changed

Lines changed: 63 additions & 4 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ private Single<Event> appendNewMessageToSession(
312312
throw new IllegalArgumentException("No parts in the new_message.");
313313
}
314314

315+
Completable saveArtifactsFlow = Completable.complete();
315316
if (this.artifactService != null && saveInputBlobsAsArtifacts) {
316317
// The runner directly saves the artifacts (if applicable) in the user message and replaces
317318
// the artifact data with a file name placeholder.
@@ -321,9 +322,11 @@ private Single<Event> appendNewMessageToSession(
321322
continue;
322323
}
323324
String fileName = "artifact_" + invocationContext.invocationId() + "_" + i;
324-
var unused =
325-
this.artifactService.saveArtifact(
326-
this.appName, session.userId(), session.id(), fileName, part);
325+
saveArtifactsFlow =
326+
saveArtifactsFlow.andThen(
327+
this.artifactService
328+
.saveArtifact(this.appName, session.userId(), session.id(), fileName, part)
329+
.ignoreElement());
327330

328331
newMessage
329332
.parts()
@@ -348,7 +351,8 @@ private Single<Event> appendNewMessageToSession(
348351
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
349352
}
350353

351-
return this.sessionService.appendEvent(session, eventBuilder.build());
354+
return saveArtifactsFlow.andThen(
355+
this.sessionService.appendEvent(session, eventBuilder.build()));
352356
}
353357

354358
/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import static com.google.adk.testing.TestUtils.createTextLlmResponse;
2525
import static com.google.adk.testing.TestUtils.simplifyEvents;
2626
import static com.google.common.truth.Truth.assertThat;
27+
import static java.nio.charset.StandardCharsets.UTF_8;
28+
import static java.util.Arrays.stream;
2729
import static org.mockito.ArgumentMatchers.any;
2830
import static org.mockito.Mockito.CALLS_REAL_METHODS;
2931
import static org.mockito.Mockito.mock;
@@ -36,6 +38,7 @@
3638
import com.google.adk.agents.LlmAgent;
3739
import com.google.adk.agents.RunConfig;
3840
import com.google.adk.apps.App;
41+
import com.google.adk.artifacts.BaseArtifactService;
3942
import com.google.adk.events.Event;
4043
import com.google.adk.flows.llmflows.Functions;
4144
import com.google.adk.models.LlmResponse;
@@ -62,19 +65,22 @@
6265
import io.reactivex.rxjava3.core.Completable;
6366
import io.reactivex.rxjava3.core.Flowable;
6467
import io.reactivex.rxjava3.core.Maybe;
68+
import io.reactivex.rxjava3.core.Single;
6569
import io.reactivex.rxjava3.subscribers.TestSubscriber;
6670
import java.util.List;
6771
import java.util.Objects;
6872
import java.util.Optional;
6973
import java.util.UUID;
7074
import java.util.concurrent.ConcurrentHashMap;
75+
import java.util.concurrent.atomic.AtomicInteger;
7176
import org.junit.After;
7277
import org.junit.Before;
7378
import org.junit.Rule;
7479
import org.junit.Test;
7580
import org.junit.runner.RunWith;
7681
import org.junit.runners.JUnit4;
7782
import org.mockito.ArgumentCaptor;
83+
import org.mockito.Mockito;
7884

7985
@RunWith(JUnit4.class)
8086
public final class RunnerTest {
@@ -846,6 +852,19 @@ private Content createContent(String text) {
846852
return Content.builder().parts(Part.builder().text(text).build()).build();
847853
}
848854

855+
private static Content createInlineDataContent(byte[]... data) {
856+
return Content.builder()
857+
.parts(
858+
stream(data)
859+
.map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream"))
860+
.toArray(Part[]::new))
861+
.build();
862+
}
863+
864+
private static Content createInlineDataContent(String... data) {
865+
return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new));
866+
}
867+
849868
@Test
850869
public void runAsync_createsInvocationSpan() {
851870
var unused =
@@ -1203,4 +1222,40 @@ public static ImmutableMap<String, Object> echoTool(String message) {
12031222
return ImmutableMap.of("message", message);
12041223
}
12051224
}
1225+
1226+
@Test
1227+
public void runner_executesSaveArtifactFlow() {
1228+
// arrange
1229+
final AtomicInteger artifactsSavedCounter = new AtomicInteger();
1230+
BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class);
1231+
when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any()))
1232+
.thenReturn(
1233+
Single.defer(
1234+
() -> {
1235+
// we want to assert not only that the saveArtifact method was
1236+
// called, but also that the flow that it returned was run, so
1237+
// we need to record the call in a counter
1238+
artifactsSavedCounter.incrementAndGet();
1239+
return Single.just(42);
1240+
}));
1241+
Runner runner =
1242+
Runner.builder()
1243+
.app(App.builder().name("test").rootAgent(agent).build())
1244+
.artifactService(mockArtifactService)
1245+
.build();
1246+
session = runner.sessionService().createSession("test", "user").blockingGet();
1247+
// each inline data will be saved using our mock artifact service
1248+
Content content = createInlineDataContent("test data", "test data 2");
1249+
RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build();
1250+
1251+
// act
1252+
var events = runner.runAsync("user", session.id(), content, runConfig).test();
1253+
1254+
// assert
1255+
events.assertComplete();
1256+
// artifacts where saved
1257+
assertThat(artifactsSavedCounter.get()).isEqualTo(2);
1258+
// agent was run
1259+
assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm");
1260+
}
12061261
}

0 commit comments

Comments
 (0)