def create_session(
    self,
    agent_external_id: str,
    actions: Sequence[Action] | None = None,
    cursor: str | None = None,
    **kwargs: Any,
) -> AgentSession:
    """Open a stateful session for a multi-turn conversation with one agent.

    The returned :class:`~cognite.client.data_classes.agents.AgentSession` keeps hold of
    ``agent_external_id``, ``actions`` and the most recent conversation cursor. Every
    ``await session.chat(...)`` sends the stored cursor along with the request and then
    stores the cursor of the response, so the caller never has to pass cursors around.

    No request is made by this method itself; it only builds the session object.

    Sessions are offered on the async client only
    (:class:`~cognite.client.AsyncCogniteClient`). With the sync client, call :meth:`chat`
    directly and carry the cursor from one call to the next yourself.

    Note:
        Extra keyword arguments are accepted and then discarded - nothing in this method
        reads ``kwargs``. This keeps call sites working if parameters are added later, but
        it also means that a misspelled keyword (for example ``cursr=...``) is dropped
        without any error.

    Args:
        agent_external_id (str): External ID of the agent the session talks to.
        actions (Sequence[Action] | None): Client-side actions offered to the agent. The
            same object is handed to every ``chat()`` call made through the session.
            Defaults to ``None`` (no client actions).
        cursor (str | None): Cursor of an earlier conversation to resume. Defaults to
            ``None`` (start a fresh conversation).
        **kwargs (Any): Reserved for future parameters; currently ignored.

    Returns:
        AgentSession: A session bound to the given agent.

    Examples:

        A simple two-turn conversation:

            >>> from cognite.client import AsyncCogniteClient
            >>> from cognite.client.data_classes.agents import Message
            >>> async def main():
            ...     client = AsyncCogniteClient()
            ...     session = client.agents.create_session(agent_external_id="my_agent")
            ...     response = await session.chat(Message("Hello"))
            ...     print(response.text)
            ...     response = await session.chat(Message("Tell me more"))
            ...     print(response.text)

        Pick up an earlier conversation from a saved cursor:

            >>> async def resume(saved_cursor: str):
            ...     client = AsyncCogniteClient()
            ...     session = client.agents.create_session(
            ...         agent_external_id="my_agent",
            ...         cursor=saved_cursor,
            ...     )
            ...     response = await session.chat(Message("Continue where we left off"))

        Offer a client-side action and answer the agent's call to it:

            >>> from cognite.client.data_classes.agents import ClientToolAction, ClientToolResult
            >>> async def with_actions():
            ...     client = AsyncCogniteClient()
            ...     add = ClientToolAction(
            ...         name="add",
            ...         description="Add two numbers",
            ...         parameters={
            ...             "type": "object",
            ...             "properties": {
            ...                 "a": {"type": "number"},
            ...                 "b": {"type": "number"},
            ...             },
            ...             "required": ["a", "b"],
            ...         },
            ...     )
            ...     session = client.agents.create_session(
            ...         agent_external_id="my_agent",
            ...         actions=[add],
            ...     )
            ...     response = await session.chat(Message("What is 42 + 58?"))
            ...     if response.action_calls:
            ...         for call in response.action_calls:
            ...             result = call.arguments["a"] + call.arguments["b"]
            ...             response = await session.chat(
            ...                 ClientToolResult(action_id=call.action_id, content=str(result))
            ...             )
    """
    # Emitted on every call, before the session is built.
    # NOTE(review): presumably the feature-preview notice shared by the agents API - confirm.
    self._warnings.warn()

    # `kwargs` is intentionally not forwarded (see the Note in the docstring).
    session = AgentSession(
        agents_api=self,
        agent_external_id=agent_external_id,
        actions=actions,
        cursor=cursor,
    )
    return session
