From a18f48eed24b566b1f31f554d46673b79cf90401 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Tue, 2 Jun 2026 18:03:58 +0000 Subject: [PATCH] fix: Support generalized history config injection for Gemini 3.1 Live on Vertex AI - Exposed history_config in RunConfig. - Mapped history_config to LLM live connect request configuration. - Generalized history connection logic to automatically inject `initial_history_in_client_content = True` when seeding history on a fresh connection for both Gemini API and Vertex AI backends. - Updated and added comprehensive unit tests to verify history configuration behaviour. TAG=agy CONV=822f8c76-9099-4f01-a2b8-10a7de0d61a2 Change-Id: Ib532626d5d7d887b17664567aed94ba09ad90b33 --- src/google/adk/agents/run_config.py | 3 + .../adk/flows/llm_flows/base_llm_flow.py | 22 +++-- src/google/adk/flows/llm_flows/basic.py | 3 + .../flows/llm_flows/test_base_llm_flow.py | 92 +++++++++++++++---- 4 files changed, 96 insertions(+), 24 deletions(-) diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index e059cd957d..8126ac5bf3 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -247,6 +247,9 @@ class RunConfig(BaseModel): session_resumption: Optional[types.SessionResumptionConfig] = None """Configures session resumption mechanism. Only support transparent session resumption mode now.""" + history_config: Optional[types.HistoryConfig] = None + """Configures the exchange of history between the client and the server.""" + context_window_compression: Optional[types.ContextWindowCompressionConfig] = ( None ) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index cf3be2b59a..20093237d3 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -529,18 +529,26 @@ async def run_live( if session_resumption.transparent is None: session_resumption.transparent = True + # When seeding a fresh connection with prior conversation history, set + # initial_history_in_client_content to True. This tells the Live server + # that the provided history already includes the model's past responses, + # preventing the server from generating duplicate responses for those replayed turns. if ( - isinstance(llm, Gemini) - and llm._api_backend == GoogleLLMVariant.GEMINI_API - and model_name_utils.is_gemini_3_1_flash_live(llm_request.model) - and llm_request.contents + llm_request.contents and not invocation_context.live_session_resumption_handle ): - if llm_request.live_connect_config is None: + if not llm_request.live_connect_config: llm_request.live_connect_config = types.LiveConnectConfig() - if llm_request.live_connect_config.history_config is None: + if not llm_request.live_connect_config.history_config: llm_request.live_connect_config.history_config = ( - types.HistoryConfig(initial_history_in_client_content=True) + types.HistoryConfig() + ) + if ( + llm_request.live_connect_config.history_config.initial_history_in_client_content + is None + ): + llm_request.live_connect_config.history_config.initial_history_in_client_content = ( + True ) logger.info( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index feb88db7e1..50f03d0bf1 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -95,6 +95,9 @@ def _build_basic_request( llm_request.live_connect_config.session_resumption = ( invocation_context.run_config.session_resumption ) + llm_request.live_connect_config.history_config = ( + invocation_context.run_config.history_config + ) llm_request.live_connect_config.context_window_compression = ( invocation_context.run_config.context_window_compression ) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 8a9bd12e19..7de544b4f1 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -1272,16 +1272,14 @@ async def mock_receive_2(): @pytest.mark.asyncio @pytest.mark.parametrize( - 'api_backend,should_have_history_config', + 'api_backend', [ - (GoogleLLMVariant.GEMINI_API, True), - (GoogleLLMVariant.VERTEX_AI, False), + GoogleLLMVariant.GEMINI_API, + GoogleLLMVariant.VERTEX_AI, ], ) -async def test_run_live_history_config_gated_by_backend( - api_backend, should_have_history_config -): - """Test that run_live only sets history_config for Gemini API backend.""" +async def test_run_live_history_config_set_for_all_backends(api_backend): + """Test that run_live sets history_config for all backends.""" real_model = Gemini(model='gemini-3.1-flash-live-preview') mock_connection = mock.AsyncMock() @@ -1334,13 +1332,73 @@ async def mock_preprocess(ctx, req): assert mock_connect.call_count == 1 called_req = mock_connect.call_args[0][0] - if should_have_history_config: - assert called_req.live_connect_config is not None - assert called_req.live_connect_config.history_config is not None - assert ( - called_req.live_connect_config.history_config.initial_history_in_client_content - is True - ) - else: - if called_req.live_connect_config: - assert called_req.live_connect_config.history_config is None + assert called_req.live_connect_config is not None + assert called_req.live_connect_config.history_config is not None + assert ( + called_req.live_connect_config.history_config.initial_history_in_client_content + is True + ) + + +@pytest.mark.asyncio +async def test_run_live_respects_explicit_initial_history_in_client_content_false(): + """Test that run_live respects explicit initial_history_in_client_content=False in RunConfig.""" + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + run_config = RunConfig( + history_config=types.HistoryConfig( + initial_history_in_client_content=False + ) + ) + invocation_context.run_config = run_config + + flow = BaseLlmFlowForTesting() + + async def mock_preprocess(ctx, req): + req.contents = [types.Content(parts=[types.Part.from_text(text='history')])] + from google.adk.flows.llm_flows.basic import _build_basic_request + + _build_basic_request(ctx, req) + yield Event(id=Event.new_id(), author='test') + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + + class StopTestError(Exception): + pass + + async def mock_receive(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopTestError('stop') + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__.return_value = mock_connection + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopTestError: + pass + + assert mock_connect.call_count == 1 + call_req = mock_connect.call_args[0][0] + assert call_req.live_connect_config.history_config is not None + assert ( + call_req.live_connect_config.history_config.initial_history_in_client_content + is False + )