diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 6072a5ddcb..cf0ec073b8 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -379,18 +379,18 @@ def _is_remote_response(self, event: Event) -> bool: def _construct_message_parts_from_session( self, ctx: InvocationContext - ) -> tuple[list[A2APart], Optional[str]]: + ) -> tuple[list[A2APart], Optional[str], Optional[str]]: """Construct A2A message parts from session events. Args: ctx: The invocation context Returns: - List of A2A parts extracted from session events, context ID, - request metadata + List of A2A parts extracted from session events, context ID, and task ID """ message_parts: list[A2APart] = [] context_id = None + task_id = None events_to_process = [] for event in reversed(ctx.session.events): @@ -400,6 +400,18 @@ def _construct_message_parts_from_session( if event.custom_metadata: metadata = event.custom_metadata context_id = metadata.get(A2A_METADATA_PREFIX + "context_id") + + # Only set task_id if the task state is input-required or auth-required + response_dict = metadata.get(A2A_METADATA_PREFIX + "response") + if isinstance(response_dict, dict): + status_dict = response_dict.get("status") + if isinstance(status_dict, dict): + task_state_val = status_dict.get("state") + if task_state_val in ( + TaskState.input_required, + TaskState.auth_required, + ): + task_id = metadata.get(A2A_METADATA_PREFIX + "task_id") # Historical note: this behavior originally always applied, regardless # of whether the agent was stateful or stateless. However, only stateful # agents can be expected to have previous events in the remote session. @@ -427,7 +439,7 @@ def _construct_message_parts_from_session( else: logger.warning("Failed to convert part to A2A format: %s", part) - return message_parts, context_id + return message_parts, context_id, task_id async def _handle_a2a_response( self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext @@ -624,7 +636,7 @@ async def _run_async_impl( # Create A2A request for function response or regular message a2a_request = self._create_a2a_request_for_user_function_response(ctx) if not a2a_request: - message_parts, context_id = self._construct_message_parts_from_session( + message_parts, context_id, task_id = self._construct_message_parts_from_session( ctx ) @@ -645,6 +657,7 @@ async def _run_async_impl( parts=message_parts, role="user", context_id=context_id, + task_id=task_id, ) logger.debug(build_a2a_request_log(a2a_request)) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 0f1ce896a3..d741d90c65 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -615,7 +615,7 @@ def test_construct_message_parts_from_session_success(self): mock_a2a_part = Mock() self.mock_genai_part_converter.return_value = mock_a2a_part - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -649,7 +649,7 @@ def test_construct_message_parts_from_session_success_multiple_parts(self): mock_a2a_part2, ] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -660,7 +660,7 @@ def test_construct_message_parts_from_session_empty_events(self): """Test message parts construction with empty events.""" self.mock_session.events = [] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -718,7 +718,7 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 1 @@ -768,7 +768,7 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 3 @@ -823,13 +823,86 @@ def mock_converter(part): "google.adk.agents.remote_a2a_agent._present_other_agent_message" ) as mock_present: mock_present.side_effect = lambda event: event - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) assert len(parts) == 1 assert parts[0].text == "User 2" assert context_id == "ctx-1" + @pytest.mark.parametrize( + "task_state, expected_task_id", + [ + (TaskState.input_required, "task-1"), + (TaskState.auth_required, "task-1"), + (TaskState.completed, None), + ("completed", None), + (None, None), + ], + ) + def test_construct_message_parts_from_session_task_id_filtering( + self, task_state, expected_task_id + ): + """Test task_id is only extracted if task state is input_required or auth_required.""" + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + + status_dict = {} + if task_state is not None: + status_dict["state"] = task_state + + agent1.custom_metadata = { + A2A_METADATA_PREFIX + "response": { + "status": status_dict + }, + A2A_METADATA_PREFIX + "context_id": "ctx-1", + A2A_METADATA_PREFIX + "task_id": "task-1", + } + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id, task_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id == "ctx-1" + assert task_id == expected_task_id + @pytest.mark.asyncio async def test_handle_a2a_response_success_with_message(self): """Test successful A2A response handling with message.""" @@ -954,7 +1027,7 @@ def mock_converter(part): self.mock_genai_part_converter.side_effect = mock_converter - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1378,7 +1451,7 @@ def test_construct_message_parts_from_session_success(self): mock_a2a_part = Mock() mock_convert_part.return_value = mock_a2a_part - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1390,7 +1463,7 @@ def test_construct_message_parts_from_session_empty_events(self): """Test message parts construction with empty events.""" self.mock_session.events = [] - parts, context_id = self.agent._construct_message_parts_from_session( + parts, context_id, _ = self.agent._construct_message_parts_from_session( self.mock_context ) @@ -1969,6 +2042,7 @@ async def test_run_async_impl_no_message_parts(self): mock_construct.return_value = ( [], None, + None, ) # Tuple with empty parts and no context_id events = [] @@ -1999,6 +2073,7 @@ async def test_run_async_impl_successful_request(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client @@ -2071,6 +2146,7 @@ async def test_run_async_impl_a2a_client_error(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client that throws an exception @@ -2138,6 +2214,7 @@ async def test_run_async_impl_with_meta_provider(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client @@ -2245,6 +2322,7 @@ async def test_run_async_impl_no_message_parts(self): mock_construct.return_value = ( [], None, + None, ) # Tuple with empty parts and no context_id events = [] @@ -2275,6 +2353,7 @@ async def test_run_async_impl_successful_request(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client @@ -2349,6 +2428,7 @@ async def test_run_async_impl_a2a_client_error(self): mock_construct.return_value = ( [mock_a2a_part], "context-123", + None, ) # Tuple with parts and context_id # Mock A2A client that throws an exception