diff --git a/aai_cli/core/llm.py b/aai_cli/core/llm.py index 19b5747b..a10a8052 100644 --- a/aai_cli/core/llm.py +++ b/aai_cli/core/llm.py @@ -2,14 +2,14 @@ import json from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from aai_cli.core import environments from aai_cli.core.errors import APIError, UsageError if TYPE_CHECKING: from openai import OpenAI - from openai.types.chat import ChatCompletion + from openai.types.chat import ChatCompletion, ChatCompletionMessageParam # The LLM Gateway is OpenAI-compatible, so we talk to it through the OpenAI SDK # pointed at the active environment's gateway base (see _client / code_gen). @@ -76,7 +76,7 @@ def build_messages( system: str | None = None, transcript_id: str | None = None, transcript_text: str | None = None, -) -> list[dict[str, str]]: +) -> list[ChatCompletionMessageParam]: """Assemble the chat `messages` array for a transcript transform or plain prompt. With a `transcript_id`, the gateway injects the transcript server-side, so we @@ -88,7 +88,7 @@ def build_messages( content = f"{prompt}\n\nTranscript:\n{transcript_text}" else: content = prompt - messages: list[dict[str, str]] = [] + messages: list[ChatCompletionMessageParam] = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": content}) @@ -130,7 +130,7 @@ def complete( api_key: str, *, model: str, - messages: list[dict[str, str]], + messages: list[ChatCompletionMessageParam], max_tokens: int = DEFAULT_MAX_TOKENS, transcript_id: str | None = None, extra: dict[str, object] | None = None, @@ -153,7 +153,7 @@ def complete( try: return client.chat.completions.create( model=model, - messages=messages, # type: ignore[arg-type] + messages=messages, max_tokens=max_tokens, extra_body=extra_body or None, ) @@ -182,7 +182,7 @@ def content_of(response: ChatCompletion) -> str: return content or "" -def usage_of(response: ChatCompletion) -> dict[str, Any] | None: +def usage_of(response: ChatCompletion) -> dict[str, object] | None: """Return the token-usage block as a plain dict, if present.""" usage = response.usage if usage is None: diff --git a/tests/test_llm.py b/tests/test_llm.py index e86eab3c..8316c7ca 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -164,7 +164,7 @@ def test_build_messages_transcript_id_uses_tag(): def test_build_messages_inline_text(): msgs = llm.build_messages("summarize", transcript_text="hello world") - assert msgs[0]["content"] == "summarize\n\nTranscript:\nhello world" + assert msgs == [{"role": "user", "content": "summarize\n\nTranscript:\nhello world"}] def test_build_messages_with_system_prompt():