Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 0 additions & 86 deletions contrib/spring-ai/DOCUMENT-GEMINI.md

This file was deleted.

26 changes: 13 additions & 13 deletions contrib/spring-ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ To use ADK Java with the Spring AI integration in your application, add the foll
<dependency>
<groupId>com.google.adk</groupId>
<artifactId>google-adk</artifactId>
<version>0.3.1-SNAPSHOT</version>
<version>1.0.1-rc.1-SNAPSHOT</version>
</dependency>

<!-- ADK Spring AI Integration -->
<dependency>
<groupId>com.google.adk</groupId>
<artifactId>google-adk-spring-ai</artifactId>
<version>0.3.1-SNAPSHOT</version>
<version>1.0.1-rc.1-SNAPSHOT</version>
</dependency>

<!-- Spring AI BOM for version management -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-bom</artifactId>
<version>1.1.0-M3</version>
<version>2.0.0-M3</version>
<type>pom</type>
<scope>import</scope>
</dependency>
Expand Down Expand Up @@ -109,14 +109,14 @@ Add the Spring AI provider dependencies for the AI services you want to use:
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.2.0</version>
<version>4.0.2</version>
<relativePath/>
</parent>

<properties>
<java.version>17</java.version>
<spring-ai.version>1.1.0-M3</spring-ai.version>
<adk.version>0.3.1-SNAPSHOT</adk.version>
<spring-ai.version>2.0.0-M3</spring-ai.version>
<adk.version>1.0.1-rc.1-SNAPSHOT</adk.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -271,7 +271,7 @@ public class MyAdkSpringAiApplication {
.anthropicApi(anthropicApi)
.build();

return new SpringAI(chatModel, "claude-3-5-sonnet-20241022");
return new SpringAI(chatModel, "claude-sonnet-4-6");
}

@Bean
Expand Down Expand Up @@ -312,7 +312,7 @@ spring:
api-key: ${ANTHROPIC_API_KEY}
chat:
options:
model: claude-3-5-sonnet-20241022
model: claude-sonnet-4-6
temperature: 0.7

# ADK Spring AI Configuration
Expand Down Expand Up @@ -365,13 +365,13 @@ The main adapter class that implements `BaseLlm` and wraps Spring AI `ChatModel`
**Usage:**
```java
// With ChatModel only
SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514");
SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6");

// With both ChatModel and StreamingChatModel
SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-20250514");
SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-6");

// With observability configuration
SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514", observabilityConfig);
SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6", observabilityConfig);
```

#### 2. MessageConverter (MessageConverter.java)
Expand Down Expand Up @@ -533,7 +533,7 @@ The library works with any Spring AI provider:
- Features: Chat, streaming, function calling, embeddings

2. **Anthropic** (`spring-ai-anthropic`)
- Models: Claude 3.5 Sonnet, Claude 3 Haiku
- Models: Claude 4.x Sonnet, Claude 4.x Haiku
- Features: Chat, streaming, function calling
- **Note:** Requires proper function schema registration

Expand Down Expand Up @@ -563,7 +563,7 @@ The library works with any Spring AI provider:

#### Anthropic
- **Function Calling:** Requires explicit schema registration using `inputSchema()` method
- **Model Names:** Use full model names like `claude-3-5-sonnet-20241022`
- **Model Names:** Use full model names like `claude-sonnet-4-6`
- **API Key:** Requires `ANTHROPIC_API_KEY` environment variable

#### OpenAI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.tool.ToolCallback;
Expand Down Expand Up @@ -172,6 +173,17 @@ public List<ToolCallback> convertToSpringAiTools(Map<String, BaseTool> tools) {
} catch (Exception e) {
logger.error("Error serializing schema to JSON: {}", e.getMessage(), e);
}
} else if (declaration.parametersJsonSchema().isPresent()) {
callbackBuilder.inputType(Map.class);
try {
String schemaJson =
new com.fasterxml.jackson.databind.ObjectMapper()
.writeValueAsString(declaration.parametersJsonSchema().get());
callbackBuilder.inputSchema(schemaJson);
logger.debug("Set input schema JSON from parametersJsonSchema: {}", schemaJson);
} catch (Exception e) {
logger.error("Error serializing parametersJsonSchema to JSON: {}", e.getMessage(), e);
}
}