class AgentSession:
    """A single multi-turn conversation with one Cognite agent.

    Instances are meant to be obtained from
    :meth:`cognite.client._api.agents.agents.AgentsAPI.create_session` on
    :class:`~cognite.client.AsyncCogniteClient` rather than constructed by hand.

    The session remembers the conversation cursor between :meth:`chat` calls: the
    stored cursor is sent with each request and replaced by the cursor of the
    response, so callers never handle cursors themselves.

    One session means one agent and one conversation. There is no ``reset()``
    method; call ``create_session()`` again for a new conversation. Because the
    cursor is read before the request and written after it, overlapping
    ``await session.chat(...)`` calls on the same instance are not safe - use one
    ``AgentSession`` per parallel conversation.

    Args:
        agents_api (AgentsAPI): The async agents API that performs the chat requests.
        agent_external_id (str): External ID of the agent this session is bound to.
        actions (Sequence[Action] | None): Client-side actions offered to the agent.
        cursor (str | None): Starting cursor; ``None`` begins a fresh conversation,
            an existing cursor resumes an earlier one.
    """

    def __init__(
        self,
        agents_api: AgentsAPI,
        agent_external_id: str,
        actions: Sequence[Action] | None,
        cursor: str | None,
    ) -> None:
        # Public, caller-visible configuration.
        self.agent_external_id = agent_external_id
        self.actions = actions
        # Private state: the cursor is exposed read-only through the property below.
        self._cursor = cursor
        self._agents_api = agents_api

    @property
    def cursor(self) -> str | None:
        """The current conversation cursor (read-only).

        Starts out as the value given at construction - ``None`` for a fresh
        conversation, or the cursor being resumed. After a successful :meth:`chat`
        call it becomes the cursor of that response; a response without a cursor
        leaves the stored value as it was. A failed chat request also leaves it
        untouched, so the same call can be retried.
        """
        return self._cursor

    async def chat(
        self,
        messages: Message | ActionResult | Sequence[Message | ActionResult],
    ) -> AgentChatResponse:
        """Send messages to the agent and return its response.

        The request carries the session's agent ID, actions and stored cursor. When
        the call succeeds and the response has a cursor, that cursor is stored for
        the next request. If the response has no cursor, or the request raises, the
        stored cursor is left unchanged.

        Args:
            messages (Message | ActionResult | Sequence[Message | ActionResult]): One or
                more messages and/or action results, in the same forms accepted by
                :meth:`cognite.client._api.agents.agents.AgentsAPI.chat`.

        Returns:
            AgentChatResponse: The agent's response.
        """
        # If this await raises, nothing below runs and the cursor stays as it was.
        reply = await self._agents_api.chat(
            agent_external_id=self.agent_external_id,
            messages=messages,
            cursor=self._cursor,
            actions=self.actions,
        )

        # Only a truthy cursor replaces the stored one; ``None`` keeps the old value.
        next_cursor = reply.cursor
        if next_cursor:
            self._cursor = next_cursor
        return reply
