Skip to content

Commit bcac136

Browse files
tilgalascopybara-github
authored andcommitted
fix: include saveArtifact invocations in event chain
PiperOrigin-RevId: 884451502
1 parent c8ab0f9 commit bcac136

2 files changed

Lines changed: 63 additions & 4 deletions

File tree

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ private Single<Event> appendNewMessageToSession(
313313
throw new IllegalArgumentException("No parts in the new_message.");
314314
}
315315

316+
Completable saveArtifactsFlow = Completable.complete();
316317
if (this.artifactService != null && saveInputBlobsAsArtifacts) {
317318
// The runner directly saves the artifacts (if applicable) in the user message and replaces
318319
// the artifact data with a file name placeholder.
@@ -322,9 +323,11 @@ private Single<Event> appendNewMessageToSession(
322323
continue;
323324
}
324325
String fileName = "artifact_" + invocationContext.invocationId() + "_" + i;
325-
var unused =
326-
this.artifactService.saveArtifact(
327-
this.appName, session.userId(), session.id(), fileName, part);
326+
saveArtifactsFlow =
327+
saveArtifactsFlow.andThen(
328+
this.artifactService
329+
.saveArtifact(this.appName, session.userId(), session.id(), fileName, part)
330+
.ignoreElement());
328331

329332
newMessage
330333
.parts()
@@ -349,7 +352,8 @@ private Single<Event> appendNewMessageToSession(
349352
EventActions.builder().stateDelta(new ConcurrentHashMap<>(stateDelta)).build());
350353
}
351354

352-
return this.sessionService.appendEvent(session, eventBuilder.build());
355+
return saveArtifactsFlow.andThen(
356+
this.sessionService.appendEvent(session, eventBuilder.build()));
353357
}
354358

355359
/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import static com.google.adk.testing.TestUtils.createTextLlmResponse;
2525
import static com.google.adk.testing.TestUtils.simplifyEvents;
2626
import static com.google.common.truth.Truth.assertThat;
27+
import static java.nio.charset.StandardCharsets.UTF_8;
28+
import static java.util.Arrays.stream;
2729
import static org.mockito.ArgumentMatchers.any;
2830
import static org.mockito.Mockito.CALLS_REAL_METHODS;
2931
import static org.mockito.Mockito.mock;
@@ -36,6 +38,7 @@
3638
import com.google.adk.agents.LlmAgent;
3739
import com.google.adk.agents.RunConfig;
3840
import com.google.adk.apps.App;
41+
import com.google.adk.artifacts.BaseArtifactService;
3942
import com.google.adk.events.Event;
4043
import com.google.adk.flows.llmflows.Functions;
4144
import com.google.adk.models.LlmResponse;
@@ -65,19 +68,22 @@
6568
import io.reactivex.rxjava3.core.Completable;
6669
import io.reactivex.rxjava3.core.Flowable;
6770
import io.reactivex.rxjava3.core.Maybe;
71+
import io.reactivex.rxjava3.core.Single;
6872
import io.reactivex.rxjava3.subscribers.TestSubscriber;
6973
import java.util.List;
7074
import java.util.Objects;
7175
import java.util.Optional;
7276
import java.util.UUID;
7377
import java.util.concurrent.ConcurrentHashMap;
78+
import java.util.concurrent.atomic.AtomicInteger;
7479
import org.junit.After;
7580
import org.junit.Before;
7681
import org.junit.Rule;
7782
import org.junit.Test;
7883
import org.junit.runner.RunWith;
7984
import org.junit.runners.JUnit4;
8085
import org.mockito.ArgumentCaptor;
86+
import org.mockito.Mockito;
8187

8288
@RunWith(JUnit4.class)
8389
public final class RunnerTest {
@@ -849,6 +855,19 @@ private Content createContent(String text) {
849855
return Content.builder().parts(Part.builder().text(text).build()).build();
850856
}
851857

858+
private static Content createInlineDataContent(byte[]... data) {
859+
return Content.builder()
860+
.parts(
861+
stream(data)
862+
.map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream"))
863+
.toArray(Part[]::new))
864+
.build();
865+
}
866+
867+
private static Content createInlineDataContent(String... data) {
868+
return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new));
869+
}
870+
852871
@Test
853872
public void runAsync_createsInvocationSpan() {
854873
var unused =
@@ -1331,4 +1350,40 @@ public static ImmutableMap<String, Object> echoTool(String message) {
13311350
return ImmutableMap.of("message", message);
13321351
}
13331352
}
1353+
1354+
@Test
1355+
public void runner_executesSaveArtifactFlow() {
1356+
// arrange
1357+
final AtomicInteger artifactsSavedCounter = new AtomicInteger();
1358+
BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class);
1359+
when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any()))
1360+
.thenReturn(
1361+
Single.defer(
1362+
() -> {
1363+
// we want to assert not only that the saveArtifact method was
1364+
// called, but also that the flow that it returned was run, so
1365+
// we need to record the call in a counter
1366+
artifactsSavedCounter.incrementAndGet();
1367+
return Single.just(42);
1368+
}));
1369+
Runner runner =
1370+
Runner.builder()
1371+
.app(App.builder().name("test").rootAgent(agent).build())
1372+
.artifactService(mockArtifactService)
1373+
.build();
1374+
session = runner.sessionService().createSession("test", "user").blockingGet();
1375+
// each inline data will be saved using our mock artifact service
1376+
Content content = createInlineDataContent("test data", "test data 2");
1377+
RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build();
1378+
1379+
// act
1380+
var events = runner.runAsync("user", session.id(), content, runConfig).test();
1381+
1382+
// assert
1383+
events.assertComplete();
1384+
// artifacts were saved
1385+
assertThat(artifactsSavedCounter.get()).isEqualTo(2);
1386+
// agent was run
1387+
assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm");
1388+
}
13341389
}

0 commit comments

Comments
 (0)