Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions aai_cli/code_agent/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import ChatGenerationChunk


def _flatten_content(messages: object) -> None:
Expand All @@ -38,6 +39,53 @@ def _flatten_content(messages: object) -> None:
)


def _hoist_tool_call_ids(chunk: object) -> None:
"""Move each streamed tool-call ``id`` from inside ``function`` up to the tool-call top level.

The AssemblyAI LLM Gateway's *streaming* ``/v1/chat/completions`` nests the tool-call
``id`` under ``function`` — ``{"function": {"id": …, "name": …}}`` — instead of at the
tool-call's top level, which is where the OpenAI streaming spec (and
``langchain_openai``, via ``id=rtc.get("id")``) reads it. Left alone, every streamed
tool call parses with a name and arguments but ``id=None``, so the reply ``ToolMessage``
fails Pydantic validation (``tool_call_id`` must be a string) and the whole turn errors
out. We move the id back up before langchain converts the chunk; the id rides only the
first delta of a call, so later argument-only deltas (no ``function.id``) are left
untouched. (The non-streaming endpoint already places the id correctly, so only the
streaming path needs this.)
"""
if not isinstance(chunk, dict):
return
choices = chunk.get("choices")
if isinstance(choices, list):
for choice in choices:
_hoist_in_choice(choice)


def _hoist_in_choice(choice: object) -> None:
"""Hoist tool-call ids within one streamed choice's delta (helper for ``_hoist_tool_call_ids``)."""
delta = choice.get("delta") if isinstance(choice, dict) else None
tool_calls = delta.get("tool_calls") if isinstance(delta, dict) else None
if isinstance(tool_calls, list):
_hoist_call_list(tool_calls)


def _hoist_call_list(tool_calls: list[object]) -> None:
"""Hoist a misplaced ``function.id`` to the tool-call top level for each call in the list.

Helper for :func:`_hoist_tool_call_ids` — split out so the per-chunk traversal stays
under the complexity bar. A call is rewritten only when it carries an ``id`` nested
under ``function`` (the gateway's misplaced first-delta shape). This stays idempotent
once the gateway is fixed: a correct delta puts the id at the top level and leaves no
``function.id``, so the move never fires.
"""
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
function = tool_call.get("function")
if isinstance(function, dict) and function.get("id") is not None:
tool_call["id"] = function.pop("id")


