diff --git a/tests/server/a2a/converters/test_event_converter.py b/tests/server/a2a/converters/test_event_converter.py index a15b38b..571d2e3 100644 --- a/tests/server/a2a/converters/test_event_converter.py +++ b/tests/server/a2a/converters/test_event_converter.py @@ -301,9 +301,10 @@ def test_includes_optional_fields(self): class TestBuildMessageMetadata: def test_includes_object_type_and_tag(self): event = _make_event(text="hi", object_type="chat.completion", tag="my_tag") - meta = _build_message_metadata(event) + meta = _build_message_metadata(event, "eff-1") assert meta[MESSAGE_METADATA_OBJECT_TYPE_KEY] == "chat.completion" assert meta[MESSAGE_METADATA_TAG_KEY] == "my_tag" + assert meta[MESSAGE_METADATA_RESPONSE_ID_KEY] == "eff-1" # --------------------------------------------------------------------------- @@ -337,12 +338,12 @@ def test_does_nothing_without_long_running_ids(self): class TestBuildMessage: def test_returns_none_for_empty_parts(self): event = _make_event(text="hi") - assert _build_message(event, [], Role.agent) is None + assert _build_message(event, [], Role.agent, "e1") is None def test_returns_message_with_parts(self): event = _make_event(text="hi", response_id="resp-1") parts = [A2APart(root=TextPart(text="hi"))] - msg = _build_message(event, parts, Role.agent) + msg = _build_message(event, parts, Role.agent, "resp-1") assert msg is not None assert msg.role == Role.agent assert msg.message_id == "resp-1" @@ -636,7 +637,7 @@ def test_basic_working(self): ) event = _make_event(text="hi") ctx = _make_invocation_context() - result = _create_status_update_event(msg, ctx, event, "t1", "ctx1") + result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1") assert result.status.state == TaskState.working def test_auth_required_for_euc(self): @@ -650,7 +651,7 @@ def test_auth_required_for_euc(self): msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=dp)]) event = _make_event(function_call=FunctionCall(name=REQUEST_EUC_FUNCTION_CALL_NAME, args={})) ctx = _make_invocation_context() - result = _create_status_update_event(msg, ctx, event, "t1", "ctx1") + result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1") assert result.status.state == TaskState.auth_required def test_input_required_for_long_running(self): @@ -664,7 +665,7 @@ def test_input_required_for_long_running(self): msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=dp)]) event = _make_event(function_call=FunctionCall(name="other_tool", args={})) ctx = _make_invocation_context() - result = _create_status_update_event(msg, ctx, event, "t1", "ctx1") + result = _create_status_update_event(msg, ctx, event, "t1", "ctx1", effective_id="m1") assert result.status.state == TaskState.input_required @@ -680,15 +681,19 @@ def test_basic(self): ) event = _make_event(text="hi", response_id="resp-1") ctx = _make_invocation_context() - result = _create_artifact_update_event(msg, event, ctx, task_id="t1", context_id="ctx1") - assert result.artifact.artifact_id == "resp-1" + result = _create_artifact_update_event( + msg, event, ctx, task_id="t1", context_id="ctx1", effective_id="m1" + ) + assert result.artifact.artifact_id == "m1" assert result.last_chunk is False def test_last_chunk(self): msg = Message(message_id="m1", role=Role.agent, parts=[A2APart(root=TextPart(text="hi"))]) event = _make_event(text="hi") ctx = _make_invocation_context() - result = _create_artifact_update_event(msg, event, ctx, task_id="t1", context_id="ctx1", last_chunk=True) + result = _create_artifact_update_event( + msg, event, ctx, task_id="t1", context_id="ctx1", last_chunk=True, effective_id="m1" + ) assert result.last_chunk is True assert result.artifact.artifact_id == "" assert result.artifact.parts == [] diff --git a/trpc_agent_sdk/server/a2a/converters/_event_converter.py b/trpc_agent_sdk/server/a2a/converters/_event_converter.py index 1c45de3..c0b4c06 100644 --- a/trpc_agent_sdk/server/a2a/converters/_event_converter.py +++ b/trpc_agent_sdk/server/a2a/converters/_event_converter.py @@ -206,24 +206,25 @@ def _build_context_metadata(event: Event, ctx: InvocationContext) -> Dict[str, A return metadata -def _build_message_metadata(event: Event) -> Dict[str, Any]: +def _build_message_metadata(event: Event, effective_id: str) -> Dict[str, Any]: """Build message/event metadata (object_type, tag, llm_response_id).""" return { MESSAGE_METADATA_OBJECT_TYPE_KEY: _infer_message_object_type(event) or "", MESSAGE_METADATA_TAG_KEY: _infer_message_tag(event), - MESSAGE_METADATA_RESPONSE_ID_KEY: event.response_id or "", + MESSAGE_METADATA_RESPONSE_ID_KEY: effective_id, } -def _build_event_metadata(event: Event, message: Message, ctx: InvocationContext) -> Dict[str, Any]: +def _build_event_metadata(event: Event, message: Message, ctx: InvocationContext, effective_id: str) -> Dict[str, Any]: metadata = _build_context_metadata(event, ctx) - msg_meta = _build_message_metadata(event) + msg_meta = _build_message_metadata(event, effective_id) set_metadata(metadata, MESSAGE_METADATA_OBJECT_TYPE_KEY, msg_meta.get(MESSAGE_METADATA_OBJECT_TYPE_KEY) or "") set_metadata(metadata, MESSAGE_METADATA_TAG_KEY, msg_meta.get(MESSAGE_METADATA_TAG_KEY) or "") set_metadata(metadata, MESSAGE_METADATA_RESPONSE_ID_KEY, msg_meta.get(MESSAGE_METADATA_RESPONSE_ID_KEY) or "") + streaming_delta = A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA if any( - get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == - A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA for p in message.parts if p.root.metadata): + get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == streaming_delta for p in message.parts + if p.root.metadata): set_metadata(metadata, "streaming_tool_call", "true") return metadata @@ -234,27 +235,41 @@ def _mark_long_running_tools(a2a_parts: List[A2APart], event: Event) -> None: return for a2a_part in a2a_parts: root = a2a_part.root - if (isinstance(root, DataPart) and root.metadata and get_metadata( - root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and root.data.get("id") in event.long_running_tool_ids): - set_metadata(root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, True) + if not isinstance(root, DataPart) or not root.metadata: + continue + if get_metadata(root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) != A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL: + continue + if root.data.get("id") not in event.long_running_tool_ids: + continue + set_metadata(root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY, True) + + +def _effective_response_id(event: Event) -> str: + """Return ``response_id`` when present, otherwise a new UUID. + + Callers that need the same id across multiple locations should invoke this + once and pass the result explicitly. + """ + return event.response_id or str(uuid.uuid4()) -def _build_message(event: Event, a2a_parts: List[A2APart], role: Role) -> Optional[Message]: +def _build_message(event: Event, a2a_parts: List[A2APart], role: Role, effective_id: str) -> Optional[Message]: """Assemble an A2A Message from converted parts, or return None if empty.""" if not a2a_parts: return None - message_id = event.response_id or str(uuid.uuid4()) - message = Message(message_id=message_id, role=role, parts=a2a_parts) - msg_meta = _build_message_metadata(event) + message = Message(message_id=effective_id, role=role, parts=a2a_parts) + msg_meta = _build_message_metadata(event, effective_id) if msg_meta: message.metadata = msg_meta return message def _is_streaming_delta(a2a_part: A2APart) -> bool: - return (a2a_part.root.metadata is not None and get_metadata(a2a_part.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) - == A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA) + meta = a2a_part.root.metadata + if meta is None: + return False + t = get_metadata(meta, A2A_DATA_PART_METADATA_TYPE_KEY) + return t == A2A_DATA_PART_METADATA_TYPE_STREAMING_FUNCTION_CALL_DELTA def _collect_parts( @@ -320,7 +335,8 @@ def convert_event_to_a2a_message( return None a2a_parts = _collect_parts(event, **rules) - return _build_message(event, a2a_parts, role) + effective_id = _effective_response_id(event) + return _build_message(event, a2a_parts, role, effective_id) def convert_content_to_a2a_message( @@ -425,8 +441,8 @@ def convert_a2a_message_to_event( if gpart is None: logger.warning("Failed to convert A2A part, skipping: %s", a2a_part) continue - if (metadata_is_true(a2a_part.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - and gpart.function_call): + is_lr = metadata_is_true(a2a_part.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) + if is_lr and gpart.function_call: long_running_tool_ids.add(gpart.function_call.id) parts.append(gpart) except Exception as ex: # pylint: disable=broad-except @@ -436,8 +452,8 @@ def convert_a2a_message_to_event( if not parts: logger.warning("No parts could be converted from A2A message %s", a2a_message) - object_type = (get_metadata(msg_meta, MESSAGE_METADATA_OBJECT_TYPE_KEY) - or _infer_a2a_message_object_type(parts, partial=partial) or _default_object_type(partial)) + ot = get_metadata(msg_meta, MESSAGE_METADATA_OBJECT_TYPE_KEY) + object_type = ot or _infer_a2a_message_object_type(parts, partial=partial) or _default_object_type(partial) return Event( invocation_id=inv_id, @@ -590,31 +606,51 @@ def _create_error_status_event( ) +def _a2a_part_requests_euc_auth(part: A2APart) -> bool: + root = part.root + md = root.metadata + if not md: + return False + t = get_metadata(md, A2A_DATA_PART_METADATA_TYPE_KEY) + return all([ + t == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + metadata_is_true(md, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY), + root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME, + ]) + + +def _a2a_part_is_long_running_function_call(part: A2APart) -> bool: + root = part.root + md = root.metadata + if not md: + return False + t = get_metadata(md, A2A_DATA_PART_METADATA_TYPE_KEY) + return all([ + t == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL, + metadata_is_true(md, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY), + ]) + + def _create_status_update_event( message: Message, ctx: InvocationContext, event: Event, task_id: Optional[str], context_id: Optional[str], + effective_id: str = "", ) -> TaskStatusUpdateEvent: status = TaskStatus(state=TaskState.working, message=message, timestamp=_now_iso()) - if any( - get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and metadata_is_true(p.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) - and p.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME for p in message.parts if p.root.metadata): + if any(_a2a_part_requests_euc_auth(p) for p in message.parts): status.state = TaskState.auth_required - elif any( - get_metadata(p.root.metadata, A2A_DATA_PART_METADATA_TYPE_KEY) == A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - and metadata_is_true(p.root.metadata, A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY) for p in message.parts - if p.root.metadata): + elif any(_a2a_part_is_long_running_function_call(p) for p in message.parts): status.state = TaskState.input_required return TaskStatusUpdateEvent( task_id=task_id, context_id=context_id, status=status, - metadata=_build_event_metadata(event, message, ctx), + metadata=_build_event_metadata(event, message, ctx, effective_id), final=False, ) @@ -626,8 +662,9 @@ def _create_artifact_update_event( task_id: Optional[str] = None, context_id: Optional[str] = None, last_chunk: bool = False, + effective_id: str = "", ) -> TaskArtifactUpdateEvent: - artifact_id = "" if last_chunk else (event.response_id or "") + artifact_id = "" if last_chunk else effective_id return TaskArtifactUpdateEvent( task_id=task_id, context_id=context_id, @@ -636,7 +673,7 @@ def _create_artifact_update_event( parts=[] if last_chunk else message.parts, ), last_chunk=last_chunk, - metadata=_build_event_metadata(event, message, ctx), + metadata=_build_event_metadata(event, message, ctx, effective_id), ) @@ -674,12 +711,14 @@ def _notify(evt: A2AEvent) -> None: message = convert_event_to_a2a_message(event, invocation_context) if message: + effective_id = message.message_id status_event = _create_status_update_event( message, invocation_context, event, task_id, context_id, + effective_id=effective_id, ) _notify(status_event) @@ -691,6 +730,7 @@ def _notify(evt: A2AEvent) -> None: task_id=task_id, context_id=context_id, last_chunk=False, + effective_id=effective_id, ) a2a_events.append(artifact_event)