Skip to content

Commit f286222

Browse files
committed
Expose Gemini Live API client override
1 parent 9670ce2 commit f286222

2 files changed

Lines changed: 99 additions & 18 deletions

File tree

src/google/adk/models/google_llm.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,19 @@ def api_client(self) -> Client:
110110
111111
Use ``@property`` instead of ``@cached_property`` if you hit asyncio
112112
lock contention in multithreaded code.
113+
114+
Customizing the Live API Client:
115+
The Live API path uses its own client. To set Live API-only options,
116+
subclass ``Gemini`` and override the ``live_api_client`` property::
117+
118+
from functools import cached_property
119+
from google.adk.models import Gemini
120+
from google.genai import Client
121+
122+
class RegionalLiveGemini(Gemini):
123+
@cached_property
124+
def live_api_client(self) -> Client:
125+
return Client(vertexai=True, location="europe-central2")
113126
"""
114127

115128
model: str = 'gemini-2.5-flash'
@@ -376,8 +389,7 @@ def _live_api_version(self) -> str:
376389
# use v1alpha for using API KEY from Google AI Studio
377390
return 'v1alpha'
378391

379-
@cached_property
380-
def _live_api_client(self) -> Client:
392+
def _build_live_api_client(self) -> Client:
381393
from google.genai import Client
382394

383395
base_url, _ = self._base_url_and_api_version
@@ -394,6 +406,27 @@ def _live_api_client(self) -> Client:
394406

395407
return Client(**kwargs)
396408

409+
def _uses_legacy_live_api_client_override(self) -> bool:
410+
for cls in type(self).__mro__:
411+
if '_live_api_client' in cls.__dict__:
412+
return cls is not Gemini
413+
return False
414+
415+
@cached_property
416+
def live_api_client(self) -> Client:
417+
"""Provides the Live API client.
418+
419+
Subclasses can override this property to customize the ``Client`` used by
420+
Live API connections.
421+
"""
422+
if self._uses_legacy_live_api_client_override():
423+
return self._live_api_client
424+
return self._build_live_api_client()
425+
426+
@cached_property
427+
def _live_api_client(self) -> Client:
428+
return self.live_api_client
429+
397430
@contextlib.asynccontextmanager
398431
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
399432
"""Connects to the Gemini model and returns an llm connection.
@@ -455,7 +488,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
455488
llm_request.live_connect_config.tools = llm_request.config.tools
456489
logger.debug('Connecting to live with llm_request:%s', llm_request)
457490
logger.debug('Live connect config: %s', llm_request.live_connect_config)
458-
async with self._live_api_client.aio.live.connect(
491+
async with self.live_api_client.aio.live.connect(
459492
model=llm_request.model, config=llm_request.live_connect_config
460493
) as live_session:
461494
yield GeminiLlmConnection(

tests/unittests/models/test_google_llm.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def test_gemini_live_api_client_creation_with_projects_prefix():
174174
model="projects/test-project/locations/test-location/publishers/google/models/gemini-2.5-pro"
175175
)
176176
with mock.patch("google.genai.Client", autospec=True) as mock_client:
177-
_ = model._live_api_client
177+
_ = model.live_api_client
178178
assert mock_client.call_count == 2
179179

180-
# Second call is for _live_api_client
180+
# Second call is for live_api_client.
181181
_, kwargs = mock_client.call_args_list[1]
182182
assert kwargs["vertexai"] is True
183183

@@ -732,25 +732,40 @@ def test_live_api_version_gemini_api(gemini_llm):
732732
assert gemini_llm._live_api_version == "v1alpha"
733733

734734

735-
def test_live_api_client_uses_api_version_from_google_base_url():
735+
@pytest.mark.parametrize(
736+
"base_url, expected_base_url",
737+
[
738+
(
739+
"https://generativelanguage.googleapis.com/v1alpha",
740+
"https://generativelanguage.googleapis.com/",
741+
),
742+
(
743+
"https://generativelanguage.mtls.googleapis.com/v1alpha",
744+
"https://generativelanguage.mtls.googleapis.com/",
745+
),
746+
],
747+
)
748+
def test_live_api_client_uses_api_version_from_google_base_url(
749+
base_url, expected_base_url
750+
):
736751
gemini_llm = Gemini(
737752
model="gemini-2.5-flash",
738-
base_url="https://generativelanguage.googleapis.com/v1alpha",
753+
base_url=base_url,
739754
)
740755

741-
client = gemini_llm._live_api_client
756+
client = gemini_llm.live_api_client
742757
http_options = client._api_client._http_options
743758

744-
assert http_options.base_url == "https://generativelanguage.googleapis.com/"
759+
assert http_options.base_url == expected_base_url
745760
assert http_options.api_version == "v1alpha"
746761

747762

748763
def test_live_api_client_properties(gemini_llm):
749-
"""Test that _live_api_client is properly configured with tracking headers and API version."""
764+
"""Test that live_api_client is properly configured with tracking headers and API version."""
750765
with mock.patch.object(
751766
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
752767
):
753-
client = gemini_llm._live_api_client
768+
client = gemini_llm.live_api_client
754769

755770
# Verify that the client has the correct headers and API version
756771
http_options = client._api_client._http_options
@@ -763,6 +778,39 @@ def test_live_api_client_properties(gemini_llm):
763778
assert value in http_options.headers[key]
764779

765780

781+
def test_live_api_client_private_alias(gemini_llm):
782+
assert gemini_llm._live_api_client is gemini_llm.live_api_client
783+
784+
785+
def test_live_api_client_public_override():
786+
custom_client = mock.MagicMock()
787+
788+
class CustomGemini(Gemini):
789+
790+
@property
791+
def live_api_client(self):
792+
return custom_client
793+
794+
gemini_llm = CustomGemini(model="gemini-2.5-flash")
795+
796+
assert gemini_llm.live_api_client is custom_client
797+
assert gemini_llm._live_api_client is custom_client
798+
799+
800+
def test_live_api_client_legacy_private_override():
801+
custom_client = mock.MagicMock()
802+
803+
class CustomGemini(Gemini):
804+
805+
@property
806+
def _live_api_client(self):
807+
return custom_client
808+
809+
gemini_llm = CustomGemini(model="gemini-2.5-flash")
810+
811+
assert gemini_llm.live_api_client is custom_client
812+
813+
766814
@pytest.mark.asyncio
767815
async def test_connect_with_custom_headers(gemini_llm, llm_request):
768816
"""Test that connect method updates tracking headers and API version when custom headers are provided."""
@@ -774,8 +822,8 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):
774822

