Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions tests/server/a2a/converters/test_event_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,10 @@ def test_includes_optional_fields(self):
class TestBuildMessageMetadata:
def test_includes_object_type_and_tag(self):
event = _make_event(text="hi", object_type="chat.completion", tag="my_tag")
meta = _build_message_metadata(event)
meta = _build_message_metadata(event, "eff-1")
assert meta[MESSAGE_METADATA_OBJECT_TYPE_KEY] == "chat.completion"
assert meta[MESSAGE_METADATA_TAG_KEY] == "my_tag"
assert meta[MESSAGE_METADATA_RESPONSE_ID_KEY] == "eff-1"


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -337,12 +338,12 @@ def test_does_nothing_without_long_running_ids(self):
class TestBuildMessage:
def test_returns_none_for_empty_parts(self):
event = _make_event(text="hi")
assert _build_message(event, [], Role.agent) is None
assert _build_message(event, [], Role.agent, "e1") is None

def test_returns_message_with_parts(self):
event = _make_event(text="hi", response_id="resp-1")
parts = [A2APart(root=TextPart(text="hi"))]
msg = _build_message(event, parts, Role.agent)
msg = _build_message(event, parts, Role.agent, "resp-1")
assert msg is not None
assert msg.role == Role.agent
assert msg.message_id == "resp-1"
Expand Down Expand Up @@ -636,7 +637,7 @@ def test_basic_working(self):
)
event = _make_event(text="hi")
ctx = _make_invocation_context()
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1")
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1")
assert result.status.state == TaskState.working

def test_auth_required_for_euc(self):
Expand All @@ -650,7 +651,7 @@ def test_auth_required_for_euc(self):
msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=dp)])
event = _make_event(function_call=FunctionCall(name=REQUEST_EUC_FUNCTION_CALL_NAME, args={}))
ctx = _make_invocation_context()
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1")
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1")
assert result.status.state == TaskState.auth_required

def test_input_required_for_long_running(self):
Expand All @@ -664,7 +665,7 @@ def test_input_required_for_long_running(self):
msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=dp)])
event = _make_event(function_call=FunctionCall(name="other_tool", args={}))
ctx = _make_invocation_context()
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1")
result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1")
assert result.status.state == TaskState.input_required


Expand All @@ -680,15 +681,19 @@ def test_basic(self):
)
event = _make_event(text="hi", response_id="resp-1")
ctx = _make_invocation_context()
result = _create_artifact_update_event(msg, event, ctx, task_id="t1", context_id="ctx1")
assert result.artifact.artifact_id == "resp-1"
result = _create_artifact_update_event(
msg, event, ctx, task_id="t1", context_id="ctx1", effective_id="m1"
)
assert result.artifact.artifact_id == "m1"
assert result.last_chunk is False

def test_last_chunk(self):
msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=TextPart(text="hi"))])
event = _make_event(text="hi")
ctx = _make_invocation_context()
result = _create_artifact_update_event(msg, event, ctx, task_id="t1", context_id="ctx1", last_chunk=True)
result = _create_artifact_update_event(
msg, event, ctx, task_id="t1", context_id="ctx1", last_chunk=True, effective_id="m1"
)
assert result.last_chunk is True
assert result.artifact.artifact_id == ""
assert result.artifact.parts == []
Expand Down
104 changes: 72 additions & 32 deletions trpc_agent_sdk/server/a2a/converters/_event_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,24 +206,25 @@ def _build_context_metadata(event: Event, ctx: InvocationContext) -> Dict[str, A
return metadata


def _build_message_metadata(event: Event) -> Dict[str, Any]:
def _build_message_metadata(event: Event, effective_id: str) -> Dict[str, Any]:
"""Build message/event metadata (object_type, tag, llm_response_id)."""
return {
MESSAGE_METADATA_OBJECT_TYPE_KEY: _infer_message_object_type(event) or "",
MESSAGE_METADATA_TAG_KEY: _infer_message_tag(event),
MESSAGE_METADATA_RESPONSE_ID_KEY: event.response_id or "",
MESSAGE_METADATA_RESPONSE_ID_KEY: effective_id,
}


def _build_event_metadata(event: Event, message: Message, ctx: InvocationContext) -> Dict[str, Any]:
def _build_event_metadata(event: Event, message: Message, ctx: InvocationContext, effective_id: str) -> Dict[str, Any]:
metadata = _build_context_metadata(event, ctx)
msg_meta = _build_message_metadata(event)
msg_meta = _build_message_metadata(event, effective_id)
set_metadata(metadata, MESSAGE_METADATA_OBJECT_TYPE_KEY, msg_meta.get(MESSAGE_METADATA_OBJECT_TYPE_KEY) or "")
set_metadata(metadata, MESSAGE_METADATA_TAG_KEY, msg_meta.get(MESSAGE_METADATA_TAG_KEY) or "")
set_metadata(metadata, MESSAGE_METADATA_RESPONSE_ID_KEY, msg_meta.get(MESSAGE_METADATA_RESPONSE_ID_KEY) or "")
streaming_delta = A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA
if any(
get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) ==
A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA for p in message.parts if p.root.metadata):
get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == streaming_delta for p in message.parts
if p.root.metadata):
set_metadata(metadata, "streaming_tool_call", "true")
return metadata

