Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from ...auth.auth_tool import AuthConfig
from ...events.event import Event
from ...models.base_llm_connection import BaseLlmConnection
from ...models.google_llm import Gemini
from ...models.google_llm import GoogleLLMVariant
from ...models.llm_request import LlmRequest
from ...models.llm_response import LlmResponse
from ...telemetry import tracing
Expand All @@ -47,8 +49,8 @@
from ...telemetry.tracing import tracer
from ...tools.base_toolset import BaseToolset
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
from ...utils import model_name_utils
from ...utils.context_utils import Aclosing
from .audio_cache_manager import AudioCacheManager
from .functions import build_auth_request_event

Expand Down Expand Up @@ -515,7 +517,17 @@ async def run_live(
llm_request.live_connect_config.session_resumption.handle = (
invocation_context.live_session_resumption_handle
)
llm_request.live_connect_config.session_resumption.transparent = True
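# Default transparent=True only on the Vertex AI backend, as the Gemini API
# backend explicitly rejects it, and only while the field is still unset
# (None) so that an explicitly configured value is preserved.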
if (
isinstance(llm, Gemini)
and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access
):
session_resumption = (
llm_request.live_connect_config.session_resumption
)
if session_resumption.transparent is None:
session_resumption.transparent = True

if (
isinstance(llm, Gemini)
Expand All @@ -527,8 +539,8 @@ async def run_live(
if llm_request.live_connect_config is None:
llm_request.live_connect_config = types.LiveConnectConfig()
if llm_request.live_connect_config.history_config is None:
llm_request.live_connect_config.history_config = types.HistoryConfig(
initial_history_in_client_content=True
llm_request.live_connect_config.history_config = (
types.HistoryConfig(initial_history_in_client_content=True)
)

logger.info(
Expand Down
9 changes: 7 additions & 2 deletions src/google/adk/flows/llm_flows/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ def _build_basic_request(
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
active_model_name = (
getattr(getattr(agent, 'canonical_live_model', None), 'model', None)
or llm_request.model
)
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
llm_request.live_connect_config.enable_affective_dialog = (
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
None
if is_gemini_31
else invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
None if is_gemini_31 else invocation_context.run_config.proactivity
Expand Down
17 changes: 12 additions & 5 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ async def send_history(self, history: list[types.Content]):
# protocol error (invalid role mid-session), we consolidate previous multi-turn
# interactions into a unified contextual preamble on a single user role turn.
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
collapsed_text = "Previous conversation history:\n"
collapsed_text = 'Previous conversation history:\n'
for c in contents:
text_parts = "".join(p.text for p in c.parts if p.text)
text_parts = ''.join(p.text for p in c.parts if p.text)
collapsed_text += f'[{c.role}]: {text_parts}\n'
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]
contents = [
types.Content(
role='user', parts=[types.Part.from_text(text=collapsed_text)]
)
]

logger.debug('Sending history to live connection: %s', contents)
await self._gemini_session.send_client_content(
Expand Down Expand Up @@ -276,8 +280,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
text += part.text
is_thought = current_is_thought
llm_response.partial = True
# don't yield the merged text event when receiving audio data
if text and not any(p.text for p in content.parts) and not has_inline_data:
if (
text
and not any(p.text for p in content.parts)
and not has_inline_data
):
yield self.__build_full_text_response(text, is_thought)
text = ''
yield llm_response
Expand Down
55 changes: 28 additions & 27 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@
from unittest import mock
from unittest.mock import AsyncMock

from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.agents.llm_agent import Agent
from google.adk.agents.run_config import RunConfig
from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.google_llm import Gemini, GoogleLLMVariant
from google.adk.models.google_llm import Gemini
from google.adk.models.google_llm import GoogleLLMVariant
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.google_search_tool import GoogleSearchTool
from google.genai import types
import pytest
from websockets.exceptions import ConnectionClosed

from ... import testing_utils

Expand Down Expand Up @@ -490,8 +493,6 @@ async def call(self, **kwargs):
@pytest.mark.asyncio
async def test_run_live_reconnects_on_connection_closed():
"""Test that run_live reconnects when ConnectionClosed occurs."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()
mock_connection = mock.AsyncMock()
Expand Down Expand Up @@ -558,7 +559,6 @@ async def mock_receive_2():
@pytest.mark.asyncio
async def test_run_live_reconnects_on_api_error():
"""Test that run_live reconnects when APIError occurs."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.genai.errors import APIError

real_model = Gemini()
Expand Down Expand Up @@ -626,7 +626,6 @@ async def mock_receive_2():
@pytest.mark.asyncio
async def test_run_live_skips_send_history_on_resumption():
"""Test that run_live skips send_history when resuming a session."""
from google.adk.agents.live_request_queue import LiveRequestQueue

real_model = Gemini()
mock_connection = mock.AsyncMock()
Expand Down Expand Up @@ -684,7 +683,6 @@ async def mock_receive():
@pytest.mark.asyncio
async def test_live_session_resumption_go_away():
"""Test that go_away triggers reconnection."""
from google.adk.agents.live_request_queue import LiveRequestQueue

real_model = Gemini()
mock_connection = mock.AsyncMock()
Expand Down Expand Up @@ -743,8 +741,6 @@ async def mock_receive_2():
@pytest.mark.asyncio
async def test_run_live_no_reconnect_without_handle():
"""Test that run_live does not reconnect when handle is missing."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()
mock_connection = mock.AsyncMock()
Expand Down Expand Up @@ -786,8 +782,6 @@ async def mock_receive():
@pytest.mark.asyncio
async def test_run_live_reconnect_limit():
"""Test that run_live stops reconnecting after 5 attempts."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()

Expand Down Expand Up @@ -843,9 +837,7 @@ async def mock_receive():
@pytest.mark.asyncio
async def test_run_live_reconnect_reset_attempt():
"""Test that attempt counter is reset on successful communication."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS
from websockets.exceptions import ConnectionClosed

real_model = Gemini()

Expand Down Expand Up @@ -987,7 +979,6 @@ async def mock_receive():
@pytest.mark.asyncio
async def test_run_live_clears_resumption_handle_on_transfer():
"""Test that run_live clears session resumption handles when transferring to another agent."""
from google.adk.agents.live_request_queue import LiveRequestQueue

agent = Agent(name='test_agent')
invocation_context = await testing_utils.create_invocation_context(
Expand Down Expand Up @@ -1184,21 +1175,27 @@ async def mock_receive_2():
mock_aenter = mock.AsyncMock()
mock_aenter.side_effect = [mock_connection, mock_connection_2]

with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__ = mock_aenter
with mock.patch.object(
Gemini, '_api_backend', new_callable=mock.PropertyMock
) as mock_backend:
mock_backend.return_value = GoogleLLMVariant.GEMINI_API
with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__ = mock_aenter

try:
async for _ in flow.run_live(invocation_context):
try:
async for _ in flow.run_live(invocation_context):
pass
except StopTestError:
pass
except StopTestError:
pass

assert mock_connect.call_count == 2
second_call_req = mock_connect.call_args_list[1][0][0]
session_resump = second_call_req.live_connect_config.session_resumption
assert session_resump.transparent is None
assert mock_connect.call_count == 2
second_call_req = mock_connect.call_args_list[1][0][0]
session_resump = (
second_call_req.live_connect_config.session_resumption
)
assert session_resump.transparent is None


@pytest.mark.asyncio
Expand Down Expand Up @@ -1275,7 +1272,7 @@ async def mock_receive_2():

@pytest.mark.asyncio
@pytest.mark.parametrize(
"api_backend,should_have_history_config",
'api_backend,should_have_history_config',
[
(GoogleLLMVariant.GEMINI_API, True),
(GoogleLLMVariant.VERTEX_AI, False),
Expand Down Expand Up @@ -1309,8 +1306,12 @@ async def mock_receive():
flow = BaseLlmFlowForTesting()

with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):

async def mock_preprocess(ctx, req):
req.contents = [types.Content(parts=[types.Part.from_text(text='history')])]
req.model = 'gemini-3.1-flash-live-preview'
req.contents = [
types.Content(parts=[types.Part.from_text(text='history')])
]
yield Event(id=Event.new_id(), author='test')

with mock.patch.object(
Expand Down
15 changes: 12 additions & 3 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,9 @@ async def mock_receive_generator():


@pytest.mark.asyncio
async def test_receive_multiplexed_parts(gemini_connection, mock_gemini_session):
async def test_receive_multiplexed_parts(
gemini_connection, mock_gemini_session
):
"""Test receive with multiplexed inline data and text content."""
mock_content = types.Content(
role='model',
Expand Down Expand Up @@ -1507,6 +1509,7 @@ async def mock_receive_generator():
async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
"""Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True."""
from google.adk.models.google_llm import GoogleLLMVariant

conn = GeminiLlmConnection(
mock_gemini_session,
api_backend=GoogleLLMVariant.GEMINI_API,
Expand All @@ -1530,6 +1533,7 @@ async def test_send_history_gemini_31_turn_complete(mock_gemini_session):
async def test_send_history_collapse_vertex_ai(mock_gemini_session):
"""Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend."""
from google.adk.models.google_llm import GoogleLLMVariant

conn = GeminiLlmConnection(
mock_gemini_session,
api_backend=GoogleLLMVariant.VERTEX_AI,
Expand All @@ -1544,10 +1548,15 @@ async def test_send_history_collapse_vertex_ai(mock_gemini_session):
await conn.send_history(mock_contents)

assert mock_gemini_session.send_client_content.call_count == 1
called_turns = mock_gemini_session.send_client_content.call_args.kwargs['turns']
called_turns = mock_gemini_session.send_client_content.call_args.kwargs[
'turns'
]
assert len(called_turns) == 1
assert called_turns[0].role == 'user'
assert 'Previous conversation history:' in called_turns[0].parts[0].text
assert '[user]: hi' in called_turns[0].parts[0].text
assert '[model]: hello' in called_turns[0].parts[0].text
assert mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] is True
assert (
mock_gemini_session.send_client_content.call_args.kwargs['turn_complete']
is True
)
Loading