class TestAgentSession:
    """Unit tests for ``AgentsAPI.create_session`` and ``AgentSession.chat``.

    HTTP traffic is faked with ``httpx_mock``. The tests register one mocked response
    per expected request and then inspect the recorded request bodies, so the order
    of ``add_response`` calls and ``session.chat`` calls matters.
    """

    @pytest.fixture
    def chat_url(self, async_client: AsyncCogniteClient) -> str:
        """Full URL of the agents ``/chat`` endpoint for the async test client."""
        return get_url(async_client.agents, async_client.agents._RESOURCE_PATH + "/chat")

    @staticmethod
    def _text_response(cursor: str | None = None, text: str = "ok") -> dict:
        """Build a chat response payload holding one agent text message.

        Args:
            cursor (str | None): Value placed in ``response.cursor``.
            text (str): Text of the single agent message.

        Returns:
            dict: The camelCase JSON payload.
        """
        return {
            "agentExternalId": "agent_1",
            "response": {
                "cursor": cursor,
                "type": "result",
                "messages": [
                    {
                        "role": "agent",
                        "content": {"type": "text", "text": text},
                    }
                ],
            },
        }

    @staticmethod
    def _client_tool_call_response(cursor: str, action_id: str, name: str = "add") -> dict:
        """Build a chat response payload in which the agent requests one client tool call.

        The tool arguments are always the JSON string for ``{"a": 1, "b": 2}``.

        Args:
            cursor (str): Value placed in ``response.cursor``.
            action_id (str): ``actionId`` of the requested call.
            name (str): Name of the client tool being called.

        Returns:
            dict: The camelCase JSON payload.
        """
        return {
            "agentExternalId": "agent_1",
            "response": {
                "cursor": cursor,
                "type": "result",
                "messages": [
                    {
                        "role": "agent",
                        "actions": [
                            {
                                "type": "clientTool",
                                "actionId": action_id,
                                "clientTool": {
                                    "name": name,
                                    "arguments": json.dumps({"a": 1, "b": 2}),
                                },
                            }
                        ],
                    }
                ],
            },
        }

    @staticmethod
    def _tool_confirmation_response(cursor: str, action_id: str) -> dict:
        """Build a chat response payload in which the agent asks for one tool confirmation.

        Args:
            cursor (str): Value placed in ``response.cursor``.
            action_id (str): ``actionId`` of the confirmation request.

        Returns:
            dict: The camelCase JSON payload.
        """
        return {
            "agentExternalId": "agent_1",
            "response": {
                "cursor": cursor,
                "type": "result",
                "messages": [
                    {
                        "role": "agent",
                        "actions": [
                            {
                                "type": "toolConfirmation",
                                "actionId": action_id,
                                "toolConfirmation": {
                                    "content": {"type": "text", "text": "Run this tool?"},
                                    "toolName": "run_func",
                                    "toolArguments": {"x": 1},
                                    "toolDescription": "Runs a function",
                                    "toolType": "callFunction",
                                },
                            }
                        ],
                    }
                ],
            },
        }

    # T008 [US2]
    def test_create_session_stores_config(self, async_client: AsyncCogniteClient) -> None:
        """The session exposes the agent ID, actions and cursor it was created with."""
        add = ClientToolAction(
            name="add",
            description="Add",
            parameters={"type": "object"},
        )
        session = async_client.agents.create_session(
            agent_external_id="my_agent",
            actions=[add],
            cursor="resume_cursor",
        )
        assert isinstance(session, AgentSession)
        assert session.agent_external_id == "my_agent"
        assert session.cursor == "resume_cursor"
        assert session.actions == [add]

    # T008 (companion): defaults when actions/cursor omitted
    def test_create_session_defaults(self, async_client: AsyncCogniteClient) -> None:
        """Omitting ``actions`` and ``cursor`` leaves both as ``None`` on the session."""
        session = async_client.agents.create_session(agent_external_id="x")
        assert session.agent_external_id == "x"
        assert session.cursor is None
        assert session.actions is None

    # T011 [US2]
    def test_create_session_ignores_unknown_kwargs(self, async_client: AsyncCogniteClient) -> None:
        """Unknown keyword arguments do not raise and do not stop the session being built."""
        session = async_client.agents.create_session(
            agent_external_id="x",
            bogus=True,
            dangerously_skip_user_confirmation=True,
        )
        assert isinstance(session, AgentSession)
        assert session.agent_external_id == "x"

    # T015
    def test_cursor_is_read_only(self, async_client: AsyncCogniteClient) -> None:
        """Assigning to ``session.cursor`` raises ``AttributeError`` (property without setter)."""
        session = async_client.agents.create_session(agent_external_id="x")
        with pytest.raises(AttributeError):
            session.cursor = "new"  # type: ignore[misc]

    # T016
    def test_sync_client_does_not_have_create_session(self, cognite_client: CogniteClient) -> None:
        """The sync client's agents API has no ``create_session`` attribute."""
        assert not hasattr(cognite_client.agents, "create_session")

    # T020
    def test_session_has_no_reset_method(self, async_client: AsyncCogniteClient) -> None:
        """A session has no ``reset`` attribute."""
        session = async_client.agents.create_session(agent_external_id="x")
        assert not hasattr(session, "reset")

    # T005 [US1]
    async def test_chat_threads_cursor(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """The cursor returned by the first chat is sent with the second chat request."""
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c1", text="r1")
        )
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c2", text="r2")
        )

        session = async_client.agents.create_session(agent_external_id="agent_1")

        r1 = await session.chat(Message("hi"))
        assert r1.text == "r1"
        assert session.cursor == "c1"

        r2 = await session.chat(Message("more"))
        assert r2.text == "r2"
        assert session.cursor == "c2"

        requests = httpx_mock.get_requests()
        assert len(requests) == 2
        body_1 = jsgz_load(requests[0].content)
        body_2 = jsgz_load(requests[1].content)
        # A fresh session has no cursor, so the first request body must not carry one.
        assert "cursor" not in body_1
        assert body_2["cursor"] == "c1"

    # T009 [US2]
    async def test_first_request_no_cursor_when_cursor_none(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """With ``cursor=None`` the first request body contains no ``cursor`` key."""
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c1"))
        session = async_client.agents.create_session(agent_external_id="agent_1", cursor=None)
        await session.chat(Message("hi"))
        body = jsgz_load(httpx_mock.get_requests()[0].content)
        assert "cursor" not in body

    # T010 [US2]
    async def test_first_request_uses_resume_cursor(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """A cursor given to ``create_session`` is sent with the very first request."""
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c2"))
        session = async_client.agents.create_session(agent_external_id="agent_1", cursor="resume_me")
        await session.chat(Message("continue"))
        body = jsgz_load(httpx_mock.get_requests()[0].content)
        assert body["cursor"] == "resume_me"

    # T007 [US1]
    async def test_null_response_cursor_retains_prior(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """A response whose cursor is null leaves the previously stored cursor in place."""
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c1"))
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor=None))
        session = async_client.agents.create_session(agent_external_id="agent_1")
        await session.chat(Message("one"))
        assert session.cursor == "c1"
        await session.chat(Message("two"))
        assert session.cursor == "c1"

    # T006 [US1]
    async def test_cursor_not_advanced_on_failure(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """When a chat request raises ``CogniteAPIError`` the stored cursor is unchanged."""
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="good_cursor")
        )
        httpx_mock.add_response(method="POST", url=chat_url, status_code=500, json={"error": {"message": "boom"}})

        session = async_client.agents.create_session(agent_external_id="agent_1")
        await session.chat(Message("hi"))
        assert session.cursor == "good_cursor"

        with pytest.raises(CogniteAPIError):
            await session.chat(Message("boom"))

        assert session.cursor == "good_cursor"

    # T012 [US2]
    async def test_actions_forwarded_to_every_chat(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """The session's actions appear, dumped in camelCase, in every request body."""
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c1"))
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c2"))

        add = ClientToolAction(
            name="add",
            description="Add",
            parameters={
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                "required": ["a", "b"],
            },
        )
        session = async_client.agents.create_session(agent_external_id="agent_1", actions=[add])

        await session.chat(Message("one"))
        await session.chat(Message("two"))

        reqs = httpx_mock.get_requests()
        assert len(reqs) == 2
        expected_actions = [add.dump(camel_case=True)]
        for req in reqs:
            body = jsgz_load(req.content)
            assert body["actions"] == expected_actions

    # T013 [US3]
    async def test_action_result_threads_cursor(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """Replying with a ``ClientToolResult`` sends the cursor of the response that asked for it."""
        httpx_mock.add_response(
            method="POST",
            url=chat_url,
            status_code=200,
            json=self._client_tool_call_response(cursor="conv_1", action_id="a1"),
        )
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="conv_2", text="done")
        )

        session = async_client.agents.create_session(agent_external_id="agent_1")
        r = await session.chat(Message("calc"))
        assert r.action_calls is not None
        assert session.cursor == "conv_1"

        call = r.action_calls[0]
        await session.chat(ClientToolResult(action_id=call.action_id, content="3"))

        reqs = httpx_mock.get_requests()
        assert len(reqs) == 2
        assert jsgz_load(reqs[1].content)["cursor"] == "conv_1"

    # T014 [US3]
    async def test_multiple_action_results_in_one_call(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """A list of two action results is sent as two messages in a single request."""
        httpx_mock.add_response(method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="final"))

        session = async_client.agents.create_session(agent_external_id="agent_1", cursor="after_calls")
        await session.chat(
            [
                ClientToolResult(action_id="a1", content="res1"),
                ClientToolResult(action_id="a2", content="res2"),
            ]
        )

        body = jsgz_load(httpx_mock.get_requests()[0].content)
        assert body["cursor"] == "after_calls"
        assert len(body["messages"]) == 2

    # T019 [US3]
    async def test_tool_confirmation_pass_through_and_response(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """A confirmation request is returned to the caller, whose reply carries the same cursor.

        The session must not answer the confirmation itself: only one request has been
        made by the time the first ``chat`` call returns.
        """
        httpx_mock.add_response(
            method="POST",
            url=chat_url,
            status_code=200,
            json=self._tool_confirmation_response(cursor="confirm_c", action_id="conf_1"),
        )
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="after", text="executed")
        )

        session = async_client.agents.create_session(agent_external_id="agent_1")
        r = await session.chat(Message("run"))

        assert r.action_calls is not None
        assert len(r.action_calls) == 1
        assert isinstance(r.action_calls[0], ToolConfirmationCall)
        assert r.action_calls[0].action_id == "conf_1"
        assert len(httpx_mock.get_requests()) == 1  # no auto-confirmation

        await session.chat(ToolConfirmationResult(action_id="conf_1", status="ALLOW"))

        reqs = httpx_mock.get_requests()
        assert len(reqs) == 2
        assert jsgz_load(reqs[1].content)["cursor"] == "confirm_c"

    # T021 [US3]
    async def test_three_action_rounds_thread_cursor(
        self, async_client: AsyncCogniteClient, httpx_mock: HTTPXMock, chat_url: str
    ) -> None:
        """Across three tool-call rounds plus a final answer, each request carries the previous cursor."""
        # Three tool-call responses with cursors c1..c3, then one final text response.
        for i in range(3):
            httpx_mock.add_response(
                method="POST",
                url=chat_url,
                status_code=200,
                json=self._client_tool_call_response(cursor=f"c{i + 1}", action_id=f"call_{i + 1}"),
            )
        httpx_mock.add_response(
            method="POST", url=chat_url, status_code=200, json=self._text_response(cursor="c_final", text="done")
        )

        session = async_client.agents.create_session(agent_external_id="agent_1")

        r = await session.chat(Message("step_1"))
        assert r.action_calls is not None
        assert session.cursor == "c1"

        r = await session.chat(ClientToolResult(action_id=r.action_calls[0].action_id, content="ok"))
        assert r.action_calls is not None
        assert session.cursor == "c2"

        r = await session.chat(ClientToolResult(action_id=r.action_calls[0].action_id, content="ok"))
        assert r.action_calls is not None
        assert session.cursor == "c3"

        r = await session.chat(ClientToolResult(action_id=r.action_calls[0].action_id, content="ok"))
        assert session.cursor == "c_final"

        reqs = httpx_mock.get_requests()
        assert len(reqs) == 4
        assert "cursor" not in jsgz_load(reqs[0].content)
        assert jsgz_load(reqs[1].content)["cursor"] == "c1"
        assert jsgz_load(reqs[2].content)["cursor"] == "c2"
        assert jsgz_load(reqs[3].content)["cursor"] == "c3"