Expand All @@ -234,27 +235,41 @@ def _mark_long_running_tools(a2a_parts: List[A2APart], event: Event) -> None:
return
for a2a_part in a2a_parts:
root = a2a_part.root
if (isinstance(root, DataPart) and root.metadata and get_metadata(
root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and root.data.get("id") in event.long_running_tool_ids):
set_metadata(root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, True)
if not isinstance(root, DataPart) or not root.metadata:
continue
if get_metadata(root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) != A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL:
continue
if root.data.get("id") not in event.long_running_tool_ids:
continue
set_metadata(root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, True)


def _effective_response_id(event: Event) -> str:
"""Return ``response_id`` when present, otherwise a new UUID.

Callers that need the same id across multiple locations should invoke this
once and pass the result explicitly.
"""
return event.response_id or str(uuid.uuid4())


def _build_message(event: Event, a2a_parts: List[A2APart], role: Role) -> Optional[Message]:
def _build_message(event: Event, a2a_parts: List[A2APart], role: Role, effective_id: str) -> Optional[Message]:
"""Assemble an A2A Message from converted parts, or return None if empty."""
if not a2a_parts:
return None
message_id = event.response_id or str(uuid.uuid4())
message = Message(message_id=message_id, role=role, parts=a2a_parts)
msg_meta = _build_message_metadata(event)
message = Message(message_id=effective_id, role=role, parts=a2a_parts)
msg_meta = _build_message_metadata(event, effective_id)
if msg_meta:
message.metadata = msg_meta
return message


def _is_streaming_delta(a2a_part: A2APart) -> bool:
return (a2a_part.root.metadata is not None and get_metadata(a2a_part.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY)
== A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA)
meta = a2a_part.root.metadata
if meta is None:
return False
t = get_metadata(meta, A2A_DATA_PART_METADATA_TYPE_KEY)
return t == A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA


def _collect_parts(
Expand Down Expand Up @@ -320,7 +335,8 @@ def convert_event_to_a2a_message(
return None

a2a_parts = _collect_parts(event, **rules)
return _build_message(event, a2a_parts, role)
effective_id = _effective_response_id(event)
return _build_message(event, a2a_parts, role, effective_id)


def convert_content_to_a2a_message(
Expand Down Expand Up @@ -425,8 +441,8 @@ def convert_a2a_message_to_event(
if gpart is None:
logger.warning("Failed to convert A2A part, skipping: %s", a2a_part)
continue
if (metadata_is_true(a2a_part.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
and gpart.function_call):
is_lr = metadata_is_true(a2a_part.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
if is_lr and gpart.function_call:
long_running_tool_ids.add(gpart.function_call.id)
parts.append(gpart)
except Exception as ex: # pylint: disable=broad-except
Expand All @@ -436,8 +452,8 @@ def convert_a2a_message_to_event(
if not parts:
logger.warning("No parts could be converted from A2A message %s", a2a_message)

object_type = (get_metadata(msg_meta, MESSAGE_METADATA_OBJECT_TYPE_KEY)
or _infer_a2a_message_object_type(parts, partial=partial) or _default_object_type(partial))
ot = get_metadata(msg_meta, MESSAGE_METADATA_OBJECT_TYPE_KEY)
object_type = ot or _infer_a2a_message_object_type(parts, partial=partial) or _default_object_type(partial)

return Event(
invocation_id=inv_id,
Expand Down Expand Up @@ -590,31 +606,51 @@ def _create_error_status_event(
)


def _a2a_part_requests_euc_auth(part: A2APart) -> bool:
root = part.root
md = root.metadata
if not md:
return False
t = get_metadata(md, A2A_DATA_PART_METADATA_TYPE_KEY)
return all([
t == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL,
metadata_is_true(md, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME,
])


def _a2a_part_is_long_running_function_call(part: A2APart) -> bool:
root = part.root
md = root.metadata
if not md:
return False
t = get_metadata(md, A2A_DATA_PART_METADATA_TYPE_KEY)
return all([
t == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL,
metadata_is_true(md, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY),
])


def _create_status_update_event(
message: Message,
ctx: InvocationContext,
event: Event,
task_id: Optional[str],
context_id: Optional[str],
effective_id: str = "",
) -> TaskStatusUpdateEvent:
status = TaskStatus(state=TaskState.working, message=message, timestamp=_now_iso())

if any(
get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and metadata_is_true(p.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY)
and p.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME for p in message.parts if p.root.metadata):
if any(_a2a_part_requests_euc_auth(p) for p in message.parts):
status.state = TaskState.auth_required
elif any(
get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
and metadata_is_true(p.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) for p in message.parts
if p.root.metadata):
elif any(_a2a_part_is_long_running_function_call(p) for p in message.parts):
status.state = TaskState.input_required

return TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=status,
metadata=_build_event_metadata(event, message, ctx),
metadata=_build_event_metadata(event, message, ctx, effective_id),
final=False,
)

Expand All @@ -626,8 +662,9 @@ def _create_artifact_update_event(
task_id: Optional[str] = None,
context_id: Optional[str] = None,
last_chunk: bool = False,
effective_id: str = "",
) -> TaskArtifactUpdateEvent:
artifact_id = "" if last_chunk else (event.response_id or "")
artifact_id = "" if last_chunk else effective_id
return TaskArtifactUpdateEvent(
task_id=task_id,
context_id=context_id,
Expand All @@ -636,7 +673,7 @@ def _create_artifact_update_event(
parts=[] if last_chunk else message.parts,
),
last_chunk=last_chunk,
metadata=_build_event_metadata(event, message, ctx),
metadata=_build_event_metadata(event, message, ctx, effective_id),
)


Expand Down Expand Up @@ -674,12 +711,14 @@ def _notify(evt: A2AEvent) -> None:

message = convert_event_to_a2a_message(event, invocation_context)
if message:
effective_id = message.message_id
status_event = _create_status_update_event(
message,
invocation_context,
event,
task_id,
context_id,
effective_id=effective_id,
)
_notify(status_event)

Expand All @@ -691,6 +730,7 @@ def _notify(evt: A2AEvent) -> None:
task_id=task_id,
context_id=context_id,
last_chunk=False,
effective_id=effective_id,
)
a2a_events.append(artifact_event)

Expand Down
Loading