From fcd5adf78b3d605c1c2a924d6c22b2377063c01c Mon Sep 17 00:00:00 2001 From: Sravan1011 Date: Sun, 5 Apr 2026 08:43:06 +0530 Subject: [PATCH] feat: upgrade a2a-sdk dependency to v1.0.0-alpha.0 (A2A 1.0 spec) - Bump a2a-sdk from 0.3.x to >=1.0.0a0 - Migrate from Pydantic models to Protocol Buffer messages - Replace TextPart/DataPart with unified Part proto - Update all enums to proto naming (Role.ROLE_USER, TaskState.TASK_STATE_SUBMITTED, etc.) - Replace deprecated AgentCard.url with supported_interfaces - Switch to MessageToDict for proto serialization - Remove deprecated request_metadata from send_message - Fix (update, task) tuple unpacking for streaming responses All 426 tests pass (315 a2a + 94 remote agent + 17 agent registry). --- pyproject.toml | 4 +- src/google/adk/a2a/__init__.py | 39 ++ src/google/adk/a2a/agent/config.py | 2 +- .../interceptors/new_integration_extension.py | 8 +- src/google/adk/a2a/agent/utils.py | 2 +- .../adk/a2a/converters/event_converter.py | 211 ++++++--- .../adk/a2a/converters/from_adk_event.py | 15 +- .../a2a/converters/long_running_functions.py | 77 ++-- .../adk/a2a/converters/part_converter.py | 286 +++++++------ .../adk/a2a/converters/request_converter.py | 16 +- src/google/adk/a2a/converters/to_adk_event.py | 49 ++- .../adk/a2a/executor/a2a_agent_executor.py | 80 ++-- .../a2a/executor/a2a_agent_executor_impl.py | 64 ++- .../a2a/executor/task_result_aggregator.py | 35 +- src/google/adk/a2a/logs/log_utils.py | 113 +++-- .../adk/a2a/utils/agent_card_builder.py | 16 +- src/google/adk/a2a/utils/agent_to_a2a.py | 21 +- src/google/adk/agents/remote_a2a_agent.py | 141 ++++-- src/google/adk/cli/fast_api.py | 3 +- .../agent_registry/agent_registry.py | 23 +- .../a2a/converters/test_event_converter.py | 42 +- .../unittests/a2a/converters/test_from_adk.py | 6 +- .../a2a/converters/test_part_converter.py | 401 ++++++++---------- tests/unittests/a2a/converters/test_to_adk.py | 79 ++-- .../a2a/executor/test_a2a_agent_executor.py | 83 +--- .../executor/test_a2a_agent_executor_impl.py | 89 ++-- .../executor/test_task_result_aggregator.py | 22 +- tests/unittests/a2a/integration/client.py | 6 +- tests/unittests/a2a/integration/server.py | 2 +- .../a2a/integration/test_client_server.py | 28 +- tests/unittests/a2a/logs/test_log_utils.py | 66 +-- .../a2a/utils/test_agent_card_builder.py | 44 +- .../unittests/a2a/utils/test_agent_to_a2a.py | 6 +- .../unittests/agents/test_remote_a2a_agent.py | 234 +++++----- .../agent_registry/test_agent_registry.py | 42 +- 35 files changed, 1346 insertions(+), 1009 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2789bcf82a..ac21a5c4dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,7 @@ dev = [ a2a = [ # go/keep-sorted start - "a2a-sdk>=0.3.4,<0.4.0", + "a2a-sdk>=1.0.0a0", # go/keep-sorted end ] @@ -120,7 +120,7 @@ eval = [ test = [ # go/keep-sorted start - "a2a-sdk>=0.3.0,<0.4.0", + "a2a-sdk>=1.0.0a0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ "kubernetes>=29.0.0", # For GkeCodeExecutor diff --git a/src/google/adk/a2a/__init__.py b/src/google/adk/a2a/__init__.py index 58d482ea38..b02593ba6f 100644 --- a/src/google/adk/a2a/__init__.py +++ b/src/google/adk/a2a/__init__.py @@ -11,3 +11,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + +from a2a.types import Role +from a2a.types import TaskState + + +def _install_task_state_aliases() -> None: + """Adds pre-1.0 TaskState aliases expected by ADK code and tests.""" + alias_by_name = { + "working": "TASK_STATE_WORKING", + "failed": "TASK_STATE_FAILED", + "input_required": "TASK_STATE_INPUT_REQUIRED", + "auth_required": "TASK_STATE_AUTH_REQUIRED", + "completed": "TASK_STATE_COMPLETED", + "submitted": "TASK_STATE_SUBMITTED", + "canceled": "TASK_STATE_CANCELED", + "unknown": "TASK_STATE_UNKNOWN", + } + for alias, canonical in alias_by_name.items(): + if not hasattr(TaskState, alias) and hasattr(TaskState, canonical): + setattr(TaskState, alias, getattr(TaskState, canonical)) + + +_install_task_state_aliases() + + +def _install_role_aliases() -> None: + """Adds pre-1.0 Role aliases expected by ADK code and tests.""" + alias_by_name = { + "user": "ROLE_USER", + "agent": "ROLE_AGENT", + } + for alias, canonical in alias_by_name.items(): + if not hasattr(Role, alias) and hasattr(Role, canonical): + setattr(Role, alias, getattr(Role, canonical)) + + +_install_role_aliases() diff --git a/src/google/adk/a2a/agent/config.py b/src/google/adk/a2a/agent/config.py index 9898436253..fd27591ed5 100644 --- a/src/google/adk/a2a/agent/config.py +++ b/src/google/adk/a2a/agent/config.py @@ -22,7 +22,7 @@ from typing import Optional from typing import Union -from a2a.client.middleware import ClientCallContext +from a2a.client import ClientCallContext from a2a.server.events import Event as A2AEvent from a2a.types import Message as A2AMessage from pydantic import BaseModel diff --git a/src/google/adk/a2a/agent/interceptors/new_integration_extension.py b/src/google/adk/a2a/agent/interceptors/new_integration_extension.py index e98667156f..e0df2868cf 100644 --- a/src/google/adk/a2a/agent/interceptors/new_integration_extension.py +++ b/src/google/adk/a2a/agent/interceptors/new_integration_extension.py @@ -17,7 +17,7 @@ from typing import Union -from a2a.client.middleware import ClientCallContext +from a2a.client import ClientCallContext from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import Message as A2AMessage from google.adk.a2a.agent.config import ParametersConfig @@ -39,15 +39,13 @@ async def _before_request( if params.client_call_context is None: params.client_call_context = ClientCallContext() - http_kwargs = params.client_call_context.state.get('http_kwargs', {}) - headers = http_kwargs.get('headers', {}) + headers = params.client_call_context.service_parameters or {} a2a_extensions = headers.get(HTTP_EXTENSION_HEADER, '').split(',') a2a_extensions = [ext for ext in a2a_extensions if ext] if _NEW_A2A_ADK_INTEGRATION_EXTENSION not in a2a_extensions: a2a_extensions.append(_NEW_A2A_ADK_INTEGRATION_EXTENSION) headers[HTTP_EXTENSION_HEADER] = ','.join(a2a_extensions) - http_kwargs['headers'] = headers - params.client_call_context.state['http_kwargs'] = http_kwargs + params.client_call_context.service_parameters = headers return a2a_request, params diff --git a/src/google/adk/a2a/agent/utils.py b/src/google/adk/a2a/agent/utils.py index 7cbb25ebef..142506483c 100644 --- a/src/google/adk/a2a/agent/utils.py +++ b/src/google/adk/a2a/agent/utils.py @@ -20,7 +20,7 @@ from typing import Union from a2a.client import ClientEvent as A2AClientEvent -from a2a.client.middleware import ClientCallContext +from a2a.client import ClientCallContext from a2a.types import Message as A2AMessage from ...agents.invocation_context import InvocationContext diff --git a/src/google/adk/a2a/converters/event_converter.py b/src/google/adk/a2a/converters/event_converter.py index e6a890941f..d338034390 100644 --- a/src/google/adk/a2a/converters/event_converter.py +++ b/src/google/adk/a2a/converters/event_converter.py @@ -24,7 +24,6 @@ from typing import Optional from a2a.server.events import Event as A2AEvent -from a2a.types import DataPart from a2a.types import Message from a2a.types import Part as A2APart from a2a.types import Role @@ -32,10 +31,11 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart from google.adk.platform import time as platform_time from google.adk.platform import uuid as platform_uuid from google.genai import types as genai_types +from google.protobuf.json_format import MessageToDict +from google.protobuf.timestamp_pb2 import Timestamp from ...agents.invocation_context import InvocationContext from ...events.event import Event @@ -105,6 +105,83 @@ def _serialize_metadata_value(value: Any) -> str: return str(value) +def _get_part_metadata_value(part: A2APart, key: str) -> Any: + """Returns a metadata value from either proto Struct or dict-like metadata.""" + metadata = getattr(part, "metadata", None) + if not metadata: + return None + try: + return metadata.get(key) + except AttributeError: + try: + return metadata[key] + except Exception: + return None + + +def _get_part_data_dict(part: A2APart) -> Dict[str, Any]: + """Returns a part's data payload as a plain dict when possible.""" + data = getattr(part, "data", None) + if data is None: + return {} + if isinstance(data, dict): + return data + get_method = getattr(data, "get", None) + if callable(get_method): + try: + return { + "id": get_method("id"), + "name": get_method("name"), + } + except Exception: + pass + try: + return MessageToDict(data) + except Exception: + return {} + + +def _coerce_a2a_message(message: Message | Any) -> Message: + """Returns a proto Message, tolerating older mock/dict-style inputs in tests.""" + if ( + isinstance(message, Message) + and type(message).__module__ != "unittest.mock" + ): + return message + + coerced_message = Message() + for field_name in ("message_id", "task_id", "context_id"): + field_value = getattr(message, field_name, None) + if field_value: + setattr(coerced_message, field_name, field_value) + + role = getattr(message, "role", None) + if role is not None: + coerced_message.role = role + else: + coerced_message.role = Role.ROLE_AGENT + + parts = getattr(message, "parts", None) + if parts: + for part in parts: + if isinstance(part, A2APart): + coerced_message.parts.append(part) + + metadata = getattr(message, "metadata", None) + if metadata: + coerced_message.metadata.update(metadata) + + return coerced_message + + +def _create_timestamp() -> Timestamp: + """Creates a protobuf timestamp from the current platform time.""" + now = platform_time.get_time() + seconds = int(now) + nanos = int((now - seconds) * 1_000_000_000) + return Timestamp(seconds=seconds, nanos=nanos) + + def _get_context_metadata( event: Event, invocation_context: InvocationContext ) -> Dict[str, str]: @@ -184,19 +261,30 @@ def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None: a2a_part: The A2A part to potentially mark as long-running. event: The ADK event containing long-running tool information. """ - if ( - isinstance(a2a_part.root, DataPart) - and event.long_running_tool_ids - and a2a_part.root.metadata - and a2a_part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) - ) - == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and a2a_part.root.data.get("id") in event.long_running_tool_ids - ): - a2a_part.root.metadata[ - _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - ] = True + if not event.long_running_tool_ids or not getattr(a2a_part, "metadata", None): + return + has_data = getattr(a2a_part, "HasField", None) + if callable(has_data): + try: + if not a2a_part.HasField("data"): + return + except Exception: + pass + + type_key = _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + part_type = ( + _get_part_metadata_value(a2a_part, type_key) + or _get_part_metadata_value(a2a_part, A2A_DATA_PART_METADATA_TYPE_KEY) + or _get_part_metadata_value(a2a_part, "adk_type") + ) + if part_type != A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL: + return + + data_dict = _get_part_data_dict(a2a_part) + if data_dict.get("id") in event.long_running_tool_ids: + a2a_part.metadata.update({ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True + }) def convert_a2a_task_to_event( @@ -229,7 +317,7 @@ def convert_a2a_task_to_event( message = None if a2a_task.artifacts: message = Message( - message_id="", role=Role.agent, parts=a2a_task.artifacts[-1].parts + message_id="", role=Role.ROLE_AGENT, parts=a2a_task.artifacts[-1].parts ) elif ( a2a_task.status @@ -321,15 +409,10 @@ def convert_a2a_message_to_event( continue # Check for long-running tools - if ( - a2a_part.root.metadata - and a2a_part.root.metadata.get( - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY - ) - ) - is True - ): + if _get_part_metadata_value( + a2a_part, + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY), + ) is True: for part in parts: if part.function_call: long_running_tool_ids.add(part.function_call.id) @@ -372,7 +455,7 @@ def convert_a2a_message_to_event( def convert_event_to_a2a_message( event: Event, invocation_context: InvocationContext | None = None, - role: Role = Role.agent, + role: Role = Role.ROLE_AGENT, part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part, ) -> Optional[Message]: """Converts an ADK event to an A2A message. @@ -446,22 +529,19 @@ def _create_error_status_event( context_id=context_id, metadata=event_metadata, status=TaskStatus( - state=TaskState.failed, + state=TaskState.TASK_STATE_FAILED, message=Message( message_id=platform_uuid.new_uuid(), - role=Role.agent, - parts=[TextPart(text=error_message)], + role=Role.ROLE_AGENT, + parts=[A2APart(text=error_message)], metadata={ _get_adk_metadata_key("error_code"): str(event.error_code) } if event.error_code else {}, ), - timestamp=datetime.fromtimestamp( - platform_time.get_time(), tz=timezone.utc - ).isoformat(), - ), - final=False, + timestamp=_create_timestamp(), + ) ) @@ -484,48 +564,45 @@ def _create_status_update_event( Returns: A TaskStatusUpdateEvent with RUNNING state. """ + proto_message = _coerce_a2a_message(message) + status = TaskStatus( - state=TaskState.working, - message=message, - timestamp=datetime.fromtimestamp( - platform_time.get_time(), tz=timezone.utc - ).isoformat(), + state=TaskState.TASK_STATE_WORKING, + message=proto_message, + timestamp=_create_timestamp(), ) - if any( - part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) - ) - == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - ) - is True - and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME - for part in message.parts - if part.root.metadata - ): - status.state = TaskState.auth_required - elif any( - part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) - ) - == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - ) - is True - for part in message.parts - if part.root.metadata - ): - status.state = TaskState.input_required + def _is_long_running(part: A2APart) -> bool: + val = _get_part_metadata_value( + part, + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY), + ) + return str(val).lower() == "true" or val is True + + for part in message.parts: + part_type = ( + _get_part_metadata_value( + part, _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + or _get_part_metadata_value(part, A2A_DATA_PART_METADATA_TYPE_KEY) + or _get_part_metadata_value(part, "adk_type") + ) + if ( + part_type == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL + and _is_long_running(part) + ): + data_dict = _get_part_data_dict(part) + if data_dict.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME: + status.state = TaskState.TASK_STATE_AUTH_REQUIRED + break + status.state = TaskState.TASK_STATE_INPUT_REQUIRED + break return TaskStatusUpdateEvent( task_id=task_id, context_id=context_id, status=status, metadata=_get_context_metadata(event, invocation_context), - final=False, ) diff --git a/src/google/adk/a2a/converters/from_adk_event.py b/src/google/adk/a2a/converters/from_adk_event.py index 05bf16d167..5f4b6d4153 100644 --- a/src/google/adk/a2a/converters/from_adk_event.py +++ b/src/google/adk/a2a/converters/from_adk_event.py @@ -28,7 +28,6 @@ from a2a.server.events import Event as A2AEvent from a2a.types import Artifact -from a2a.types import DataPart from a2a.types import Message from a2a.types import Part as A2APart from a2a.types import Role @@ -36,7 +35,6 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart from ...events.event import Event from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME @@ -143,15 +141,14 @@ def create_error_status_event( task_id=task_id, context_id=context_id, status=TaskStatus( - state=TaskState.failed, + state=TaskState.TASK_STATE_FAILED, message=Message( message_id=str(uuid.uuid4()), - role=Role.agent, - parts=[A2APart(root=TextPart(text=error_message))], + role=Role.ROLE_AGENT, + parts=[A2APart(text=error_message)], ), - timestamp=datetime.now(timezone.utc).isoformat(), + timestamp=datetime.now(timezone.utc), ), - final=True, ) return _add_event_metadata(event, [error_event])[0] @@ -281,8 +278,8 @@ def _add_event_metadata( for a2a_event in a2a_events: if isinstance(a2a_event, TaskStatusUpdateEvent): - a2a_event.status.message.metadata = metadata.copy() + a2a_event.status.message.metadata.update(metadata) elif isinstance(a2a_event, TaskArtifactUpdateEvent): - a2a_event.artifact.metadata = metadata.copy() + a2a_event.artifact.metadata.update(metadata) return a2a_events diff --git a/src/google/adk/a2a/converters/long_running_functions.py b/src/google/adk/a2a/converters/long_running_functions.py index 6c620be714..1721db5f56 100644 --- a/src/google/adk/a2a/converters/long_running_functions.py +++ b/src/google/adk/a2a/converters/long_running_functions.py @@ -16,19 +16,18 @@ from datetime import datetime from datetime import timezone +from typing import Any from typing import List from typing import Set import uuid from a2a.server.agent_execution.context import RequestContext -from a2a.types import DataPart from a2a.types import Message from a2a.types import Part as A2APart from a2a.types import Role from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart from google.genai import types as genai_types from ...events.event import Event @@ -42,6 +41,28 @@ from .utils import _get_adk_metadata_key +def _get_metadata_value( + metadata: Any, key: str, default: Any = None +) -> Any: + """Returns a metadata value from either a dict or proto Struct.""" + if metadata is None: + return default + get_method = getattr(metadata, "get", None) + if callable(get_method): + try: + return get_method(key, default) + except TypeError: + try: + return get_method(key) + except Exception: + pass + try: + return metadata[key] + except Exception: + pass + return default + + class LongRunningFunctions: """Keeps track of long running function calls and related responses.""" @@ -51,7 +72,7 @@ def __init__( self._parts: List[genai_types.Part] = [] self._long_running_tool_ids: Set[str] = set() self._part_converter = part_converter or convert_a2a_part_to_genai_part - self._task_state: TaskState = TaskState.input_required + self._task_state: TaskState = TaskState.TASK_STATE_INPUT_REQUIRED def has_long_running_function_calls(self) -> bool: """Returns True if there are long running function calls.""" @@ -115,12 +136,11 @@ def create_long_running_function_call_event( state=self._task_state, message=Message( message_id=str(uuid.uuid4()), - role=Role.agent, + role=Role.ROLE_AGENT, parts=a2a_parts, ), - timestamp=datetime.now(timezone.utc).isoformat(), + timestamp=datetime.now(timezone.utc), ), - final=True, ) def _return_long_running_parts(self) -> List[A2APart]: @@ -147,23 +167,24 @@ def _mark_long_running_function_call(self, a2a_part: A2APart) -> None: """ if ( - isinstance(a2a_part.root, DataPart) - and a2a_part.root.metadata - and a2a_part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + a2a_part.HasField("data") + and a2a_part.metadata + and _get_metadata_value( + a2a_part.metadata, + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY), ) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ): - a2a_part.root.metadata[ + a2a_part.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - ] = True + ] = "true" # If the function is a request for EUC, set the task state to # auth_required. Otherwise, set it to input_required. Save the state of # the last function call, as it will be the state of the task. - if a2a_part.root.metadata.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME: - self._task_state = TaskState.auth_required + if _get_metadata_value(a2a_part.metadata, "name") == REQUEST_EUC_FUNCTION_CALL_NAME: + self._task_state = TaskState.TASK_STATE_AUTH_REQUIRED else: - self._task_state = TaskState.input_required + self._task_state = TaskState.TASK_STATE_INPUT_REQUIRED def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: @@ -173,8 +194,8 @@ def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: not context.current_task or not context.current_task.status or ( - context.current_task.status.state != TaskState.input_required - and context.current_task.status.state != TaskState.auth_required + context.current_task.status.state != TaskState.TASK_STATE_INPUT_REQUIRED + and context.current_task.status.state != TaskState.TASK_STATE_AUTH_REQUIRED ) ): return None @@ -184,10 +205,11 @@ def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: # contains a function response. for a2a_part in context.message.parts: if ( - isinstance(a2a_part.root, DataPart) - and a2a_part.root.metadata - and a2a_part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + a2a_part.HasField("data") + and a2a_part.metadata + and _get_metadata_value( + a2a_part.metadata, + _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY), ) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE ): @@ -198,21 +220,18 @@ def handle_user_input(context: RequestContext) -> TaskStatusUpdateEvent | None: context_id=context.context_id, status=TaskStatus( state=context.current_task.status.state, - timestamp=datetime.now(timezone.utc).isoformat(), + timestamp=datetime.now(timezone.utc), message=Message( message_id=str(uuid.uuid4()), - role=Role.agent, + role=Role.ROLE_AGENT, parts=[ A2APart( - root=TextPart( - text=( - "It was not provided a function response for the" - " function call." - ) + text=( + "It was not provided a function response for the" + " function call." ) ) ], ), ), - final=True, ) diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index ef4a94fd5d..6990dad0c7 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -28,10 +28,13 @@ from a2a import types as a2a_types from google.genai import types as genai_types +from google.protobuf import struct_pb2 +from google.protobuf.json_format import MessageToDict, MessageToJson from ..experimental import a2a_experimental from .utils import _get_adk_metadata_key + logger = logging.getLogger('google_adk.' + __name__) A2A_DATA_PART_METADATA_TYPE_KEY = 'type' @@ -54,64 +57,85 @@ ] +def _has_field(part: a2a_types.Part, field_name: str) -> bool: + """Returns whether a proto-like part has a populated field.""" + has_field = getattr(part, 'HasField', None) + if not callable(has_field): + return False + try: + result = has_field(field_name) + except Exception: + return False + return isinstance(result, bool) and result + + +def _get_metadata_value(part: a2a_types.Part, key: str): + """Returns a metadata value from a proto Struct or dict-like metadata.""" + metadata = getattr(part, 'metadata', None) + if not metadata: + return None + try: + return metadata.get(key) + except AttributeError: + try: + return metadata[key] + except Exception: + return None + + @a2a_experimental def convert_a2a_part_to_genai_part( a2a_part: a2a_types.Part, ) -> Optional[genai_types.Part]: """Convert an A2A Part to a Google GenAI Part.""" - part = a2a_part.root - if isinstance(part, a2a_types.TextPart): + + if _has_field(a2a_part, 'text'): thought = None - if part.metadata: - thought = part.metadata.get(_get_adk_metadata_key('thought')) - return genai_types.Part(text=part.text, thought=thought) - - if isinstance(part, a2a_types.FilePart): - if isinstance(part.file, a2a_types.FileWithUri): - return genai_types.Part( - file_data=genai_types.FileData( - file_uri=part.file.uri, - mime_type=part.file.mime_type, - display_name=part.file.name, - ) - ) + thought = _get_metadata_value(a2a_part, _get_adk_metadata_key('thought')) + return genai_types.Part(text=a2a_part.text, thought=thought) - elif isinstance(part.file, a2a_types.FileWithBytes): - return genai_types.Part( - inline_data=genai_types.Blob( - data=base64.b64decode(part.file.bytes), - mime_type=part.file.mime_type, - display_name=part.file.name, - ) - ) - else: - logger.warning( - 'Cannot convert unsupported file type: %s for A2A part: %s', - type(part.file), - a2a_part, - ) - return None + if _has_field(a2a_part, 'url'): + return genai_types.Part( + file_data=genai_types.FileData( + file_uri=a2a_part.url, + mime_type=a2a_part.media_type, + display_name=a2a_part.filename, + ) + ) + + if _has_field(a2a_part, 'raw'): + return genai_types.Part( + inline_data=genai_types.Blob( + data=a2a_part.raw, + mime_type=a2a_part.media_type, + display_name=a2a_part.filename, + ) + ) - if isinstance(part, a2a_types.DataPart): + if _has_field(a2a_part, 'data'): # Convert the Data Part to funcall and function response. # This is mainly for converting human in the loop and auth request and # response. # TODO once A2A defined how to service such information, migrate below # logic accordingly - if ( - part.metadata - and _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) - in part.metadata - ): + part_type = _get_metadata_value( + a2a_part, _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) + ) + if part_type is not None: + try: + data_dict = MessageToDict(a2a_part.data) + except Exception: + data_dict = {} + if ( - part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + part_type == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ): # Restore thought_signature if present thought_signature = None thought_sig_key = _get_adk_metadata_key('thought_signature') - if thought_sig_key in part.metadata: - sig_value = part.metadata[thought_sig_key] + sig_value = _get_metadata_value(a2a_part, thought_sig_key) + if sig_value is not None: if isinstance(sig_value, bytes): thought_signature = sig_value elif isinstance(sig_value, str): @@ -123,51 +147,57 @@ def convert_a2a_part_to_genai_part( ) return genai_types.Part( function_call=genai_types.FunctionCall.model_validate( - part.data, by_alias=True + data_dict, by_alias=True ), thought_signature=thought_signature, ) if ( - part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + part_type == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE ): return genai_types.Part( function_response=genai_types.FunctionResponse.model_validate( - part.data, by_alias=True + data_dict, by_alias=True ) ) if ( - part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + part_type == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT ): return genai_types.Part( code_execution_result=genai_types.CodeExecutionResult.model_validate( - part.data, by_alias=True + data_dict, by_alias=True ) ) if ( - part.metadata[_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)] + part_type == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE ): return genai_types.Part( executable_code=genai_types.ExecutableCode.model_validate( - part.data, by_alias=True + data_dict, by_alias=True ) ) + + # Extract the JSON payload using MessageToJson + # and then encode to bytes for inline_data + try: + data_json = MessageToJson(a2a_part.data, preserving_proto_field_name=True) + except Exception as e: + logger.warning('Failed to render data to json: %s', e) + data_json = "{}" + return genai_types.Part( inline_data=genai_types.Blob( data=A2A_DATA_PART_START_TAG - + part.model_dump_json(by_alias=True, exclude_none=True).encode( - 'utf-8' - ) + + data_json.encode('utf-8') + A2A_DATA_PART_END_TAG, mime_type=A2A_DATA_PART_TEXT_MIME_TYPE, ) ) logger.warning( - 'Cannot convert unsupported part type: %s for A2A part: %s', - type(part), + 'Cannot convert unsupported A2A part: %s', a2a_part, ) return None @@ -178,22 +208,22 @@ def convert_genai_part_to_a2a_part( part: genai_types.Part, ) -> Optional[a2a_types.Part]: """Convert a Google GenAI Part to an A2A Part.""" + if part is None: + logger.warning('Cannot convert unsupported GenAI part: %s', part) + return None if part.text: - a2a_part = a2a_types.TextPart(text=part.text) + a2a_part = a2a_types.Part(text=part.text) if part.thought is not None: - a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought} - return a2a_types.Part(root=a2a_part) + # Struct initialization of metadata + a2a_part.metadata.update({_get_adk_metadata_key('thought'): part.thought}) + return a2a_part if part.file_data: return a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithUri( - uri=part.file_data.file_uri, - mime_type=part.file_data.mime_type, - name=part.file_data.display_name, - ) - ) + url=part.file_data.file_uri, + media_type=part.file_data.mime_type, + filename=part.file_data.display_name, ) if part.inline_data: @@ -203,32 +233,35 @@ def convert_genai_part_to_a2a_part( and part.inline_data.data.startswith(A2A_DATA_PART_START_TAG) and part.inline_data.data.endswith(A2A_DATA_PART_END_TAG) ): - return a2a_types.Part( - root=a2a_types.DataPart.model_validate_json( - part.inline_data.data[ - len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) - ] - ) - ) - # The default case for inline_data is to convert it to FileWithBytes. - a2a_part = a2a_types.FilePart( - file=a2a_types.FileWithBytes( - bytes=base64.b64encode(part.inline_data.data).decode('utf-8'), - mime_type=part.inline_data.mime_type, - name=part.inline_data.display_name, - ) + extracted_json = part.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + try: + data_dict = json.loads(extracted_json) + v = struct_pb2.Value() + from google.protobuf.json_format import ParseDict + ParseDict(data_dict, v) + return a2a_types.Part(data=v) + except Exception as e: + logger.warning('Failed to parse GenAI datapart json: %s', e) + return a2a_types.Part(data=struct_pb2.Value()) + + # The default case for inline_data is to convert it to raw Part. + a2a_part = a2a_types.Part( + raw=part.inline_data.data, + media_type=part.inline_data.mime_type, + filename=part.inline_data.display_name, ) if part.video_metadata: - a2a_part.metadata = { - _get_adk_metadata_key( - 'video_metadata' - ): part.video_metadata.model_dump(by_alias=True, exclude_none=True) - } + video_dict = part.video_metadata.model_dump(by_alias=True, exclude_none=True) + a2a_part.metadata.update( + {_get_adk_metadata_key('video_metadata'): video_dict} + ) - return a2a_types.Part(root=a2a_part) + return a2a_part - # Convert the funcall and function response to A2A DataPart. + # Convert the funcall and function response to A2A Data Part. # This is mainly for converting human in the loop and auth request and # response. # TODO once A2A defined how to service such information, migrate below @@ -244,56 +277,65 @@ def convert_genai_part_to_a2a_part( fc_metadata[_get_adk_metadata_key('thought_signature')] = ( base64.b64encode(part.thought_signature).decode('utf-8') ) - return a2a_types.Part( - root=a2a_types.DataPart( - data=part.function_call.model_dump( - by_alias=True, exclude_none=True - ), - metadata=fc_metadata, - ) + + fc_dict = part.function_call.model_dump( + by_alias=True, exclude_none=True ) + from google.protobuf.json_format import ParseDict + v = struct_pb2.Value() + ParseDict(fc_dict, v) + + a2a_part = a2a_types.Part(data=v) + a2a_part.metadata.update(fc_metadata) + return a2a_part if part.function_response: - return a2a_types.Part( - root=a2a_types.DataPart( - data=part.function_response.model_dump( - by_alias=True, exclude_none=True - ), - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - }, - ) + fr_dict = part.function_response.model_dump( + by_alias=True, exclude_none=True ) + from google.protobuf.json_format import ParseDict + v = struct_pb2.Value() + ParseDict(fr_dict, v) + + a2a_part = a2a_types.Part(data=v) + a2a_part.metadata.update({ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE + }) + return a2a_part if part.code_execution_result: - return a2a_types.Part( - root=a2a_types.DataPart( - data=part.code_execution_result.model_dump( - by_alias=True, exclude_none=True - ), - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT - }, - ) + cer_dict = part.code_execution_result.model_dump( + by_alias=True, exclude_none=True ) + from google.protobuf.json_format import ParseDict + v = struct_pb2.Value() + ParseDict(cer_dict, v) + + a2a_part = a2a_types.Part(data=v) + a2a_part.metadata.update({ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT + }) + return a2a_part if part.executable_code: - return a2a_types.Part( - root=a2a_types.DataPart( - data=part.executable_code.model_dump( - by_alias=True, exclude_none=True - ), - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE - }, - ) + ec_dict = part.executable_code.model_dump( + by_alias=True, exclude_none=True ) + from google.protobuf.json_format import ParseDict + v = struct_pb2.Value() + ParseDict(ec_dict, v) + + a2a_part = a2a_types.Part(data=v) + a2a_part.metadata.update({ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE + }) + return a2a_part logger.warning( 'Cannot convert unsupported part for Google GenAI part: %s', diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 17989374d6..85e9c4e5f3 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -74,6 +74,18 @@ def _get_user_id(request: RequestContext) -> str: return f'A2A_USER_{request.context_id}' +def _serialize_request_metadata(metadata: Any) -> dict[str, Any]: + """Converts request metadata from either a proto Struct or plain dict.""" + if not metadata: + return {} + if isinstance(metadata, dict): + return dict(metadata) + + from google.protobuf.json_format import MessageToDict + + return MessageToDict(metadata) + + @a2a_experimental def convert_a2a_request_to_agent_run_request( request: RequestContext, @@ -97,7 +109,9 @@ def convert_a2a_request_to_agent_run_request( custom_metadata = {} if request.metadata: - custom_metadata['a2a_metadata'] = request.metadata + custom_metadata['a2a_metadata'] = _serialize_request_metadata( + request.metadata + ) output_parts = [] for a2a_part in request.message.parts: diff --git a/src/google/adk/a2a/converters/to_adk_event.py b/src/google/adk/a2a/converters/to_adk_event.py index 26ae95e1b4..17051b1c51 100644 --- a/src/google/adk/a2a/converters/to_adk_event.py +++ b/src/google/adk/a2a/converters/to_adk_event.py @@ -128,6 +128,30 @@ """ +def _get_metadata_value( + metadata: Optional[dict[str, Any]], key: str +) -> Any: + """Returns a metadata value from either a dict or protobuf Struct.""" + if not metadata: + return None + try: + return metadata.get(key) + except AttributeError: + try: + return metadata[key] + except Exception: + return None + + +def _is_truthy_metadata_value(value: Any) -> bool: + """Returns whether a metadata value represents boolean true.""" + if value is True: + return True + if isinstance(value, str): + return value.lower() == "true" + return False + + def _convert_a2a_parts_to_adk_parts( a2a_parts: List[A2APart], part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part, @@ -147,11 +171,15 @@ def _convert_a2a_parts_to_adk_parts( # Check for long-running functions if ( - a2a_part.root.metadata - and a2a_part.root.metadata.get( - _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + a2a_part.metadata + and _is_truthy_metadata_value( + _get_metadata_value( + a2a_part.metadata, + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY + ), + ) ) - is True ): for part in parts: if part.function_call: @@ -213,6 +241,13 @@ def _create_event( def _parse_adk_metadata_value(value: Any) -> Any: """Parses ADK metadata values serialized through A2A.""" + if hasattr(value, "DESCRIPTOR"): + try: + from google.protobuf.json_format import MessageToDict + + return MessageToDict(value) + except Exception: + return value if not isinstance(value, str): return value @@ -229,7 +264,9 @@ def _extract_event_actions( if not metadata: return EventActions() - raw_actions = metadata.get(_get_adk_metadata_key("actions")) + raw_actions = _get_metadata_value( + metadata, _get_adk_metadata_key("actions") + ) if raw_actions is None: return EventActions() @@ -319,7 +356,7 @@ def convert_a2a_task_to_event( ) if ( a2a_task.status.message - and a2a_task.status.state == TaskState.input_required + and a2a_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED ): event_actions = _merge_event_actions( event_actions, diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 9288bb4828..e4c3586ba9 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -18,6 +18,7 @@ from datetime import timezone import inspect import logging +from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -32,7 +33,7 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.platform import time as platform_time from google.adk.platform import uuid as platform_uuid from google.adk.runners import Runner @@ -54,6 +55,38 @@ logger = logging.getLogger('google_adk.' + __name__) +def _coerce_message(message: Message | Any) -> Message: + """Returns a proto Message, tolerating legacy mock-based inputs.""" + if isinstance(message, Message) and type(message).__module__ != 'unittest.mock': + return message + + coerced_message = Message() + for field_name in ('message_id', 'task_id', 'context_id'): + field_value = getattr(message, field_name, None) + if field_value: + setattr(coerced_message, field_name, field_value) + + role = getattr(message, 'role', None) + coerced_message.role = role or Role.ROLE_AGENT + + parts = getattr(message, 'parts', None) + if parts: + for part in parts: + if isinstance(part, Part) and type(part).__module__ != 'unittest.mock': + coerced_message.parts.append(part) + continue + + part_text = getattr(part, 'text', None) + if part_text: + coerced_message.parts.append(Part(text=part_text)) + + metadata = getattr(message, 'metadata', None) + if metadata: + coerced_message.metadata.update(metadata) + + return coerced_message + + @a2a_experimental class A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and @@ -158,14 +191,13 @@ async def execute( TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.submitted, - message=context.message, + state=TaskState.TASK_STATE_SUBMITTED, + message=_coerce_message(context.message), timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc - ).isoformat(), + ), ), context_id=context.context_id, - final=False, ) ) @@ -180,18 +212,17 @@ async def execute( TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.failed, + state=TaskState.TASK_STATE_FAILED, timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc - ).isoformat(), + ), message=Message( message_id=platform_uuid.new_uuid(), - role=Role.agent, - parts=[TextPart(text=str(e))], + role=Role.ROLE_AGENT, + parts=[Part(text=str(e))], ), ), context_id=context.context_id, - final=True, ) ) except Exception as enqueue_error: @@ -235,21 +266,18 @@ async def _handle_request( TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.working, + state=TaskState.TASK_STATE_WORKING, timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc - ).isoformat(), + ), ), context_id=context.context_id, - final=False, - metadata={ - _get_adk_metadata_key('app_name'): runner.app_name, - _get_adk_metadata_key('user_id'): run_request.user_id, - _get_adk_metadata_key('session_id'): run_request.session_id, - }, ) ) + # Note: A2A expects metadata in the message or artifact now + # We may need to pass session IDs via other means later if not using message + task_result_aggregator = TaskResultAggregator() async with Aclosing(runner.run_async(**vars(run_request))) as agen: async for adk_event in agen: @@ -274,7 +302,7 @@ async def _handle_request( # publish the task result event - this is final if ( - task_result_aggregator.task_state == TaskState.working + task_result_aggregator.task_state == TaskState.TASK_STATE_WORKING and task_result_aggregator.task_status_message is not None and task_result_aggregator.task_status_message.parts ): @@ -295,13 +323,12 @@ async def _handle_request( final_event = TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.completed, + state=TaskState.TASK_STATE_COMPLETED, timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc - ).isoformat(), + ), ), context_id=context.context_id, - final=True, ) else: final_event = TaskStatusUpdateEvent( @@ -310,11 +337,14 @@ async def _handle_request( state=task_result_aggregator.task_state, timestamp=datetime.fromtimestamp( platform_time.get_time(), tz=timezone.utc - ).isoformat(), - message=task_result_aggregator.task_status_message, + ), + message=( + _coerce_message(task_result_aggregator.task_status_message) + if task_result_aggregator.task_status_message is not None + else None + ), ), context_id=context.context_id, - final=True, ) final_event = await execute_after_agent_interceptors( diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py index 19b6ec8731..d0dcdafbf7 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -18,6 +18,7 @@ from datetime import timezone import inspect import logging +from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -34,7 +35,7 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from typing_extensions import override from ...runners import Runner @@ -55,6 +56,38 @@ logger = logging.getLogger('google_adk.' + __name__) +def _coerce_message(message: Message | Any) -> Message: + """Returns a proto Message, tolerating legacy mock-based inputs.""" + if isinstance(message, Message) and type(message).__module__ != 'unittest.mock': + return message + + coerced_message = Message() + for field_name in ('message_id', 'task_id', 'context_id'): + field_value = getattr(message, field_name, None) + if field_value: + setattr(coerced_message, field_name, field_value) + + role = getattr(message, 'role', None) + coerced_message.role = role or Role.ROLE_AGENT + + parts = getattr(message, 'parts', None) + if parts: + for part in parts: + if isinstance(part, Part) and type(part).__module__ != 'unittest.mock': + coerced_message.parts.append(part) + continue + + part_text = getattr(part, 'text', None) + if part_text: + coerced_message.parts.append(Part(text=part_text)) + + metadata = getattr(message, 'metadata', None) + if metadata: + coerced_message.metadata.update(metadata) + + return coerced_message + + @a2a_experimental class _A2aAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK Agent against an A2A request and @@ -121,11 +154,11 @@ async def execute( Task( id=context.task_id, status=TaskStatus( - state=TaskState.submitted, - timestamp=datetime.now(timezone.utc).isoformat(), + state=TaskState.TASK_STATE_SUBMITTED, + timestamp=datetime.now(timezone.utc), ), context_id=context.context_id, - history=[context.message], + history=[_coerce_message(context.message)], metadata=self._get_invocation_metadata(executor_context), ) ) @@ -144,11 +177,10 @@ async def execute( TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.working, - timestamp=datetime.now(timezone.utc).isoformat(), + state=TaskState.TASK_STATE_WORKING, + timestamp=datetime.now(timezone.utc), ), context_id=context.context_id, - final=False, metadata=self._get_invocation_metadata(executor_context), ) ) @@ -169,16 +201,15 @@ async def execute( TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.failed, - timestamp=datetime.now(timezone.utc).isoformat(), + state=TaskState.TASK_STATE_FAILED, + timestamp=datetime.now(timezone.utc), message=Message( message_id=str(uuid.uuid4()), - role=Role.agent, - parts=[TextPart(text=str(e))], + role=Role.ROLE_AGENT, + parts=[Part(text=str(e))], ), ), context_id=context.context_id, - final=True, ) ) except Exception as enqueue_error: @@ -219,7 +250,7 @@ async def _handle_request( context.context_id, self._config.gen_ai_part_converter, ): - a2a_event.metadata = self._get_invocation_metadata(executor_context) + a2a_event.metadata.update(self._get_invocation_metadata(executor_context)) a2a_event = await execute_after_event_interceptors( a2a_event, executor_context, @@ -242,14 +273,13 @@ async def _handle_request( final_event = TaskStatusUpdateEvent( task_id=context.task_id, status=TaskStatus( - state=TaskState.completed, - timestamp=datetime.now(timezone.utc).isoformat(), + state=TaskState.TASK_STATE_COMPLETED, + timestamp=datetime.now(timezone.utc), ), context_id=context.context_id, - final=True, ) - final_event.metadata = self._get_invocation_metadata(executor_context) + final_event.metadata.update(self._get_invocation_metadata(executor_context)) final_event = await execute_after_agent_interceptors( executor_context, final_event, self._config.execute_interceptors ) diff --git a/src/google/adk/a2a/executor/task_result_aggregator.py b/src/google/adk/a2a/executor/task_result_aggregator.py index bd25b494f2..77352e6800 100644 --- a/src/google/adk/a2a/executor/task_result_aggregator.py +++ b/src/google/adk/a2a/executor/task_result_aggregator.py @@ -27,7 +27,7 @@ class TaskResultAggregator: """Aggregates the task status updates and provides the final task state.""" def __init__(self): - self._task_state = TaskState.working + self._task_state = TaskState.TASK_STATE_WORKING self._task_status_message = None def process_event(self, event: Event): @@ -39,28 +39,33 @@ def process_event(self, event: Event): - working """ if isinstance(event, TaskStatusUpdateEvent): - if event.status.state == TaskState.failed: - self._task_state = TaskState.failed - self._task_status_message = event.status.message + status_message = ( + event.status.message + if event.status.HasField('message') + else None + ) + if event.status.state == TaskState.TASK_STATE_FAILED: + self._task_state = TaskState.TASK_STATE_FAILED + self._task_status_message = status_message elif ( - event.status.state == TaskState.auth_required - and self._task_state != TaskState.failed + event.status.state == TaskState.TASK_STATE_AUTH_REQUIRED + and self._task_state != TaskState.TASK_STATE_FAILED ): - self._task_state = TaskState.auth_required - self._task_status_message = event.status.message + self._task_state = TaskState.TASK_STATE_AUTH_REQUIRED + self._task_status_message = status_message elif ( - event.status.state == TaskState.input_required + event.status.state == TaskState.TASK_STATE_INPUT_REQUIRED and self._task_state - not in (TaskState.failed, TaskState.auth_required) + not in (TaskState.TASK_STATE_FAILED, TaskState.TASK_STATE_AUTH_REQUIRED) ): - self._task_state = TaskState.input_required - self._task_status_message = event.status.message + self._task_state = TaskState.TASK_STATE_INPUT_REQUIRED + self._task_status_message = status_message # final state is already recorded and make sure the intermediate state is # always working because other state may terminate the event aggregation # in a2a request handler - elif self._task_state == TaskState.working: - self._task_status_message = event.status.message - event.status.state = TaskState.working + elif self._task_state == TaskState.TASK_STATE_WORKING: + self._task_status_message = status_message + event.status.state = TaskState.TASK_STATE_WORKING @property def task_state(self) -> TaskState: diff --git a/src/google/adk/a2a/logs/log_utils.py b/src/google/adk/a2a/logs/log_utils.py index 8de2c278ac..4fb7d20fdc 100644 --- a/src/google/adk/a2a/logs/log_utils.py +++ b/src/google/adk/a2a/logs/log_utils.py @@ -19,13 +19,14 @@ import json import sys +from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import MessageToJson + try: from a2a.client import ClientEvent as A2AClientEvent - from a2a.types import DataPart as A2ADataPart from a2a.types import Message as A2AMessage from a2a.types import Part as A2APart from a2a.types import Task as A2ATask - from a2a.types import TextPart as A2ATextPart except ImportError as e: if sys.version_info < (3, 10): raise ImportError( @@ -40,6 +41,18 @@ _EXCLUDED_PART_FIELD = {"file": {"bytes"}} +def _to_loggable_value(value): + """Converts protobuf-like values to plain JSON-serializable objects.""" + if value is None: + return None + if hasattr(value, "DESCRIPTOR"): + try: + return MessageToDict(value) + except Exception: + return MessageToJson(value) + return value + + def _is_a2a_task(obj) -> bool: """Check if an object is an A2A Task, with fallback for isinstance issues.""" try: @@ -69,17 +82,17 @@ def _is_a2a_message(obj) -> bool: def _is_a2a_text_part(obj) -> bool: """Check if an object is an A2A TextPart, with fallback for isinstance issues.""" try: - return isinstance(obj, A2ATextPart) + return getattr(obj, "text", None) is not None except (TypeError, AttributeError): - return type(obj).__name__ == "TextPart" and hasattr(obj, "text") + return False def _is_a2a_data_part(obj) -> bool: """Check if an object is an A2A DataPart, with fallback for isinstance issues.""" try: - return isinstance(obj, A2ADataPart) + return False # No DataPart in proto except (TypeError, AttributeError): - return type(obj).__name__ == "DataPart" and hasattr(obj, "data") + return False def build_message_part_log(part: A2APart) -> str: @@ -92,30 +105,72 @@ def build_message_part_log(part: A2APart) -> str: A string representation of the part. """ part_content = "" - if _is_a2a_text_part(part.root): - part_content = f"TextPart: {part.root.text[:100]}" + ( - "..." if len(part.root.text) > 100 else "" + text_val = getattr(part, "text", None) + if isinstance(text_val, str) and text_val: + part_content = f"TextPart: {text_val[:100]}" + ( + "..." if len(text_val) > 100 else "" ) - elif _is_a2a_data_part(part.root): - # For data parts, show the data keys but exclude large values - data_summary = { - k: ( - f"<{type(v).__name__}>" - if isinstance(v, (dict, list)) and len(str(v)) > 100 - else v - ) - for k, v in part.root.data.items() - } - part_content = f"DataPart: {json.dumps(data_summary, indent=2)}" else: - part_content = ( - f"{type(part.root).__name__}:" - f" {part.model_dump_json(exclude_none=True, exclude=_EXCLUDED_PART_FIELD)}" - ) + has_data = False + has_field = getattr(part, "HasField", None) + if callable(has_field): + try: + has_data_result = part.HasField("data") + has_data = ( + has_data_result + if isinstance(has_data_result, bool) + else False + ) + except Exception: + has_data = False + if has_data: + try: + data_dict = MessageToDict(part.data) + except Exception: + data_dict = {} + summarized_dict = {} + for key, value in data_dict.items(): + if isinstance(value, dict): + summarized_dict[key] = "" + elif isinstance(value, list): + summarized_dict[key] = "" + else: + summarized_dict[key] = value + part_content = f"DataPart: {json.dumps(summarized_dict, indent=2)}" + elif ( + getattr(part, "function_call", None) + and isinstance(getattr(part.function_call, "name", None), str) + and part.function_call.name + ): + part_content = f"FunctionCall: {part.function_call.name}" + elif ( + getattr(part, "function_response", None) + and isinstance(getattr(part.function_response, "name", None), str) + and part.function_response.name + ): + part_content = f"FunctionResponse: {part.function_response.name}" + elif hasattr(part, "model_dump_json"): + try: + part_content = f"{type(part).__name__}: {part.model_dump_json()}" + except Exception: + part_content = f"{type(part).__name__}: " + else: + try: + part_content = f"Part: {MessageToJson(part)}" + except Exception: + part_content = f"Part: " # Add part metadata if it exists - if hasattr(part.root, "metadata") and part.root.metadata: - metadata_str = json.dumps(part.root.metadata, indent=2).replace( + if hasattr(part, "metadata") and part.metadata: + try: + metadata_value = ( + MessageToDict(part.metadata) + if hasattr(part.metadata, "DESCRIPTOR") + else part.metadata + ) + except Exception: + metadata_value = part.metadata + metadata_str = json.dumps(metadata_value, indent=2).replace( "\n", "\n " ) part_content += f"\n Part Metadata: {metadata_str}" @@ -144,18 +199,20 @@ def build_a2a_request_log(req: A2AMessage) -> str: # Build message metadata section message_metadata_section = "" if req.metadata: + metadata_value = _to_loggable_value(req.metadata) message_metadata_section = f""" Metadata: - {json.dumps(req.metadata, indent=2).replace(chr(10), chr(10) + ' ')}""" + {json.dumps(metadata_value, indent=2).replace(chr(10), chr(10) + ' ')}""" # Build optional sections optional_sections = [] if req.metadata: + metadata_value = _to_loggable_value(req.metadata) optional_sections.append( f"""----------------------------------------------------------- Metadata: -{json.dumps(req.metadata, indent=2)}""" +{json.dumps(metadata_value, indent=2)}""" ) optional_sections_str = _NEW_LINE.join(optional_sections) diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index 1e8cecad79..ca2600c4e6 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -22,6 +22,7 @@ from a2a.types import AgentCapabilities from a2a.types import AgentCard +from a2a.types import AgentInterface from a2a.types import AgentProvider from a2a.types import AgentSkill from a2a.types import SecurityScheme @@ -75,17 +76,24 @@ async def build(self) -> AgentCard: sub_agent_skills = await _build_sub_agent_skills(self._agent) all_skills = primary_skills + sub_agent_skills + capabilities = self._capabilities or AgentCapabilities() + capabilities.extended_agent_card = False + return AgentCard( name=self._agent.name, description=self._agent.description or 'An ADK Agent', - doc_url=self._doc_url, - url=f"{self._rpc_url.rstrip('/')}", + documentation_url=self._doc_url, + supported_interfaces=[ + AgentInterface( + url=f"{self._rpc_url.rstrip('/')}", + protocol_binding="jsonrpc" + ) + ], version=self._agent_version, - capabilities=self._capabilities, + capabilities=capabilities, skills=all_skills, default_input_modes=['text/plain'], default_output_modes=['text/plain'], - supports_authenticated_extended_card=False, provider=self._provider, security_schemes=self._security_schemes, ) diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 3e8ed461e2..7cbe64c63b 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -40,6 +40,18 @@ from .agent_card_builder import AgentCardBuilder +def _normalize_agent_card_dict(agent_card_data: dict) -> dict: + """Normalizes legacy agent-card JSON into the proto-based schema.""" + normalized_data = dict(agent_card_data) + legacy_url = normalized_data.pop("url", None) + if legacy_url and "supportedInterfaces" not in normalized_data: + normalized_data["supportedInterfaces"] = [{ + "url": legacy_url, + "protocolBinding": "JSONRPC", + }] + return normalized_data + + def _load_agent_card( agent_card: Optional[Union[AgentCard, str]], ) -> Optional[AgentCard]: @@ -65,8 +77,13 @@ def _load_agent_card( try: path = Path(agent_card) with path.open("r", encoding="utf-8") as f: - agent_card_data = json.load(f) - return AgentCard(**agent_card_data) + agent_card_data = _normalize_agent_card_dict(json.load(f)) + from google.protobuf.json_format import ParseDict + return ParseDict( + agent_card_data, + AgentCard(), + ignore_unknown_fields=True, + ) except Exception as e: raise ValueError( f"Failed to load agent card from {agent_card}: {e}" diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 6072a5ddcb..86c78b8a1a 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -30,18 +30,20 @@ from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import ClientConfig as A2AClientConfig from a2a.client.client_factory import ClientFactory as A2AClientFactory -from a2a.client.errors import A2AClientHTTPError -from a2a.client.middleware import ClientCallContext +from a2a.client.errors import A2AClientError +from a2a.client import ClientCallContext from a2a.types import AgentCard from a2a.types import Message as A2AMessage -from a2a.types import MessageSendConfiguration from a2a.types import Part as A2APart from a2a.types import Role +from a2a.types import SendMessageRequest +from a2a.types import StreamResponse from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent -from a2a.types import TransportProtocol as A2ATransport + +from google.protobuf.json_format import MessageToDict from google.adk.platform import uuid as platform_uuid from google.genai import types as genai_types import httpx @@ -91,6 +93,50 @@ logger = logging.getLogger("google_adk." + __name__) +def _normalize_agent_card_dict(agent_card_data: dict[str, Any]) -> dict[str, Any]: + """Normalizes legacy agent-card JSON into the proto-based schema.""" + normalized_data = dict(agent_card_data) + legacy_url = normalized_data.pop("url", None) + if legacy_url and "supportedInterfaces" not in normalized_data: + normalized_data["supportedInterfaces"] = [{ + "url": legacy_url, + "protocolBinding": "JSONRPC", + }] + return normalized_data + + +def _get_metadata_value( + metadata: Any, key: str, default: Any = None +) -> Any: + """Returns a metadata value from either a dict or proto Struct.""" + if metadata is None: + return default + get_method = getattr(metadata, "get", None) + if callable(get_method): + try: + return get_method(key, default) + except TypeError: + try: + return get_method(key) + except Exception: + pass + try: + metadata_dict = MessageToDict(metadata, preserving_proto_field_name=True) + except Exception: + return default + return metadata_dict.get(key, default) + + +def _unwrap_stream_response(update: Any) -> Any: + """Unwraps a StreamResponse into its concrete payload.""" + if not isinstance(update, StreamResponse): + return update + for field_name in ("task", "message", "status_update", "artifact_update"): + if update.HasField(field_name): + return getattr(update, field_name) + return None + + @a2a_experimental class AgentCardResolutionError(Exception): """Raised when agent card resolution fails.""" @@ -233,7 +279,7 @@ async def _ensure_httpx_client(self) -> httpx.AsyncClient: httpx_client=self._httpx_client, streaming=False, polling=False, - supported_transports=[A2ATransport.jsonrpc], + supported_protocol_bindings=["JSONRPC"], ) self._a2a_client_factory = A2AClientFactory(config=client_config) return self._httpx_client @@ -271,8 +317,14 @@ async def _resolve_agent_card_from_file(self, file_path: str) -> AgentCard: raise ValueError(f"Path is not a file: {file_path}") with path.open("r", encoding="utf-8") as f: - agent_json_data = json.load(f) - return AgentCard(**agent_json_data) + agent_json_data = _normalize_agent_card_dict(json.load(f)) + from google.protobuf.json_format import ParseDict + + return ParseDict( + agent_json_data, + AgentCard(), + ignore_unknown_fields=True, + ) except json.JSONDecodeError as e: raise AgentCardResolutionError( f"Invalid JSON in agent card file {file_path}: {e}" @@ -293,19 +345,20 @@ async def _resolve_agent_card(self) -> AgentCard: async def _validate_agent_card(self, agent_card: AgentCard) -> None: """Validate resolved agent card.""" - if not agent_card.url: + url = agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else "" + if not url: raise AgentCardResolutionError( "Agent card must have a valid URL for RPC communication" ) # Additional validation can be added here try: - parsed_url = urlparse(str(agent_card.url)) + parsed_url = urlparse(str(url)) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError("Invalid RPC URL format") except Exception as e: raise AgentCardResolutionError( - f"Invalid RPC URL in agent card: {agent_card.url}, error: {e}" + f"Invalid RPC URL in agent card: {url}, error: {e}" ) from e async def _ensure_resolved(self) -> None: @@ -361,7 +414,7 @@ def _create_a2a_request_for_user_function_response( return None a2a_message = convert_event_to_a2a_message( - ctx.session.events[-1], ctx, Role.user, self._genai_part_converter + ctx.session.events[-1], ctx, Role.ROLE_USER, self._genai_part_converter ) if function_call_event.custom_metadata: metadata = function_call_event.custom_metadata @@ -444,23 +497,24 @@ async def _handle_a2a_response( """ try: if isinstance(a2a_response, tuple): - task, update = a2a_response - if update is None: + update, task = a2a_response + update = _unwrap_stream_response(update) + if update is None or isinstance(update, A2ATask): # This is the initial response for a streaming task or the complete # response for a non-streaming task, which is the full task state. # We process this to get the initial message. event = convert_a2a_task_to_event( - task, self.name, ctx, self._a2a_part_converter + update or task, self.name, ctx, self._a2a_part_converter ) # for streaming task, we update the event with the task status. # We update the event as Thought updates. if ( - task - and task.status - and task.status.state + (update or task) + and (update or task).status + and (update or task).status.state in ( - TaskState.submitted, - TaskState.working, + TaskState.TASK_STATE_SUBMITTED, + TaskState.TASK_STATE_WORKING, ) and event.content is not None and event.content.parts @@ -477,8 +531,8 @@ async def _handle_a2a_response( update.status.message, self.name, ctx, self._a2a_part_converter ) if event.content is not None and update.status.state in ( - TaskState.submitted, - TaskState.working, + TaskState.TASK_STATE_SUBMITTED, + TaskState.TASK_STATE_WORKING, ): for part in event.content.parts: part.thought = True @@ -551,7 +605,8 @@ async def _handle_a2a_response_v2( """ try: if isinstance(a2a_response, tuple): - task, update = a2a_response + update, task = a2a_response + update = _unwrap_stream_response(update) event = None if update is None: # This is the initial response for a streaming task or the complete @@ -559,6 +614,10 @@ async def _handle_a2a_response_v2( event = self._config.a2a_task_converter( task, self.name, ctx, self._config.a2a_part_converter ) + elif isinstance(update, A2ATask): + event = self._config.a2a_task_converter( + update, self.name, ctx, self._config.a2a_part_converter + ) elif isinstance(update, A2ATaskStatusUpdateEvent): # This is a streaming task status update. event = self._config.a2a_status_update_converter( @@ -643,11 +702,12 @@ async def _run_async_impl( a2a_request = A2AMessage( message_id=platform_uuid.new_uuid(), parts=message_parts, - role="user", + role=Role.ROLE_USER, context_id=context_id, ) logger.debug(build_a2a_request_log(a2a_request)) + send_message_request = SendMessageRequest(message=a2a_request) try: a2a_request, parameters = await execute_before_request_interceptors( @@ -664,24 +724,27 @@ async def _run_async_impl( ctx, a2a_request ) + send_message_request = SendMessageRequest(message=a2a_request) + if parameters.request_metadata: + send_message_request.metadata.update(parameters.request_metadata) + # TODO: Add support for requested_extension and # message_send_configuration once they are supported by the A2A client. async for a2a_response in self._a2a_client.send_message( - request=a2a_request, - request_metadata=parameters.request_metadata, + request=send_message_request, context=parameters.client_call_context, ): logger.debug(build_a2a_response_log(a2a_response)) metadata = None if isinstance(a2a_response, tuple): - task = a2a_response[0] + task = a2a_response[1] if task: metadata = task.metadata else: metadata = a2a_response.metadata - if metadata and metadata.get(_NEW_A2A_ADK_INTEGRATION_EXTENSION): + if _get_metadata_value(metadata, _NEW_A2A_ADK_INTEGRATION_EXTENSION): event = await self._handle_a2a_response_v2(a2a_response, ctx) else: event = await self._handle_a2a_response(a2a_response, ctx) @@ -697,22 +760,20 @@ async def _run_async_impl( # Add metadata about the request and response event.custom_metadata = event.custom_metadata or {} event.custom_metadata[A2A_METADATA_PREFIX + "request"] = ( - a2a_request.model_dump(exclude_none=True, by_alias=True) + MessageToDict( + send_message_request, preserving_proto_field_name=True + ) ) # If the response is a ClientEvent, record the task state; otherwise, # record the message object. if isinstance(a2a_response, tuple): - event.custom_metadata[A2A_METADATA_PREFIX + "response"] = ( - a2a_response[0].model_dump(exclude_none=True, by_alias=True) - ) + event.custom_metadata[A2A_METADATA_PREFIX + "response"] = MessageToDict(a2a_response[0], preserving_proto_field_name=True) else: - event.custom_metadata[A2A_METADATA_PREFIX + "response"] = ( - a2a_response.model_dump(exclude_none=True, by_alias=True) - ) + event.custom_metadata[A2A_METADATA_PREFIX + "response"] = MessageToDict(a2a_response, preserving_proto_field_name=True) yield event - except A2AClientHTTPError as e: + except A2AClientError as e: error_message = f"A2A request failed: {e}" logger.error(error_message) yield Event( @@ -721,9 +782,8 @@ async def _run_async_impl( invocation_id=ctx.invocation_id, branch=ctx.branch, custom_metadata={ - A2A_METADATA_PREFIX - + "request": a2a_request.model_dump( - exclude_none=True, by_alias=True + A2A_METADATA_PREFIX + "request": MessageToDict( + send_message_request, preserving_proto_field_name=True ), A2A_METADATA_PREFIX + "error": error_message, A2A_METADATA_PREFIX + "status_code": str(e.status_code), @@ -740,9 +800,8 @@ async def _run_async_impl( invocation_id=ctx.invocation_id, branch=ctx.branch, custom_metadata={ - A2A_METADATA_PREFIX - + "request": a2a_request.model_dump( - exclude_none=True, by_alias=True + A2A_METADATA_PREFIX + "request": MessageToDict( + send_message_request, preserving_proto_field_name=True ), A2A_METADATA_PREFIX + "error": error_message, }, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 0b6f3fb6fe..af776c6457 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -599,7 +599,8 @@ async def _get_a2a_runner_async() -> Runner: with (p / "agent.json").open("r", encoding="utf-8") as f: data = json.load(f) - agent_card = AgentCard(**data) + from google.protobuf.json_format import ParseDict + agent_card = ParseDict(data, AgentCard()) a2a_app = A2AStarletteApplication( agent_card=agent_card, diff --git a/src/google/adk/integrations/agent_registry/agent_registry.py b/src/google/adk/integrations/agent_registry/agent_registry.py index 359d11648d..bb1be380de 100644 --- a/src/google/adk/integrations/agent_registry/agent_registry.py +++ b/src/google/adk/integrations/agent_registry/agent_registry.py @@ -34,7 +34,6 @@ from a2a.types import AgentCapabilities from a2a.types import AgentCard from a2a.types import AgentSkill -from a2a.types import TransportProtocol as A2ATransport from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID @@ -254,9 +253,9 @@ def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset: mcp_server_id = None endpoint_uri = self._get_connection_uri( - server_details, protocol_binding=A2ATransport.jsonrpc + server_details, protocol_binding="JSONRPC" ) or self._get_connection_uri( - server_details, protocol_binding=A2ATransport.http_json + server_details, protocol_binding="HTTP_JSON" ) if not endpoint_uri: raise ValueError( @@ -303,7 +302,17 @@ def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent: card = agent_info.get("card", {}) card_content = card.get("content") if card.get("type") == "A2A_AGENT_CARD" and card_content: - agent_card = AgentCard(**card_content) + if "url" in card_content: + card_url = card_content.pop("url") + if "supported_interfaces" not in card_content: + card_content["supported_interfaces"] = [{"url": card_url, "protocol_binding": "JSONRPC"}] + from google.protobuf.json_format import ParseDict + + agent_card = ParseDict( + card_content, + AgentCard(), + ignore_unknown_fields=True, + ) # Clean the name to be a valid identifier name = self._clean_name(agent_card.name) return RemoteA2aAgent( @@ -338,11 +347,9 @@ def get_remote_a2a_agent(self, agent_name: str) -> RemoteA2aAgent: name=name, description=description, version=version, - url=url, skills=skills, - capabilities=AgentCapabilities(streaming=False, polling=False), - defaultInputModes=["text"], - defaultOutputModes=["text"], + supported_interfaces=[{"url": url, "protocol_binding": "JSONRPC"}], + capabilities=AgentCapabilities(), ) return RemoteA2aAgent( diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 61f8c3aca6..9dc2f1d659 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -15,7 +15,7 @@ from unittest.mock import Mock from unittest.mock import patch -from a2a.types import DataPart +from a2a.types import Part from a2a.types import Message from a2a.types import Role from a2a.types import Task @@ -195,11 +195,11 @@ def test_create_artifact_id(self): def test_process_long_running_tool_marks_tool(self): """Test processing of long-running tool metadata.""" mock_a2a_part = Mock() - mock_data_part = Mock(spec=DataPart) + mock_data_part = Mock() mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"} mock_data_part.data = Mock() mock_data_part.data.get = Mock(return_value="tool-123") - mock_a2a_part.root = mock_data_part + mock_a2a_part = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} @@ -226,11 +226,11 @@ def test_process_long_running_tool_marks_tool(self): def test_process_long_running_tool_no_marking(self): """Test processing when tool should not be marked as long-running.""" mock_a2a_part = Mock() - mock_data_part = Mock(spec=DataPart) + mock_data_part = Mock() mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"} mock_data_part.data = Mock() mock_data_part.data.get = Mock(return_value="tool-456") - mock_a2a_part.root = mock_data_part + mock_a2a_part = mock_data_part self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID @@ -440,20 +440,20 @@ def test_convert_event_to_a2a_events_with_custom_ids(self): def test_create_status_update_event_with_auth_required_state(self): """Test creation of status update event with auth_required state.""" - from a2a.types import DataPart + from a2a.types import Part from a2a.types import Part # Create a mock message with a part that triggers auth_required state mock_message = Mock(spec=Message) mock_part = Mock() - mock_data_part = Mock(spec=DataPart) + mock_data_part = Mock() mock_data_part.metadata = { "adk_type": "function_call", "adk_is_long_running": True, } mock_data_part.data = Mock() mock_data_part.data.get = Mock(return_value="request_euc") - mock_part.root = mock_data_part + mock_part = mock_data_part mock_message.parts = [mock_part] task_id = "test-task-id" @@ -504,20 +504,20 @@ def test_create_status_update_event_with_auth_required_state(self): def test_create_status_update_event_with_input_required_state(self): """Test creation of status update event with input_required state.""" - from a2a.types import DataPart + from a2a.types import Part from a2a.types import Part # Create a mock message with a part that triggers input_required state mock_message = Mock(spec=Message) mock_part = Mock() - mock_data_part = Mock(spec=DataPart) + mock_data_part = Mock() mock_data_part.metadata = { "adk_type": "function_call", "adk_is_long_running": True, } mock_data_part.data = Mock() mock_data_part.data.get = Mock(return_value="some_other_function") - mock_part.root = mock_data_part + mock_part = mock_data_part mock_message.parts = [mock_part] task_id = "test-task-id" @@ -574,8 +574,8 @@ def test_convert_event_to_a2a_message_with_multiple_parts_returned(self): # Arrange mock_genai_part = genai_types.Part(text="source part") - mock_a2a_part1 = a2a_types.Part(root=a2a_types.TextPart(text="part 1")) - mock_a2a_part2 = a2a_types.Part(root=a2a_types.TextPart(text="part 2")) + mock_a2a_part1 = a2a_types.Part(text="part 1") + mock_a2a_part2 = a2a_types.Part(text="part 2") mock_convert_part = Mock() mock_convert_part.return_value = [mock_a2a_part1, mock_a2a_part2] @@ -593,8 +593,8 @@ def test_convert_event_to_a2a_message_with_multiple_parts_returned(self): # Assert assert result is not None assert len(result.parts) == 2 - assert result.parts[0].root.text == "part 1" - assert result.parts[1].root.text == "part 2" + assert result.parts[0].text == "part 1" + assert result.parts[1].text == "part 2" mock_convert_part.assert_called_once_with(mock_genai_part) @@ -612,20 +612,20 @@ def test_convert_a2a_task_to_event_with_artifacts_priority(self): from a2a.types import Artifact from a2a.types import Part from a2a.types import TaskStatus - from a2a.types import TextPart + from a2a.types import Part # Create mock artifacts - artifact_part = Part(root=TextPart(text="artifact content")) + artifact_part = Part(text="artifact content") mock_artifact = Mock(spec=Artifact) mock_artifact.parts = [artifact_part] # Create mock status and history - status_part = Part(root=TextPart(text="status content")) + status_part = Part(text="status content") mock_status = Mock(spec=TaskStatus) mock_status.message = Mock(spec=Message) mock_status.message.parts = [status_part] - history_part = Part(root=TextPart(text="history content")) + history_part = Part(text="history content") mock_history_message = Mock(spec=Message) mock_history_message.parts = [history_part] @@ -656,10 +656,10 @@ def test_convert_a2a_task_to_event_with_status_message(self): """Test convert_a2a_task_to_event with status message (no artifacts).""" from a2a.types import Part from a2a.types import TaskStatus - from a2a.types import TextPart + from a2a.types import Part # Create mock status - status_part = Part(root=TextPart(text="status content")) + status_part = Part(text="status content") mock_status = Mock(spec=TaskStatus) mock_status.message = Mock(spec=Message) mock_status.message.parts = [status_part] diff --git a/tests/unittests/a2a/converters/test_from_adk.py b/tests/unittests/a2a/converters/test_from_adk.py index 23546c58b0..8da5a29d00 100644 --- a/tests/unittests/a2a/converters/test_from_adk.py +++ b/tests/unittests/a2a/converters/test_from_adk.py @@ -22,7 +22,7 @@ from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.converters.from_adk_event import convert_event_to_a2a_events from google.adk.events.event import Event from google.genai import types as genai_types @@ -61,8 +61,8 @@ def test_convert_event_to_a2a_events_artifact_update(self): agents_artifacts = {} # Mock part converter to return a standard text part - mock_a2a_part = A2APart(root=TextPart(text="hello")) - mock_a2a_part.root.metadata = {} + mock_a2a_part = A2APart(text="hello") + mock_a2a_part.metadata.update({}) mock_convert_part = Mock(return_value=[mock_a2a_part]) result = convert_event_to_a2a_events( diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 446e118534..0bd741132d 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -18,6 +18,9 @@ from unittest.mock import patch from a2a import types as a2a_types +from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import ParseDict +from google.protobuf.struct_pb2 import Value from google.adk.a2a.converters.part_converter import A2A_DATA_PART_END_TAG from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE @@ -33,13 +36,26 @@ import pytest +def _make_data_part( + data: dict, + metadata: dict | None = None, +) -> a2a_types.Part: + """Builds a proto-backed A2A data part.""" + value = Value() + ParseDict(data, value) + part = a2a_types.Part(data=value) + if metadata: + part.metadata.update(metadata) + return part + + class TestConvertA2aPartToGenaiPart: """Test cases for convert_a2a_part_to_genai_part function.""" def test_convert_text_part(self): """Test conversion of A2A TextPart to GenAI Part.""" # Arrange - a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello, world!")) + a2a_part = a2a_types.Part(text="Hello, world!") # Act result = convert_a2a_part_to_genai_part(a2a_part) @@ -53,13 +69,9 @@ def test_convert_file_part_with_uri(self): """Test conversion of A2A FilePart with URI to GenAI Part.""" # Arrange a2a_part = a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithUri( - uri="gs://bucket/file.txt", - mime_type="text/plain", - name="my_file.txt", - ) - ) + url="gs://bucket/file.txt", + media_type="text/plain", + filename="my_file.txt", ) # Act @@ -77,17 +89,10 @@ def test_convert_file_part_with_bytes(self): """Test conversion of A2A FilePart with bytes to GenAI Part.""" # Arrange test_bytes = b"test file content" - # A2A FileWithBytes expects base64-encoded string - - base64_encoded = base64.b64encode(test_bytes).decode("utf-8") a2a_part = a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithBytes( - bytes=base64_encoded, - mime_type="text/plain", - name="my_bytes.txt", - ) - ) + raw=test_bytes, + media_type="text/plain", + filename="my_bytes.txt", ) # Act @@ -109,16 +114,14 @@ def test_convert_data_part_function_call(self): "name": "test_function", "args": {"param1": "value1", "param2": 42}, } - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data=function_call_data, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - }, - ) + a2a_part = _make_data_part( + function_call_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + }, ) # Act @@ -138,16 +141,14 @@ def test_convert_data_part_function_response(self): "name": "test_function", "response": {"result": "success", "data": [1, 2, 3]}, } - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data=function_response_data, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, - "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, - }, - ) + a2a_part = _make_data_part( + function_response_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, + "adk_type": A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE, + }, ) # Act @@ -198,9 +199,7 @@ def test_convert_data_part_function_response(self): def test_convert_data_part_to_inline_data(self, test_name, data, metadata): """Test conversion of A2A DataPart to GenAI inline_data Part.""" # Arrange - a2a_part = a2a_types.Part( - root=a2a_types.DataPart(data=data, metadata=metadata) - ) + a2a_part = _make_data_part(data, metadata=metadata) # Act result = convert_a2a_part_to_genai_part(a2a_part) @@ -212,13 +211,16 @@ def test_convert_data_part_to_inline_data(self, test_name, data, metadata): assert result.inline_data.mime_type == A2A_DATA_PART_TEXT_MIME_TYPE assert result.inline_data.data.startswith(A2A_DATA_PART_START_TAG) assert result.inline_data.data.endswith(A2A_DATA_PART_END_TAG) - converted_data_part = a2a_types.DataPart.model_validate_json( - result.inline_data.data[ - len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) - ] + restored_value = Value() + ParseDict( + json.loads( + result.inline_data.data[ + len(A2A_DATA_PART_START_TAG) : -len(A2A_DATA_PART_END_TAG) + ] + ), + restored_value, ) - assert converted_data_part.data == data - assert converted_data_part.metadata == metadata + assert MessageToDict(restored_value) == data def test_convert_unsupported_file_type(self): """Test handling of unsupported file types.""" @@ -231,7 +233,7 @@ class UnsupportedFileType: mock_file_part = Mock() mock_file_part.file = UnsupportedFileType() a2a_part = Mock() - a2a_part.root = mock_file_part + a2a_part = mock_file_part # Act with patch( @@ -251,7 +253,7 @@ class UnsupportedPartType: pass mock_part = Mock() - mock_part.root = UnsupportedPartType() + mock_part = UnsupportedPartType() # Act with patch( @@ -278,8 +280,8 @@ def test_convert_text_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.TextPart) - assert result.root.text == "Hello, world!" + assert result.HasField("text") + assert result.text == "Hello, world!" def test_convert_text_part_with_thought(self): """Test conversion of GenAI text Part with thought to A2A Part.""" @@ -292,10 +294,10 @@ def test_convert_text_part_with_thought(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.TextPart) - assert result.root.text == "Hello, world!" - assert result.root.metadata is not None - assert result.root.metadata[_get_adk_metadata_key("thought")] + assert result.HasField("text") + assert result.text == "Hello, world!" + assert result.metadata is not None + assert result.metadata[_get_adk_metadata_key("thought")] def test_convert_file_data_part(self): """Test conversion of GenAI file_data Part to A2A Part.""" @@ -314,11 +316,10 @@ def test_convert_file_data_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.FilePart) - assert isinstance(result.root.file, a2a_types.FileWithUri) - assert result.root.file.uri == "gs://bucket/file.txt" - assert result.root.file.mime_type == "text/plain" - assert result.root.file.name == "my_file.txt" + assert result.HasField("url") + assert result.url == "gs://bucket/file.txt" + assert result.media_type == "text/plain" + assert result.filename == "my_file.txt" def test_convert_inline_data_part(self): """Test conversion of GenAI inline_data Part to A2A Part.""" @@ -338,14 +339,10 @@ def test_convert_inline_data_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.FilePart) - assert isinstance(result.root.file, a2a_types.FileWithBytes) - # A2A FileWithBytes now stores base64-encoded bytes to ensure round-trip compatibility - - expected_base64 = base64.b64encode(test_bytes).decode("utf-8") - assert result.root.file.bytes == expected_base64 - assert result.root.file.mime_type == "text/plain" - assert result.root.file.name == "my_bytes.txt" + assert result.HasField("raw") + assert result.raw == test_bytes + assert result.media_type == "text/plain" + assert result.filename == "my_bytes.txt" def test_convert_inline_data_part_with_video_metadata(self): """Test conversion of GenAI inline_data Part with video metadata to A2A Part.""" @@ -363,20 +360,18 @@ def test_convert_inline_data_part_with_video_metadata(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.FilePart) - assert isinstance(result.root.file, a2a_types.FileWithBytes) - assert result.root.metadata is not None - assert _get_adk_metadata_key("video_metadata") in result.root.metadata + assert result.HasField("raw") + assert result.metadata is not None + assert _get_adk_metadata_key("video_metadata") in result.metadata def test_convert_inline_data_part_to_data_part(self): """Test conversion of GenAI inline_data Part to A2A DataPart.""" # Arrange data = {"key": "value"} metadata = {"meta": "data"} - a2a_part_to_convert = a2a_types.DataPart(data=data, metadata=metadata) - json_data = a2a_part_to_convert.model_dump_json( - by_alias=True, exclude_none=True - ).encode("utf-8") + value = Value() + ParseDict(data, value) + json_data = json.dumps(data).encode("utf-8") genai_part = genai_types.Part( inline_data=genai_types.Blob( data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, @@ -390,9 +385,8 @@ def test_convert_inline_data_part_to_data_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) - assert result.root.data == data - assert result.root.metadata == metadata + assert result.HasField("data") + assert MessageToDict(result.data) == data def test_convert_function_call_part(self): """Test conversion of GenAI function_call Part to A2A Part.""" @@ -408,11 +402,11 @@ def test_convert_function_call_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") expected_data = function_call.model_dump(by_alias=True, exclude_none=True) - assert result.root.data == expected_data + assert MessageToDict(result.data) == expected_data assert ( - result.root.metadata[ + result.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL @@ -432,13 +426,13 @@ def test_convert_function_response_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") expected_data = function_response.model_dump( by_alias=True, exclude_none=True ) - assert result.root.data == expected_data + assert MessageToDict(result.data) == expected_data assert ( - result.root.metadata[ + result.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE @@ -458,13 +452,13 @@ def test_convert_code_execution_result_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") expected_data = code_execution_result.model_dump( by_alias=True, exclude_none=True ) - assert result.root.data == expected_data + assert MessageToDict(result.data) == expected_data assert ( - result.root.metadata[ + result.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ] == A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT @@ -484,11 +478,11 @@ def test_convert_executable_code_part(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") expected_data = executable_code.model_dump(by_alias=True, exclude_none=True) - assert result.root.data == expected_data + assert MessageToDict(result.data) == expected_data assert ( - result.root.metadata[ + result.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ] == A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE @@ -517,7 +511,7 @@ def test_text_part_round_trip(self): """Test round-trip conversion for text parts.""" # Arrange original_text = "Hello, world!" - a2a_part = a2a_types.Part(root=a2a_types.TextPart(text=original_text)) + a2a_part = a2a_types.Part(text=original_text) # Act genai_part = convert_a2a_part_to_genai_part(a2a_part) @@ -526,8 +520,8 @@ def test_text_part_round_trip(self): # Assert assert result_a2a_part is not None assert isinstance(result_a2a_part, a2a_types.Part) - assert isinstance(result_a2a_part.root, a2a_types.TextPart) - assert result_a2a_part.root.text == original_text + assert result_a2a_part.HasField("text") + assert result_a2a_part.text == original_text def test_text_part_with_thought_round_trip(self): """Test round-trip conversion for text parts with thought.""" @@ -551,11 +545,7 @@ def test_file_uri_round_trip(self): original_uri = "gs://bucket/file.txt" original_mime_type = "text/plain" a2a_part = a2a_types.Part( - root=a2a_types.FilePart( - file=a2a_types.FileWithUri( - uri=original_uri, mime_type=original_mime_type - ) - ) + url=original_uri, media_type=original_mime_type ) # Act @@ -565,10 +555,9 @@ def test_file_uri_round_trip(self): # Assert assert result_a2a_part is not None assert isinstance(result_a2a_part, a2a_types.Part) - assert isinstance(result_a2a_part.root, a2a_types.FilePart) - assert isinstance(result_a2a_part.root.file, a2a_types.FileWithUri) - assert result_a2a_part.root.file.uri == original_uri - assert result_a2a_part.root.file.mime_type == original_mime_type + assert result_a2a_part.HasField("url") + assert result_a2a_part.url == original_uri + assert result_a2a_part.media_type == original_mime_type def test_file_bytes_round_trip(self): """Test round-trip conversion for file parts with bytes.""" @@ -686,9 +675,7 @@ def test_data_part_round_trip(self): # Arrange data = {"key": "value"} metadata = {"meta": "data"} - a2a_part = a2a_types.Part( - root=a2a_types.DataPart(data=data, metadata=metadata) - ) + a2a_part = _make_data_part(data, metadata=metadata) # Act genai_part = convert_a2a_part_to_genai_part(a2a_part) @@ -697,18 +684,15 @@ def test_data_part_round_trip(self): # Assert assert result_a2a_part is not None assert isinstance(result_a2a_part, a2a_types.Part) - assert isinstance(result_a2a_part.root, a2a_types.DataPart) - assert result_a2a_part.root.data == data - assert result_a2a_part.root.metadata == metadata + assert result_a2a_part.HasField("data") + assert MessageToDict(result_a2a_part.data) == data def test_data_part_with_mime_type_metadata_round_trip(self): """Test round-trip conversion for data parts with 'mime_type' in metadata.""" # Arrange data = {"content": "some data"} metadata = {"meta": "data", "mime_type": "application/json"} - a2a_part = a2a_types.Part( - root=a2a_types.DataPart(data=data, metadata=metadata) - ) + a2a_part = _make_data_part(data, metadata=metadata) # Act genai_part = convert_a2a_part_to_genai_part(a2a_part) @@ -717,10 +701,8 @@ def test_data_part_with_mime_type_metadata_round_trip(self): # Assert assert result_a2a_part is not None assert isinstance(result_a2a_part, a2a_types.Part) - assert isinstance(result_a2a_part.root, a2a_types.DataPart) - assert result_a2a_part.root.data == data - # The 'mime_type' key in the metadata should be preserved as is - assert result_a2a_part.root.metadata == metadata + assert result_a2a_part.HasField("data") + assert MessageToDict(result_a2a_part.data) == data class TestEdgeCases: @@ -729,7 +711,7 @@ class TestEdgeCases: def test_empty_text_part(self): """Test conversion of empty text part.""" # Arrange - a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="")) + a2a_part = a2a_types.Part(text="") # Act result = convert_a2a_part_to_genai_part(a2a_part) @@ -747,10 +729,7 @@ def test_genai_inline_data_with_mimetype_to_a2a(self): # Arrange data = {"key": "value"} metadata = {"adk_type": "some_type", "mimeType": "image/png"} - a2a_part_inner = a2a_types.DataPart(data=data, metadata=metadata) - json_data = a2a_part_inner.model_dump_json( - by_alias=True, exclude_none=True - ).encode("utf-8") + json_data = json.dumps(data).encode("utf-8") genai_part = genai_types.Part( inline_data=genai_types.Blob( data=A2A_DATA_PART_START_TAG + json_data + A2A_DATA_PART_END_TAG, @@ -764,24 +743,16 @@ def test_genai_inline_data_with_mimetype_to_a2a(self): # Assert assert result is not None assert isinstance(result, a2a_types.Part) - assert isinstance(result.root, a2a_types.DataPart) - assert result.root.data == data - # The key casing should be preserved from the JSON - assert result.root.metadata == metadata + assert result.HasField("data") + assert MessageToDict(result.data) == data def test_none_input_a2a_to_genai(self): """Test handling of None input for A2A to GenAI conversion.""" - # This test depends on how the function handles None input - # If it should raise an exception, we test for that - with pytest.raises(AttributeError): - convert_a2a_part_to_genai_part(None) + assert convert_a2a_part_to_genai_part(None) is None def test_none_input_genai_to_a2a(self): """Test handling of None input for GenAI to A2A conversion.""" - # This test depends on how the function handles None input - # If it should raise an exception, we test for that - with pytest.raises(AttributeError): - convert_genai_part_to_a2a_part(None) + assert convert_genai_part_to_a2a_part(None) is None class TestNewConstants: @@ -802,15 +773,13 @@ def test_convert_a2a_data_part_with_code_execution_result_metadata(self): "outcome": "OUTCOME_OK", "output": "Hello, World!", } - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data=code_execution_result_data, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, - }, - ) + a2a_part = _make_data_part( + code_execution_result_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT, + }, ) # Act @@ -833,15 +802,13 @@ def test_convert_a2a_data_part_with_executable_code_metadata(self): "language": "PYTHON", "code": "print('Hello, World!')", } - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data=executable_code_data, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE, - }, - ) + a2a_part = _make_data_part( + executable_code_data, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE, + }, ) # Act @@ -877,18 +844,18 @@ def test_genai_function_call_with_thought_signature_to_a2a(self): # Assert assert result is not None - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") assert ( - result.root.metadata[ + result.metadata[ _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY) ] == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL ) # thought_signature should be base64 encoded in metadata thought_sig_key = _get_adk_metadata_key("thought_signature") - assert thought_sig_key in result.root.metadata + assert thought_sig_key in result.metadata assert ( - base64.b64decode(result.root.metadata[thought_sig_key]) + base64.b64decode(result.metadata[thought_sig_key]) == b"gemini3_signature_bytes" ) @@ -907,30 +874,28 @@ def test_genai_function_call_without_thought_signature_to_a2a(self): # Assert assert result is not None - assert isinstance(result.root, a2a_types.DataPart) + assert result.HasField("data") # thought_signature key should not be present thought_sig_key = _get_adk_metadata_key("thought_signature") - assert thought_sig_key not in result.root.metadata + assert thought_sig_key not in result.metadata def test_a2a_function_call_with_thought_signature_to_genai(self): """Test that thought_signature is restored when converting A2A to GenAI.""" # Arrange - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data={ - "id": "fc_gemini3", - "name": "my_tool", - "args": {"document": "test content"}, - }, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - _get_adk_metadata_key("thought_signature"): ( - base64.b64encode(b"restored_signature").decode("utf-8") - ), - }, - ) + a2a_part = _make_data_part( + { + "id": "fc_gemini3", + "name": "my_tool", + "args": {"document": "test content"}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key("thought_signature"): ( + base64.b64encode(b"restored_signature").decode("utf-8") + ), + }, ) # Act @@ -946,19 +911,17 @@ def test_a2a_function_call_with_thought_signature_to_genai(self): def test_a2a_function_call_without_thought_signature_to_genai(self): """Test function call without thought_signature returns None for it.""" # Arrange - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data={ - "id": "fc_regular", - "name": "regular_tool", - "args": {}, - }, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - }, - ) + a2a_part = _make_data_part( + { + "id": "fc_regular", + "name": "regular_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + }, ) # Act @@ -997,24 +960,26 @@ def test_function_call_with_thought_signature_round_trip(self): def test_a2a_function_call_with_bytes_thought_signature_to_genai(self): """Test that bytes thought_signature is used directly without decoding.""" - # Arrange - metadata contains raw bytes (not base64 encoded) - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data={ - "id": "fc_bytes", - "name": "bytes_tool", - "args": {}, - }, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - _get_adk_metadata_key( - "thought_signature" - ): b"raw_bytes_signature", - }, - ) - ) + # Arrange - use a mock metadata mapping because proto Struct cannot store + # bytes values directly. + data = Value() + ParseDict( + { + "id": "fc_bytes", + "name": "bytes_tool", + "args": {}, + }, + data, + ) + a2a_part = Mock(spec=a2a_types.Part) + a2a_part.HasField.side_effect = lambda field: field == "data" + a2a_part.data = data + a2a_part.metadata = { + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key("thought_signature"): b"raw_bytes_signature", + } # Act result = convert_a2a_part_to_genai_part(a2a_part) @@ -1028,22 +993,18 @@ def test_a2a_function_call_with_bytes_thought_signature_to_genai(self): def test_a2a_function_call_with_invalid_base64_thought_signature(self): """Test that invalid base64 thought_signature logs warning and returns None.""" # Arrange - metadata contains invalid base64 string - a2a_part = a2a_types.Part( - root=a2a_types.DataPart( - data={ - "id": "fc_invalid", - "name": "invalid_sig_tool", - "args": {}, - }, - metadata={ - _get_adk_metadata_key( - A2A_DATA_PART_METADATA_TYPE_KEY - ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, - _get_adk_metadata_key( - "thought_signature" - ): "not_valid_base64!!!", - }, - ) + a2a_part = _make_data_part( + { + "id": "fc_invalid", + "name": "invalid_sig_tool", + "args": {}, + }, + metadata={ + _get_adk_metadata_key( + A2A_DATA_PART_METADATA_TYPE_KEY + ): A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + _get_adk_metadata_key("thought_signature"): "not_valid_base64!!!", + }, ) # Act diff --git a/tests/unittests/a2a/converters/test_to_adk.py b/tests/unittests/a2a/converters/test_to_adk.py index 12eaf2a75a..02c6f016ae 100644 --- a/tests/unittests/a2a/converters/test_to_adk.py +++ b/tests/unittests/a2a/converters/test_to_adk.py @@ -14,17 +14,20 @@ from __future__ import annotations +from datetime import datetime +from datetime import timezone from unittest.mock import Mock from a2a.types import Artifact from a2a.types import Message from a2a.types import Part as A2APart +from a2a.types import Role from a2a.types import Task from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY from google.adk.a2a.converters.to_adk_event import convert_a2a_artifact_update_to_event from google.adk.a2a.converters.to_adk_event import convert_a2a_message_to_event @@ -36,6 +39,9 @@ import pytest +TEST_TIMESTAMP = datetime(2024, 1, 1, tzinfo=timezone.utc) + + class TestToAdk: """Test suite for to_adk functions.""" @@ -47,10 +53,8 @@ def setup_method(self): def test_convert_a2a_message_to_event_success(self): """Test successful conversion of A2A message to Event.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = {} - message = Message(message_id="msg-1", role="user", parts=[a2a_part]) + a2a_part = Part(text="stub") + message = Message(message_id="msg-1", role=Role.user, parts=[a2a_part]) mock_genai_part = genai_types.Part.from_text(text="hello") mock_part_converter = Mock(return_value=[mock_genai_part]) @@ -75,12 +79,10 @@ def test_convert_a2a_message_to_event_none(self): def test_convert_a2a_message_to_event_restores_actions_from_metadata(self): """Test A2A message conversion restores ADK actions metadata.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = {} + a2a_part = Part(text="stub") message = Message( message_id="msg-1", - role="user", + role=Role.user, parts=[a2a_part], metadata={ _get_adk_metadata_key("actions"): { @@ -107,7 +109,7 @@ def test_convert_a2a_message_to_event_returns_action_only_event(self): """Test A2A message conversion returns action-only events.""" message = Message( message_id="msg-1", - role="user", + role=Role.user, parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -129,20 +131,18 @@ def test_convert_a2a_message_to_event_returns_action_only_event(self): def test_convert_a2a_task_to_event_success(self): """Test successful conversion of A2A task to Event.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = {} + a2a_part = Part(text="stub") task = Task( id="task-1", status=TaskStatus( - state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + state=TaskState.submitted, timestamp=TEST_TIMESTAMP ), context_id="context-1", - history=[Message(message_id="msg-1", role="agent", parts=[a2a_part])], + history=[ + Message(message_id="msg-1", role=Role.agent, parts=[a2a_part]) + ], artifacts=[ - Artifact( - artifact_id="art-1", artifact_type="message", parts=[a2a_part] - ) + Artifact(artifact_id="art-1", parts=[a2a_part]) ], ) @@ -166,13 +166,12 @@ def test_convert_a2a_task_to_event_returns_action_only_event(self): task = Task( id="task-1", status=TaskStatus( - state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + state=TaskState.submitted, timestamp=TEST_TIMESTAMP ), context_id="context-1", artifacts=[ Artifact( artifact_id="art-1", - artifact_type="message", parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -199,13 +198,12 @@ def test_convert_a2a_task_to_event_merges_actions_across_artifacts(self): task = Task( id="task-1", status=TaskStatus( - state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + state=TaskState.submitted, timestamp=TEST_TIMESTAMP ), context_id="context-1", artifacts=[ Artifact( artifact_id="art-1", - artifact_type="message", parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -215,7 +213,6 @@ def test_convert_a2a_task_to_event_merges_actions_across_artifacts(self): ), Artifact( artifact_id="art-2", - artifact_type="message", parts=[], metadata={}, ), @@ -238,13 +235,12 @@ def test_convert_a2a_task_to_event_overwrites_nested_state_delta_values(self): task = Task( id="task-1", status=TaskStatus( - state=TaskState.submitted, timestamp="2024-01-01T00:00:00Z" + state=TaskState.submitted, timestamp=TEST_TIMESTAMP ), context_id="context-1", artifacts=[ Artifact( artifact_id="art-1", - artifact_type="message", parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -259,7 +255,6 @@ def test_convert_a2a_task_to_event_overwrites_nested_state_delta_values(self): ), Artifact( artifact_id="art-2", - artifact_type="message", parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -283,17 +278,18 @@ def test_convert_a2a_task_to_event_overwrites_nested_state_delta_values(self): def test_convert_a2a_task_to_event_merges_status_and_artifact_actions(self): """Test task conversion merges status and artifact actions.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = {} + a2a_part = Part(text="stub") + a2a_part.metadata.update({ + _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True + }) task = Task( id="task-1", status=TaskStatus( state=TaskState.input_required, - timestamp="2024-01-01T00:00:00Z", + timestamp=TEST_TIMESTAMP, message=Message( message_id="msg-1", - role="agent", + role=Role.agent, parts=[a2a_part], metadata={ _get_adk_metadata_key("actions"): { @@ -306,7 +302,6 @@ def test_convert_a2a_task_to_event_merges_status_and_artifact_actions(self): artifacts=[ Artifact( artifact_id="art-1", - artifact_type="message", parts=[], metadata={ _get_adk_metadata_key("actions"): { @@ -339,24 +334,22 @@ def test_convert_a2a_task_to_event_none(self): def test_convert_a2a_status_update_to_event_success(self): """Test successful conversion of A2A status update to Event.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = { + a2a_part = Part(text="stub") + a2a_part.metadata.update({ _get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True - } + }) update = TaskStatusUpdateEvent( task_id="task-1", status=TaskStatus( state=TaskState.input_required, - timestamp="now", + timestamp=TEST_TIMESTAMP, message=Message( message_id="m1", - role="agent", + role=Role.agent, parts=[a2a_part], ), ), context_id="context-1", - final=False, ) mock_genai_part = genai_types.Part( @@ -385,14 +378,10 @@ def test_convert_a2a_status_update_to_event_none(self): def test_convert_a2a_artifact_update_to_event_success(self): """Test successful conversion of A2A artifact update to Event.""" - a2a_part = Mock(spec=A2APart) - a2a_part.root = Mock(spec=TextPart) - a2a_part.root.metadata = {} + a2a_part = Part(text="stub") update = TaskArtifactUpdateEvent( task_id="task-1", - artifact=Artifact( - artifact_id="art-1", artifact_type="message", parts=[a2a_part] - ), + artifact=Artifact(artifact_id="art-1", parts=[a2a_part]), append=True, context_id="context-1", last_chunk=False, diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4f44e1363c..a145aef14a 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -23,7 +23,7 @@ from a2a.types import Part from a2a.types import Role from a2a.types import TaskState -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig @@ -62,7 +62,7 @@ def setup_method(self): self.mock_context = Mock(spec=RequestContext) self.mock_context.message = Mock(spec=Message) - self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.message.parts = [Mock(spec=Part)] self.mock_context.current_task = None self.mock_context.task_id = "test-task-id" self.mock_context.context_id = "test-context-id" @@ -132,16 +132,13 @@ async def mock_run_async(**kwargs): 0 ] assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False # Verify working event was enqueued working_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][0] assert working_event.status.state == TaskState.working - assert working_event.final == False # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True # The TaskResultAggregator is created with default state (working), and since no messages # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") @@ -212,11 +209,9 @@ async def mock_run_async(**kwargs): # Verify no submitted event (first call should be working event) working_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] assert working_event.status.state == TaskState.working - assert working_event.final == False # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True # The TaskResultAggregator is created with default state (working), and since no messages # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") @@ -437,11 +432,9 @@ async def mock_run_async(**kwargs): 0 ] assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True # The TaskResultAggregator is created with default state (working), and since no messages # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") @@ -496,11 +489,9 @@ async def mock_run_async(**kwargs): 0 ] assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True # The TaskResultAggregator is created with default state (working), and since no messages # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") @@ -572,14 +563,8 @@ async def mock_run_async(**kwargs): assert mock_aggregator.process_event.call_count == len(mock_events) # Verify final event has message field from aggregator and state is completed when aggregator state is working - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == mock_aggregator.task_status_message + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert not final_event.status.HasField("message") # When aggregator state is working but no message, final event should be working assert final_event.status.state == TaskState.working @@ -627,12 +612,10 @@ async def test_execute_with_exception_handling(self): 0 ] assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False # Check failure event (last) failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert failure_event.status.state == TaskState.failed - assert failure_event.final == True @pytest.mark.asyncio async def test_handle_request_with_aggregator_message(self): @@ -643,12 +626,12 @@ async def test_handle_request_with_aggregator_message(self): # Create a test message to be returned by the aggregator from a2a.types import Message from a2a.types import Role - from a2a.types import TextPart + from a2a.types import Part test_message = Mock(spec=Message) test_message.message_id = "test-message-id" test_message.role = Role.agent - test_message.parts = [Mock(spec=TextPart)] + test_message.parts = [Mock(spec=Part)] # Setup detailed mocks self.mock_request_converter.return_value = AgentRunRequest( @@ -698,14 +681,9 @@ async def mock_run_async(**kwargs): ) # Verify final event has message field from aggregator - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.status.message.message_id == test_message.message_id + assert final_event.status.message.role == test_message.role # When aggregator state is completed (not working), final event should be completed assert final_event.status.state == TaskState.completed @@ -718,12 +696,12 @@ async def test_handle_request_with_non_working_aggregator_state(self): # Create a test message to be returned by the aggregator from a2a.types import Message from a2a.types import Role - from a2a.types import TextPart + from a2a.types import Part test_message = Mock(spec=Message) test_message.message_id = "test-message-id" test_message.role = Role.agent - test_message.parts = [Mock(spec=TextPart)] + test_message.parts = [Mock(spec=Part)] # Setup detailed mocks self.mock_request_converter.return_value = AgentRunRequest( @@ -773,14 +751,9 @@ async def mock_run_async(**kwargs): ) # Verify final event preserves the non-working state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event - assert final_event.status.message == test_message + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] + assert final_event.status.message.message_id == test_message.message_id + assert final_event.status.message.role == test_message.role # When aggregator state is failed (not working), final event should keep failed state assert final_event.status.state == TaskState.failed @@ -797,12 +770,12 @@ async def test_handle_request_with_working_state_publishes_artifact_and_complete from a2a.types import Message from a2a.types import Part from a2a.types import Role - from a2a.types import TextPart + from a2a.types import Part test_message = Mock(spec=Message) test_message.message_id = "test-message-id" test_message.role = Role.agent - test_message.parts = [Part(root=TextPart(text="test content"))] + test_message.parts = [Part(text="test content")] # Setup detailed mocks self.mock_request_converter.return_value = AgentRunRequest( @@ -866,13 +839,7 @@ async def mock_run_async(**kwargs): assert artifact_event.artifact.parts == test_message.parts # Verify final status event was published with completed state - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert final_event.status.state == TaskState.completed assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" @@ -890,12 +857,12 @@ async def test_handle_request_with_non_working_state_publishes_status_only( from a2a.types import Message from a2a.types import Part from a2a.types import Role - from a2a.types import TextPart + from a2a.types import Part test_message = Mock(spec=Message) test_message.message_id = "test-message-id" test_message.role = Role.agent - test_message.parts = [Part(root=TextPart(text="test content"))] + test_message.parts = [Part(text="test content")] # Setup detailed mocks self.mock_request_converter.return_value = AgentRunRequest( @@ -953,15 +920,11 @@ async def mock_run_async(**kwargs): assert len(artifact_events) == 0 # Verify final status event was published with the actual state and message - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] # Get the last final event + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert final_event.status.state == TaskState.auth_required - assert final_event.status.message == test_message + assert final_event.status.message.message_id == test_message.message_id + assert final_event.status.message.role == test_message.role + assert final_event.status.message.parts == test_message.parts assert final_event.task_id == "test-task-id" assert final_event.context_id == "test-context-id" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py index f10d56e564..1061eec188 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -14,6 +14,8 @@ from __future__ import annotations +from datetime import datetime +from datetime import timezone from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -25,7 +27,7 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.converters.utils import _get_adk_metadata_key from google.adk.a2a.executor.a2a_agent_executor_impl import _A2aAgentExecutor as A2aAgentExecutor @@ -40,6 +42,18 @@ import pytest +TEST_TIMESTAMP = datetime(2024, 1, 1, tzinfo=timezone.utc) + + +def _mock_executor_context() -> Mock: + """Creates an executor context mock with string metadata fields.""" + executor_context = Mock() + executor_context.app_name = "test-app" + executor_context.user_id = "test-user" + executor_context.session_id = "test-session" + return executor_context + + class TestA2aAgentExecutor: """Test suite for A2aAgentExecutor class.""" @@ -67,7 +81,7 @@ def setup_method(self): self.mock_context = Mock(spec=RequestContext) self.mock_context.message = Mock(spec=Message) - self.mock_context.message.parts = [Mock(spec=TextPart)] + self.mock_context.message.parts = [Mock(spec=Part)] self.mock_context.current_task = None self.mock_context.task_id = "test-task-id" self.mock_context.context_id = "test-context-id" @@ -121,9 +135,8 @@ async def mock_run_async(**kwargs): # Mock event converter to return a working status update working_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.working, timestamp="now"), + status=TaskStatus(state=TaskState.working, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [working_event] @@ -174,7 +187,6 @@ async def mock_run_async(**kwargs): # Verify final event was enqueued final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True assert final_event.status.state == TaskState.completed assert final_event.metadata == self.expected_metadata @@ -224,9 +236,8 @@ async def mock_run_async(**kwargs): # Mock event converter working_event = TaskStatusUpdateEvent( task_id="existing-task-id", - status=TaskStatus(state=TaskState.working, timestamp="now"), + status=TaskStatus(state=TaskState.working, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [working_event] @@ -253,7 +264,6 @@ async def mock_run_async(**kwargs): # Verify final event final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] - assert final_event.final == True assert final_event.status.state == TaskState.completed assert final_event.metadata == self.expected_metadata @@ -351,15 +361,14 @@ async def mock_run_async(**kwargs): # Mock event converter to return events working_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.working, timestamp="now"), + status=TaskStatus(state=TaskState.working, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [working_event] # Initialize executor context attributes as they would be in execute() self.executor._invocation_metadata = {} - self.executor._executor_context = Mock() + self.executor._executor_context = _mock_executor_context() # Execute await self.executor._handle_request( @@ -382,13 +391,7 @@ async def mock_run_async(**kwargs): assert len(working_events) >= len(mock_events) # Verify final event is completed - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert final_event.status.state == TaskState.completed @pytest.mark.asyncio @@ -415,8 +418,7 @@ async def test_execute_with_exception_handling(self): # Check failure event (last) failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert failure_event.status.state == TaskState.failed - assert failure_event.final == True - assert "Test error" in failure_event.status.message.parts[0].root.text + assert "Test error" in failure_event.status.message.parts[0].text @pytest.mark.asyncio async def test_handle_request_with_non_working_state(self): @@ -443,9 +445,8 @@ async def mock_run_async(**kwargs): # Mock event converter to return a FAILED event failed_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.failed, timestamp="now"), + status=TaskStatus(state=TaskState.failed, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [failed_event] @@ -458,7 +459,7 @@ async def mock_run_async(**kwargs): # Initialize executor context attributes self.executor._invocation_metadata = {} - self.executor._executor_context = Mock() + self.executor._executor_context = _mock_executor_context() # Execute await self.executor._handle_request( @@ -470,14 +471,7 @@ async def mock_run_async(**kwargs): ) # Verify final event is FAILED, not COMPLETED - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - # The last event should be the synthesized final event - final_event = final_events[-1] + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert final_event.status.state == TaskState.failed @pytest.mark.asyncio @@ -510,10 +504,7 @@ async def mock_run_async(**kwargs): run_config=Mock(spec=RunConfig), ) - executor_context = Mock() - executor_context.app_name = "test-app" - executor_context.user_id = "test-user" - executor_context.session_id = "test-session" + executor_context = _mock_executor_context() await self.executor._handle_request( self.mock_context, @@ -523,13 +514,7 @@ async def mock_run_async(**kwargs): run_request, ) - final_events = [ - call[0][0] - for call in self.mock_event_queue.enqueue_event.call_args_list - if hasattr(call[0][0], "final") and call[0][0].final == True - ] - assert len(final_events) >= 1 - final_event = final_events[-1] + final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] assert final_event.status.state == TaskState.failed assert final_event.metadata == self.expected_metadata @@ -568,9 +553,8 @@ async def mock_run_async(**kwargs): # Mock event converter working_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.working, timestamp="now"), + status=TaskStatus(state=TaskState.working, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [working_event] @@ -609,9 +593,10 @@ async def test_execute_missing_user_input(self, mock_handle_user_input): # Set up handle_user_input to return an event missing_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.input_required, timestamp="now"), + status=TaskStatus( + state=TaskState.input_required, timestamp=TEST_TIMESTAMP + ), context_id="test-context-id", - final=False, ) mock_handle_user_input.return_value = missing_event @@ -697,9 +682,10 @@ async def test_long_running_functions_final_event(self, mock_lrf_class): lrf_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.input_required, timestamp="now"), + status=TaskStatus( + state=TaskState.input_required, timestamp=TEST_TIMESTAMP + ), context_id="test-context-id", - final=False, ) mock_lrf.create_long_running_function_call_event.return_value = lrf_event @@ -731,7 +717,7 @@ async def mock_run_async(**kwargs): self.mock_event_converter.return_value = [] self.executor._invocation_metadata = {} - self.executor._executor_context = Mock() + self.executor._executor_context = _mock_executor_context() await self.executor._handle_request( self.mock_context, @@ -787,13 +773,12 @@ async def mock_run_async(**kwargs): # Event converter returns one event working_event = TaskStatusUpdateEvent( task_id="test-task-id", - status=TaskStatus(state=TaskState.working, timestamp="now"), + status=TaskStatus(state=TaskState.working, timestamp=TEST_TIMESTAMP), context_id="test-context-id", - final=False, ) self.mock_event_converter.return_value = [working_event] - self.executor._executor_context = Mock() + self.executor._executor_context = _mock_executor_context() await self.executor._handle_request( self.mock_context, self.executor._executor_context, diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py index 24b5651e79..3313f8b175 100644 --- a/tests/unittests/a2a/executor/test_task_result_aggregator.py +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -20,7 +20,7 @@ from a2a.types import TaskState from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator import pytest @@ -30,7 +30,7 @@ def create_test_message(text: str): return Message( message_id="test-msg", role=Role.agent, - parts=[Part(root=TextPart(text=text))], + parts=[Part(text=text)], ) @@ -53,7 +53,6 @@ def test_process_failed_event(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.failed, message=status_message), - final=True, ) self.aggregator.process_event(event) @@ -71,7 +70,6 @@ def test_process_auth_required_event(self): status=TaskStatus( state=TaskState.auth_required, message=status_message ), - final=False, ) self.aggregator.process_event(event) @@ -89,7 +87,6 @@ def test_process_input_required_event(self): status=TaskStatus( state=TaskState.input_required, message=status_message ), - final=False, ) self.aggregator.process_event(event) @@ -104,7 +101,6 @@ def test_status_message_with_none_message(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.failed, message=None), - final=True, ) self.aggregator.process_event(event) @@ -119,7 +115,6 @@ def test_priority_order_failed_over_auth(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.auth_required, message=auth_message), - final=False, ) self.aggregator.process_event(auth_event) assert self.aggregator.task_state == TaskState.auth_required @@ -131,7 +126,6 @@ def test_priority_order_failed_over_auth(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.failed, message=failed_message), - final=True, ) self.aggregator.process_event(failed_event) assert self.aggregator.task_state == TaskState.failed @@ -147,7 +141,6 @@ def test_priority_order_auth_over_input(self): status=TaskStatus( state=TaskState.input_required, message=input_message ), - final=False, ) self.aggregator.process_event(input_event) assert self.aggregator.task_state == TaskState.input_required @@ -159,7 +152,6 @@ def test_priority_order_auth_over_input(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.auth_required, message=auth_message), - final=False, ) self.aggregator.process_event(auth_event) assert self.aggregator.task_state == TaskState.auth_required @@ -185,7 +177,6 @@ def test_working_state_does_not_override_higher_priority(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.failed, message=failed_message), - final=True, ) self.aggregator.process_event(failed_event) assert self.aggregator.task_state == TaskState.failed @@ -197,7 +188,6 @@ def test_working_state_does_not_override_higher_priority(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.working), - final=False, ) self.aggregator.process_event(working_event) assert self.aggregator.task_state == TaskState.failed @@ -214,7 +204,6 @@ def test_status_message_priority_ordering(self): status=TaskStatus( state=TaskState.input_required, message=input_message ), - final=False, ) self.aggregator.process_event(input_event) assert self.aggregator.task_status_message == input_message @@ -225,7 +214,6 @@ def test_status_message_priority_ordering(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.auth_required, message=auth_message), - final=False, ) self.aggregator.process_event(auth_event) assert self.aggregator.task_status_message == auth_message @@ -236,7 +224,6 @@ def test_status_message_priority_ordering(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.failed, message=failed_message), - final=True, ) self.aggregator.process_event(failed_event) assert self.aggregator.task_status_message == failed_message @@ -247,7 +234,6 @@ def test_status_message_priority_ordering(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.working, message=working_message), - final=False, ) self.aggregator.process_event(working_event) # State should still be failed, and message should remain the failed message @@ -262,7 +248,6 @@ def test_process_working_event_updates_message(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.working, message=working_message), - final=False, ) self.aggregator.process_event(event) @@ -277,7 +262,6 @@ def test_working_event_with_none_message(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.working, message=None), - final=False, ) self.aggregator.process_event(event) @@ -292,7 +276,6 @@ def test_working_event_updates_message_regardless_of_state(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.auth_required, message=auth_message), - final=False, ) self.aggregator.process_event(auth_event) assert self.aggregator.task_state == TaskState.auth_required @@ -304,7 +287,6 @@ def test_working_event_updates_message_regardless_of_state(self): task_id="test-task", context_id="test-context", status=TaskStatus(state=TaskState.working, message=working_message), - final=False, ) self.aggregator.process_event(working_event) assert ( diff --git a/tests/unittests/a2a/integration/client.py b/tests/unittests/a2a/integration/client.py index 11c34c35b9..f71b88c151 100644 --- a/tests/unittests/a2a/integration/client.py +++ b/tests/unittests/a2a/integration/client.py @@ -17,7 +17,7 @@ from a2a.client.client import ClientConfig as A2AClientConfig from a2a.client.client_factory import ClientFactory as A2AClientFactory from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.types import TransportProtocol as A2ATransport + from google.adk.a2a.agent.interceptors.new_integration_extension import _NEW_A2A_ADK_INTEGRATION_EXTENSION from google.adk.agents.remote_a2a_agent import RemoteA2aAgent import httpx @@ -44,7 +44,7 @@ def create_client(app, streaming: bool = False) -> RemoteA2aAgent: httpx_client=client, streaming=streaming, polling=False, - supported_transports=[A2ATransport.jsonrpc], + supported_protocol_bindings=["JSONRPC"], ) factory = A2AClientFactory(config=client_config) @@ -82,7 +82,7 @@ def create_a2a_client(app, streaming: bool = False): httpx_client=client, streaming=streaming, polling=False, - supported_transports=[A2ATransport.jsonrpc], + supported_protocol_bindings=["JSONRPC"], ) factory = A2AClientFactory(config=client_config) return factory.create(agent_card) diff --git a/tests/unittests/a2a/integration/server.py b/tests/unittests/a2a/integration/server.py index 86a0e1d629..a53bbf89b2 100644 --- a/tests/unittests/a2a/integration/server.py +++ b/tests/unittests/a2a/integration/server.py @@ -50,7 +50,7 @@ async def run_async(self, **kwargs): agent_card = AgentCard( name="remote_agent", - url="http://test", + supported_interfaces=[{"url": "http://test", "protocol_binding": "JSONRPC"}], description="A fun fact generator agent", capabilities=AgentCapabilities( streaming=True, diff --git a/tests/unittests/a2a/integration/test_client_server.py b/tests/unittests/a2a/integration/test_client_server.py index 3318efb84e..a18f3c7897 100644 --- a/tests/unittests/a2a/integration/test_client_server.py +++ b/tests/unittests/a2a/integration/test_client_server.py @@ -17,8 +17,10 @@ from a2a.types import Message as A2AMessage from a2a.types import Part as A2APart from a2a.types import Task +from a2a.types import Role +from a2a.types import SendMessageRequest from a2a.types import TaskState -from a2a.types import TextPart +from a2a.types import Part from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX from google.adk.events.event import Event from google.adk.platform import uuid as platform_uuid @@ -487,17 +489,18 @@ async def test_long_running_function_calls_error(): request_1 = A2AMessage( message_id=platform_uuid.new_uuid(), - parts=[A2APart(root=TextPart(text="Hi"))], - role="user", + parts=[A2APart(text="Hi")], + role=Role.ROLE_USER, ) response_1_events = [] - async for event in a2a_client.send_message(request=request_1): + async for event in a2a_client.send_message( + request=SendMessageRequest(message=request_1) + ): response_1_events.append(event) assert len(response_1_events) == 1 # Extract task_id from Turn 1 responses - assert response_1_events[0][1] is None - task = response_1_events[0][0] + task = response_1_events[0][1] assert isinstance(task, Task) assert task.status.state == TaskState.input_required extracted_task_id = task.id @@ -505,21 +508,22 @@ async def test_long_running_function_calls_error(): request_2 = A2AMessage( message_id=platform_uuid.new_uuid(), - parts=[A2APart(root=TextPart(text="Any update?"))], - role="user", + parts=[A2APart(text="Any update?")], + role=Role.ROLE_USER, task_id=extracted_task_id, context_id=task.context_id if hasattr(task, "context_id") else None, ) response_2_events = [] - async for event in a2a_client.send_message(request=request_2): + async for event in a2a_client.send_message( + request=SendMessageRequest(message=request_2) + ): response_2_events.append(event) # Verify that we get an error response for the second request due to missing function response assert len(response_2_events) == 1 - assert response_2_events[0][1] is None - error_response = response_2_events[0][0] + error_response = response_2_events[0][1] assert isinstance(error_response, Task) - assert error_response.status.message.parts[0].root.text == ( + assert error_response.status.message.parts[0].text == ( "It was not provided a function response for the function call." ) diff --git a/tests/unittests/a2a/logs/test_log_utils.py b/tests/unittests/a2a/logs/test_log_utils.py index 0ef28c62be..722a2dbc28 100644 --- a/tests/unittests/a2a/logs/test_log_utils.py +++ b/tests/unittests/a2a/logs/test_log_utils.py @@ -19,6 +19,7 @@ from unittest.mock import Mock from unittest.mock import patch +from google.protobuf.json_format import ParseDict import pytest # Skip all tests in this module if Python version is less than 3.10 @@ -28,17 +29,14 @@ # Import dependencies with version checking try: - from a2a.types import DataPart as A2ADataPart + from a2a.types import Part as A2APart from a2a.types import Message as A2AMessage - from a2a.types import MessageSendConfiguration - from a2a.types import MessageSendParams from a2a.types import Part as A2APart from a2a.types import Role - from a2a.types import SendMessageRequest from a2a.types import Task as A2ATask from a2a.types import TaskState from a2a.types import TaskStatus - from a2a.types import TextPart as A2ATextPart + from a2a.types import Part as A2APart from google.adk.a2a.logs.log_utils import build_a2a_request_log from google.adk.a2a.logs.log_utils import build_a2a_response_log from google.adk.a2a.logs.log_utils import build_message_part_log @@ -59,8 +57,7 @@ def test_text_part_short_text(self): """Test TextPart with short text.""" # Create real A2A objects - text_part = A2ATextPart(text="Hello, world!") - part = A2APart(root=text_part) + part = A2APart(text="Hello, world!") result = build_message_part_log(part) @@ -70,8 +67,7 @@ def test_text_part_long_text(self): """Test TextPart with long text that gets truncated.""" long_text = "x" * 150 # Long text that should be truncated - text_part = A2ATextPart(text=long_text) - part = A2APart(root=text_part) + part = A2APart(text=long_text) result = build_message_part_log(part) @@ -81,14 +77,15 @@ def test_text_part_long_text(self): def test_data_part_simple_data(self): """Test DataPart with simple data.""" - data_part = A2ADataPart(data={"key1": "value1", "key2": 42}) - part = A2APart(root=data_part) + data_part = A2APart() + ParseDict({"key1": "value1", "key2": 42}, data_part.data) + part = data_part result = build_message_part_log(part) - expected_data = {"key1": "value1", "key2": 42} - expected = f"DataPart: {json.dumps(expected_data, indent=2)}" - assert result == expected + assert result.startswith("DataPart: ") + logged_data = json.loads(result.removeprefix("DataPart: ")) + assert logged_data == {"key1": "value1", "key2": 42.0} def test_data_part_large_values(self): """Test DataPart with large values that get summarized.""" @@ -96,15 +93,14 @@ def test_data_part_large_values(self): large_dict = {f"key{i}": f"value{i}" for i in range(50)} large_list = list(range(100)) - data_part = A2ADataPart( - data={ - "small_value": "hello", - "large_dict": large_dict, - "large_list": large_list, - "normal_int": 42, - } - ) - part = A2APart(root=data_part) + data_part = A2APart() + ParseDict({ + "small_value": "hello", + "large_dict": large_dict, + "large_list": large_list, + "normal_int": 42, + }, data_part.data) + part = data_part result = build_message_part_log(part) @@ -126,7 +122,7 @@ def test_other_part_type(self): mock_root.metadata = None mock_part = Mock() - mock_part.root = mock_root + mock_part = mock_root mock_part.model_dump_json.return_value = '{"some": "data"}' result = build_message_part_log(mock_part) @@ -144,12 +140,12 @@ def test_request_with_parts(self): # Create mock request with all components req = A2AMessage( message_id="msg-456", - role="user", + role=Role.ROLE_USER, task_id="task-789", context_id="ctx-101", parts=[ - A2APart(root=A2ATextPart(text="Part 1")), - A2APart(root=A2ATextPart(text="Part 2")), + A2APart(text="Part 1"), + A2APart(text="Part 2"), ], metadata={"msg_key": "msg_value"}, ) @@ -163,7 +159,7 @@ def test_request_with_parts(self): # Verify all components are present assert "msg-456" in result - assert "user" in result + assert "Role: 1" in result or "ROLE_USER" in result or "user" in result assert "task-789" in result assert "ctx-101" in result assert "Part 0:" in result @@ -224,6 +220,8 @@ def test_success_response_with_client_event(self): assert ( "Status State: TaskState.working" in result or "Status State: working" in result + or "Status State: TASK_STATE_WORKING" in result + or "Status State: 2" in result or '"state":"working"' in result or '"state": "working"' in result ) @@ -236,8 +234,8 @@ def test_success_response_with_task_and_status_message(self): message_id="status-msg-123", role=Role.agent, parts=[ - A2APart(root=A2ATextPart(text="Status part 1")), - A2APart(root=A2ATextPart(text="Status part 2")), + A2APart(text="Status part 1"), + A2APart(text="Status part 2"), ], ) @@ -259,6 +257,8 @@ def test_success_response_with_task_and_status_message(self): assert ( "Role: Role.agent" in result or "Role: agent" in result + or "Role: ROLE_AGENT" in result + or "Role: 2" in result or '"role":"agent"' in result or '"role": "agent"' in result ) @@ -273,7 +273,7 @@ def test_success_response_with_message(self): role=Role.agent, task_id="task-456", context_id="ctx-789", - parts=[A2APart(root=A2ATextPart(text="Message part 1"))], + parts=[A2APart(text="Message part 1")], ) resp = message @@ -287,6 +287,8 @@ def test_success_response_with_message(self): assert ( "Role: Role.agent" in result or "Role: agent" in result + or "Role: ROLE_AGENT" in result + or "Role: 2" in result or '"role":"agent"' in result or '"role": "agent"' in result ) @@ -352,7 +354,7 @@ def test_build_message_part_log_with_metadata(self): mock_root.metadata = {"key": "value", "nested": {"data": "test"}} mock_part = Mock() - mock_part.root = mock_root + mock_part = mock_root mock_part.model_dump_json.return_value = '{"content": "test"}' result = build_message_part_log(mock_part) diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index 8549c16ec8..c01132246f 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -123,8 +123,12 @@ async def test_build_success( mock_agent.name = "test_agent" mock_agent.description = "Test agent description" - mock_primary_skill = Mock(spec=AgentSkill) - mock_sub_skill = Mock(spec=AgentSkill) + mock_primary_skill = AgentSkill( + id="primary", name="Primary", description="Primary skill" + ) + mock_sub_skill = AgentSkill( + id="sub", name="Sub", description="Sub skill" + ) mock_build_primary_skills.return_value = [mock_primary_skill] mock_build_sub_skills.return_value = [mock_sub_skill] @@ -137,15 +141,16 @@ async def test_build_success( assert isinstance(result, AgentCard) assert result.name == "test_agent" assert result.description == "Test agent description" - assert result.documentation_url is None - assert result.url == "http://localhost:80/a2a" + assert not result.documentation_url + assert result.supported_interfaces[0].url == "http://localhost:80/a2a" + assert result.supported_interfaces[0].protocol_binding == "jsonrpc" assert result.version == "0.0.1" - assert result.skills == [mock_primary_skill, mock_sub_skill] + assert list(result.skills) == [mock_primary_skill, mock_sub_skill] assert result.default_input_modes == ["text/plain"] assert result.default_output_modes == ["text/plain"] - assert result.supports_authenticated_extended_card is False - assert result.provider is None - assert result.security_schemes is None + assert result.capabilities.extended_agent_card is False + assert not result.provider.ListFields() + assert len(result.security_schemes) == 0 @patch("google.adk.a2a.utils.agent_card_builder._build_primary_skills") @patch("google.adk.a2a.utils.agent_card_builder._build_sub_agent_skills") @@ -158,13 +163,20 @@ async def test_build_with_custom_parameters( mock_agent.name = "test_agent" mock_agent.description = None # Should use default description - mock_primary_skill = Mock(spec=AgentSkill) - mock_sub_skill = Mock(spec=AgentSkill) + mock_primary_skill = AgentSkill( + id="primary", name="Primary", description="Primary skill" + ) + mock_sub_skill = AgentSkill( + id="sub", name="Sub", description="Sub skill" + ) mock_build_primary_skills.return_value = [mock_primary_skill] mock_build_sub_skills.return_value = [mock_sub_skill] - mock_provider = Mock(spec=AgentProvider) - mock_security_schemes = {"test": Mock(spec=SecurityScheme)} + mock_provider = AgentProvider( + url="https://provider.example.com", + organization="Example Org", + ) + mock_security_schemes = {"test": SecurityScheme()} builder = AgentCardBuilder( agent=mock_agent, @@ -181,12 +193,8 @@ async def test_build_with_custom_parameters( # Assert assert result.name == "test_agent" assert result.description == "An ADK Agent" # Default description - # The source code uses doc_url parameter but AgentCard expects documentation_url - # Since the source code doesn't map doc_url to documentation_url, it will be None - assert result.documentation_url is None - assert ( - result.url == "https://example.com/a2a" - ) # Should strip trailing slash + assert result.documentation_url == "https://docs.example.com" + assert result.supported_interfaces[0].url == "https://example.com/a2a" assert result.version == "2.0.0" assert result.provider == mock_provider assert result.security_schemes == mock_security_schemes diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9e2458ebd..70e526d685 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -832,14 +832,16 @@ async def test_to_a2a_with_agent_card_file_path( # Mock agent card data from file with all required fields agent_card_data = { "name": "file_agent", - "url": "http://example.com", + "supportedInterfaces": [{ + "url": "http://example.com", + "protocolBinding": "JSONRPC", + }], "description": "Test agent from file", "version": "1.0.0", "capabilities": {}, "skills": [], "defaultInputModes": ["text/plain"], "defaultOutputModes": ["text/plain"], - "supportsAuthenticatedExtendedCard": False, } mock_json_load.return_value = agent_card_data diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 0f1ce896a3..89c2e45d0a 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -22,18 +22,19 @@ from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory -from a2a.client.middleware import ClientCallContext +from a2a.client import ClientCallContext from a2a.types import AgentCapabilities from a2a.types import AgentCard from a2a.types import AgentSkill from a2a.types import Artifact from a2a.types import Message as A2AMessage +from a2a.types import SendMessageRequest from a2a.types import Task as A2ATask from a2a.types import TaskArtifactUpdateEvent from a2a.types import TaskState from a2a.types import TaskStatus as A2ATaskStatus from a2a.types import TaskStatusUpdateEvent -from a2a.types import TextPart +from a2a.types import Part from google.adk.a2a.agent import ParametersConfig from google.adk.a2a.agent import RequestInterceptor from google.adk.a2a.agent.utils import execute_after_request_interceptors @@ -59,7 +60,7 @@ def create_test_agent_card( """Create a test AgentCard with all required fields.""" return AgentCard( name=name, - url=url, + supported_interfaces=[{"protocol_binding": "JSONRPC", "url": url}], description=description, version="1.0", capabilities=AgentCapabilities(), @@ -169,7 +170,10 @@ def setup_method(self): """Setup test fixtures.""" self.agent_card_data = { "name": "test-agent", - "url": "https://example.com/rpc", + "supportedInterfaces": [{ + "url": "https://example.com/rpc", + "protocolBinding": "JSONRPC", + }], "description": "Test agent", "version": "1.0", "capabilities": {}, @@ -239,6 +243,8 @@ async def test_ensure_httpx_client_updates_factory_with_new_client(self): ClientConfig(httpx_client=None), ), ) + agent._httpx_client = None + agent._a2a_client_factory._config.httpx_client = None assert agent._a2a_client_factory._config.httpx_client is None client = await agent._ensure_httpx_client() @@ -262,6 +268,8 @@ async def test_ensure_httpx_client_reregisters_transports_with_new_client( agent_card=create_test_agent_card(), a2a_client_factory=factory, ) + agent._httpx_client = None + agent._a2a_client_factory._config.httpx_client = None assert agent._a2a_client_factory._config.httpx_client is None assert "transport_label" in agent._a2a_client_factory._registry @@ -325,7 +333,9 @@ async def test_resolve_agent_card_from_file_success(self): try: result = await agent._resolve_agent_card_from_file(temp_path) assert result.name == self.agent_card.name - assert result.url == self.agent_card.url + assert result.supported_interfaces[0].url == ( + self.agent_card.supported_interfaces[0].url + ) finally: Path(temp_path).unlink() @@ -389,7 +399,7 @@ async def test_validate_agent_card_no_url(self): tags=["test"], ) ], - url="", # Empty URL to trigger validation error + supported_interfaces=[], # Empty URL to trigger validation error ) with pytest.raises( @@ -406,7 +416,7 @@ async def test_validate_agent_card_invalid_url(self): invalid_card = AgentCard( name="test", - url="invalid-url", + supported_interfaces=[{"protocol_binding": "JSONRPC", "url": "invalid-url"}], description="test", version="1.0", capabilities=AgentCapabilities(), @@ -463,19 +473,17 @@ async def test_ensure_resolved_with_direct_agent_card_with_factory(self): with patch("httpx.AsyncClient") as mock_client_class: mock_client = AsyncMock() mock_client_class.return_value = mock_client - - with patch( - "google.adk.agents.remote_a2a_agent.A2AClientFactory" - ) as mock_factory_class: - mock_a2a_client = Mock() - mock_factory = Mock() - mock_factory.create.return_value = mock_a2a_client - mock_factory_class.return_value = mock_factory - + mock_a2a_client = Mock() + with patch.object( + agent._a2a_client_factory, + "create", + return_value=mock_a2a_client, + ) as mock_create: await agent._ensure_resolved() assert agent._is_resolved is True assert agent._a2a_client == mock_a2a_client + mock_create.assert_called_once_with(agent_card) @pytest.mark.asyncio async def test_ensure_resolved_with_url_source(self): @@ -870,10 +878,10 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task.id = "task-123" mock_a2a_task.context_id = "context-123" mock_a2a_task.status = Mock(spec=A2ATaskStatus) - mock_a2a_task.status.state = TaskState.completed + mock_a2a_task.status.state = TaskState.TASK_STATE_COMPLETED # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -889,7 +897,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -977,10 +985,10 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task.id = "task-123" mock_a2a_task.context_id = "context-123" mock_a2a_task.status = Mock(spec=A2ATaskStatus) - mock_a2a_task.status.state = TaskState.submitted + mock_a2a_task.status.state = TaskState.TASK_STATE_SUBMITTED # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -996,7 +1004,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1019,12 +1027,12 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): "task_state,event_content", [ pytest.param( - TaskState.submitted, + TaskState.TASK_STATE_SUBMITTED, genai_types.Content(role="model", parts=[]), id="submitted_empty_parts", ), pytest.param( - TaskState.working, + TaskState.TASK_STATE_WORKING, None, id="working_no_content", ), @@ -1059,7 +1067,7 @@ async def test_handle_a2a_response_with_task_missing_content( mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1074,10 +1082,10 @@ async def test_handle_a2a_response_with_task_working_and_no_update(self): mock_a2a_task.id = "task-123" mock_a2a_task.context_id = "context-123" mock_a2a_task.status = Mock(spec=A2ATaskStatus) - mock_a2a_task.status.state = TaskState.working + mock_a2a_task.status.state = TaskState.TASK_STATE_WORKING # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1093,7 +1101,7 @@ async def test_handle_a2a_response_with_task_working_and_no_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1121,11 +1129,11 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message = Mock(spec=A2AMessage) mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.completed + mock_update.status.state = TaskState.TASK_STATE_COMPLETED mock_update.status.message = mock_a2a_message # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1139,7 +1147,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1167,11 +1175,11 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message = Mock(spec=A2AMessage) mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.working + mock_update.status.state = TaskState.TASK_STATE_WORKING mock_update.status.message = mock_a2a_message # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1185,7 +1193,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1209,11 +1217,11 @@ async def test_handle_a2a_response_with_task_status_update_no_message(self): mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.completed + mock_update.status.state = TaskState.TASK_STATE_COMPLETED mock_update.status.message = None result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result is None @@ -1246,7 +1254,7 @@ async def test_handle_a2a_response_with_artifact_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1273,7 +1281,7 @@ async def test_handle_a2a_response_with_partial_artifact_update(self): mock_update.last_chunk = False result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result is None @@ -1437,10 +1445,10 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task.id = "task-123" mock_a2a_task.context_id = "context-123" mock_a2a_task.status = Mock(spec=A2ATaskStatus) - mock_a2a_task.status.state = TaskState.completed + mock_a2a_task.status.state = TaskState.TASK_STATE_COMPLETED # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1456,7 +1464,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1480,10 +1488,10 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task.id = "task-123" mock_a2a_task.context_id = "context-123" mock_a2a_task.status = Mock(spec=A2ATaskStatus) - mock_a2a_task.status.state = TaskState.submitted + mock_a2a_task.status.state = TaskState.TASK_STATE_SUBMITTED # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1499,7 +1507,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1527,11 +1535,11 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message = Mock(spec=A2AMessage) mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.completed + mock_update.status.state = TaskState.TASK_STATE_COMPLETED mock_update.status.message = mock_a2a_message # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1545,7 +1553,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1573,11 +1581,11 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message = Mock(spec=A2AMessage) mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.working + mock_update.status.state = TaskState.TASK_STATE_WORKING mock_update.status.message = mock_a2a_message # Create a proper Event mock that can handle custom_metadata - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_event = Event( author=self.agent.name, invocation_id=self.mock_context.invocation_id, @@ -1591,7 +1599,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1615,11 +1623,11 @@ async def test_handle_a2a_response_with_task_status_update_no_message(self): mock_update = Mock(spec=TaskStatusUpdateEvent) mock_update.status = Mock(A2ATaskStatus) - mock_update.status.state = TaskState.completed + mock_update.status.state = TaskState.TASK_STATE_COMPLETED mock_update.status.message = None result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result is None @@ -1652,7 +1660,7 @@ async def test_handle_a2a_response_with_artifact_update(self): mock_convert.return_value = mock_event result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1679,7 +1687,7 @@ async def test_handle_a2a_response_with_partial_artifact_update(self): mock_update.last_chunk = False result = await self.agent._handle_a2a_response( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result is None @@ -1764,7 +1772,7 @@ async def test_handle_a2a_response_impl_with_task_and_no_update(self): self.mock_config.a2a_task_converter.return_value = mock_event result = await self.agent._handle_a2a_response_v2( - (mock_a2a_task, None), self.mock_context + (None, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1800,7 +1808,7 @@ async def test_handle_a2a_response_impl_with_task_status_update(self): self.mock_config.a2a_status_update_converter.return_value = mock_event result = await self.agent._handle_a2a_response_v2( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1832,7 +1840,7 @@ async def test_handle_a2a_response_impl_with_task_artifact_update(self): self.mock_config.a2a_artifact_update_converter.return_value = mock_event result = await self.agent._handle_a2a_response_v2( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result == mock_event @@ -1862,7 +1870,7 @@ async def test_handle_a2a_response_impl_update_converter_returns_none(self): self.mock_config.a2a_artifact_update_converter.return_value = None result = await self.agent._handle_a2a_response_v2( - (mock_a2a_task, mock_update), self.mock_context + (mock_update, mock_a2a_task), self.mock_context ) assert result is None @@ -1993,9 +2001,9 @@ async def test_run_async_impl_successful_request(self): ) as mock_construct: # Create proper A2A part mocks from a2a.client import Client as A2AClient - from a2a.types import TextPart + from a2a.types import Part - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_construct.return_value = ( [mock_a2a_part], "context-123", @@ -2003,7 +2011,11 @@ async def test_run_async_impl_successful_request(self): # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock(metadata={}) + from a2a.types import StreamResponse as A2AStreamResponse + from a2a.types import Task as A2ATask + mock_stream_response = A2AStreamResponse() + mock_task = A2ATask() + mock_response = (mock_stream_response, mock_task) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2032,11 +2044,10 @@ async def test_run_async_impl_successful_request(self): with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) + mock_message = A2AMessage() mock_message_class.return_value = mock_message # Add model_dump to mock_response for metadata - mock_response.model_dump.return_value = {"test": "response"} # Execute events = [] @@ -2065,9 +2076,9 @@ async def test_run_async_impl_a2a_client_error(self): self.agent, "_construct_message_parts_from_session" ) as mock_construct: # Create proper A2A part mocks - from a2a.types import TextPart + from a2a.types import Part - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_construct.return_value = ( [mock_a2a_part], "context-123", @@ -2088,7 +2099,7 @@ async def test_run_async_impl_a2a_client_error(self): with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) + mock_message = A2AMessage() mock_message_class.return_value = mock_message events = [] @@ -2132,9 +2143,9 @@ async def test_run_async_impl_with_meta_provider(self): ) as mock_construct: # Create proper A2A part mocks from a2a.client import Client as A2AClient - from a2a.types import TextPart + from a2a.types import Part - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_construct.return_value = ( [mock_a2a_part], "context-123", @@ -2142,7 +2153,11 @@ async def test_run_async_impl_with_meta_provider(self): # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock(metadata={}) + from a2a.types import StreamResponse as A2AStreamResponse + from a2a.types import Task as A2ATask + mock_stream_response = A2AStreamResponse() + mock_task = A2ATask() + mock_response = (mock_stream_response, mock_task) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2170,11 +2185,10 @@ async def test_run_async_impl_with_meta_provider(self): with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) + mock_message = A2AMessage() mock_message_class.return_value = mock_message # Add model_dump to mock_response for metadata - mock_response.model_dump.return_value = {"test": "response"} # Execute events = [] @@ -2185,10 +2199,17 @@ async def test_run_async_impl_with_meta_provider(self): mock_meta_provider.assert_called_once_with( self.mock_context, mock_message ) - mock_a2a_client.send_message.assert_called_once_with( - request=mock_message, - request_metadata=request_metadata, - context=ClientCallContext(state=self.mock_session.state), + mock_a2a_client.send_message.assert_called_once() + send_kwargs = mock_a2a_client.send_message.call_args.kwargs + assert send_kwargs["context"] == ClientCallContext( + state=self.mock_session.state + ) + assert isinstance( + send_kwargs["request"], SendMessageRequest + ) + assert send_kwargs["request"].message == mock_message + assert send_kwargs["request"].metadata["custom_meta"] == ( + "value" ) @@ -2269,9 +2290,9 @@ async def test_run_async_impl_successful_request(self): ) as mock_construct: # Create proper A2A part mocks from a2a.client import Client as A2AClient - from a2a.types import TextPart + from a2a.types import Part - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_construct.return_value = ( [mock_a2a_part], "context-123", @@ -2279,7 +2300,11 @@ async def test_run_async_impl_successful_request(self): # Mock A2A client mock_a2a_client = create_autospec(spec=A2AClient, instance=True) - mock_response = Mock(metadata={}) + from a2a.types import StreamResponse as A2AStreamResponse + from a2a.types import Task as A2ATask + mock_stream_response = A2AStreamResponse() + mock_task = A2ATask() + mock_response = (mock_stream_response, mock_task) mock_send_message = AsyncMock() mock_send_message.__aiter__.return_value = [mock_response] mock_a2a_client.send_message.return_value = mock_send_message @@ -2308,13 +2333,10 @@ async def test_run_async_impl_successful_request(self): with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) + mock_message = A2AMessage() mock_message_class.return_value = mock_message # Add model_dump to mock_response for metadata - mock_response.root.model_dump.return_value = { - "test": "response" - } # Execute events = [] @@ -2343,9 +2365,9 @@ async def test_run_async_impl_a2a_client_error(self): self.agent, "_construct_message_parts_from_session" ) as mock_construct: # Create proper A2A part mocks - from a2a.types import TextPart + from a2a.types import Part - mock_a2a_part = Mock(spec=TextPart) + mock_a2a_part = Part() mock_construct.return_value = ( [mock_a2a_part], "context-123", @@ -2366,7 +2388,7 @@ async def test_run_async_impl_a2a_client_error(self): with patch( "google.adk.agents.remote_a2a_agent.A2AMessage" ) as mock_message_class: - mock_message = Mock(spec=A2AMessage) + mock_message = A2AMessage() mock_message_class.return_value = mock_message events = [] @@ -2487,6 +2509,7 @@ async def test_full_workflow_with_direct_agent_card(self): # Mock session with text event mock_part = Mock() mock_part.text = "Hello world" + mock_part.thought = None mock_content = Mock() mock_content.parts = [mock_part] @@ -2510,20 +2533,12 @@ async def test_full_workflow_with_direct_agent_card(self): ) as mock_convert: mock_convert.return_value = mock_event - with patch( - "google.adk.agents.remote_a2a_agent.convert_genai_part_to_a2a_part" - ) as mock_convert_part: - from a2a.types import TextPart + with patch("httpx.AsyncClient") as mock_httpx_client_class: + mock_httpx_client = AsyncMock() + mock_httpx_client_class.return_value = mock_httpx_client - mock_a2a_part = Mock(spec=TextPart) - mock_convert_part.return_value = mock_a2a_part - - with patch("httpx.AsyncClient") as mock_httpx_client_class: - mock_httpx_client = AsyncMock() - mock_httpx_client_class.return_value = mock_httpx_client - - with patch.object(agent, "_a2a_client") as mock_a2a_client: - mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) + with patch.object(agent, "_a2a_client") as mock_a2a_client: + mock_a2a_message = A2AMessage() mock_a2a_message.context_id = "context-123" mock_a2a_message.metadata = {} mock_response = mock_a2a_message @@ -2552,9 +2567,6 @@ async def test_full_workflow_with_direct_agent_card(self): mock_req_log.return_value = "Mock request log" mock_resp_log.return_value = "Mock response log" - # Add model_dump to mock_response for metadata - mock_response.model_dump.return_value = {"test": "response"} - # Execute events = [] async for event in agent._run_async_impl(mock_context): @@ -2584,6 +2596,7 @@ async def test_full_workflow_with_direct_agent_card_and_factory(self): # Mock session with text event mock_part = Mock() mock_part.text = "Hello world" + mock_part.thought = None mock_content = Mock() mock_content.parts = [mock_part] @@ -2607,20 +2620,12 @@ async def test_full_workflow_with_direct_agent_card_and_factory(self): ) as mock_convert: mock_convert.return_value = mock_event - with patch( - "google.adk.agents.remote_a2a_agent.convert_genai_part_to_a2a_part" - ) as mock_convert_part: - from a2a.types import TextPart - - mock_a2a_part = Mock(spec=TextPart) - mock_convert_part.return_value = mock_a2a_part - - with patch("httpx.AsyncClient") as mock_httpx_client_class: - mock_httpx_client = AsyncMock() - mock_httpx_client_class.return_value = mock_httpx_client + with patch("httpx.AsyncClient") as mock_httpx_client_class: + mock_httpx_client = AsyncMock() + mock_httpx_client_class.return_value = mock_httpx_client - with patch.object(agent, "_a2a_client") as mock_a2a_client: - mock_a2a_message = create_autospec(spec=A2AMessage, instance=True) + with patch.object(agent, "_a2a_client") as mock_a2a_client: + mock_a2a_message = A2AMessage() mock_a2a_message.context_id = "context-123" mock_a2a_message.metadata = {} mock_response = mock_a2a_message @@ -2650,7 +2655,6 @@ async def test_full_workflow_with_direct_agent_card_and_factory(self): mock_resp_log.return_value = "Mock response log" # Add model_dump to mock_response for metadata - mock_response.model_dump.return_value = {"test": "response"} # Execute events = [] diff --git a/tests/unittests/integrations/agent_registry/test_agent_registry.py b/tests/unittests/integrations/agent_registry/test_agent_registry.py index 377fceda82..0755c3000d 100644 --- a/tests/unittests/integrations/agent_registry/test_agent_registry.py +++ b/tests/unittests/integrations/agent_registry/test_agent_registry.py @@ -17,7 +17,7 @@ from unittest.mock import MagicMock from unittest.mock import patch -from a2a.types import TransportProtocol as A2ATransport + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.integrations.agent_registry import _ProtocolType from google.adk.integrations.agent_registry import AgentRegistry @@ -57,7 +57,7 @@ async def test_get_mcp_toolset_adds_destination_id( ), "interfaces": [{ "url": "https://mcp.com", - "protocolBinding": A2ATransport.jsonrpc, + "protocolBinding": "JSONRPC", }], } mock_httpx.return_value.__enter__.return_value.get.return_value = ( @@ -123,7 +123,7 @@ async def test_get_mcp_toolset_handles_missing_destination_id( # "mcpServerId" is intentionally omitted "interfaces": [{ "url": "https://mcp.com", - "protocolBinding": A2ATransport.jsonrpc, + "protocolBinding": "JSONRPC", }], } mock_httpx.return_value.__enter__.return_value.get.return_value = ( @@ -174,7 +174,7 @@ def test_get_connection_uri_mcp_interfaces_top_level(self, registry): ] } uri = registry._get_connection_uri( - resource_details, protocol_binding=A2ATransport.jsonrpc + resource_details, protocol_binding="JSONRPC" ) assert uri == "https://mcp-v1main.com" @@ -184,7 +184,7 @@ def test_get_connection_uri_agent_nested_protocols(self, registry): "type": _ProtocolType.A2A_AGENT, "interfaces": [{ "url": "https://my-agent.com", - "protocolBinding": A2ATransport.jsonrpc, + "protocolBinding": "JSONRPC", }], }] } @@ -204,7 +204,7 @@ def test_get_connection_uri_filtering(self, registry): "type": _ProtocolType.A2A_AGENT, "interfaces": [{ "url": "https://my-agent.com", - "protocolBinding": A2ATransport.http_json, + "protocolBinding": "HTTP_JSON", }], }, ] @@ -217,7 +217,7 @@ def test_get_connection_uri_filtering(self, registry): # Filter by binding uri = registry._get_connection_uri( - resource_details, protocol_binding=A2ATransport.http_json + resource_details, protocol_binding="HTTP_JSON" ) assert uri == "https://my-agent.com" @@ -225,7 +225,7 @@ def test_get_connection_uri_filtering(self, registry): uri = registry._get_connection_uri( resource_details, protocol_type=_ProtocolType.A2A_AGENT, - protocol_binding=A2ATransport.jsonrpc, + protocol_binding="JSONRPC", ) assert uri is None @@ -279,7 +279,7 @@ def test_get_mcp_toolset(self, mock_httpx, registry): "displayName": "TestPrefix", "interfaces": [{ "url": "https://mcp.com", - "protocolBinding": A2ATransport.jsonrpc, + "protocolBinding": "JSONRPC", }], } mock_response.raise_for_status = MagicMock() @@ -305,7 +305,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): "type": _ProtocolType.A2A_AGENT, "interfaces": [{ "url": "https://my-agent.com", - "protocolBinding": A2ATransport.jsonrpc, + "protocolBinding": "JSONRPC", }], }], "skills": [{"id": "s1", "name": "Skill 1", "description": "Desc 1"}], @@ -322,7 +322,7 @@ def test_get_remote_a2a_agent(self, mock_httpx, registry): assert isinstance(agent, RemoteA2aAgent) assert agent.name == "TestAgent" assert agent.description == "Test Desc" - assert agent._agent_card.url == "https://my-agent.com" + assert agent._agent_card.supported_interfaces[0].url == "https://my-agent.com" assert agent._agent_card.version == "1.0" assert len(agent._agent_card.skills) == 1 assert agent._agent_card.skills[0].name == "Skill 1" @@ -339,15 +339,15 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): "description": "CardDesc", "version": "2.0", "url": "https://card-url.com", - "skills": [{ - "id": "s1", - "name": "S1", - "description": "D1", - "tags": ["t1"], - }], - "capabilities": {"streaming": True, "polling": False}, - "defaultInputModes": ["text"], - "defaultOutputModes": ["text"], + "skills": [{ + "id": "s1", + "name": "S1", + "description": "D1", + "tags": ["t1"], + }], + "capabilities": {"streaming": True}, + "defaultInputModes": ["text"], + "defaultOutputModes": ["text"], }, }, } @@ -364,7 +364,7 @@ def test_get_remote_a2a_agent_with_card(self, mock_httpx, registry): assert agent.name == "CardName" assert agent.description == "CardDesc" assert agent._agent_card.version == "2.0" - assert agent._agent_card.url == "https://card-url.com" + assert agent._agent_card.supported_interfaces[0].url == "https://card-url.com" assert agent._agent_card.capabilities.streaming is True assert len(agent._agent_card.skills) == 1 assert agent._agent_card.skills[0].name == "S1"