|
24 | 24 | import static com.google.adk.testing.TestUtils.createTextLlmResponse; |
25 | 25 | import static com.google.adk.testing.TestUtils.simplifyEvents; |
26 | 26 | import static com.google.common.truth.Truth.assertThat; |
| 27 | +import static java.nio.charset.StandardCharsets.UTF_8; |
| 28 | +import static java.util.Arrays.stream; |
27 | 29 | import static org.mockito.ArgumentMatchers.any; |
28 | 30 | import static org.mockito.Mockito.CALLS_REAL_METHODS; |
29 | 31 | import static org.mockito.Mockito.mock; |
|
36 | 38 | import com.google.adk.agents.LlmAgent; |
37 | 39 | import com.google.adk.agents.RunConfig; |
38 | 40 | import com.google.adk.apps.App; |
| 41 | +import com.google.adk.artifacts.BaseArtifactService; |
39 | 42 | import com.google.adk.events.Event; |
40 | 43 | import com.google.adk.flows.llmflows.Functions; |
41 | 44 | import com.google.adk.models.LlmResponse; |
|
65 | 68 | import io.reactivex.rxjava3.core.Completable; |
66 | 69 | import io.reactivex.rxjava3.core.Flowable; |
67 | 70 | import io.reactivex.rxjava3.core.Maybe; |
| 71 | +import io.reactivex.rxjava3.core.Single; |
68 | 72 | import io.reactivex.rxjava3.subscribers.TestSubscriber; |
69 | 73 | import java.util.List; |
70 | 74 | import java.util.Objects; |
71 | 75 | import java.util.Optional; |
72 | 76 | import java.util.UUID; |
73 | 77 | import java.util.concurrent.ConcurrentHashMap; |
| 78 | +import java.util.concurrent.atomic.AtomicInteger; |
74 | 79 | import org.junit.After; |
75 | 80 | import org.junit.Before; |
76 | 81 | import org.junit.Rule; |
77 | 82 | import org.junit.Test; |
78 | 83 | import org.junit.runner.RunWith; |
79 | 84 | import org.junit.runners.JUnit4; |
80 | 85 | import org.mockito.ArgumentCaptor; |
| 86 | +import org.mockito.Mockito; |
81 | 87 |
|
82 | 88 | @RunWith(JUnit4.class) |
83 | 89 | public final class RunnerTest { |
@@ -849,6 +855,19 @@ private Content createContent(String text) { |
849 | 855 | return Content.builder().parts(Part.builder().text(text).build()).build(); |
850 | 856 | } |
851 | 857 |
|
| 858 | + private static Content createInlineDataContent(byte[]... data) { |
| 859 | + return Content.builder() |
| 860 | + .parts( |
| 861 | + stream(data) |
| 862 | + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) |
| 863 | + .toArray(Part[]::new)) |
| 864 | + .build(); |
| 865 | + } |
| 866 | + |
| 867 | + private static Content createInlineDataContent(String... data) { |
| 868 | + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); |
| 869 | + } |
| 870 | + |
852 | 871 | @Test |
853 | 872 | public void runAsync_createsInvocationSpan() { |
854 | 873 | var unused = |
@@ -1331,4 +1350,40 @@ public static ImmutableMap<String, Object> echoTool(String message) { |
1331 | 1350 | return ImmutableMap.of("message", message); |
1332 | 1351 | } |
1333 | 1352 | } |
| 1353 | + |
| 1354 | + @Test |
| 1355 | + public void runner_executesSaveArtifactFlow() { |
| 1356 | + // arrange |
| 1357 | + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); |
| 1358 | + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); |
| 1359 | + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) |
| 1360 | + .thenReturn( |
| 1361 | + Single.defer( |
| 1362 | + () -> { |
| 1363 | + // we want to assert not only that the saveArtifact method was |
| 1364 | + // called, but also that the flow that it returned was run, so |
| 1365 | + // we need to record the call in a counter |
| 1366 | + artifactsSavedCounter.incrementAndGet(); |
| 1367 | + return Single.just(42); |
| 1368 | + })); |
| 1369 | + Runner runner = |
| 1370 | + Runner.builder() |
| 1371 | + .app(App.builder().name("test").rootAgent(agent).build()) |
| 1372 | + .artifactService(mockArtifactService) |
| 1373 | + .build(); |
| 1374 | + session = runner.sessionService().createSession("test", "user").blockingGet(); |
| 1375 | + // each inline data will be saved using our mock artifact service |
| 1376 | + Content content = createInlineDataContent("test data", "test data 2"); |
| 1377 | + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); |
| 1378 | + |
| 1379 | + // act |
| 1380 | + var events = runner.runAsync("user", session.id(), content, runConfig).test(); |
| 1381 | + |
| 1382 | + // assert |
| 1383 | + events.assertComplete(); |
| 1384 | + // artifacts were saved |
| 1385 | + assertThat(artifactsSavedCounter.get()).isEqualTo(2); |
| 1386 | + // agent was run |
| 1387 | + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); |
| 1388 | + } |
1334 | 1389 | } |
0 commit comments