diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 7fc39748ec..4aa3fb829f 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -44,6 +44,7 @@ def __init__( gemini_session: live.AsyncSession, api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI, model_version: str | None = None, + live_config: types.LiveConnectConfig | None = None, ): self._gemini_session = gemini_session self._input_transcription_text: str = '' @@ -51,6 +52,10 @@ def __init__( self._api_backend = api_backend self._model_version = model_version + self._audio_active = False + if live_config and getattr(live_config, 'response_modalities', None): + self._audio_active = 'AUDIO' in live_config.response_modalities + async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -111,25 +116,27 @@ async def send_content(self, content: types.Content): ), ) else: - logger.debug('Sending LLM new content %s', content) is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live( self._model_version ) is_gemini_api = self._api_backend == GoogleLLMVariant.GEMINI_API - # As of now, Gemini 3.1 Flash Live is only available in Gemini API, not - # Vertex AI. - if ( - is_gemini_31 - and is_gemini_api - and len(content.parts) == 1 - and content.parts[0].text - ): - logger.debug('Using send_realtime_input for Gemini 3.1 text input') - await self._gemini_session.send_realtime_input( - text=content.parts[0].text + # Route via send_realtime_input if audio is active OR if targeting 3.1 API + if self._audio_active or (is_gemini_31 and is_gemini_api): + logger.debug( + 'Routing text via send_realtime_input %s', + content, ) + has_text = False + for part in content.parts: + if isinstance(part.text, str): + await self._gemini_session.send_realtime_input(text=part.text) + has_text = True + + if not has_text: + logger.warning('Encountered unsupported content in send_content') else: + logger.debug('Sending LLM new content %s', content) await self._gemini_session.send( input=types.LiveClientContent( turns=[content], @@ -154,9 +161,9 @@ async def send_realtime(self, input: RealtimeInput): # As of now, Gemini 3.1 Flash Live is only available in Gemini API, not # Vertex AI. if is_gemini_31 and is_gemini_api: - if input.mime_type and input.mime_type.startswith('audio/'): + if isinstance(input.mime_type, str) and input.mime_type.startswith('audio/'): await self._gemini_session.send_realtime_input(audio=input) - elif input.mime_type and input.mime_type.startswith('image/'): + elif isinstance(input.mime_type, str) and input.mime_type.startswith('image/'): await self._gemini_session.send_realtime_input(video=input) else: logger.warning( @@ -165,7 +172,12 @@ async def send_realtime(self, input: RealtimeInput): input.mime_type, ) else: - await self._gemini_session.send_realtime_input(media=input) + if isinstance(input.mime_type, str) and input.mime_type.startswith('video/'): + await self._gemini_session.send_realtime_input(video=input) + elif isinstance(input.mime_type, str) and input.mime_type.startswith('audio/'): + await self._gemini_session.send_realtime_input(audio=input) + else: + await self._gemini_session.send_realtime_input(media=input) elif isinstance(input, types.ActivityStart): logger.debug('Sending LLM activity start signal.') diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 0114d73a82..e13a7315ea 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -424,6 +424,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: live_session, api_backend=self._api_backend, model_version=llm_request.model, + live_config=llm_request.live_connect_config, ) async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 09bd537d8e..968ded9783 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -54,16 +54,37 @@ def test_blob(): return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm') +@pytest.fixture +def test_fallback_blob(): + """Test blob for unknown media data.""" + return types.Blob(data=b'\x01\x02', mime_type='application/pdf') + + @pytest.mark.asyncio -async def test_send_realtime_default_behavior( +async def test_send_realtime_audio_routing( gemini_connection, mock_gemini_session, test_blob ): - """Test send_realtime with default automatic_activity_detection value (True).""" + """Test send_realtime explicitly routing audio mimetypes to the audio parameter.""" await gemini_connection.send_realtime(test_blob) # Should call send once mock_gemini_session.send_realtime_input.assert_called_once_with( - media=test_blob + audio=test_blob + ) + # Should not call .send function + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_realtime_media_fallback_routing( + gemini_connection, mock_gemini_session, test_fallback_blob +): + """Test send_realtime falling back to media for non-audio/video mimetypes.""" + await gemini_connection.send_realtime(test_fallback_blob) + + # Should call send once + mock_gemini_session.send_realtime_input.assert_called_once_with( + media=test_fallback_blob ) # Should not call .send function mock_gemini_session.send.assert_not_called() @@ -90,7 +111,12 @@ async def test_send_history(gemini_connection, mock_gemini_session): @pytest.mark.asyncio async def test_send_content_text(gemini_connection, mock_gemini_session): - """Test send_content with text content.""" + """Test send_content with text content when audio is inactive. + + Note: gemini_connection._audio_active is False by default. + """ + assert gemini_connection._audio_active is False + content = types.Content( role='user', parts=[types.Part.from_text(text='Hello')] ) @@ -104,6 +130,21 @@ async def test_send_content_text(gemini_connection, mock_gemini_session): assert call_args['input'].turn_complete is True +@pytest.mark.asyncio +async def test_send_content_text_audio_active(gemini_connection, mock_gemini_session): + """Test send_content routes to send_realtime_input when audio is active.""" + gemini_connection._audio_active = True + + content = types.Content( + role='user', parts=[types.Part.from_text(text='Hello')] + ) + + await gemini_connection.send_content(content) + + mock_gemini_session.send_realtime_input.assert_called_once_with(text='Hello') + mock_gemini_session.send.assert_not_called() + + @pytest.mark.asyncio async def test_send_content_function_response( gemini_connection, mock_gemini_session diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 5ab56bf42b..aec75dd2e1 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -752,6 +752,7 @@ async def __aexit__(self, *args): mock_live_session, api_backend=gemini_llm._api_backend, model_version=llm_request.model, + live_config=llm_request.live_connect_config, )