diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index e059cd957d..7e4120b43d 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -344,6 +344,14 @@ class RunConfig(BaseModel): ) """ + model_input_context: list[types.Content] | None = None + """Transient context to include in the model input for this invocation. + + The Runner does not persist these contents to the session. They are only + added to the LLM request assembled for the current invocation, which lets + callers provide per-turn context without changing the conversation history. + """ + @model_validator(mode='before') @classmethod def check_for_deprecated_save_live_audio(cls, data: Any) -> Any: diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index feeb8ef972..56e9f5aba9 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -85,6 +85,16 @@ async def run_async( preserve_function_call_ids=preserve_function_call_ids, ) + if ( + invocation_context.run_config + and invocation_context.run_config.model_input_context + ): + _add_model_input_context_to_user_content( + invocation_context, + llm_request, + copy.deepcopy(invocation_context.run_config.model_input_context), + ) + # Add instruction-related contents to proper position in conversation await _add_instructions_to_user_content( invocation_context, llm_request, instruction_related_contents @@ -845,6 +855,26 @@ def _content_contains_function_response(content: types.Content) -> bool: return False +def _add_model_input_context_to_user_content( + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_input_context: list[types.Content], +) -> None: + """Insert transient model input context before the invocation user content.""" + if not model_input_context: + return + + insert_index = 0 + user_content = invocation_context.user_content + if user_content: + for i in range(len(llm_request.contents) - 1, -1, -1): + if llm_request.contents[i] == user_content: + insert_index = i + break + + llm_request.contents[insert_index:insert_index] = model_input_context + + async def _add_instructions_to_user_content( invocation_context: InvocationContext, llm_request: LlmRequest, diff --git a/tests/unittests/agents/test_llm_agent_include_contents.py b/tests/unittests/agents/test_llm_agent_include_contents.py index a196f93553..c93701b743 100644 --- a/tests/unittests/agents/test_llm_agent_include_contents.py +++ b/tests/unittests/agents/test_llm_agent_include_contents.py @@ -15,6 +15,7 @@ """Unit tests for LlmAgent include_contents field behavior.""" from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig from google.adk.agents.sequential_agent import SequentialAgent from google.genai import types import pytest @@ -189,6 +190,153 @@ def simple_tool(message: str) -> dict: assert len(mock_model.requests[0].config.tools) > 0 +def test_model_input_context_is_sent_to_model_without_persisting_to_session(): + mock_model = testing_utils.MockModel.create(responses=["Answer"]) + agent = LlmAgent(name="test_agent", model=mock_model) + runner = testing_utils.InMemoryRunner(agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Question"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ] + assert testing_utils.simplify_events(runner.session.events) == [ + ("user", "Question"), + ("test_agent", "Answer"), + ] + + +def test_model_input_context_stays_before_user_message_after_tool_call(): + def simple_tool(message: str) -> dict: + return {"result": f"Tool processed: {message}"} + + mock_model = testing_utils.MockModel.create( + responses=[ + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + "Answer", + ] + ) + agent = LlmAgent(name="test_agent", model=mock_model, tools=[simple_tool]) + runner = testing_utils.InMemoryRunner(agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Question"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ] + assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Question"), + ( + "model", + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + ), + ( + "user", + types.Part.from_function_response( + name="simple_tool", + response={"result": "Tool processed: payload"}, + ), + ), + ] + assert testing_utils.simplify_events(runner.session.events) == [ + ("user", "Question"), + ( + "test_agent", + types.Part.from_function_call( + name="simple_tool", args={"message": "payload"} + ), + ), + ( + "test_agent", + types.Part.from_function_response( + name="simple_tool", + response={"result": "Tool processed: payload"}, + ), + ), + ("test_agent", "Answer"), + ] + + +def test_model_input_context_with_include_contents_none_sub_agent(): + agent1_model = testing_utils.MockModel.create( + responses=["Agent1 response: XYZ"] + ) + agent1 = LlmAgent(name="agent1", model=agent1_model) + + agent2_model = testing_utils.MockModel.create( + responses=["Agent2 final response"] + ) + agent2 = LlmAgent( + name="agent2", + model=agent2_model, + include_contents="none", + ) + sequential_agent = SequentialAgent( + name="sequential_test_agent", sub_agents=[agent1, agent2] + ) + runner = testing_utils.InMemoryRunner(sequential_agent) + session = runner.session + + list( + runner.runner.run( + user_id=session.user_id, + session_id=session.id, + new_message=testing_utils.get_user_content("Original user request"), + run_config=RunConfig( + model_input_context=[ + types.UserContent("Relevant context for this turn") + ] + ), + ) + ) + + assert testing_utils.simplify_contents(agent1_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ("user", "Original user request"), + ] + assert testing_utils.simplify_contents(agent2_model.requests[0].contents) == [ + ("user", "Relevant context for this turn"), + ( + "user", + [ + types.Part(text="For context:"), + types.Part(text="[agent1] said: Agent1 response: XYZ"), + ], + ), + ] + + @pytest.mark.asyncio async def test_include_contents_none_sequential_agents(): """Test include_contents='none' with sequential agents.""" diff --git a/tests/unittests/agents/test_run_config.py b/tests/unittests/agents/test_run_config.py index cbb82af019..c08a1a52c3 100644 --- a/tests/unittests/agents/test_run_config.py +++ b/tests/unittests/agents/test_run_config.py @@ -97,3 +97,11 @@ def test_avatar_config_with_name(): assert run_config.avatar_config == avatar_config assert run_config.avatar_config.avatar_name == "test_avatar" assert run_config.avatar_config.customized_avatar is None + + +def test_model_input_context_accepts_transient_contents(): + context_content = types.UserContent("Relevant context for this turn") + + run_config = RunConfig(model_input_context=[context_content]) + + assert run_config.model_input_context == [context_content]