From 8ad384a5e90e8d8800978bfc1fa4cbfb8f14b5eb Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 1 Sep 2025 09:58:23 +0200 Subject: [PATCH] fix(chat): run expiring after 10 minutes --- src/askui/chat/api/runs/models.py | 23 ++++++++++++++++++++--- src/askui/chat/api/runs/runner/runner.py | 14 ++++++++------ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/askui/chat/api/runs/models.py b/src/askui/chat/api/runs/models.py index 1e91c73c..4a085c17 100644 --- a/src/askui/chat/api/runs/models.py +++ b/src/askui/chat/api/runs/models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import timedelta from typing import Literal from pydantic import BaseModel, computed_field @@ -64,7 +64,7 @@ def create(cls, thread_id: ThreadId, params: RunCreateParams) -> "Run": id=generate_time_ordered_id("run"), thread_id=thread_id, created_at=now(), - expires_at=datetime.now(tz=timezone.utc) + timedelta(minutes=10), + expires_at=now() + timedelta(minutes=10), **params.model_dump(exclude={"stream"}), ) @@ -77,10 +77,27 @@ def status(self) -> RunStatus: return "failed" if self.completed_at: return "completed" - if self.expires_at and self.expires_at < datetime.now(tz=timezone.utc): + if self.expires_at and self.expires_at < now(): return "expired" if self.tried_cancelling_at: return "cancelling" if self.started_at: return "in_progress" return "queued" + + def start(self) -> None: + self.started_at = now() + self.expires_at = now() + timedelta(minutes=10) + + def ping(self) -> None: + self.expires_at = now() + timedelta(minutes=10) + + def complete(self) -> None: + self.completed_at = now() + + def cancel(self) -> None: + self.cancelled_at = now() + + def fail(self, error: RunError) -> None: + self.failed_at = now() + self.last_error = error diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 57a33d48..565e03e4 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -1,6 +1,5 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime, timezone from typing import TYPE_CHECKING, Literal, Sequence import anthropic @@ -142,6 +141,8 @@ async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: updated_run = self._retrieve() if self._should_abort(updated_run): break + updated_run.ping() + self._run_service.save(updated_run) while event := self._agent_os.poll_event(): if self._should_abort(updated_run): break @@ -287,6 +288,8 @@ async def async_on_message( updated_run = self._retrieve() if self._should_abort(updated_run): return None + updated_run.ping() + self._run_service.save(updated_run) return on_message_cb_param.message on_message = syncify(async_on_message) @@ -390,7 +393,7 @@ async def run( ) updated_run = self._retrieve() if updated_run.status == "in_progress": - updated_run.completed_at = datetime.now(tz=timezone.utc) + updated_run.complete() self._run_service.save(updated_run) await send_stream.send( RunEvent( @@ -405,7 +408,7 @@ async def run( event="thread.run.cancelling", ) ) - updated_run.cancelled_at = datetime.now(tz=timezone.utc) + updated_run.cancel() self._run_service.save(updated_run) await send_stream.send( RunEvent( @@ -424,8 +427,7 @@ async def run( except Exception as e: # noqa: BLE001 logger.exception("Exception in runner") updated_run = self._retrieve() - updated_run.failed_at = datetime.now(tz=timezone.utc) - updated_run.last_error = RunError(message=str(e), code="server_error") + updated_run.fail(RunError(message=str(e), code="server_error")) self._run_service.save(updated_run) await send_stream.send( RunEvent( @@ -440,7 +442,7 @@ async def run( ) def _mark_run_as_started(self) -> None: - self._run.started_at = datetime.now(tz=timezone.utc) + self._run.start() self._run_service.save(self._run) def _should_abort(self, run: Run) -> bool: