From 9d6cc80660d81fc217d058b7c8edb1ff906e2c30 Mon Sep 17 00:00:00 2001 From: Guillaume Laforge Date: Thu, 2 Apr 2026 23:59:45 -0700 Subject: [PATCH] feat: add support for Gemma models in LlmRegistry - Registered 'gemma-.*' models to use the Gemini builder. - Added GemmaTest.java to verify valid and invalid model names. - Restored main BUILD file and added GemmaTest target with env setup. Fixes b/499032158 PiperOrigin-RevId: 893918083 --- .../com/google/adk/models/LlmRegistry.java | 13 +++++++ .../java/com/google/adk/models/GemmaTest.java | 39 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 core/src/test/java/com/google/adk/models/GemmaTest.java diff --git a/core/src/main/java/com/google/adk/models/LlmRegistry.java b/core/src/main/java/com/google/adk/models/LlmRegistry.java index a73d89430..acc038695 100644 --- a/core/src/main/java/com/google/adk/models/LlmRegistry.java +++ b/core/src/main/java/com/google/adk/models/LlmRegistry.java @@ -16,6 +16,7 @@ package com.google.adk.models; +import com.google.common.annotations.VisibleForTesting; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -38,6 +39,7 @@ public interface LlmFactory { static { registerLlm("gemini-.*", modelName -> Gemini.builder().modelName(modelName).build()); registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build()); + registerLlm("gemma-.*", modelName -> Gemini.builder().modelName(modelName).build()); } /** @@ -50,6 +52,17 @@ public static void registerLlm(String modelNamePattern, LlmFactory factory) { llmFactories.put(modelNamePattern, factory); } + /** + * Checks if the given model name matches any of the registered LLM factory patterns. + * + * @param modelName The model name to check. + * @return {@code true} if the model name matches at least one pattern, {@code false} otherwise. + */ + @VisibleForTesting + static boolean matchesAnyPattern(String modelName) { + return llmFactories.keySet().stream().anyMatch(modelName::matches); + } + /** * Returns an LLM instance for the given model name, using a cached or new factory-created * instance. diff --git a/core/src/test/java/com/google/adk/models/GemmaTest.java b/core/src/test/java/com/google/adk/models/GemmaTest.java new file mode 100644 index 000000000..fe6315dcc --- /dev/null +++ b/core/src/test/java/com/google/adk/models/GemmaTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GemmaTest { + + @Test + public void getLlm_withValidGemmaModels_succeeds() { + assertThat(LlmRegistry.matchesAnyPattern("gemma-4-26b-a4b-it")).isTrue(); + assertThat(LlmRegistry.matchesAnyPattern("gemma-4-31b-it")).isTrue(); + } + + @Test + public void getLlm_withInvalidGemmaModels_throwsException() { + assertThat(LlmRegistry.matchesAnyPattern("not-a-gemma")).isFalse(); + assertThat(LlmRegistry.matchesAnyPattern("gemma")).isFalse(); + } +}