toolCallbacks.add(callbackBuilder.build());
Expand All @@ -187,45 +199,63 @@ public List<ToolCallback> convertToSpringAiTools(Map<String, BaseTool> tools) {
*/
private Map<String, Object> processArguments(
Map<String, Object> args, FunctionDeclaration declaration) {
// If the arguments already match the expected format, return as-is
if (declaration.parameters().isPresent()) {
var schema = declaration.parameters().get();
if (schema.properties().isPresent()) {
var expectedParams = schema.properties().get().keySet();

// Check if all expected parameters are present at the top level
boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey);
if (allParamsPresent) {
return args;
return normalizeArguments(args, schema.properties().get().keySet());
}
} else if (declaration.parametersJsonSchema().isPresent()) {
try {
@SuppressWarnings("unchecked")
Map<String, Object> schemaMap =
new com.fasterxml.jackson.databind.ObjectMapper()
.convertValue(declaration.parametersJsonSchema().get(), Map.class);
Object propertiesObj = schemaMap.get("properties");
if (propertiesObj instanceof Map) {
@SuppressWarnings("unchecked")
Set<String> expectedParams = ((Map<String, Object>) propertiesObj).keySet();
return normalizeArguments(args, expectedParams);
}
} catch (Exception e) {
logger.warn(
"Error processing parametersJsonSchema for argument mapping: {}", e.getMessage());
}
}

// Check if arguments are nested under a single key (common pattern)
if (args.size() == 1) {
var singleValue = args.values().iterator().next();
if (singleValue instanceof Map) {
@SuppressWarnings("unchecked")
Map<String, Object> nestedArgs = (Map<String, Object>) singleValue;
boolean allNestedParamsPresent =
expectedParams.stream().allMatch(nestedArgs::containsKey);
if (allNestedParamsPresent) {
return nestedArgs;
}
}
}
// If no processing worked, return original args and let ADK handle the error
return args;
}

// Check if we have a single parameter function and got a direct value
if (expectedParams.size() == 1) {
String expectedParam = expectedParams.iterator().next();
if (args.size() == 1 && !args.containsKey(expectedParam)) {
// Try to map the single value to the expected parameter name
Object singleValue = args.values().iterator().next();
return Map.of(expectedParam, singleValue);
}
private Map<String, Object> normalizeArguments(
Map<String, Object> args, Set<String> expectedParams) {
// Check if all expected parameters are present at the top level
boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey);
if (allParamsPresent) {
return args;
}

// Check if arguments are nested under a single key (common pattern)
if (args.size() == 1) {
var singleValue = args.values().iterator().next();
if (singleValue instanceof Map) {
@SuppressWarnings("unchecked")
Map<String, Object> nestedArgs = (Map<String, Object>) singleValue;
boolean allNestedParamsPresent = expectedParams.stream().allMatch(nestedArgs::containsKey);
if (allNestedParamsPresent) {
return nestedArgs;
}
}
}

// If no processing worked, return original args and let ADK handle the error
// Check if we have a single parameter function and got a direct value
if (expectedParams.size() == 1) {
String expectedParam = expectedParams.iterator().next();
if (args.size() == 1 && !args.containsKey(expectedParam)) {
Object singleValue = args.values().iterator().next();
return Map.of(expectedParam, singleValue);
}
}

return args;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,90 @@ private Map<String, Object> invokeProcessArguments(
return (Map<String, Object>) method.invoke(converter, args, declaration);
}

@Test
void testArgumentProcessingWithParametersJsonSchema_correctFormat() throws Exception {
ToolConverter converter = new ToolConverter();
Method processArguments = getProcessArgumentsMethod(converter);

com.google.genai.types.FunctionDeclaration declaration =
com.google.genai.types.FunctionDeclaration.builder()
.name("getWeatherInfo")
.description("Get weather information")
.parametersJsonSchema(
Map.of(
"type", "object", "properties", Map.of("location", Map.of("type", "string"))))
.build();

Map<String, Object> correctArgs = Map.of("location", "San Francisco");
Map<String, Object> processedArgs =
invokeProcessArguments(processArguments, converter, correctArgs, declaration);

assertThat(processedArgs).isEqualTo(correctArgs);
}

@Test
void testArgumentProcessingWithParametersJsonSchema_nestedFormat() throws Exception {
ToolConverter converter = new ToolConverter();
Method processArguments = getProcessArgumentsMethod(converter);

com.google.genai.types.FunctionDeclaration declaration =
com.google.genai.types.FunctionDeclaration.builder()
.name("getWeatherInfo")
.description("Get weather information")
.parametersJsonSchema(
Map.of(
"type", "object", "properties", Map.of("location", Map.of("type", "string"))))
.build();

Map<String, Object> nestedArgs = Map.of("args", Map.of("location", "San Francisco"));
Map<String, Object> processedArgs =
invokeProcessArguments(processArguments, converter, nestedArgs, declaration);

assertThat(processedArgs).containsEntry("location", "San Francisco");
}

@Test
void testArgumentProcessingWithParametersJsonSchema_directValue() throws Exception {
ToolConverter converter = new ToolConverter();
Method processArguments = getProcessArgumentsMethod(converter);

com.google.genai.types.FunctionDeclaration declaration =
com.google.genai.types.FunctionDeclaration.builder()
.name("getWeatherInfo")
.description("Get weather information")
.parametersJsonSchema(
Map.of(
"type", "object", "properties", Map.of("location", Map.of("type", "string"))))
.build();

Map<String, Object> directValueArgs = Map.of("value", "San Francisco");
Map<String, Object> processedArgs =
invokeProcessArguments(processArguments, converter, directValueArgs, declaration);

assertThat(processedArgs).containsEntry("location", "San Francisco");
}

@Test
void testArgumentProcessingWithParametersJsonSchema_noMatch() throws Exception {
ToolConverter converter = new ToolConverter();
Method processArguments = getProcessArgumentsMethod(converter);

com.google.genai.types.FunctionDeclaration declaration =
com.google.genai.types.FunctionDeclaration.builder()
.name("getWeatherInfo")
.description("Get weather information")
.parametersJsonSchema(
Map.of(
"type", "object", "properties", Map.of("location", Map.of("type", "string"))))
.build();

Map<String, Object> wrongArgs = Map.of("city", "San Francisco", "country", "USA");
Map<String, Object> processedArgs =
invokeProcessArguments(processArguments, converter, wrongArgs, declaration);

assertThat(processedArgs).isEqualTo(wrongArgs);
}

public static class WeatherTools {
public static Map<String, Object> getWeatherInfo(String location) {
return Map.of(
Expand Down
Loading