def build_model(
api_key: str,
*,
Expand All @@ -51,7 +99,8 @@ def build_model(
implements (the same one `aai_cli.core.llm` uses), rather than the OpenAI
Responses API that langchain would otherwise prefer for ``openai:`` models. The
subclass also flattens content-parts arrays the gateway rejects (see
:func:`_flatten_content`).
:func:`_flatten_content`) and repairs misplaced streamed tool-call ids (see
:func:`_hoist_tool_call_ids`).

``max_tokens`` caps the per-reply length (the live voice agent passes a small cap to
keep spoken replies short and fast); ``extra`` passes any additional gateway request
Expand All @@ -63,7 +112,14 @@ def build_model(
from pydantic import SecretStr

class _GatewayChatOpenAI(ChatOpenAI):
"""ChatOpenAI that rewrites list-content messages to plain strings for the gateway."""
"""ChatOpenAI that adapts the gateway's OpenAI-incompatible quirks for langchain.

Two fix-ups, each working around a gateway response/request bug the upstream client
doesn't expect: flatten list-content messages the gateway 500s on (request side, see
:func:`_flatten_content`), and hoist each streamed tool-call ``id`` back to the
tool-call top level where langchain reads it (response side, see
:func:`_hoist_tool_call_ids`).
"""

def _get_request_payload(
self, input_: object, *, stop: list[str] | None = None, **kwargs: object
Expand All @@ -72,6 +128,17 @@ def _get_request_payload(
_flatten_content(payload.get("messages"))
return payload

def _convert_chunk_to_generation_chunk(
    self,
    chunk: dict,
    default_chunk_class: type,
    base_generation_info: dict | None,
) -> ChatGenerationChunk | None:
    """Repair streamed tool-call ids in ``chunk``, then defer to the parent converter.

    ``chunk`` is mutated in place by ``_hoist_tool_call_ids`` before the parent
    implementation sees it; the parent's result is returned unchanged.
    """
    # The repair has to happen first: the parent conversion is what reads the id.
    _hoist_tool_call_ids(chunk)
    converted = super()._convert_chunk_to_generation_chunk(
        chunk, default_chunk_class, base_generation_info
    )
    return converted

return _GatewayChatOpenAI(
model=model,
base_url=environments.active().llm_gateway_base,
Expand Down
64 changes: 64 additions & 0 deletions tests/test_code_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,70 @@ def test_flatten_content_guards() -> None:
assert items == ["raw", 123]


def test_hoist_tool_call_ids_moves_id_out_of_function_only_when_missing() -> None:
    """Walk every guard with one chunk; only the function-nested id may be hoisted."""
    # The helper mutates in place, so keep handles on the inner dicts and assert on
    # those directly instead of digging back through the (loosely typed) chunk.
    plain_function: dict[str, object] = {"name": "b"}
    plain_call: dict[str, object] = {"index": 1, "function": plain_function}
    nested_function: dict[str, object] = {"id": "HOIST", "name": "c", "arguments": ""}
    nested_call: dict[str, object] = {"index": 2, "function": nested_function}

    calls: list[object] = [
        None,  # tool call is not a dict -> skipped
        {"index": 0, "function": 7},  # function is not a dict -> skipped
        plain_call,  # function carries no id -> nothing to hoist
        nested_call,  # the real gateway shape -> id hoisted out of function
    ]
    skipped_choices: list[object] = [
        None,  # choice is not a dict
        {"delta": None},  # delta is not a dict
        {"delta": {"content": "hi"}},  # delta without tool_calls
        {"delta": {"tool_calls": 99}},  # tool_calls is not a list
    ]
    chunk: dict[str, object] = {
        "choices": [*skipped_choices, {"delta": {"tool_calls": calls}}],
    }

    model_mod._hoist_tool_call_ids(chunk)

    # A call that never had an id must not gain one, and its function stays untouched.
    assert "id" not in plain_call
    assert plain_function == {"name": "b"}
    # The nested id moves to the top level (where langchain reads it) and is not duplicated.
    assert nested_call["id"] == "HOIST"
    assert "id" not in nested_function


def test_hoist_tool_call_ids_guards() -> None:
    """Malformed top-level input must be ignored without raising."""
    malformed_inputs: tuple[object, ...] = (
        None,  # not a dict -> early return
        {"choices": 99},  # choices is not a list -> early return
    )
    for malformed in malformed_inputs:
        model_mod._hoist_tool_call_ids(malformed)


def test_convert_chunk_hoists_streamed_tool_call_id() -> None:
    """The model built by ``build_model`` must surface a function-nested tool-call id."""
    from langchain_core.messages import AIMessageChunk
    from langchain_openai import ChatOpenAI

    gateway_model = model_mod.build_model("sk-test", model="claude-sonnet-4-6")
    # Narrow to the subclass that overrides the converter.
    assert isinstance(gateway_model, ChatOpenAI)

    # Gateway shape: the id is streamed inside `function` rather than at the tool-call
    # top level. The override has to hoist it so the converted chunk carries the id
    # (otherwise the reply ToolMessage would be built with None).
    streamed_call = {"index": 0, "function": {"id": "toolu_X", "name": "get_weather"}}
    streamed_delta = {"role": "assistant", "tool_calls": [streamed_call]}
    raw_chunk = {
        "choices": [{"index": 0, "delta": streamed_delta, "finish_reason": None}],
    }

    generation = gateway_model._convert_chunk_to_generation_chunk(
        raw_chunk, AIMessageChunk, None
    )

    assert generation is not None
    message = generation.message
    assert isinstance(message, AIMessageChunk)
    assert message.tool_call_chunks[0]["id"] == "toolu_X"


def test_fetch_url_fetches_and_truncates(monkeypatch: pytest.MonkeyPatch) -> None:
import httpx

Expand Down
Loading