775823
mock_live_session = mock.AsyncMock()
776824

777-
# Mock the _live_api_client to return a mock client
778-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
825+
# Mock the live_api_client to return a mock client
826+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
779827
# Create a mock context manager
780828
class MockLiveConnect:
781829

@@ -817,7 +865,7 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request):
817865

818866
mock_live_session = mock.AsyncMock()
819867

820-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
868+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
821869

822870
class MockLiveConnect:
823871

@@ -2099,7 +2147,7 @@ async def test_connect_uses_gemini_speech_config_when_request_is_none(
20992147

21002148
mock_live_session = mock.AsyncMock()
21012149

2102-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
2150+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
21032151

21042152
class MockLiveConnect:
21052153

@@ -2147,7 +2195,7 @@ async def test_connect_uses_request_speech_config_when_gemini_is_none(
21472195

21482196
mock_live_session = mock.AsyncMock()
21492197

2150-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
2198+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
21512199

21522200
class MockLiveConnect:
21532201

@@ -2201,7 +2249,7 @@ async def test_connect_request_gemini_config_overrides_speech_config(
22012249

22022250
mock_live_session = mock.AsyncMock()
22032251

2204-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
2252+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
22052253

22062254
class MockLiveConnect:
22072255

@@ -2242,7 +2290,7 @@ async def test_connect_speech_config_remains_none_when_both_are_none(
22422290

22432291
mock_live_session = mock.AsyncMock()
22442292

2245-
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
2293+
with mock.patch.object(gemini_llm, "live_api_client") as mock_live_client:
22462294

22472295
class MockLiveConnect:
22482296

0 commit comments

Comments
 (0)