diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 62e41e1b69..cf3be2b59a 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -39,6 +39,8 @@ from ...auth.auth_tool import AuthConfig from ...events.event import Event from ...models.base_llm_connection import BaseLlmConnection +from ...models.google_llm import Gemini +from ...models.google_llm import GoogleLLMVariant from ...models.llm_request import LlmRequest from ...models.llm_response import LlmResponse from ...telemetry import tracing @@ -47,8 +49,8 @@ from ...telemetry.tracing import tracer from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext -from ...utils.context_utils import Aclosing from ...utils import model_name_utils +from ...utils.context_utils import Aclosing from .audio_cache_manager import AudioCacheManager from .functions import build_auth_request_event @@ -515,7 +517,17 @@ async def run_live( llm_request.live_connect_config.session_resumption.handle = ( invocation_context.live_session_resumption_handle ) - llm_request.live_connect_config.session_resumption.transparent = True + # Only set transparent=True for Vertex AI backend, as the Gemini API + # backend explicitly rejects it. + if ( + isinstance(llm, Gemini) + and llm._api_backend == GoogleLLMVariant.VERTEX_AI # pylint: disable=protected-access + ): + session_resumption = ( + llm_request.live_connect_config.session_resumption + ) + if session_resumption.transparent is None: + session_resumption.transparent = True if ( isinstance(llm, Gemini) @@ -527,8 +539,8 @@ async def run_live( if llm_request.live_connect_config is None: llm_request.live_connect_config = types.LiveConnectConfig() if llm_request.live_connect_config.history_config is None: - llm_request.live_connect_config.history_config = types.HistoryConfig( - initial_history_in_client_content=True + llm_request.live_connect_config.history_config = ( + types.HistoryConfig(initial_history_in_client_content=True) ) logger.info( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index aadfd39dec..feb88db7e1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -79,10 +79,15 @@ def _build_basic_request( llm_request.live_connect_config.realtime_input_config = ( invocation_context.run_config.realtime_input_config ) - active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model + active_model_name = ( + getattr(getattr(agent, 'canonical_live_model', None), 'model', None) + or llm_request.model + ) is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name) llm_request.live_connect_config.enable_affective_dialog = ( - None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog + None + if is_gemini_31 + else invocation_context.run_config.enable_affective_dialog ) llm_request.live_connect_config.proactivity = ( None if is_gemini_31 else invocation_context.run_config.proactivity diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index f5d0400b5e..7599de0dad 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -88,11 +88,15 @@ async def send_history(self, history: list[types.Content]): # protocol error (invalid role mid-session), we consolidate previous multi-turn # interactions into a unified contextual preamble on a single user role turn. if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API: - collapsed_text = "Previous conversation history:\n" + collapsed_text = 'Previous conversation history:\n' for c in contents: - text_parts = "".join(p.text for p in c.parts if p.text) + text_parts = ''.join(p.text for p in c.parts if p.text) collapsed_text += f'[{c.role}]: {text_parts}\n' - contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])] + contents = [ + types.Content( + role='user', parts=[types.Part.from_text(text=collapsed_text)] + ) + ] logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send_client_content( @@ -276,8 +280,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: text += part.text is_thought = current_is_thought llm_response.partial = True - # don't yield the merged text event when receiving audio data - if text and not any(p.text for p in content.parts) and not has_inline_data: + if ( + text + and not any(p.text for p in content.parts) + and not has_inline_data + ): yield self.__build_full_text_response(text, is_thought) text = '' yield llm_response diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 59a988c5d1..8a9bd12e19 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -17,12 +17,14 @@ from unittest import mock from unittest.mock import AsyncMock +from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.agents.llm_agent import Agent from google.adk.agents.run_config import RunConfig from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow -from google.adk.models.google_llm import Gemini, GoogleLLMVariant +from google.adk.models.google_llm import Gemini +from google.adk.models.google_llm import GoogleLLMVariant from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin @@ -30,6 +32,7 @@ from google.adk.tools.google_search_tool import GoogleSearchTool from google.genai import types import pytest +from websockets.exceptions import ConnectionClosed from ... import testing_utils @@ -490,8 +493,6 @@ async def call(self, **kwargs): @pytest.mark.asyncio async def test_run_live_reconnects_on_connection_closed(): """Test that run_live reconnects when ConnectionClosed occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -558,7 +559,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_reconnects_on_api_error(): """Test that run_live reconnects when APIError occurs.""" - from google.adk.agents.live_request_queue import LiveRequestQueue from google.genai.errors import APIError real_model = Gemini() @@ -626,7 +626,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_skips_send_history_on_resumption(): """Test that run_live skips send_history when resuming a session.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -684,7 +683,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_live_session_resumption_go_away(): """Test that go_away triggers reconnection.""" - from google.adk.agents.live_request_queue import LiveRequestQueue real_model = Gemini() mock_connection = mock.AsyncMock() @@ -743,8 +741,6 @@ async def mock_receive_2(): @pytest.mark.asyncio async def test_run_live_no_reconnect_without_handle(): """Test that run_live does not reconnect when handle is missing.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() mock_connection = mock.AsyncMock() @@ -786,8 +782,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_limit(): """Test that run_live stops reconnecting after 5 attempts.""" - from google.adk.agents.live_request_queue import LiveRequestQueue - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -843,9 +837,7 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_reconnect_reset_attempt(): """Test that attempt counter is reset on successful communication.""" - from google.adk.agents.live_request_queue import LiveRequestQueue from google.adk.flows.llm_flows.base_llm_flow import DEFAULT_MAX_RECONNECT_ATTEMPTS - from websockets.exceptions import ConnectionClosed real_model = Gemini() @@ -987,7 +979,6 @@ async def mock_receive(): @pytest.mark.asyncio async def test_run_live_clears_resumption_handle_on_transfer(): """Test that run_live clears session resumption handles when transferring to another agent.""" - from google.adk.agents.live_request_queue import LiveRequestQueue agent = Agent(name='test_agent') invocation_context = await testing_utils.create_invocation_context( @@ -1184,21 +1175,27 @@ async def mock_receive_2(): mock_aenter = mock.AsyncMock() mock_aenter.side_effect = [mock_connection, mock_connection_2] - with mock.patch( - 'google.adk.models.google_llm.Gemini.connect' - ) as mock_connect: - mock_connect.return_value.__aenter__ = mock_aenter + with mock.patch.object( + Gemini, '_api_backend', new_callable=mock.PropertyMock + ) as mock_backend: + mock_backend.return_value = GoogleLLMVariant.GEMINI_API + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter - try: - async for _ in flow.run_live(invocation_context): + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: pass - except StopTestError: - pass - assert mock_connect.call_count == 2 - second_call_req = mock_connect.call_args_list[1][0][0] - session_resump = second_call_req.live_connect_config.session_resumption - assert session_resump.transparent is None + assert mock_connect.call_count == 2 + second_call_req = mock_connect.call_args_list[1][0][0] + session_resump = ( + second_call_req.live_connect_config.session_resumption + ) + assert session_resump.transparent is None @pytest.mark.asyncio @@ -1275,7 +1272,7 @@ async def mock_receive_2(): @pytest.mark.asyncio @pytest.mark.parametrize( - "api_backend,should_have_history_config", + 'api_backend,should_have_history_config', [ (GoogleLLMVariant.GEMINI_API, True), (GoogleLLMVariant.VERTEX_AI, False), @@ -1309,8 +1306,12 @@ async def mock_receive(): flow = BaseLlmFlowForTesting() with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + async def mock_preprocess(ctx, req): - req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] + req.model = 'gemini-3.1-flash-live-preview' + req.contents = [ + types.Content(parts=[types.Part.from_text(text='history')]) + ] yield Event(id=Event.new_id(), author='test') with mock.patch.object( diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 7cc7c22290..555c2c2dad 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -1462,7 +1462,9 @@ async def mock_receive_generator(): @pytest.mark.asyncio -async def test_receive_multiplexed_parts(gemini_connection, mock_gemini_session): +async def test_receive_multiplexed_parts( + gemini_connection, mock_gemini_session +): """Test receive with multiplexed inline data and text content.""" mock_content = types.Content( role='model', @@ -1507,6 +1509,7 @@ async def mock_receive_generator(): async def test_send_history_gemini_31_turn_complete(mock_gemini_session): """Verify Gemini 3.1 Live history seeding explicitly appends turn_complete=True.""" from google.adk.models.google_llm import GoogleLLMVariant + conn = GeminiLlmConnection( mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API, @@ -1530,6 +1533,7 @@ async def test_send_history_gemini_31_turn_complete(mock_gemini_session): async def test_send_history_collapse_vertex_ai(mock_gemini_session): """Verify history prompt collapse when seeding Gemini 3.1 Live on Vertex AI backend.""" from google.adk.models.google_llm import GoogleLLMVariant + conn = GeminiLlmConnection( mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI, @@ -1544,10 +1548,15 @@ async def test_send_history_collapse_vertex_ai(mock_gemini_session): await conn.send_history(mock_contents) assert mock_gemini_session.send_client_content.call_count == 1 - called_turns = mock_gemini_session.send_client_content.call_args.kwargs['turns'] + called_turns = mock_gemini_session.send_client_content.call_args.kwargs[ + 'turns' + ] assert len(called_turns) == 1 assert called_turns[0].role == 'user' assert 'Previous conversation history:' in called_turns[0].parts[0].text assert '[user]: hi' in called_turns[0].parts[0].text assert '[model]: hello' in called_turns[0].parts[0].text - assert mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] is True + assert ( + mock_gemini_session.send_client_content.call_args.kwargs['turn_complete'] + is True + )