1- from concurrent .futures import ThreadPoolExecutor
1+ import asyncio
2+ import queue
3+ import threading
4+ from collections .abc import AsyncGenerator
25from datetime import datetime , timedelta , timezone
36from pathlib import Path
4- from typing import Literal , Sequence , cast
7+ from typing import Literal , Sequence , cast , overload
58
69from pydantic import AwareDatetime , BaseModel , Field , computed_field
710
811from askui .agent import VisionAgent
9- from askui .chat .api .messages .service import MessageService
12+ from askui .chat .api .messages .service import MessageEvent , MessageService
13+ from askui .chat .api .models import Event
1014from askui .chat .api .utils import generate_time_ordered_id
1115from askui .models .shared .computer_agent_cb_param import OnMessageCbParam
1216from askui .models .shared .computer_agent_message_param import MessageParam
@@ -70,15 +74,33 @@ class RunListResponse(BaseModel):
7074 has_more : bool = False
7175
7276
77+ class RunEvent (Event ):
78+ data : Run
79+ event : Literal [
80+ "run.created" ,
81+ "run.started" ,
82+ "run.completed" ,
83+ "run.failed" ,
84+ "run.cancelled" ,
85+ "run.expired" ,
86+ ]
87+
88+
7389class Runner :
7490 def __init__ (self , run : Run , base_dir : Path ) -> None :
7591 self ._run = run
7692 self ._base_dir = base_dir
7793 self ._runs_dir = base_dir / "runs"
7894 self ._msg_service = MessageService (self ._base_dir )
7995
80- def run_task (self ) -> None :
96+ def run (self , event_queue : queue . Queue [ RunEvent | MessageEvent | None ] ) -> None :
8197 self ._mark_started ()
98+ event_queue .put (
99+ RunEvent (
100+ data = self ._run ,
101+ event = "run.started" ,
102+ )
103+ )
82104 messages : list [MessageParam ] = [
83105 cast ("MessageParam" , msg )
84106 for msg in self ._msg_service .list_ (self ._run .thread_id ).data
@@ -87,27 +109,63 @@ def run_task(self) -> None:
87109 def on_message (
88110 on_message_cb_param : OnMessageCbParam ,
89111 ) -> MessageParam | None :
90- self ._msg_service .create (
112+ message = self ._msg_service .create (
91113 thread_id = self ._run .thread_id ,
92114 message = on_message_cb_param .message ,
93115 )
116+ event_queue .put (
117+ MessageEvent (
118+ data = message ,
119+ event = "message.created" ,
120+ )
121+ )
94122 updated_run = self ._retrieve_run ()
95- if self . _should_abort ( updated_run ) :
123+ if updated_run . status == "cancelling" :
96124 updated_run .cancelled_at = datetime .now (tz = timezone .utc )
97125 self ._update_run_file (updated_run )
126+ event_queue .put (
127+ RunEvent (
128+ data = updated_run ,
129+ event = "run.cancelled" ,
130+ )
131+ )
132+ return None
133+ if updated_run .status == "expired" :
134+ event_queue .put (
135+ RunEvent (
136+ data = updated_run ,
137+ event = "run.expired" ,
138+ )
139+ )
98140 return None
99141 return on_message_cb_param .message
100142
101143 try :
102144 with VisionAgent () as agent :
103145 agent .act (messages , on_message = on_message )
104- self ._run .completed_at = datetime .now (tz = timezone .utc )
105- self ._update_run_file (self ._run )
146+ updated_run = self ._retrieve_run ()
147+ if updated_run .status == "in_progress" :
148+ updated_run .completed_at = datetime .now (tz = timezone .utc )
149+ self ._update_run_file (updated_run )
150+ event_queue .put (
151+ RunEvent (
152+ data = updated_run ,
153+ event = "run.completed" ,
154+ )
155+ )
106156 except Exception as e : # noqa: BLE001
107- self ._run .failed_at = datetime .now (tz = timezone .utc )
108- self ._run .last_error = RunError (message = str (e ), code = "server_error" )
109- self ._update_run_file (self ._run )
110- raise
157+ updated_run = self ._retrieve_run ()
158+ updated_run .failed_at = datetime .now (tz = timezone .utc )
159+ updated_run .last_error = RunError (message = str (e ), code = "server_error" )
160+ self ._update_run_file (updated_run )
161+ event_queue .put (
162+ RunEvent (
163+ data = updated_run ,
164+ event = "run.failed" ,
165+ )
166+ )
167+ finally :
168+ event_queue .put (None )
111169
112170 def _mark_started (self ) -> None :
113171 self ._run .started_at = datetime .now (tz = timezone .utc )
@@ -132,29 +190,56 @@ class RunService:
132190 Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs.
133191 """
134192
135- _executor : ThreadPoolExecutor = ThreadPoolExecutor (max_workers = 4 )
136-
137193 def __init__ (self , base_dir : Path ) -> None :
138194 self ._base_dir = base_dir
139195 self ._runs_dir = base_dir / "runs"
140196
141197 def _run_path (self , thread_id : str , run_id : str ) -> Path :
142198 return self ._runs_dir / f"{ thread_id } __{ run_id } .json"
143199
144- def create (self , thread_id : str , stream : bool ) -> Run :
200+ def _create_run (self , thread_id : str ) -> Run :
145201 run = Run (thread_id = thread_id )
146202 self ._runs_dir .mkdir (parents = True , exist_ok = True )
147203 self ._update_run_file (run )
148- runner = Runner (run , self ._base_dir )
149- # TODO(adi-wan-askui): Run differently depending on `stream` parameter
150- runner .run_task ()
151- # if not stream:
152- # self._start_run_background(run)
153204 return run
154205
155- def _start_run_background (self , run : Run ) -> None :
206+ @overload
207+ def create (self , thread_id : str , stream : Literal [False ]) -> Run : ...
208+
209+ @overload
210+ def create (
211+ self , thread_id : str , stream : Literal [True ]
212+ ) -> AsyncGenerator [RunEvent | MessageEvent , None ]: ...
213+
214+ @overload
215+ def create (
216+ self , thread_id : str , stream : bool
217+ ) -> Run | AsyncGenerator [RunEvent | MessageEvent , None ]: ...
218+
219+ def create (
220+ self , thread_id : str , stream : bool
221+ ) -> Run | AsyncGenerator [RunEvent | MessageEvent , None ]:
222+ run = self ._create_run (thread_id )
223+ event_queue : queue .Queue [RunEvent | MessageEvent | None ] = queue .Queue ()
156224 runner = Runner (run , self ._base_dir )
157- self ._executor .submit (runner .run_task )
225+ thread = threading .Thread (target = runner .run , args = (event_queue ,))
226+ thread .start ()
227+ if stream :
228+
229+ async def event_stream () -> AsyncGenerator [RunEvent | MessageEvent , None ]:
230+ yield RunEvent (
231+ data = run ,
232+ event = "run.created" ,
233+ )
234+ loop = asyncio .get_event_loop ()
235+ while True :
236+ event = await loop .run_in_executor (None , event_queue .get )
237+ if event is None :
238+ break
239+ yield event
240+
241+ return event_stream ()
242+ return run
158243
159244 def _update_run_file (self , run : Run ) -> None :
160245 run_file = self ._run_path (run .thread_id , run .id )
0 commit comments