Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ dev = [

a2a = [
# go/keep-sorted start
"a2a-sdk>=0.3.4,<0.4.0",
"a2a-sdk>=1.0.0a0",
# go/keep-sorted end
]

Expand All @@ -120,7 +120,7 @@ eval = [

test = [
# go/keep-sorted start
"a2a-sdk>=0.3.0,<0.4.0",
"a2a-sdk>=1.0.0a0",
"anthropic>=0.43.0", # For anthropic model tests
"crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+
"kubernetes>=29.0.0", # For GkeCodeExecutor
Expand Down
39 changes: 39 additions & 0 deletions src/google/adk/a2a/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from a2a.types import Role
from a2a.types import TaskState


def _install_task_state_aliases() -> None:
"""Adds pre-1.0 TaskState aliases expected by ADK code and tests."""
alias_by_name = {
"working": "TASK_STATE_WORKING",
"failed": "TASK_STATE_FAILED",
"input_required": "TASK_STATE_INPUT_REQUIRED",
"auth_required": "TASK_STATE_AUTH_REQUIRED",
"completed": "TASK_STATE_COMPLETED",
"submitted": "TASK_STATE_SUBMITTED",
"canceled": "TASK_STATE_CANCELED",
"unknown": "TASK_STATE_UNKNOWN",
}
for alias, canonical in alias_by_name.items():
if not hasattr(TaskState, alias) and hasattr(TaskState, canonical):
setattr(TaskState, alias, getattr(TaskState, canonical))


_install_task_state_aliases()


def _install_role_aliases() -> None:
"""Adds pre-1.0 Role aliases expected by ADK code and tests."""
alias_by_name = {
"user": "ROLE_USER",
"agent": "ROLE_AGENT",
}
for alias, canonical in alias_by_name.items():
if not hasattr(Role, alias) and hasattr(Role, canonical):
setattr(Role, alias, getattr(Role, canonical))


_install_role_aliases()
2 changes: 1 addition & 1 deletion src/google/adk/a2a/agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Optional
from typing import Union

from a2a.client.middleware import ClientCallContext
from a2a.client import ClientCallContext
from a2a.server.events import Event as A2AEvent
from a2a.types import Message as A2AMessage
from pydantic import BaseModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Union

from a2a.client.middleware import ClientCallContext
from a2a.client import ClientCallContext
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import Message as A2AMessage
from google.adk.a2a.agent.config import ParametersConfig
Expand All @@ -39,15 +39,13 @@ async def _before_request(
if params.client_call_context is None:
params.client_call_context = ClientCallContext()

http_kwargs = params.client_call_context.state.get('http_kwargs', {})
headers = http_kwargs.get('headers', {})
headers = params.client_call_context.service_parameters or {}
a2a_extensions = headers.get(HTTP_EXTENSION_HEADER, '').split(',')
a2a_extensions = [ext for ext in a2a_extensions if ext]
if _NEW_A2A_ADK_INTEGRATION_EXTENSION not in a2a_extensions:
a2a_extensions.append(_NEW_A2A_ADK_INTEGRATION_EXTENSION)
headers[HTTP_EXTENSION_HEADER] = ','.join(a2a_extensions)
http_kwargs['headers'] = headers
params.client_call_context.state['http_kwargs'] = http_kwargs
params.client_call_context.service_parameters = headers
return a2a_request, params


Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/a2a/agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Union

from a2a.client import ClientEvent as A2AClientEvent
from a2a.client.middleware import ClientCallContext
from a2a.client import ClientCallContext
from a2a.types import Message as A2AMessage

from ...agents.invocation_context import InvocationContext
Expand Down
211 changes: 144 additions & 67 deletions src/google/adk/a2a/converters/event_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@
from typing import Optional

from a2a.server.events import Event as A2AEvent
from a2a.types import DataPart
from a2a.types import Message
from a2a.types import Part as A2APart
from a2a.types import Role
from a2a.types import Task
from a2a.types import TaskState
from a2a.types import TaskStatus
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.platform import time as platform_time
from google.adk.platform import uuid as platform_uuid
from google.genai import types as genai_types
from google.protobuf.json_format import MessageToDict
from google.protobuf.timestamp_pb2 import Timestamp

from ...agents.invocation_context import InvocationContext
from ...events.event import Event
Expand Down Expand Up @@ -105,6 +105,83 @@ def _serialize_metadata_value(value: Any) -> str:
return str(value)


def _get_part_metadata_value(part: A2APart, key: str) -> Any:
"""Returns a metadata value from either proto Struct or dict-like metadata."""
metadata = getattr(part, "metadata", None)
if not metadata:
return None
try:
return metadata.get(key)
except AttributeError:
try:
return metadata[key]
except Exception:
return None


def _get_part_data_dict(part: A2APart) -> Dict[str, Any]:
"""Returns a part's data payload as a plain dict when possible."""
data = getattr(part, "data", None)
if data is None:
return {}
if isinstance(data, dict):
return data
get_method = getattr(data, "get", None)
if callable(get_method):
try:
return {
"id": get_method("id"),
"name": get_method("name"),
}
except Exception:
pass
try:
return MessageToDict(data)
except Exception:
return {}


def _coerce_a2a_message(message: Message | Any) -> Message:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid using "any" dataype

"""Returns a proto Message, tolerating older mock/dict-style inputs in tests."""
if (
isinstance(message, Message)
and type(message).__module__ != "unittest.mock"
):
return message

coerced_message = Message()
for field_name in ("message_id", "task_id", "context_id"):
field_value = getattr(message, field_name, None)
if field_value:
setattr(coerced_message, field_name, field_value)

role = getattr(message, "role", None)
if role is not None:
coerced_message.role = role
else:
coerced_message.role = Role.ROLE_AGENT

parts = getattr(message, "parts", None)
if parts:
for part in parts:
if isinstance(part, A2APart):
coerced_message.parts.append(part)

metadata = getattr(message, "metadata", None)
if metadata:
coerced_message.metadata.update(metadata)

return coerced_message


def _create_timestamp() -> Timestamp:
"""Creates a protobuf timestamp from the current platform time."""
now = platform_time.get_time()
seconds = int(now)
nanos = int((now - seconds) * 1_000_000_000)
return Timestamp(seconds=seconds, nanos=nanos)


def _get_context_metadata(
event: Event, invocation_context: InvocationContext
) -> Dict[str, str]:
Expand Down Expand Up @@ -184,19 +261,30 @@ def _process_long_running_tool(a2a_part: A2APart, event: Event) -> None:
a2a_part: The A2A part to potentially mark as long-running.
event: The ADK event containing long-running tool information.
"""
if (
isinstance(a2a_part.root, DataPart)
and event.long_running_tool_ids
and a2a_part.root.metadata
and a2a_part.root.metadata.get(
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
)
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and a2a_part.root.data.get("id") in event.long_running_tool_ids
):
a2a_part.root.metadata[
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
] = True
if not event.long_running_tool_ids or not getattr(a2a_part, "metadata", None):
return
has_data = getattr(a2a_part, "HasField", None)
if callable(has_data):
try:
if not a2a_part.HasField("data"):
return
except Exception:
pass

type_key = _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
part_type = (
_get_part_metadata_value(a2a_part, type_key)
or _get_part_metadata_value(a2a_part, A2A_DATA_PART_METADATA_TYPE_KEY)
or _get_part_metadata_value(a2a_part, "adk_type")
)
if part_type != A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL:
return

data_dict = _get_part_data_dict(a2a_part)
if data_dict.get("id") in event.long_running_tool_ids:
a2a_part.metadata.update({
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY): True
})


def convert_a2a_task_to_event(
Expand Down Expand Up @@ -229,7 +317,7 @@ def convert_a2a_task_to_event(
message = None
if a2a_task.artifacts:
message = Message(
message_id="", role=Role.agent, parts=a2a_task.artifacts[-1].parts
message_id="", role=Role.ROLE_AGENT, parts=a2a_task.artifacts[-1].parts
)
elif (
a2a_task.status
Expand Down Expand Up @@ -321,15 +409,10 @@ def convert_a2a_message_to_event(
continue

# Check for long-running tools
if (
a2a_part.root.metadata
and a2a_part.root.metadata.get(
_get_adk_metadata_key(
A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY
)
)
is True
):
if _get_part_metadata_value(
a2a_part,
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
) is True:
for part in parts:
if part.function_call:
long_running_tool_ids.add(part.function_call.id)
Expand Down Expand Up @@ -372,7 +455,7 @@ def convert_a2a_message_to_event(
def convert_event_to_a2a_message(
event: Event,
invocation_context: InvocationContext | None = None,
role: Role = Role.agent,
role: Role = Role.ROLE_AGENT,
part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
) -> Optional[Message]:
"""Converts an ADK event to an A2A message.
Expand Down Expand Up @@ -446,22 +529,19 @@ def _create_error_status_event(
context_id=context_id,
metadata=event_metadata,
status=TaskStatus(
state=TaskState.failed,
state=TaskState.TASK_STATE_FAILED,
message=Message(
message_id=platform_uuid.new_uuid(),
role=Role.agent,
parts=[TextPart(text=error_message)],
role=Role.ROLE_AGENT,
parts=[A2APart(text=error_message)],
metadata={
_get_adk_metadata_key("error_code"): str(event.error_code)
}
if event.error_code
else {},
),
timestamp=datetime.fromtimestamp(
platform_time.get_time(), tz=timezone.utc
).isoformat(),
),
final=False,
timestamp=_create_timestamp(),
)
)


Expand All @@ -484,48 +564,45 @@ def _create_status_update_event(
Returns:
A TaskStatusUpdateEvent with RUNNING state.
"""
proto_message = _coerce_a2a_message(message)

status = TaskStatus(
state=TaskState.working,
message=message,
timestamp=datetime.fromtimestamp(
platform_time.get_time(), tz=timezone.utc
).isoformat(),
state=TaskState.TASK_STATE_WORKING,
message=proto_message,
timestamp=_create_timestamp(),
)

if any(
part.root.metadata.get(
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
)
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and part.root.metadata.get(
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
)
is True
and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME
for part in message.parts
if part.root.metadata
):
status.state = TaskState.auth_required
elif any(
part.root.metadata.get(
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
)
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and part.root.metadata.get(
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
)
is True
for part in message.parts
if part.root.metadata
):
status.state = TaskState.input_required
def _is_long_running(part: A2APart) -> bool:
val = _get_part_metadata_value(
part,
_get_adk_metadata_key(A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
)
return str(val).lower() == "true" or val is True

for part in message.parts:
part_type = (
_get_part_metadata_value(
part, _get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
)
or _get_part_metadata_value(part, A2A_DATA_PART_METADATA_TYPE_KEY)
or _get_part_metadata_value(part, "adk_type")
)
if (
part_type == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and _is_long_running(part)
):
data_dict = _get_part_data_dict(part)
if data_dict.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME:
status.state = TaskState.TASK_STATE_AUTH_REQUIRED
break
status.state = TaskState.TASK_STATE_INPUT_REQUIRED
break

return TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=status,
metadata=_get_context_metadata(event, invocation_context),
final=False,
)


Expand Down
Loading