|
24 | 24 | import static com.google.adk.testing.TestUtils.createTextLlmResponse; |
25 | 25 | import static com.google.adk.testing.TestUtils.simplifyEvents; |
26 | 26 | import static com.google.common.truth.Truth.assertThat; |
| 27 | +import static java.nio.charset.StandardCharsets.UTF_8; |
| 28 | +import static java.util.Arrays.stream; |
27 | 29 | import static org.mockito.ArgumentMatchers.any; |
28 | 30 | import static org.mockito.Mockito.CALLS_REAL_METHODS; |
29 | 31 | import static org.mockito.Mockito.mock; |
|
36 | 38 | import com.google.adk.agents.LlmAgent; |
37 | 39 | import com.google.adk.agents.RunConfig; |
38 | 40 | import com.google.adk.apps.App; |
| 41 | +import com.google.adk.artifacts.BaseArtifactService; |
39 | 42 | import com.google.adk.events.Event; |
40 | 43 | import com.google.adk.flows.llmflows.Functions; |
41 | 44 | import com.google.adk.models.LlmResponse; |
|
62 | 65 | import io.reactivex.rxjava3.core.Completable; |
63 | 66 | import io.reactivex.rxjava3.core.Flowable; |
64 | 67 | import io.reactivex.rxjava3.core.Maybe; |
| 68 | +import io.reactivex.rxjava3.core.Single; |
65 | 69 | import io.reactivex.rxjava3.subscribers.TestSubscriber; |
66 | 70 | import java.util.List; |
67 | 71 | import java.util.Objects; |
68 | 72 | import java.util.Optional; |
69 | 73 | import java.util.UUID; |
70 | 74 | import java.util.concurrent.ConcurrentHashMap; |
| 75 | +import java.util.concurrent.atomic.AtomicInteger; |
71 | 76 | import org.junit.After; |
72 | 77 | import org.junit.Before; |
73 | 78 | import org.junit.Rule; |
74 | 79 | import org.junit.Test; |
75 | 80 | import org.junit.runner.RunWith; |
76 | 81 | import org.junit.runners.JUnit4; |
77 | 82 | import org.mockito.ArgumentCaptor; |
| 83 | +import org.mockito.Mockito; |
78 | 84 |
|
79 | 85 | @RunWith(JUnit4.class) |
80 | 86 | public final class RunnerTest { |
@@ -846,6 +852,19 @@ private Content createContent(String text) { |
846 | 852 | return Content.builder().parts(Part.builder().text(text).build()).build(); |
847 | 853 | } |
848 | 854 |
|
| 855 | + private static Content createInlineDataContent(byte[]... data) { |
| 856 | + return Content.builder() |
| 857 | + .parts( |
| 858 | + stream(data) |
| 859 | + .map(dataBytes -> Part.fromBytes(dataBytes, "example/octet-stream")) |
| 860 | + .toArray(Part[]::new)) |
| 861 | + .build(); |
| 862 | + } |
| 863 | + |
| 864 | + private static Content createInlineDataContent(String... data) { |
| 865 | + return createInlineDataContent(stream(data).map(d -> d.getBytes(UTF_8)).toArray(byte[][]::new)); |
| 866 | + } |
| 867 | + |
849 | 868 | @Test |
850 | 869 | public void runAsync_createsInvocationSpan() { |
851 | 870 | var unused = |
@@ -1203,4 +1222,40 @@ public static ImmutableMap<String, Object> echoTool(String message) { |
1203 | 1222 | return ImmutableMap.of("message", message); |
1204 | 1223 | } |
1205 | 1224 | } |
| 1225 | + |
| 1226 | + @Test |
| 1227 | + public void runner_executesSaveArtifactFlow() { |
| 1228 | + // arrange |
| 1229 | + final AtomicInteger artifactsSavedCounter = new AtomicInteger(); |
| 1230 | + BaseArtifactService mockArtifactService = Mockito.mock(BaseArtifactService.class); |
| 1231 | + when(mockArtifactService.saveArtifact(any(), any(), any(), any(), any())) |
| 1232 | + .thenReturn( |
| 1233 | + Single.defer( |
| 1234 | + () -> { |
| 1235 | + // we want to assert not only that the saveArtifact method was |
| 1236 | + // called, but also that the flow that it returned was run, so |
| 1237 | + // we need to record the call in a counter |
| 1238 | + artifactsSavedCounter.incrementAndGet(); |
| 1239 | + return Single.just(42); |
| 1240 | + })); |
| 1241 | + Runner runner = |
| 1242 | + Runner.builder() |
| 1243 | + .app(App.builder().name("test").rootAgent(agent).build()) |
| 1244 | + .artifactService(mockArtifactService) |
| 1245 | + .build(); |
| 1246 | + session = runner.sessionService().createSession("test", "user").blockingGet(); |
| 1247 | + // each inline data will be saved using our mock artifact service |
| 1248 | + Content content = createInlineDataContent("test data", "test data 2"); |
| 1249 | + RunConfig runConfig = RunConfig.builder().setSaveInputBlobsAsArtifacts(true).build(); |
| 1250 | + |
| 1251 | + // act |
| 1252 | + var events = runner.runAsync("user", session.id(), content, runConfig).test(); |
| 1253 | + |
| 1254 | + // assert |
| 1255 | + events.assertComplete(); |
| 1256 | + // artifacts where saved |
| 1257 | + assertThat(artifactsSavedCounter.get()).isEqualTo(2); |
| 1258 | + // agent was run |
| 1259 | + assertThat(simplifyEvents(events.values())).containsExactly("test agent: from llm"); |
| 1260 | + } |
1206 | 1261 | } |
0 commit comments