Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ def api_client(self) -> Client:

Use ``@property`` instead of ``@cached_property`` if you hit asyncio
lock contention in multithreaded code.

Customizing the Live API Client:
The Live API path uses its own client. To set Live API-only options,
subclass ``Gemini`` and override the ``live_api_client`` property::

from functools import cached_property
from google.adk.models import Gemini
from google.genai import Client

class RegionalLiveGemini(Gemini):
@cached_property
def live_api_client(self) -> Client:
return Client(vertexai=True, location="europe-central2")
"""

model: str = 'gemini-2.5-flash'
Expand Down Expand Up @@ -376,8 +389,7 @@ def _live_api_version(self) -> str:
# use v1alpha for using API KEY from Google AI Studio
return 'v1alpha'

@cached_property
def _live_api_client(self) -> Client:
def _build_live_api_client(self) -> Client:
from google.genai import Client

base_url, _ = self._base_url_and_api_version
Expand All @@ -394,6 +406,27 @@ def _live_api_client(self) -> Client:

return Client(**kwargs)

def _uses_legacy_live_api_client_override(self) -> bool:
  """Report whether a subclass overrides the private ``_live_api_client``.

  Walks the instance's MRO and finds the first class whose own namespace
  defines ``_live_api_client``.

  Returns:
    True when that first defining class is something other than ``Gemini``
    (i.e. a subclass supplied its own definition); False when ``Gemini``
    itself is the first definer or when no class defines the name.
  """
  # Only each class's own namespace is inspected, so inherited definitions
  # do not count as an override on the inheriting class.
  defining_cls = next(
      (
          klass
          for klass in type(self).__mro__
          if '_live_api_client' in vars(klass)
      ),
      None,
  )
  return defining_cls is not None and defining_cls is not Gemini

@cached_property
def live_api_client(self) -> Client:
  """Return the ``Client`` used for Live API connections.

  Subclasses can override this property to customize the ``Client`` used by
  Live API connections. When a subclass instead overrides the private
  ``_live_api_client`` attribute, that override is returned so existing
  customizations keep working.

  Returns:
    The subclass-provided ``_live_api_client`` if one is detected, otherwise
    the client produced by ``_build_live_api_client()``.
  """
  # NOTE(review): a legacy override that itself delegates to
  # ``super()._live_api_client`` appears to re-enter this property and
  # recurse -- confirm whether that pattern needs to be supported.
  has_legacy_override = self._uses_legacy_live_api_client_override()
  return (
      self._live_api_client
      if has_legacy_override
      else self._build_live_api_client()
  )

@cached_property
def _live_api_client(self) -> Client:
  """Private alias that resolves to the public ``live_api_client``.

  Returns:
    Whatever ``self.live_api_client`` evaluates to, so callers using the
    private name receive the same client as callers using the public one.
  """
  public_client = self.live_api_client
  return public_client

@contextlib.asynccontextmanager
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
"""Connects to the Gemini model and returns an llm connection.
Expand Down Expand Up @@ -455,7 +488,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
llm_request.live_connect_config.tools = llm_request.config.tools
logger.debug('Connecting to live with llm_request:%s', llm_request)
logger.debug('Live connect config: %s', llm_request.live_connect_config)
async with self._live_api_client.aio.live.connect(
async with self.live_api_client.aio.live.connect(
model=llm_request.model, config=llm_request.live_connect_config
) as live_session:
yield GeminiLlmConnection(
Expand Down
78 changes: 63 additions & 15 deletions tests/unittests/models/test_google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ def test_gemini_live_api_client_creation_with_projects_prefix():
model="projects/test-project/locations/test-location/publishers/google/models/gemini-2.5-pro"
)
with mock.patch("google.genai.Client", autospec=True) as mock_client:
_ = model._live_api_client
_ = model.live_api_client
assert mock_client.call_count == 2

# Second call is for _live_api_client
# Second call is for live_api_client.
_, kwargs = mock_client.call_args_list[1]
assert kwargs["vertexai"] is True

Expand Down Expand Up @@ -732,25 +732,40 @@ def test_live_api_version_gemini_api(gemini_llm):
assert gemini_llm._live_api_version == "v1alpha"


def test_live_api_client_uses_api_version_from_google_base_url():
@pytest.mark.parametrize(
"base_url, expected_base_url",
[
(
"https://generativelanguage.googleapis.com/v1alpha",
"https://generativelanguage.googleapis.com/",
),
(
"https://generativelanguage.mtls.googleapis.com/v1alpha",
"https://generativelanguage.mtls.googleapis.com/",
),
],
)
def test_live_api_client_uses_api_version_from_google_base_url(
base_url, expected_base_url
):
gemini_llm = Gemini(
model="gemini-2.5-flash",
base_url="https://generativelanguage.googleapis.com/v1alpha",
base_url=base_url,
)

client = gemini_llm._live_api_client
client = gemini_llm.live_api_client
http_options = client._api_client._http_options

assert http_options.base_url == "https://generativelanguage.googleapis.com/"
assert http_options.base_url == expected_base_url
assert http_options.api_version == "v1alpha"


def test_live_api_client_properties(gemini_llm):
"""Test that _live_api_client is properly configured with tracking headers and API version."""
"""Test that live_api_client is properly configured with tracking headers and API version."""
with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
):
client = gemini_llm._live_api_client
client = gemini_llm.live_api_client

