Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/google/adk/agents/remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,18 +379,18 @@ def _is_remote_response(self, event: Event) -> bool:

def _construct_message_parts_from_session(
self, ctx: InvocationContext
) -> tuple[list[A2APart], Optional[str]]:
) -> tuple[list[A2APart], Optional[str], Optional[str]]:
"""Construct A2A message parts from session events.

Args:
ctx: The invocation context

Returns:
List of A2A parts extracted from session events, context ID,
request metadata
List of A2A parts extracted from session events, context ID, and task ID
"""
message_parts: list[A2APart] = []
context_id = None
task_id = None

events_to_process = []
for event in reversed(ctx.session.events):
Expand All @@ -400,6 +400,18 @@ def _construct_message_parts_from_session(
if event.custom_metadata:
metadata = event.custom_metadata
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")

# Only set task_id if the task state is input-required or auth-required
response_dict = metadata.get(A2A_METADATA_PREFIX + "response")
if isinstance(response_dict, dict):
status_dict = response_dict.get("status")
if isinstance(status_dict, dict):
task_state_val = status_dict.get("state")
if task_state_val in (
TaskState.input_required,
TaskState.auth_required,
):
task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
# Historical note: this behavior originally always applied, regardless
# of whether the agent was stateful or stateless. However, only stateful
# agents can be expected to have previous events in the remote session.
Expand Down Expand Up @@ -427,7 +439,7 @@ def _construct_message_parts_from_session(
else:
logger.warning("Failed to convert part to A2A format: %s", part)

return message_parts, context_id
return message_parts, context_id, task_id

async def _handle_a2a_response(
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
Expand Down Expand Up @@ -624,7 +636,7 @@ async def _run_async_impl(
# Create A2A request for function response or regular message
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
if not a2a_request:
message_parts, context_id = self._construct_message_parts_from_session(
message_parts, context_id, task_id = self._construct_message_parts_from_session(
ctx
)

Expand All @@ -645,6 +657,7 @@ async def _run_async_impl(
parts=message_parts,
role="user",
context_id=context_id,
task_id=task_id,
)

logger.debug(build_a2a_request_log(a2a_request))
Expand Down
98 changes: 89 additions & 9 deletions tests/unittests/agents/test_remote_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def test_construct_message_parts_from_session_success(self):
mock_a2a_part = Mock()
self.mock_genai_part_converter.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -649,7 +649,7 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
mock_a2a_part2,
]

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -660,7 +660,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -718,7 +718,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
Expand Down Expand Up @@ -768,7 +768,7 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 3
Expand Down Expand Up @@ -823,13 +823,86 @@ def mock_converter(part):
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
assert parts[0].text == "User 2"
assert context_id == "ctx-1"

@pytest.mark.parametrize(
"task_state, expected_task_id",
[
(TaskState.input_required, "task-1"),
(TaskState.auth_required, "task-1"),
(TaskState.completed, None),
("completed", None),
(None, None),
],
)
def test_construct_message_parts_from_session_task_id_filtering(
self, task_state, expected_task_id
):
"""Test task_id is only extracted if task state is input_required or auth_required."""
part1 = Mock()
part1.text = "User 1"
content1 = Mock()
content1.parts = [part1]
user1 = Mock()
user1.content = content1
user1.author = "user"
user1.custom_metadata = None

part2 = Mock()
part2.text = "Agent 1"
content2 = Mock()
content2.parts = [part2]
agent1 = Mock()
agent1.content = content2
agent1.author = self.agent.name

status_dict = {}
if task_state is not None:
status_dict["state"] = task_state

agent1.custom_metadata = {
A2A_METADATA_PREFIX + "response": {
"status": status_dict
},
A2A_METADATA_PREFIX + "context_id": "ctx-1",
A2A_METADATA_PREFIX + "task_id": "task-1",
}

part3 = Mock()
part3.text = "User 2"
content3 = Mock()
content3.parts = [part3]
user2 = Mock()
user2.content = content3
user2.author = "user"
user2.custom_metadata = None

self.mock_session.events = [user1, agent1, user2]

def mock_converter(part):
mock_a2a_part = Mock()
mock_a2a_part.text = part.text
return mock_a2a_part

self.mock_genai_part_converter.side_effect = mock_converter

with patch(
"google.adk.agents.remote_a2a_agent._present_other_agent_message"
) as mock_present:
mock_present.side_effect = lambda event: event
parts, context_id, task_id = self.agent._construct_message_parts_from_session(
self.mock_context
)
assert len(parts) == 1
assert parts[0].text == "User 2"
assert context_id == "ctx-1"
assert task_id == expected_task_id

@pytest.mark.asyncio
async def test_handle_a2a_response_success_with_message(self):
"""Test successful A2A response handling with message."""
Expand Down Expand Up @@ -954,7 +1027,7 @@ def mock_converter(part):

self.mock_genai_part_converter.side_effect = mock_converter

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1378,7 +1451,7 @@ def test_construct_message_parts_from_session_success(self):
mock_a2a_part = Mock()
mock_convert_part.return_value = mock_a2a_part

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand All @@ -1390,7 +1463,7 @@ def test_construct_message_parts_from_session_empty_events(self):
"""Test message parts construction with empty events."""
self.mock_session.events = []

parts, context_id = self.agent._construct_message_parts_from_session(
parts, context_id, _ = self.agent._construct_message_parts_from_session(
self.mock_context
)

Expand Down Expand Up @@ -1969,6 +2042,7 @@ async def test_run_async_impl_no_message_parts(self):
mock_construct.return_value = (
[],
None,
None,
) # Tuple with empty parts and no context_id

events = []
Expand Down Expand Up @@ -1999,6 +2073,7 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2071,6 +2146,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down Expand Up @@ -2138,6 +2214,7 @@ async def test_run_async_impl_with_meta_provider(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2245,6 +2322,7 @@ async def test_run_async_impl_no_message_parts(self):
mock_construct.return_value = (
[],
None,
None,
) # Tuple with empty parts and no context_id

events = []
Expand Down Expand Up @@ -2275,6 +2353,7 @@ async def test_run_async_impl_successful_request(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client
Expand Down Expand Up @@ -2349,6 +2428,7 @@ async def test_run_async_impl_a2a_client_error(self):
mock_construct.return_value = (
[mock_a2a_part],
"context-123",
None,
) # Tuple with parts and context_id

# Mock A2A client that throws an exception
Expand Down
Loading