def _hoist_tool_call_ids(chunk: object) -> None:
    """Move each streamed tool-call ``id`` from inside ``function`` up to the tool-call top level.

    The AssemblyAI LLM Gateway's *streaming* ``/v1/chat/completions`` nests the tool-call
    ``id`` under ``function`` — ``{"function": {"id": …, "name": …}}`` — instead of at the
    tool-call's top level, which is where the OpenAI streaming spec (and
    ``langchain_openai``, via ``id=rtc.get("id")``) reads it. Left alone, every streamed
    tool call parses with a name and arguments but ``id=None``, so the reply ``ToolMessage``
    fails Pydantic validation (``tool_call_id`` must be a string) and the whole turn errors
    out. We move the id back up before langchain converts the chunk; the id rides only the
    first delta of a call, so later argument-only deltas (no ``function.id``) are left
    untouched. (The non-streaming endpoint already places the id correctly, so only the
    streaming path needs this.)

    The chunk is mutated in place and nothing is returned. Anything that is not the
    expected shape (non-dict chunk, non-list ``choices``) is silently ignored so a
    malformed or unrelated chunk is passed through to langchain unchanged.
    """
    if not isinstance(chunk, dict):
        return
    choices = chunk.get("choices")
    if isinstance(choices, list):
        for choice in choices:
            _hoist_in_choice(choice)


def _hoist_in_choice(choice: object) -> None:
    """Hoist tool-call ids within one streamed choice's delta (helper for ``_hoist_tool_call_ids``).

    Non-dict choices, non-dict deltas and non-list ``tool_calls`` are skipped.
    """
    delta = choice.get("delta") if isinstance(choice, dict) else None
    tool_calls = delta.get("tool_calls") if isinstance(delta, dict) else None
    if isinstance(tool_calls, list):
        _hoist_call_list(tool_calls)


def _hoist_call_list(tool_calls: list[object]) -> None:
    """Hoist a misplaced ``function.id`` to the tool-call top level for each call in the list.

    Helper for :func:`_hoist_tool_call_ids` — split out so the per-chunk traversal stays
    under the complexity bar. A call is rewritten only when it carries an ``id`` nested
    under ``function`` (the gateway's misplaced first-delta shape) *and* has no usable
    top-level ``id`` of its own (absent or ``None``). A call that already has a top-level
    id is in the shape langchain reads, so it is left completely untouched — the nested
    value never overwrites it. That keeps the fix-up idempotent, and safe if the gateway
    ever starts sending the id in both places.
    """
    for tool_call in tool_calls:
        if not isinstance(tool_call, dict):
            continue
        function = tool_call.get("function")
        if not isinstance(function, dict) or function.get("id") is None:
            # Nothing nested to hoist (argument-only delta, or malformed function).
            continue
        if tool_call.get("id") is not None:
            # Already correctly placed; never clobber a valid top-level id.
            continue
        tool_call["id"] = function.pop("id")
