Skip to content

Commit 00cbbd3

Browse files
committed
feat(chat): add streaming support to creating runs
1 parent d6fcf84 commit 00cbbd3

File tree

5 files changed

+169
-27
lines changed

5 files changed

+169
-27
lines changed

src/askui/chat/__main__.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ def write_message(
338338
)
339339
write_message(last_message)
340340
run = run_service.create(thread_id, stream=False)
341-
print(run)
342341
time.sleep(1)
343342
while run := run_service.retrieve(run.id):
344343
new_messages = message_service.list_(
@@ -352,5 +351,32 @@ def write_message(
352351
time.sleep(1)
353352

354353

354+
if act_prompt := st.chat_input("Ask AI (streaming)"):
355+
if act_prompt != "Continue":
356+
last_message = message_service.create(
357+
thread_id=thread_id,
358+
message=MessageParam(
359+
role="user",
360+
content=act_prompt,
361+
),
362+
)
363+
write_message(last_message)
364+
365+
# Use the streaming API
366+
event_stream = run_service.create(thread_id, stream=True)
367+
import asyncio
368+
369+
async def handle_stream() -> None:
370+
last_msg_id = last_message.id if last_message else None
371+
async for event in event_stream:
372+
if event.event == "message.created":
373+
msg = event.data
374+
if msg and (not last_msg_id or msg.id > last_msg_id):
375+
write_message(msg)
376+
last_msg_id = msg.id
377+
378+
# Run the async handler in Streamlit (sync context)
379+
asyncio.run(handle_stream())
380+
355381
# if st.button("Rerun"):
356382
# rerun()

src/askui/chat/api/messages/service.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from datetime import datetime, timezone
22
from pathlib import Path
3+
from typing import Literal
34

45
from pydantic import AwareDatetime, BaseModel, Field
56

7+
from askui.chat.api.models import Event
68
from askui.chat.api.utils import generate_time_ordered_id
79
from askui.models.shared.computer_agent_message_param import MessageParam
810

@@ -18,6 +20,11 @@ class Message(MessageParam):
1820
object: str = "message"
1921

2022

23+
class MessageEvent(Event):
24+
data: Message
25+
event: Literal["message.created"]
26+
27+
2128
class MessageListResponse(BaseModel):
2229
"""Response model for listing messages."""
2330

src/askui/chat/api/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import Literal
2+
3+
from pydantic import BaseModel
4+
5+
6+
class Event(BaseModel):
7+
object: Literal["event"] = "event"

src/askui/chat/api/runs/router.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
from typing import Annotated
1+
from collections.abc import AsyncGenerator
2+
from typing import TYPE_CHECKING, Annotated, cast
23

34
from fastapi import APIRouter, Body, HTTPException, Path
5+
from fastapi.responses import StreamingResponse
46
from pydantic import BaseModel
57

8+
if TYPE_CHECKING:
9+
from askui.chat.api.messages.service import MessageEvent
10+
611
from .dependencies import RunServiceDep
7-
from .service import Run, RunListResponse, RunService
12+
from .service import Run, RunEvent, RunListResponse, RunService
813

914

1015
class CreateRunRequest(BaseModel):
@@ -19,11 +24,23 @@ def create_run(
1924
thread_id: Annotated[str, Path(...)],
2025
request: Annotated[CreateRunRequest, Body(...)],
2126
run_service: RunService = RunServiceDep,
22-
) -> Run:
27+
) -> Run | StreamingResponse:
2328
"""
2429
Create a new run for a given thread.
2530
"""
26-
return run_service.create(thread_id, request.stream)
31+
stream = request.stream
32+
run_or_async_generator = run_service.create(thread_id, stream)
33+
if stream:
34+
async_generator = cast(
35+
"AsyncGenerator[RunEvent | MessageEvent, None]", run_or_async_generator
36+
)
37+
38+
async def sse_event_stream() -> AsyncGenerator[str, None]:
39+
async for event in async_generator:
40+
yield f"event: {event.event}\ndata: {event.model_dump_json()}\n\n"
41+
42+
return StreamingResponse(sse_event_stream(), media_type="text/event-stream")
43+
return cast("Run", run_or_async_generator)
2744

2845

2946
@router.get("/{run_id}")

src/askui/chat/api/runs/service.py

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from concurrent.futures import ThreadPoolExecutor
1+
import asyncio
2+
import queue
3+
import threading
4+
from collections.abc import AsyncGenerator
25
from datetime import datetime, timedelta, timezone
36
from pathlib import Path
4-
from typing import Literal, Sequence, cast
7+
from typing import Literal, Sequence, cast, overload
58

69
from pydantic import AwareDatetime, BaseModel, Field, computed_field
710

811
from askui.agent import VisionAgent
9-
from askui.chat.api.messages.service import MessageService
12+
from askui.chat.api.messages.service import MessageEvent, MessageService
13+
from askui.chat.api.models import Event
1014
from askui.chat.api.utils import generate_time_ordered_id
1115
from askui.models.shared.computer_agent_cb_param import OnMessageCbParam
1216
from askui.models.shared.computer_agent_message_param import MessageParam
@@ -70,15 +74,33 @@ class RunListResponse(BaseModel):
7074
has_more: bool = False
7175

7276

77+
class RunEvent(Event):
78+
data: Run
79+
event: Literal[
80+
"run.created",
81+
"run.started",
82+
"run.completed",
83+
"run.failed",
84+
"run.cancelled",
85+
"run.expired",
86+
]
87+
88+
7389
class Runner:
7490
def __init__(self, run: Run, base_dir: Path) -> None:
7591
self._run = run
7692
self._base_dir = base_dir
7793
self._runs_dir = base_dir / "runs"
7894
self._msg_service = MessageService(self._base_dir)
7995

80-
def run_task(self) -> None:
96+
def run(self, event_queue: queue.Queue[RunEvent | MessageEvent | None]) -> None:
8197
self._mark_started()
98+
event_queue.put(
99+
RunEvent(
100+
data=self._run,
101+
event="run.started",
102+
)
103+
)
82104
messages: list[MessageParam] = [
83105
cast("MessageParam", msg)
84106
for msg in self._msg_service.list_(self._run.thread_id).data
@@ -87,27 +109,63 @@ def run_task(self) -> None:
87109
def on_message(
88110
on_message_cb_param: OnMessageCbParam,
89111
) -> MessageParam | None:
90-
self._msg_service.create(
112+
message = self._msg_service.create(
91113
thread_id=self._run.thread_id,
92114
message=on_message_cb_param.message,
93115
)
116+
event_queue.put(
117+
MessageEvent(
118+
data=message,
119+
event="message.created",
120+
)
121+
)
94122
updated_run = self._retrieve_run()
95-
if self._should_abort(updated_run):
123+
if updated_run.status == "cancelling":
96124
updated_run.cancelled_at = datetime.now(tz=timezone.utc)
97125
self._update_run_file(updated_run)
126+
event_queue.put(
127+
RunEvent(
128+
data=updated_run,
129+
event="run.cancelled",
130+
)
131+
)
132+
return None
133+
if updated_run.status == "expired":
134+
event_queue.put(
135+
RunEvent(
136+
data=updated_run,
137+
event="run.expired",
138+
)
139+
)
98140
return None
99141
return on_message_cb_param.message
100142

101143
try:
102144
with VisionAgent() as agent:
103145
agent.act(messages, on_message=on_message)
104-
self._run.completed_at = datetime.now(tz=timezone.utc)
105-
self._update_run_file(self._run)
146+
updated_run = self._retrieve_run()
147+
if updated_run.status == "in_progress":
148+
updated_run.completed_at = datetime.now(tz=timezone.utc)
149+
self._update_run_file(updated_run)
150+
event_queue.put(
151+
RunEvent(
152+
data=updated_run,
153+
event="run.completed",
154+
)
155+
)
106156
except Exception as e: # noqa: BLE001
107-
self._run.failed_at = datetime.now(tz=timezone.utc)
108-
self._run.last_error = RunError(message=str(e), code="server_error")
109-
self._update_run_file(self._run)
110-
raise
157+
updated_run = self._retrieve_run()
158+
updated_run.failed_at = datetime.now(tz=timezone.utc)
159+
updated_run.last_error = RunError(message=str(e), code="server_error")
160+
self._update_run_file(updated_run)
161+
event_queue.put(
162+
RunEvent(
163+
data=updated_run,
164+
event="run.failed",
165+
)
166+
)
167+
finally:
168+
event_queue.put(None)
111169

112170
def _mark_started(self) -> None:
113171
self._run.started_at = datetime.now(tz=timezone.utc)
@@ -132,29 +190,56 @@ class RunService:
132190
Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs.
133191
"""
134192

135-
_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4)
136-
137193
def __init__(self, base_dir: Path) -> None:
138194
self._base_dir = base_dir
139195
self._runs_dir = base_dir / "runs"
140196

141197
def _run_path(self, thread_id: str, run_id: str) -> Path:
142198
return self._runs_dir / f"{thread_id}__{run_id}.json"
143199

144-
def create(self, thread_id: str, stream: bool) -> Run:
200+
def _create_run(self, thread_id: str) -> Run:
145201
run = Run(thread_id=thread_id)
146202
self._runs_dir.mkdir(parents=True, exist_ok=True)
147203
self._update_run_file(run)
148-
runner = Runner(run, self._base_dir)
149-
# TODO(adi-wan-askui): Run differently depending on `stream` parameter
150-
runner.run_task()
151-
# if not stream:
152-
# self._start_run_background(run)
153204
return run
154205

155-
def _start_run_background(self, run: Run) -> None:
206+
@overload
207+
def create(self, thread_id: str, stream: Literal[False]) -> Run: ...
208+
209+
@overload
210+
def create(
211+
self, thread_id: str, stream: Literal[True]
212+
) -> AsyncGenerator[RunEvent | MessageEvent, None]: ...
213+
214+
@overload
215+
def create(
216+
self, thread_id: str, stream: bool
217+
) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]: ...
218+
219+
def create(
220+
self, thread_id: str, stream: bool
221+
) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]:
222+
run = self._create_run(thread_id)
223+
event_queue: queue.Queue[RunEvent | MessageEvent | None] = queue.Queue()
156224
runner = Runner(run, self._base_dir)
157-
self._executor.submit(runner.run_task)
225+
thread = threading.Thread(target=runner.run, args=(event_queue,))
226+
thread.start()
227+
if stream:
228+
229+
async def event_stream() -> AsyncGenerator[RunEvent | MessageEvent, None]:
230+
yield RunEvent(
231+
data=run,
232+
event="run.created",
233+
)
234+
loop = asyncio.get_event_loop()
235+
while True:
236+
event = await loop.run_in_executor(None, event_queue.get)
237+
if event is None:
238+
break
239+
yield event
240+
241+
return event_stream()
242+
return run
158243

159244
def _update_run_file(self, run: Run) -> None:
160245
run_file = self._run_path(run.thread_id, run.id)

0 commit comments

Comments
 (0)