diff --git a/contrib/spring-ai/DOCUMENT-GEMINI.md b/contrib/spring-ai/DOCUMENT-GEMINI.md deleted file mode 100644 index 393562528..000000000 --- a/contrib/spring-ai/DOCUMENT-GEMINI.md +++ /dev/null @@ -1,86 +0,0 @@ -# Documentation for the ADK Spring AI Library - -## 📖 Overview -The `google-adk-spring-ai` library provides an integration layer between the Google Agent Development Kit (ADK) and the Spring AI project. It allows developers to use Spring AI's `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` as `BaseLlm` and `Embedding` implementations within the ADK framework. - -The library handles the conversion between ADK's request/response formats and Spring AI's prompt/chat response formats. It also includes auto-configuration to automatically expose Spring AI models as ADK `SpringAI` and `SpringAIEmbedding` beans in a Spring Boot application. - -## 🛠️ Building -To include this library in your project, use the following Maven coordinates: - -```xml - - com.google.adk - google-adk-spring-ai - 0.3.1-SNAPSHOT - -``` - -You will also need to include a dependency for the specific Spring AI model you want to use, for example: -```xml - - org.springframework.ai - spring-ai-openai - -``` - -## 🚀 Usage -The primary way to use this library is through Spring Boot auto-configuration. By including the `google-adk-spring-ai` dependency and a Spring AI model dependency (e.g., `spring-ai-openai`), the library will automatically create a `SpringAI` bean. This bean can then be injected and used as a `BaseLlm` in the ADK. - -**Example `application.properties`:** -```properties -# OpenAI configuration -spring.ai.openai.api-key=${OPENAI_API_KEY} -spring.ai.openai.chat.options.model=gpt-4o-mini -spring.ai.openai.chat.options.temperature=0.7 - -# ADK Spring AI configuration -adk.spring-ai.model=gpt-4o-mini -adk.spring-ai.validation.enabled=true -``` - -**Example usage in a Spring service:** -```java -import com.google.adk.models.BaseLlm; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; -import reactor.core.publisher.Mono; - -@Service -public class MyAgentService { - - private final BaseLlm llm; - - @Autowired - public MyAgentService(BaseLlm llm) { - this.llm = llm; - } - - public Mono generateResponse(String prompt) { - LlmRequest request = LlmRequest.builder() - .addText(prompt) - .build(); - return Mono.from(llm.generateContent(request)) - .map(llmResponse -> llmResponse.content().get().parts().get(0).text().get()); - } -} -``` - -## 📚 API Reference -### Key Classes -- **`SpringAI`**: The main class that wraps a Spring AI `ChatModel` and/or `StreamingChatModel` and implements the ADK `BaseLlm` interface. - - **Methods**: - - `generateContent(LlmRequest llmRequest, boolean stream)`: Generates content, either streaming or non-streaming, by calling the underlying Spring AI model. It converts the ADK `LlmRequest` to a Spring AI `Prompt` and the `ChatResponse` back to an ADK `LlmResponse`. - -- **`SpringAIEmbedding`**: Wraps a Spring AI `EmbeddingModel` to be used for generating embeddings within the ADK framework. - -- **`SpringAIAutoConfiguration`**: The Spring Boot auto-configuration class that automatically discovers and configures `SpringAI` and `SpringAIEmbedding` beans based on the `ChatModel`, `StreamingChatModel`, and `EmbeddingModel` beans present in the application context. - -- **`SpringAIProperties`**: A configuration properties class (`@ConfigurationProperties("adk.spring-ai")`) that allows for customization of the Spring AI integration. - - **Properties**: - - `model`: The model name to use. - - `validation.enabled`: Whether to enable configuration validation. - - `validation.fail-fast`: Whether to fail fast on validation errors. - - `observability.enabled`: Whether to enable observability features. diff --git a/contrib/spring-ai/README.md b/contrib/spring-ai/README.md index c45f0e033..0ce7de4fe 100644 --- a/contrib/spring-ai/README.md +++ b/contrib/spring-ai/README.md @@ -18,21 +18,21 @@ To use ADK Java with the Spring AI integration in your application, add the foll com.google.adk google-adk - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT com.google.adk google-adk-spring-ai - 0.3.1-SNAPSHOT + 1.0.1-rc.1-SNAPSHOT org.springframework.ai spring-ai-bom - 1.1.0-M3 + 2.0.0-M3 pom import @@ -109,14 +109,14 @@ Add the Spring AI provider dependencies for the AI services you want to use: org.springframework.boot spring-boot-starter-parent - 3.2.0 + 4.0.2 17 - 1.1.0-M3 - 0.3.1-SNAPSHOT + 2.0.0-M3 + 1.0.1-rc.1-SNAPSHOT @@ -271,7 +271,7 @@ public class MyAdkSpringAiApplication { .anthropicApi(anthropicApi) .build(); - return new SpringAI(chatModel, "claude-3-5-sonnet-20241022"); + return new SpringAI(chatModel, "claude-sonnet-4-6"); } @Bean @@ -312,7 +312,7 @@ spring: api-key: ${ANTHROPIC_API_KEY} chat: options: - model: claude-3-5-sonnet-20241022 + model: claude-sonnet-4-6 temperature: 0.7 # ADK Spring AI Configuration @@ -365,13 +365,13 @@ The main adapter class that implements `BaseLlm` and wraps Spring AI `ChatModel` **Usage:** ```java // With ChatModel only -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6"); // With both ChatModel and StreamingChatModel -SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-20250514"); +SpringAI springAI = new SpringAI(chatModel, streamingChatModel, "claude-sonnet-4-6"); // With observability configuration -SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-20250514", observabilityConfig); +SpringAI springAI = new SpringAI(chatModel, "claude-sonnet-4-6", observabilityConfig); ``` #### 2. MessageConverter (MessageConverter.java) @@ -533,7 +533,7 @@ The library works with any Spring AI provider: - Features: Chat, streaming, function calling, embeddings 2. **Anthropic** (`spring-ai-anthropic`) - - Models: Claude 3.5 Sonnet, Claude 3 Haiku + - Models: Claude 4.x Sonnet, Claude 4.x Haiku - Features: Chat, streaming, function calling - **Note:** Requires proper function schema registration @@ -563,7 +563,7 @@ The library works with any Spring AI provider: #### Anthropic - **Function Calling:** Requires explicit schema registration using `inputSchema()` method -- **Model Names:** Use full model names like `claude-3-5-sonnet-20241022` +- **Model Names:** Use full model names like `claude-sonnet-4-6` - **API Key:** Requires `ANTHROPIC_API_KEY` environment variable #### OpenAI diff --git a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java index 95dafadb4..4012ee5d6 100644 --- a/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java +++ b/contrib/spring-ai/src/main/java/com/google/adk/models/springai/ToolConverter.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.tool.ToolCallback; @@ -172,6 +173,17 @@ public List convertToSpringAiTools(Map tools) { } catch (Exception e) { logger.error("Error serializing schema to JSON: {}", e.getMessage(), e); } + } else if (declaration.parametersJsonSchema().isPresent()) { + callbackBuilder.inputType(Map.class); + try { + String schemaJson = + new com.fasterxml.jackson.databind.ObjectMapper() + .writeValueAsString(declaration.parametersJsonSchema().get()); + callbackBuilder.inputSchema(schemaJson); + logger.debug("Set input schema JSON from parametersJsonSchema: {}", schemaJson); + } catch (Exception e) { + logger.error("Error serializing parametersJsonSchema to JSON: {}", e.getMessage(), e); + } } toolCallbacks.add(callbackBuilder.build()); @@ -187,45 +199,63 @@ public List convertToSpringAiTools(Map tools) { */ private Map processArguments( Map args, FunctionDeclaration declaration) { - // If the arguments already match the expected format, return as-is if (declaration.parameters().isPresent()) { var schema = declaration.parameters().get(); if (schema.properties().isPresent()) { - var expectedParams = schema.properties().get().keySet(); - - // Check if all expected parameters are present at the top level - boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); - if (allParamsPresent) { - return args; + return normalizeArguments(args, schema.properties().get().keySet()); + } + } else if (declaration.parametersJsonSchema().isPresent()) { + try { + @SuppressWarnings("unchecked") + Map schemaMap = + new com.fasterxml.jackson.databind.ObjectMapper() + .convertValue(declaration.parametersJsonSchema().get(), Map.class); + Object propertiesObj = schemaMap.get("properties"); + if (propertiesObj instanceof Map) { + @SuppressWarnings("unchecked") + Set expectedParams = ((Map) propertiesObj).keySet(); + return normalizeArguments(args, expectedParams); } + } catch (Exception e) { + logger.warn( + "Error processing parametersJsonSchema for argument mapping: {}", e.getMessage()); + } + } - // Check if arguments are nested under a single key (common pattern) - if (args.size() == 1) { - var singleValue = args.values().iterator().next(); - if (singleValue instanceof Map) { - @SuppressWarnings("unchecked") - Map nestedArgs = (Map) singleValue; - boolean allNestedParamsPresent = - expectedParams.stream().allMatch(nestedArgs::containsKey); - if (allNestedParamsPresent) { - return nestedArgs; - } - } - } + // If no processing worked, return original args and let ADK handle the error + return args; + } - // Check if we have a single parameter function and got a direct value - if (expectedParams.size() == 1) { - String expectedParam = expectedParams.iterator().next(); - if (args.size() == 1 && !args.containsKey(expectedParam)) { - // Try to map the single value to the expected parameter name - Object singleValue = args.values().iterator().next(); - return Map.of(expectedParam, singleValue); - } + private Map normalizeArguments( + Map args, Set expectedParams) { + // Check if all expected parameters are present at the top level + boolean allParamsPresent = expectedParams.stream().allMatch(args::containsKey); + if (allParamsPresent) { + return args; + } + + // Check if arguments are nested under a single key (common pattern) + if (args.size() == 1) { + var singleValue = args.values().iterator().next(); + if (singleValue instanceof Map) { + @SuppressWarnings("unchecked") + Map nestedArgs = (Map) singleValue; + boolean allNestedParamsPresent = expectedParams.stream().allMatch(nestedArgs::containsKey); + if (allNestedParamsPresent) { + return nestedArgs; } } } - // If no processing worked, return original args and let ADK handle the error + // Check if we have a single parameter function and got a direct value + if (expectedParams.size() == 1) { + String expectedParam = expectedParams.iterator().next(); + if (args.size() == 1 && !args.containsKey(expectedParam)) { + Object singleValue = args.values().iterator().next(); + return Map.of(expectedParam, singleValue); + } + } + return args; } diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java index 301a145e0..77b988837 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterArgumentProcessingTest.java @@ -115,6 +115,90 @@ private Map invokeProcessArguments( return (Map) method.invoke(converter, args, declaration); } + @Test + void testArgumentProcessingWithParametersJsonSchema_correctFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map correctArgs = Map.of("location", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, correctArgs, declaration); + + assertThat(processedArgs).isEqualTo(correctArgs); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_nestedFormat() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map nestedArgs = Map.of("args", Map.of("location", "San Francisco")); + Map processedArgs = + invokeProcessArguments(processArguments, converter, nestedArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_directValue() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map directValueArgs = Map.of("value", "San Francisco"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, directValueArgs, declaration); + + assertThat(processedArgs).containsEntry("location", "San Francisco"); + } + + @Test + void testArgumentProcessingWithParametersJsonSchema_noMatch() throws Exception { + ToolConverter converter = new ToolConverter(); + Method processArguments = getProcessArgumentsMethod(converter); + + com.google.genai.types.FunctionDeclaration declaration = + com.google.genai.types.FunctionDeclaration.builder() + .name("getWeatherInfo") + .description("Get weather information") + .parametersJsonSchema( + Map.of( + "type", "object", "properties", Map.of("location", Map.of("type", "string")))) + .build(); + + Map wrongArgs = Map.of("city", "San Francisco", "country", "USA"); + Map processedArgs = + invokeProcessArguments(processArguments, converter, wrongArgs, declaration); + + assertThat(processedArgs).isEqualTo(wrongArgs); + } + public static class WeatherTools { public static Map getWeatherInfo(String location) { return Map.of( diff --git a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java index 231c8e1fe..1f3044159 100644 --- a/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java +++ b/contrib/spring-ai/src/test/java/com/google/adk/models/springai/ToolConverterTest.java @@ -26,6 +26,7 @@ import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.ToolCallback; class ToolConverterTest { @@ -178,4 +179,37 @@ void testToolMetadata() { assertThat(metadata.getDescription()).isEqualTo("Test description"); assertThat(metadata.getDeclaration()).isEqualTo(function); } + + @Test + void testConvertToSpringAiToolsWithParametersJsonSchema() { + Map jsonSchema = + Map.of( + "type", + "object", + "properties", + Map.of("location", Map.of("type", "string", "description", "City name")), + "required", + List.of("location")); + + FunctionDeclaration function = + FunctionDeclaration.builder() + .name("get_weather") + .description("Get weather for a location") + .parametersJsonSchema(jsonSchema) + .build(); + + BaseTool testTool = + new BaseTool("get_weather", "Get weather for a location") { + @Override + public Optional declaration() { + return Optional.of(function); + } + }; + + Map tools = Map.of("get_weather", testTool); + List toolCallbacks = toolConverter.convertToSpringAiTools(tools); + + assertThat(toolCallbacks).hasSize(1); + assertThat(toolCallbacks.get(0).getToolDefinition().name()).isEqualTo("get_weather"); + } }