Skip to content
20 changes: 19 additions & 1 deletion src/google/adk/agents/context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from __future__ import annotations

from typing import Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
Expand All @@ -38,10 +41,12 @@ class ContextCacheConfig(BaseModel):
cache_intervals: Maximum number of invocations to reuse the same cache before refreshing it
ttl_seconds: Time-to-live for cache in seconds
min_tokens: Minimum tokens required to enable caching
create_http_options: HTTP options for cache creation API calls
"""

model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
)

cache_intervals: int = Field(
Expand Down Expand Up @@ -72,6 +77,18 @@ class ContextCacheConfig(BaseModel):
),
)

create_http_options: Optional[types.HttpOptions] = Field(
default=None,
description=(
"HTTP options for cache creation API calls. Use this to set a"
" timeout on CachedContent.create() calls (e.g."
" types.HttpOptions(timeout=10000) for a 10-second timeout in"
" milliseconds). When the cache creation call exceeds the timeout,"
" it fails and the request proceeds without caching. None uses the"
" client's default HTTP options."
),
)

@property
def ttl_string(self) -> str:
"""Get TTL as string format for cache creation."""
Expand All @@ -81,5 +98,6 @@ def __str__(self) -> str:
"""String representation for logging."""
return (
f"ContextCacheConfig(cache_intervals={self.cache_intervals}, "
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens})"
f"ttl={self.ttl_seconds}s, min_tokens={self.min_tokens}, "
f"create_http_options={self.create_http_options})"
)
12 changes: 11 additions & 1 deletion src/google/adk/models/gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ async def handle_context_caching(
)
if cache_metadata:
self._apply_cache_to_request(
llm_request, cache_metadata.cache_name, cache_contents_count
llm_request,
cache_metadata.cache_name,
cache_contents_count,
)
return cache_metadata

Expand All @@ -127,6 +129,7 @@ async def handle_context_caching(
fingerprint_for_all = self._generate_cache_fingerprint(
llm_request, total_contents_count
)

return CacheMetadata(
fingerprint=fingerprint_for_all,
contents_count=total_contents_count,
Expand Down Expand Up @@ -386,6 +389,13 @@ async def _create_gemini_cache(
if llm_request.config and llm_request.config.tool_config:
cache_config.tool_config = llm_request.config.tool_config

# Pass through HTTP options (e.g. timeout) from cache config
if (
llm_request.cache_config
and llm_request.cache_config.create_http_options
):
cache_config.http_options = llm_request.cache_config.create_http_options

span.set_attribute("cache_contents_count", cache_contents_count)
span.set_attribute("model", llm_request.model)
span.set_attribute("ttl_seconds", llm_request.cache_config.ttl_seconds)
Expand Down
28 changes: 13 additions & 15 deletions tests/unittests/agents/test_context_cache_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,19 @@ def test_str_representation(self):
)

expected = (
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024)"
"ContextCacheConfig(cache_intervals=15, ttl=3600s, min_tokens=1024,"
" create_http_options=None)"
)
assert str(config) == expected

def test_str_representation_defaults(self):
"""Test string representation with default values."""
config = ContextCacheConfig()

expected = "ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0)"
expected = (
"ContextCacheConfig(cache_intervals=10, ttl=1800s, min_tokens=0,"
" create_http_options=None)"
)
assert str(config) == expected

def test_pydantic_model_validation(self):
Expand All @@ -126,25 +130,19 @@ def test_pydantic_model_validation(self):

def test_field_descriptions(self):
"""Test that fields have proper descriptions."""
config = ContextCacheConfig()
schema = config.model_json_schema()
fields = ContextCacheConfig.model_fields

assert "cache_intervals" in schema["properties"]
assert "cache_intervals" in fields
assert (
"Maximum number of invocations"
in schema["properties"]["cache_intervals"]["description"]
"Maximum number of invocations" in fields["cache_intervals"].description
)

assert "ttl_seconds" in schema["properties"]
assert (
"Time-to-live for cache"
in schema["properties"]["ttl_seconds"]["description"]
)
assert "ttl_seconds" in fields
assert "Time-to-live for cache" in fields["ttl_seconds"].description

assert "min_tokens" in schema["properties"]
assert "min_tokens" in fields
assert (
"Minimum estimated request tokens"
in schema["properties"]["min_tokens"]["description"]
"Minimum estimated request tokens" in fields["min_tokens"].description
)

def test_immutability_config(self):
Expand Down
60 changes: 60 additions & 0 deletions tests/unittests/agents/test_gemini_context_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,63 @@ async def test_cache_creation_without_token_count(self):
assert result.cache_name is None
assert result.fingerprint == "test_fp"
self.manager.genai_client.aio.caches.create.assert_not_called()

async def test_create_http_options_passthrough(self):
  """Test that create_http_options is passed through to cache creation config."""
  # Stub the genai cache-creation endpoint so no real API call is made.
  fake_cached_content = AsyncMock()
  fake_cached_content.name = (
      "projects/test/locations/us-central1/cachedContents/test123"
  )
  self.manager.genai_client.aio.caches.create = AsyncMock(
      return_value=fake_cached_content
  )

  # Build a request whose cache config carries explicit HTTP options
  # (a 10-second timeout expressed in milliseconds).
  request = self.create_llm_request()
  request.cache_config = ContextCacheConfig(
      cache_intervals=10,
      ttl_seconds=1800,
      min_tokens=0,
      create_http_options=types.HttpOptions(timeout=10000),
  )

  contents_count = max(0, len(request.contents) - 1)

  with patch.object(
      self.manager, "_generate_cache_fingerprint", return_value="test_fp"
  ):
    await self.manager._create_gemini_cache(request, contents_count)

  # The config handed to caches.create must carry the same http_options.
  call = self.manager.genai_client.aio.caches.create.call_args
  assert call is not None
  forwarded_config = call[1]["config"]
  assert forwarded_config.http_options is not None
  assert forwarded_config.http_options.timeout == 10000

async def test_create_without_http_options(self):
  """Test that cache creation works without create_http_options."""
  # Stub the genai cache-creation endpoint so no real API call is made.
  fake_cached_content = AsyncMock()
  fake_cached_content.name = (
      "projects/test/locations/us-central1/cachedContents/test123"
  )
  self.manager.genai_client.aio.caches.create = AsyncMock(
      return_value=fake_cached_content
  )

  # Default request: cache_config has no create_http_options set.
  request = self.create_llm_request()
  contents_count = max(0, len(request.contents) - 1)

  with patch.object(
      self.manager, "_generate_cache_fingerprint", return_value="test_fp"
  ):
    await self.manager._create_gemini_cache(request, contents_count)

  # With no create_http_options configured, none should be forwarded.
  call = self.manager.genai_client.aio.caches.create.call_args
  assert call is not None
  assert call[1]["config"].http_options is None
3 changes: 2 additions & 1 deletion tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,8 @@ def test_runner_realistic_cache_config_scenario(self):

# Verify string representation
expected_str = (
"ContextCacheConfig(cache_intervals=30, ttl=14400s, min_tokens=4096)"
"ContextCacheConfig(cache_intervals=30, ttl=14400s, min_tokens=4096,"
" create_http_options=None)"
)
assert str(runner.context_cache_config) == expected_str

Expand Down
Loading