diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GroundingTests.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GroundingTests.kt index bfd0f8f67d7..9e19b86dc89 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GroundingTests.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GroundingTests.kt @@ -27,11 +27,13 @@ import kotlinx.coroutines.runBlocking import org.junit.Test class GroundingTests { + private val validator = TypesValidator() @Test fun groundingTests_canRecognizeAreas(): Unit = runBlocking { val model = setupModel(config = ToolConfig()) val response = model.generateContent("Where is a good place to grab a coffee near Alameda, CA?") + validator.validateResponse(response) response.candidates.isEmpty() shouldBe false response.candidates[0].groundingMetadata?.groundingChunks?.any { it.maps != null } shouldBe true @@ -51,11 +53,27 @@ class GroundingTests { ) ) val response = model.generateContent("Find bookstores in my area.") + validator.validateResponse(response) response.candidates.isEmpty() shouldBe false response.candidates[0].groundingMetadata?.groundingChunks?.any { it.maps != null } shouldBe true } + @Test + fun groundingTests_canSearchWeather(): Unit = runBlocking { + val model = + FirebaseAI.getInstance(app(), GenerativeBackend.vertexAI()) + .generativeModel( + modelName = "gemini-2.5-flash", + tools = listOf(Tool.googleSearch()), + ) + val response = model.generateContent("What temperature is it today in Cancún?") + // Grounding indices should be correct + validator.validateResponse(response) + // Search grounding should be used + response.candidates.any { it.groundingMetadata != null } shouldBe true + } + companion object { @JvmStatic diff --git a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt index 93cee71d662..4aa2bd34fb8 100644 --- a/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt +++ b/ai-logic/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt @@ -18,7 +18,11 @@ package com.google.firebase.ai import com.google.firebase.ai.type.Candidate import com.google.firebase.ai.type.Content import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.GroundingSupport import com.google.firebase.ai.type.TextPart +import io.kotest.matchers.ints.shouldBeGreaterThanOrEqual +import io.kotest.matchers.ints.shouldBeLessThan +import io.kotest.matchers.ints.shouldBeLessThanOrEqual import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe @@ -38,6 +42,24 @@ class TypesValidator { fun validateCandidate(candidate: Candidate) { validateContent(candidate.content) + if (candidate.groundingMetadata != null) { + for (grounding in candidate.groundingMetadata.groundingSupports) { + validateGroundingSupport(candidate, grounding) + } + } + } + + fun validateGroundingSupport(candidate: Candidate, grounding: GroundingSupport) { + val segment = grounding.segment + segment.partIndex shouldBeGreaterThanOrEqual 0 + segment.partIndex shouldBeLessThan candidate.content.parts.size + val part = candidate.content.parts[segment.partIndex] + part::class shouldBe TextPart::class + val text = (part as TextPart).text + segment.startIndex shouldBeGreaterThanOrEqual 0 + segment.startIndex shouldBeLessThanOrEqual segment.endIndex + segment.endIndex shouldBeLessThanOrEqual text.length + segment.text shouldBe text.substring(segment.startIndex, segment.endIndex) } fun validateContent(content: Content) { diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt index 29b3062927e..d4059cc66f5 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Candidate.kt @@ -65,14 +65,15 @@ internal constructor( @OptIn(PublicPreviewAPI::class) internal fun toPublic(): Candidate { + val content = this.content?.toPublic() ?: content("model") {} val safetyRatings = safetyRatings?.mapNotNull { it.toPublic() }.orEmpty() - val citations = citationMetadata?.toPublic() + val citations = citationMetadata?.toPublic(content) val finishReason = finishReason?.toPublic() - val groundingMetadata = groundingMetadata?.toPublic() + val groundingMetadata = groundingMetadata?.toPublic(content) val urlContextMetadata = urlContextMetadata?.toPublic() return Candidate( - this.content?.toPublic() ?: content("model") {}, + content, safetyRatings, citations, finishReason, @@ -163,7 +164,8 @@ public class CitationMetadata internal constructor(public val citations: List) { - internal fun toPublic() = CitationMetadata(citationSources.map { it.toPublic() }) + internal fun toPublic(content: Content) = + CitationMetadata(citationSources.map { it.toPublic(content) }) } } @@ -203,7 +205,7 @@ internal constructor( val publicationDate: Date? = null, ) { - internal fun toPublic(): Citation { + internal fun toPublic(content: Content): Citation { val publicationDateAsCalendar = publicationDate?.let { val calendar = Calendar.getInstance() @@ -220,8 +222,8 @@ internal constructor( } return Citation( title = title, - startIndex = startIndex, - endIndex = endIndex, + startIndex = convertUtf8IndexToUtf16(content, startIndex), + endIndex = convertUtf8IndexToUtf16(content, endIndex), uri = uri, license = license, publicationDate = publicationDateAsCalendar @@ -371,14 +373,15 @@ public class GroundingMetadata( val groundingChunks: List?, val groundingSupports: List?, ) { - internal fun toPublic() = + internal fun toPublic(content: Content) = GroundingMetadata( webSearchQueries = webSearchQueries.orEmpty(), searchEntryPoint = searchEntryPoint?.toPublic(), retrievalQueries = retrievalQueries.orEmpty(), - groundingAttribution = groundingAttribution?.map { it.toPublic() }.orEmpty(), + groundingAttribution = groundingAttribution?.map { it.toPublic(content) }.orEmpty(), groundingChunks = groundingChunks?.map { it.toPublic() }.orEmpty(), - groundingSupports = groundingSupports?.map { it.toPublic() }.orEmpty().filterNotNull() + groundingSupports = + groundingSupports?.map { it.toPublic(content) }.orEmpty().filterNotNull() ) } } @@ -491,12 +494,12 @@ public class GroundingSupport( val segment: Segment.Internal?, val groundingChunkIndices: List?, ) { - internal fun toPublic(): GroundingSupport? { + internal fun toPublic(content: Content): GroundingSupport? { if (segment == null) { return null } return GroundingSupport( - segment = segment.toPublic(), + segment = segment.toPublic(content), groundingChunkIndices = groundingChunkIndices.orEmpty(), ) } @@ -514,8 +517,8 @@ public class GroundingAttribution( val segment: Segment.Internal, val confidenceScore: Float?, ) { - internal fun toPublic() = - GroundingAttribution(segment = segment.toPublic(), confidenceScore = confidenceScore) + internal fun toPublic(content: Content) = + GroundingAttribution(segment = segment.toPublic(content), confidenceScore = confidenceScore) } } @@ -546,13 +549,17 @@ public class Segment( val partIndex: Int?, val text: String?, ) { - internal fun toPublic() = - Segment( - startIndex = startIndex ?: 0, - endIndex = endIndex ?: 0, - partIndex = partIndex ?: 0, + internal fun toPublic(content: Content): Segment { + val partIndex = this.partIndex ?: 0 + val part = content.parts.getOrNull(partIndex) + val fakeContent = Content(content.role, if (part == null) emptyList() else listOf(part)) + return Segment( + startIndex = convertUtf8IndexToUtf16(fakeContent, startIndex ?: 0), + endIndex = convertUtf8IndexToUtf16(fakeContent, endIndex ?: 0), + partIndex = partIndex, text = text ?: "" ) + } } } @@ -635,3 +642,36 @@ private constructor(public val name: String, public val ordinal: Int) { @JvmField public val UNSAFE: UrlRetrievalStatus = UrlRetrievalStatus("UNSAFE", 4) } } + +internal fun convertUtf8IndexToUtf16(content: Content, originalIndex: Int): Int { + if (originalIndex == 0) { + return 0 + } + var sumIndex = 0 + var progress = 0 + for (part in content.parts) { + val text = part.asTextOrNull() ?: "" + var i = 0 + while (i < text.length) { + val c = text[i].code + progress += + when { + c < 0x80 -> 1 // ASCII + c < 0x800 -> 2 // Two-byte codepoint + c in 0xD800..0xDBFF -> 4 // High surrogate character + else -> 3 + } + if (c in 0xD800..0xDBFF && i + 1 < text.length) { + i++ // Skip the low surrogate + } + i++ + if (progress >= originalIndex) { + return sumIndex + i + } + } + sumIndex += text.length + } + throw StringIndexOutOfBoundsException( + "Desired index $originalIndex is higher than content size $progress" + ) +} diff --git a/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt new file mode 100644 index 00000000000..6b65b0adf4f --- /dev/null +++ b/ai-logic/firebase-ai/src/test/java/com/google/firebase/ai/EncodingTests.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.ai.type.Candidate +import com.google.firebase.ai.type.Citation +import com.google.firebase.ai.type.CitationMetadata +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.TextPart +import com.google.firebase.ai.type.content +import com.google.firebase.ai.type.convertUtf8IndexToUtf16 +import io.kotest.matchers.shouldBe +import kotlinx.serialization.ExperimentalSerializationApi +import org.junit.Test + +@OptIn(PublicPreviewAPI::class, ExperimentalSerializationApi::class) +class EncodingTests { + val testStrings = + listOf( + "hello world", + "¡Sí! Tengo muchos años.", + "🙂🤝📩", + "速度を上げて", + "", + ) + + @Test + fun `UTF-8 to UFT-16 index mapping matches length`() { + for (string in testStrings) { + val content = content { text(string) } + val ba = string.toByteArray(Charsets.UTF_8) + val index = convertUtf8IndexToUtf16(content, ba.size) + index shouldBe string.length + } + } + + @Test + fun `CitationMetadata gets converted to UTF-16`() { + val internalCandidate = + Candidate.Internal( + content = Content.Internal("", listOf(TextPart.Internal("í abc í"))), + citationMetadata = + CitationMetadata.Internal( + listOf( + Citation.Internal( + startIndex = 3, + endIndex = 6, + ) + ) + ) + ) + val candidate = internalCandidate.toPublic() + val start = candidate.citationMetadata!!.citations.first().startIndex + val end = candidate.citationMetadata.citations.first().endIndex + (candidate.content.parts.first() as TextPart).text.substring(start, end) shouldBe "abc" + start shouldBe 2 + end shouldBe 5 + } +}