``max_tokens`` caps the per-reply length (the live voice agent passes a small cap to keep spoken replies short and fast); ``extra`` passes any additional gateway request @@ -63,7 +112,14 @@ def build_model( from pydantic import SecretStr class _GatewayChatOpenAI(ChatOpenAI): - """ChatOpenAI that rewrites list-content messages to plain strings for the gateway.""" + """ChatOpenAI that adapts the gateway's OpenAI-incompatible quirks for langchain. + + Two fix-ups, each working around a gateway response/request bug the upstream client + doesn't expect: flatten list-content messages the gateway 500s on (request side, see + :func:`_flatten_content`), and hoist each streamed tool-call ``id`` back to the + tool-call top level where langchain reads it (response side, see + :func:`_hoist_tool_call_ids`). + """ def _get_request_payload( self, input_: object, *, stop: list[str] | None = None, **kwargs: object @@ -72,6 +128,17 @@ def _get_request_payload( _flatten_content(payload.get("messages")) return payload + def _convert_chunk_to_generation_chunk( + self, + chunk: dict, + default_chunk_class: type, + base_generation_info: dict | None, + ) -> ChatGenerationChunk | None: + _hoist_tool_call_ids(chunk) + return super()._convert_chunk_to_generation_chunk( + chunk, default_chunk_class, base_generation_info + ) + return _GatewayChatOpenAI( model=model, base_url=environments.active().llm_gateway_base, diff --git a/tests/test_code_agent.py b/tests/test_code_agent.py index 0e5d17c..05d9fcb 100644 --- a/tests/test_code_agent.py +++ b/tests/test_code_agent.py @@ -282,6 +282,70 @@ def test_flatten_content_guards() -> None: assert items == ["raw", 123] +def test_hoist_tool_call_ids_moves_id_out_of_function_only_when_missing() -> None: + # One chunk exercising every branch: each malformed variant is skipped, and only a + # tool call carrying a function-nested id gets hoisted. Hold references to the inner + # dicts so the in-place mutation is asserted with a clean type. 
def test_hoist_tool_call_ids_moves_id_out_of_function_only_when_missing() -> None:
    # One chunk exercising every branch: each malformed variant is skipped, and only a
    # tool call carrying a function-nested id gets hoisted. Hold references to the inner
    # dicts so the in-place mutation is asserted with a clean type.
    noid_fn: dict[str, object] = {"name": "b"}
    hoist_fn: dict[str, object] = {"id": "HOIST", "name": "c", "arguments": ""}
    bad_fn_call: dict[str, object] = {"index": 0, "function": 7}
    noid_call: dict[str, object] = {"index": 1, "function": noid_fn}
    hoist_call: dict[str, object] = {"index": 2, "function": hoist_fn}
    tool_calls: list[object] = [
        None,  # tool_call not a dict -> skipped
        bad_fn_call,  # function not a dict -> skipped
        noid_call,  # function has no id -> nothing to hoist
        hoist_call,  # the real gateway shape -> id hoisted out of function
    ]
    chunk: dict[str, object] = {
        "choices": [
            None,  # choice not a dict -> skipped
            {"delta": None},  # delta not a dict -> skipped
            {"delta": {"content": "hi"}},  # no tool_calls -> skipped
            {"delta": {"tool_calls": 99}},  # tool_calls not a list -> skipped
            {"delta": {"tool_calls": tool_calls}},
        ]
    }
    model_mod._hoist_tool_call_ids(chunk)
    # Skipped entries must come through byte-for-byte: the helper only ever moves an id.
    assert len(tool_calls) == 4
    assert tool_calls[0] is None
    assert bad_fn_call == {"index": 0, "function": 7}
    assert "id" not in noid_call  # no id invented for a call that never had one
    assert noid_fn == {"name": "b"}  # left untouched
    assert hoist_call["id"] == "HOIST"  # hoisted to the top level where langchain reads it
    assert "id" not in hoist_fn  # and removed from function so it isn't duplicated
    assert hoist_fn == {"name": "c", "arguments": ""}  # the rest of function is preserved


def test_hoist_tool_call_ids_guards() -> None:
    model_mod._hoist_tool_call_ids(None)  # not a dict -> early return, no error
    # choices not a list -> early return, and the chunk is left exactly as it was
    bad_choices: dict[str, object] = {"choices": 99}
    model_mod._hoist_tool_call_ids(bad_choices)
    assert bad_choices == {"choices": 99}
    # no choices key / empty choices -> nothing to walk, nothing added
    empty: dict[str, object] = {}
    model_mod._hoist_tool_call_ids(empty)
    assert empty == {}
    no_choices: dict[str, object] = {"choices": []}
    model_mod._hoist_tool_call_ids(no_choices)
    assert no_choices == {"choices": []}


def test_convert_chunk_hoists_streamed_tool_call_id() -> None:
    from langchain_core.messages import AIMessageChunk
    from langchain_openai import ChatOpenAI

    m = model_mod.build_model("sk-test", model="claude-sonnet-4-6")
    assert isinstance(m, ChatOpenAI)  # narrow to the subclass that overrides the converter
    # The gateway streams the tool-call id nested inside `function`; the override must hoist
    # it so langchain's converted chunk carries the id (else the reply ToolMessage gets None).
    chunk = {
        "choices": [
            {
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "tool_calls": [
                        {"index": 0, "function": {"id": "toolu_X", "name": "get_weather"}}
                    ],
                },
                "finish_reason": None,
            }
        ]
    }
    gen = m._convert_chunk_to_generation_chunk(chunk, AIMessageChunk, None)
    assert gen is not None
    msg = gen.message
    assert isinstance(msg, AIMessageChunk)
    assert msg.tool_call_chunks[0]["id"] == "toolu_X"
    # Hoisting the id must not disturb the rest of the call.
    assert msg.tool_call_chunks[0]["name"] == "get_weather"


def test_fetch_url_fetches_and_truncates(monkeypatch: pytest.MonkeyPatch) -> None:
    import httpx