# Verify that the client has the correct headers and API version
http_options = client._api_client._http_options
Expand All @@ -763,6 +778,39 @@ def test_live_api_client_properties(gemini_llm):
assert value in http_options.headers[key]


def test_live_api_client_private_alias(gemini_llm):
  """The private name must resolve to the same object as the public one."""
  # Read the private attribute first, then the public one, and compare by
  # identity rather than equality.
  private_client = gemini_llm._live_api_client
  public_client = gemini_llm.live_api_client
  assert private_client is public_client


def test_live_api_client_public_override():
  """A subclass overriding ``live_api_client`` is seen through both names."""
  sentinel_client = mock.MagicMock()

  class CustomGemini(Gemini):
    # Override only the public property; the private name is left untouched.

    @property
    def live_api_client(self):
      return sentinel_client

  llm = CustomGemini(model="gemini-2.5-flash")

  # The public name returns the override directly ...
  assert llm.live_api_client is sentinel_client
  # ... and the private name yields the same object.
  assert llm._live_api_client is sentinel_client


def test_live_api_client_legacy_private_override():
  """A subclass overriding ``_live_api_client`` is honored by the public name."""
  sentinel_client = mock.MagicMock()

  class CustomGemini(Gemini):
    # Override only the legacy private name, as older subclasses would.

    @property
    def _live_api_client(self):
      return sentinel_client

  llm = CustomGemini(model="gemini-2.5-flash")

  # The public property should return the legacy override's client.
  assert llm.live_api_client is sentinel_client


@pytest.mark.asyncio
async def test_connect_with_custom_headers(gemini_llm, llm_request):
"""Test that connect method updates tracking headers and API version when custom headers are provided."""
Expand All @@ -774,8 +822,8 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):

mock_live_session = mock.AsyncMock()

# Mock the _live_api_client to return a mock client
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
# Mock the live_api_client to return a mock client
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
# Create a mock context manager
class MockLiveConnect:

Expand Down Expand Up @@ -817,7 +865,7 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request):

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:

class MockLiveConnect:

Expand Down Expand Up @@ -2099,7 +2147,7 @@ async def test_connect_uses_gemini_speech_config_when_request_is_none(

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:

class MockLiveConnect:

Expand Down Expand Up @@ -2147,7 +2195,7 @@ async def test_connect_uses_request_speech_config_when_gemini_is_none(

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:

class MockLiveConnect:

Expand Down Expand Up @@ -2201,7 +2249,7 @@ async def test_connect_request_gemini_config_overrides_speech_config(

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:

class MockLiveConnect:

Expand Down Expand Up @@ -2242,7 +2290,7 @@ async def test_connect_speech_config_remains_none_when_both_are_none(

mock_live_session = mock.AsyncMock()

with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:

class MockLiveConnect:

Expand Down