diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index 76c7b8e160..532644d5d8 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -1020,9 +1020,16 @@ async def _postprocess_live(
         and not llm_response.input_transcription
         and not llm_response.output_transcription
         and not llm_response.usage_metadata
+        and not llm_response.setup_complete
     ):
       return
 
+    # Handle setup complete events
+    if llm_response.setup_complete:
+      model_response_event.setup_complete = llm_response.setup_complete
+      yield model_response_event
+      return
+
     # Handle transcription events ONCE per llm_response, outside the event loop
     if llm_response.input_transcription:
       model_response_event.input_transcription = (
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 8f491dc88b..e4ad9c1c89 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -209,6 +209,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     # partial content and emit responses as needed.
     async for message in agen:
       logger.debug('Got LLM Live message: %s', message)
+      if message.setup_complete:
+        yield LlmResponse(setup_complete=True)
       if message.usage_metadata:
         # Tracks token usage data per model.
         yield LlmResponse(
diff --git a/src/google/adk/models/llm_response.py b/src/google/adk/models/llm_response.py
index 92a687bae3..699135ff8a 100644
--- a/src/google/adk/models/llm_response.py
+++ b/src/google/adk/models/llm_response.py
@@ -46,6 +46,8 @@ class LlmResponse(BaseModel):
     output_transcription: Audio transcription of model output.
     avg_logprobs: Average log probability of the generated tokens.
     logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
+    setup_complete: Indicates whether the initial model setup is complete.
+      Only used for Gemini Live streaming mode.
   """
 
   model_config = ConfigDict(
@@ -80,6 +82,12 @@ class LlmResponse(BaseModel):
   Only used for streaming mode.
   """
 
+  setup_complete: Optional[bool] = None
+  """Indicates whether the initial model setup is complete.
+
+  Only used for Gemini Live streaming mode.
+  """
+
   finish_reason: Optional[types.FinishReason] = None
   """The finish reason of the response."""
 
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index 0bed24831e..f5c44d500c 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -256,6 +256,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -272,6 +273,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -288,6 +290,7 @@ async def test_receive_transcript_finished_on_interrupt(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = True
@@ -336,6 +339,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -352,6 +356,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -368,6 +373,7 @@ async def test_receive_transcript_finished_on_generation_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -415,6 +421,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -431,6 +438,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -447,6 +455,7 @@ async def test_receive_transcript_finished_on_turn_complete(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -493,6 +502,7 @@ async def test_receive_handles_input_transcription_fragments(
   """Test receive handles input transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -509,6 +519,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -525,6 +536,7 @@ async def test_receive_handles_input_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -568,6 +580,7 @@ async def test_receive_handles_output_transcription_fragments(
   """Test receive handles output transcription fragments correctly."""
   message1 = mock.Mock()
   message1.usage_metadata = None
+  message1.setup_complete = None
   message1.server_content = mock.Mock()
   message1.server_content.model_turn = None
   message1.server_content.interrupted = False
@@ -584,6 +597,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message2 = mock.Mock()
   message2.usage_metadata = None
+  message2.setup_complete = None
   message2.server_content = mock.Mock()
   message2.server_content.model_turn = None
   message2.server_content.interrupted = False
@@ -600,6 +614,7 @@ async def test_receive_handles_output_transcription_fragments(
 
   message3 = mock.Mock()
   message3.usage_metadata = None
+  message3.setup_complete = None
   message3.server_content = mock.Mock()
   message3.server_content.model_turn = None
   message3.server_content.interrupted = False
@@ -811,6 +826,30 @@ async def test_send_history_filters_various_audio_mime_types(
 
 
 @pytest.mark.asyncio
+async def test_receive_setup_complete(gemini_connection, mock_gemini_session):
+  """Test receive handles setup_complete signal."""
+
+  # Create a mock message simulating BidiGenerateContentSetupComplete
+  message = mock.Mock()
+  message.setup_complete = True
+  message.usage_metadata = None
+  message.server_content = None
+  message.tool_call = None
+  message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message
+
+  # Wire the mock session and drain receive(), mirroring the other tests
+  mock_gemini_session.receive = mock_receive_generator
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  assert len(responses) == 1
+  assert responses[0].setup_complete is True
+
+
+@pytest.mark.asyncio
 async def test_receive_grounding_metadata_standalone(
     gemini_connection, mock_gemini_session
 ):
diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py
index d77b13e538..8c582f031f 100644
--- a/tests/unittests/streaming/test_streaming.py
+++ b/tests/unittests/streaming/test_streaming.py
@@ -54,6 +54,41 @@ def test_streaming():
   ), 'Expected at least one response, but got an empty list.'
 
 
+def test_live_streaming_setup_complete():
+  """Test live streaming with setup complete event."""
+  # Create LLM responses: setup complete followed by turn completion
+  response1 = LlmResponse(
+      setup_complete=True,
+  )
+  response2 = LlmResponse(
+      turn_complete=True,
+  )
+
+  mock_model = testing_utils.MockModel.create([response1, response2])
+
+  root_agent = Agent(
+      name='root_agent',
+      model=mock_model,
+      tools=[],
+  )
+
+  runner = testing_utils.InMemoryRunner(root_agent=root_agent)
+  live_request_queue = LiveRequestQueue()
+  res_events = runner.run_live(live_request_queue)
+
+  assert res_events is not None, 'Expected a list of events, got None.'
+  assert len(res_events) >= 1, 'Expected at least one event.'
+
+  # Check that we got a setup complete event
+  setup_complete_found = False
+  for event in res_events:
+    if event.setup_complete:
+      setup_complete_found = True
+      break
+
+  assert setup_complete_found, 'Expected a setup complete event.'
+
+
 def test_live_streaming_function_call_single():
   """Test live streaming with a single function call response."""
   # Create a function call response