Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@

package com.google.adk.tools;

import com.google.adk.agents.LlmAgent;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.LlmRequest;
import com.google.adk.utils.ModelNameUtils;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Tool;
import com.google.genai.types.ToolCodeExecution;
import io.reactivex.rxjava3.core.Completable;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A built-in code execution tool that is automatically invoked by Gemini 2 models.
Expand All @@ -32,6 +38,7 @@
*/
public final class BuiltInCodeExecutionTool extends BaseTool {
public static final BuiltInCodeExecutionTool INSTANCE = new BuiltInCodeExecutionTool();
private static final Logger LOG = LoggerFactory.getLogger(BuiltInCodeExecutionTool.class);

public BuiltInCodeExecutionTool() {
super("code_execution", "code_execution");
Expand All @@ -41,10 +48,28 @@ public BuiltInCodeExecutionTool() {
public Completable processLlmRequest(
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {

String model = llmRequestBuilder.build().model().get();
if (model.isEmpty() || !model.startsWith("gemini-2")) {
return Completable.error(
new IllegalArgumentException("Code execution tool is not supported for model " + model));
Optional<BaseLlm> model =
Optional.ofNullable(toolContext)
.flatMap(tCtx -> Optional.ofNullable(tCtx.invocationContext()))
.flatMap(
iCtx -> {
if (iCtx.agent() instanceof LlmAgent llmAgent) {
return Optional.of(llmAgent);
} else {
return Optional.empty();
}
})
.flatMap(llmAgent -> llmAgent.resolvedModel().model());

String modelName = llmRequestBuilder.build().model().get();
if (!ModelNameUtils.isGeminiModel(modelName)
|| model.filter(ModelNameUtils::isInstanceOfGemini).isEmpty()) {
// model name is not a gemini model, or the model isn't an instance of Gemini class (eg.
// LangChain case).
LOG.warn(
"Code execution tool is not supported for model: {} ({}).",
modelName,
model.map(Object::getClass).map(Class::toString).orElse("<unknown class>"));
}
GenerateContentConfig.Builder configBuilder =
llmRequestBuilder
Expand Down
31 changes: 31 additions & 0 deletions core/src/main/java/com/google/adk/utils/ModelNameUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,24 @@

package com.google.adk.utils;

import com.google.common.base.Strings;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public final class ModelNameUtils {
private static final String GEMINI_PREFIX = "gemini-";
private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*");
private static final String GEMINI_CLASS = "com.google.adk.models.Gemini";
private static final Pattern PATH_PATTERN =
Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$");
private static final Pattern APIGEE_PATTERN =
Pattern.compile("^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$");

public static boolean isGeminiModel(String modelString) {
return extractModelName(Strings.nullToEmpty(modelString)).startsWith(GEMINI_PREFIX);
}

public static boolean isGemini2Model(String modelString) {
if (modelString == null) {
return false;
Expand All @@ -34,6 +42,29 @@ public static boolean isGemini2Model(String modelString) {
return GEMINI_2_PATTERN.matcher(modelName).matches();
}

/**
* Checks whether an object is an instance of {@link com.google.adk.models.Gemini}, by searching
* through its class hierarchy for a class whose name equals the hardcoded String name of Gemini
* class.
*
* <p>This method can be used where the "real" instanceof check is not possible because the Gemini
* type is not known at compile time.
*
* @param o The object to check.
* @return true if object's class is {@link com.google.adk.models.Gemini}, false otherwise.
*/
public static boolean isInstanceOfGemini(Object o) {
if (o == null) {
return false;
}
for (Class<?> clazz = o.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
if (Objects.equals(clazz.getName(), GEMINI_CLASS)) {
return true;
}
}
return false;
}

/**
* Extract the actual model name from either simple or path-based format.
*
Expand Down
45 changes: 43 additions & 2 deletions core/src/test/java/com/google/adk/tools/BaseToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.models.Gemini;
import com.google.adk.models.LlmRequest;
import com.google.adk.sessions.InMemorySessionService;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.FunctionDeclaration;
import com.google.genai.types.GenerateContentConfig;
Expand Down Expand Up @@ -171,13 +175,27 @@ public void processLlmRequestWithUrlContextToolAddsToolToConfig() {
Tool.builder().urlContext(UrlContext.builder().build()).build());
}

private static InvocationContext.Builder testInvocationContext() {
InvocationContext.Builder builder = InvocationContext.builder();
builder.agent(testAgent().build());
InMemorySessionService inMemorySessionService = new InMemorySessionService();
builder.sessionService(inMemorySessionService);
builder.session(inMemorySessionService.createSession("test-app", "test-user-id").blockingGet());
return builder;
}

private static LlmAgent.Builder testAgent() {
return LlmAgent.builder().name("test-agent");
}

@Test
public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() {
public void
processLlmRequestWithBuiltInCodeExecutionToolAndNonGeminiModelAndNullContextAddsToolToConfig() {
BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool();
LlmRequest llmRequest =
LlmRequest.builder()
.config(GenerateContentConfig.builder().build())
.model("gemini-2")
.model("text-bison")
.build();
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
Completable unused =
Expand All @@ -189,6 +207,29 @@ public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() {
.containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build());
}

@Test
public void processLlmRequestWithBuiltInCodeExecutionToolAndGemini2ModelAddsToolToConfig() {
BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool();
LlmRequest llmRequest =
LlmRequest.builder()
.config(GenerateContentConfig.builder().build())
.model("gemini-2")
.build();
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
ToolContext toolContext =
ToolContext.builder(
testInvocationContext()
.agent(testAgent().model(new Gemini("gemini-2", "")).build())
.build())
.build();
Completable unused = builtInCodeExecutionTool.processLlmRequest(llmRequestBuilder, toolContext);
LlmRequest updatedLlmRequest = llmRequestBuilder.build();
assertThat(updatedLlmRequest.config()).isPresent();
assertThat(updatedLlmRequest.config().get().tools()).isPresent();
assertThat(updatedLlmRequest.config().get().tools().get())
.containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build());
}

@Test
public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() {
GoogleMapsTool googleMapsTool = new GoogleMapsTool();
Expand Down
100 changes: 100 additions & 0 deletions core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.adk.models.Gemini;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -69,4 +70,103 @@ public void isGemini2Model_withApigeeProviderV1BetaGemini2Model_returnsTrue() {
public void isGemini2Model_withNullModel_returnsFalse() {
assertThat(ModelNameUtils.isGemini2Model(null)).isFalse();
}

@Test
public void isGeminiModel_withGeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withNonGeminiModel_returnsFalse() {
assertThat(ModelNameUtils.isGeminiModel("text-bison")).isFalse();
}

@Test
public void isGeminiModel_withPathBasedGeminiModel_returnsTrue() {
assertThat(
ModelNameUtils.isGeminiModel(
"projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro"))
.isTrue();
}

@Test
public void isGeminiModel_withPathBasedNonGeminiModel_returnsFalse() {
assertThat(
ModelNameUtils.isGeminiModel(
"projects/test-project/locations/us-central1/publishers/google/models/text-bison"))
.isFalse();
}

@Test
public void isGeminiModel_withApigeeGeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withApigeeV1GeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/v1/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withApigeeProviderGeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withApigeeProviderVertexGeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withApigeeProviderV1GeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/v1/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withApigeeProviderV1BetaGeminiModel_returnsTrue() {
assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/v1beta/gemini-1.5-flash")).isTrue();
}

@Test
public void isGeminiModel_withNullModel_returnsFalse() {
assertThat(ModelNameUtils.isGeminiModel(null)).isFalse();
}

@Test
public void isGeminiModel_withEmptyModel_returnsFalse() {
assertThat(ModelNameUtils.isGeminiModel("")).isFalse();
}

@Test
public void isInstanceOfGemini_withGeminiInstance_returnsTrue() {
assertThat(ModelNameUtils.isInstanceOfGemini(new Gemini("", ""))).isTrue();
}

@Test
public void isInstanceOfGemini_withNonGeminiInstance_returnsFalse() {
assertThat(ModelNameUtils.isInstanceOfGemini(new Object())).isFalse();
}

@Test
public void isInstanceOfGemini_withNullInstance_returnsFalse() {
assertThat(ModelNameUtils.isInstanceOfGemini(null)).isFalse();
}

private static class GeminiSubclass extends Gemini {
GeminiSubclass() {
super("test-model", "test-api-key");
}
}

private static class GeminiSubclassSubclass extends GeminiSubclass {}

@Test
public void isInstanceOfGemini_withGeminiSubclassInstance_returnsTrue() {
assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclass())).isTrue();
}

@Test
public void isInstanceOfGemini_withSubclassOfGeminiSubclassInstance_returnsTrue() {
assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclassSubclass())).isTrue();
}
}