diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e916fce..e44085e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,4 +14,18 @@ jobs: - uses: actions/checkout@v4 - uses: oven-sh/setup-bun@v2 - run: bun install --frozen-lockfile + - name: Setup test config + run: | + mkdir -p ~/.agentara + cat > ~/.agentara/config.yaml << 'EOF' + agents: + default: + type: claude-code + model: claude-sonnet-4-6 + tasking: + max_retries: 3 + messaging: + default_channel_id: test-channel + channels: [] + EOF - run: bun test diff --git a/src/community/anthropic/claude-agent-runner.ts b/src/community/anthropic/claude-agent-runner.ts index 9b1f31e..36c9b0f 100644 --- a/src/community/anthropic/claude-agent-runner.ts +++ b/src/community/anthropic/claude-agent-runner.ts @@ -1,5 +1,6 @@ import { config, + createLogger, extractTextContent, type MessageContent, type ToolMessage, @@ -10,6 +11,18 @@ import { type UserMessage, } from "@/shared"; +const logger = createLogger("claude-agent-runner"); + +/** + * Error thrown when the agent runner is aborted. + */ +export class AgentAbortError extends Error { + constructor(message = "Agent execution was aborted") { + super(message); + this.name = "AgentAbortError"; + } +} + /** * The agent runner for Claude Code CLI. */ @@ -22,6 +35,7 @@ export class ClaudeAgentRunner implements AgentRunner { ): AsyncIterableIterator { const sessionId = message.session_id; const isNew = options?.isNewSession ?? false; + const signal = options?.signal; const textContentOfUserMessage = JSON.stringify( extractTextContent(message), ); @@ -43,6 +57,22 @@ export class ClaudeAgentRunner implements AgentRunner { }, stderr: "pipe", }); + + // Handle abort signal + let aborted = false; + const abortHandler = () => { + aborted = true; + logger.info({ session_id: sessionId }, "killing Claude Code process"); + proc.kill(); + }; + if (signal) { + if (signal.aborted) { + proc.kill(); + throw new AgentAbortError(); + } + signal.addEventListener("abort", abortHandler, { once: true }); + } + const decoder = new TextDecoder(); const stderrChunks: Uint8Array[] = []; const stderrPipe = proc.stderr.pipeTo( @@ -54,27 +84,41 @@ export class ClaudeAgentRunner implements AgentRunner { ); let buffer = ""; let stdoutRaw = ""; - for await (const chunk of proc.stdout) { - const decoded = decoder.decode(chunk, { stream: true }); - buffer += decoded; - stdoutRaw += decoded; - const lines = buffer.split("\n"); - buffer = lines.pop()!; - for (const line of lines) { - if (line.trim()) { - const parsed = this._parseStreamLine(line.trim(), sessionId); - if (parsed) { - yield parsed; + try { + for await (const chunk of proc.stdout) { + if (aborted) { + break; + } + const decoded = decoder.decode(chunk, { stream: true }); + buffer += decoded; + stdoutRaw += decoded; + const lines = buffer.split("\n"); + buffer = lines.pop()!; + for (const line of lines) { + if (line.trim()) { + const parsed = this._parseStreamLine(line.trim(), sessionId); + if (parsed) { + yield parsed; + } } } } - } - if (buffer.trim()) { - const parsed = this._parseStreamLine(buffer.trim(), sessionId); - if (parsed) { - yield parsed; + if (!aborted && buffer.trim()) { + const parsed = this._parseStreamLine(buffer.trim(), sessionId); + if (parsed) { + yield parsed; + } } + } finally { + if (signal) { + signal.removeEventListener("abort", abortHandler); + } + } + + if (aborted) { + throw new AgentAbortError(); } + const exitCode = await proc.exited; await stderrPipe; if (exitCode !== 0) { diff --git a/src/community/feishu/messaging/message-channel.ts b/src/community/feishu/messaging/message-channel.ts index 48c92e1..752580c 100644 --- a/src/community/feishu/messaging/message-channel.ts +++ b/src/community/feishu/messaging/message-channel.ts @@ -95,6 +95,7 @@ export class FeishuMessageChannel await this._inboundClient.start({ eventDispatcher: new EventDispatcher({}).register({ "im.message.receive_v1": this._handleMessageReceive, + "im.message.recalled_v1": this._handleMessageRecall, }), }); } @@ -563,6 +564,17 @@ export class FeishuMessageChannel this.emit("message:inbound", userMessage); }; + private _handleMessageRecall = async (data: { + message_id?: string; + chat_id?: string; + recall_time?: string; + recall_type?: string; + }) => { + if (!data.message_id) return; + this._logger.info({ message_id: data.message_id }, "message recalled"); + this.emit("message:recalled", data.message_id, this.id); + }; + private _threadIdToSessionId = new Map(); /** Persist a thread→session mapping to DB and update the in-memory cache. */ diff --git a/src/community/openai/codex-agent-runner.ts b/src/community/openai/codex-agent-runner.ts index dc8ddbd..7cfb9fd 100644 --- a/src/community/openai/codex-agent-runner.ts +++ b/src/community/openai/codex-agent-runner.ts @@ -17,6 +17,16 @@ import { const logger = createLogger("codex-agent-runner"); +/** + * Error thrown when the agent runner is aborted. + */ +export class AgentAbortError extends Error { + constructor(message = "Agent execution was aborted") { + super(message); + this.name = "AgentAbortError"; + } +} + /** * The agent runner for OpenAI Codex CLI. * @@ -33,6 +43,7 @@ export class CodexAgentRunner implements AgentRunner { ): AsyncIterableIterator { const sessionId = message.session_id; const isNew = options?.isNewSession ?? false; + const signal = options?.signal; const resumeId = options.runnerSessionId ?? sessionId; const textContentOfUserMessage = JSON.stringify( extractTextContent(message), @@ -56,6 +67,21 @@ export class CodexAgentRunner implements AgentRunner { stderr: "pipe", }); + // Handle abort signal + let aborted = false; + const abortHandler = () => { + aborted = true; + logger.info({ session_id: sessionId }, "killing Codex CLI process"); + proc.kill(); + }; + if (signal) { + if (signal.aborted) { + proc.kill(); + throw new AgentAbortError(); + } + signal.addEventListener("abort", abortHandler, { once: true }); + } + const decoder = new TextDecoder(); const stderrChunks: Uint8Array[] = []; const stderrPipe = proc.stderr.pipeTo( @@ -68,27 +94,40 @@ export class CodexAgentRunner implements AgentRunner { let buffer = ""; let stdoutRaw = ""; - for await (const chunk of proc.stdout) { - const decoded = decoder.decode(chunk, { stream: true }); - buffer += decoded; - stdoutRaw += decoded; - const lines = buffer.split("\n"); - buffer = lines.pop()!; - for (const line of lines) { - if (line.trim()) { - const messages = this._parseStreamLine(line.trim(), sessionId); - for (const msg of messages) { - yield msg; + try { + for await (const chunk of proc.stdout) { + if (aborted) { + break; + } + const decoded = decoder.decode(chunk, { stream: true }); + buffer += decoded; + stdoutRaw += decoded; + const lines = buffer.split("\n"); + buffer = lines.pop()!; + for (const line of lines) { + if (line.trim()) { + const messages = this._parseStreamLine(line.trim(), sessionId); + for (const msg of messages) { + yield msg; + } } } } - } - if (buffer.trim()) { - const messages = this._parseStreamLine(buffer.trim(), sessionId); - for (const msg of messages) { - yield msg; + if (!aborted && buffer.trim()) { + const messages = this._parseStreamLine(buffer.trim(), sessionId); + for (const msg of messages) { + yield msg; + } } + } finally { + if (signal) { + signal.removeEventListener("abort", abortHandler); + } + } + + if (aborted) { + throw new AgentAbortError(); } const exitCode = await proc.exited; diff --git a/src/kernel/kernel.ts b/src/kernel/kernel.ts index 8453f8a..15b5a80 100644 --- a/src/kernel/kernel.ts +++ b/src/kernel/kernel.ts @@ -98,6 +98,7 @@ class Kernel { ); } this._messageGateway.on("message:inbound", this._handleInboundMessage); + this._messageGateway.on("message:recalled", this._handleMessageRecall); } /** @@ -111,6 +112,14 @@ class Kernel { } private _handleInboundMessage = async (message: UserMessage) => { + const text = extractTextContent(message).trim(); + + // Handle /stop command + if (text === "/stop") { + await this._handleStopCommand(message); + return; + } + const task: InboundMessageTaskPayload = { type: "inbound_message", message, @@ -118,10 +127,46 @@ class Kernel { await this._taskDispatcher.dispatch(message.session_id, task); }; + private _handleStopCommand = async (message: UserMessage) => { + const sessionId = message.session_id; + const runningTaskId = + this._taskDispatcher.getRunningTaskForSession(sessionId); + + if (runningTaskId) { + await this._taskDispatcher.deleteTask(runningTaskId); + await this._messageGateway.replyMessage(message.id, { + role: "assistant", + session_id: sessionId, + content: [{ type: "text", text: "Task stopped." }], + }); + } else { + await this._messageGateway.replyMessage(message.id, { + role: "assistant", + session_id: sessionId, + content: [{ type: "text", text: "No running task found." }], + }); + } + }; + + private _handleMessageRecall = async ( + messageId: string, + channelId: string, + ) => { + const taskId = this._taskDispatcher.getTaskByMessageId(messageId); + if (taskId) { + await this._taskDispatcher.deleteTask(taskId); + this._logger.info( + { message_id: messageId, task_id: taskId, channel_id: channelId }, + "task stopped due to message recall", + ); + } + }; + private _handleInboundMessageTask = async ( taskId: string, sessionId: string, payload: InboundMessageTaskPayload, + signal?: AbortSignal, ) => { const inboundMessage = payload.message; const session = await this._sessionManager.resolveSession(sessionId, { @@ -146,7 +191,7 @@ class Kernel { }, ); contents = []; - const stream = await session.stream(inboundMessage); + const stream = await session.stream(inboundMessage, { signal }); let lastMessage: AssistantMessage | undefined; for await (const message of stream) { if (message.role === "assistant") { @@ -175,6 +220,7 @@ class Kernel { _taskId: string, sessionId: string, payload: ScheduledTaskPayload, + signal?: AbortSignal, ) => { const payload_without_instruction: { instruction?: string } = { ...payload, @@ -201,7 +247,7 @@ ${payload.instruction}`, firstMessage: userMessage, }); delete payload_without_instruction.instruction; - const assistantMessage = await session.run(userMessage); + const assistantMessage = await session.run(userMessage, { signal }); if (extractTextContent(assistantMessage).includes("[SKIPPED]")) { return; } diff --git a/src/kernel/messaging/multi-channel-message-gateway.ts b/src/kernel/messaging/multi-channel-message-gateway.ts index 2ffa0ec..1f82c70 100644 --- a/src/kernel/messaging/multi-channel-message-gateway.ts +++ b/src/kernel/messaging/multi-channel-message-gateway.ts @@ -44,6 +44,9 @@ export class MultiChannelMessageGateway channel.on("message:inbound", (message: UserMessage) => { this._handleInboundMessage(channel.id, message); }); + channel.on("message:recalled", (messageId: string, channelId: string) => { + this.emit("message:recalled", messageId, channelId); + }); this._logger.info(`Registered channel: ${channel.id}`); } diff --git a/src/kernel/sessioning/session.ts b/src/kernel/sessioning/session.ts index e821cf8..233beae 100644 --- a/src/kernel/sessioning/session.ts +++ b/src/kernel/sessioning/session.ts @@ -16,6 +16,16 @@ export interface SessionEventTypes { message: (message: Message) => void; } +/** + * Options for streaming messages from the session. + */ +export interface SessionStreamOptions { + /** + * Abort signal for cancelling the running task. + */ + signal?: AbortSignal; +} + /** * Represent a session context of the agent. */ @@ -41,10 +51,12 @@ export class Session extends EventEmitter { /** * Return a stream of messages from the agent. * @param userMessage - The message to send to the agent. + * @param streamOptions - Optional options for the stream (e.g., abort signal). * @returns The stream of messages from the agent. */ async stream( userMessage: UserMessage, + streamOptions?: SessionStreamOptions, ): Promise< AsyncIterableIterator > { @@ -52,6 +64,7 @@ export class Session extends EventEmitter { const runner = createAgentRunner(this.agentType); const rawStream = runner.stream(userMessage, { ...this.options, + signal: streamOptions?.signal, }); this.options.isNewSession = false; // eslint-disable-next-line @typescript-eslint/no-this-alias @@ -68,10 +81,14 @@ export class Session extends EventEmitter { /** * Send a message to the agent and return the last message. * @param userMessage - The message to send to the agent. + * @param streamOptions - Optional options for the stream (e.g., abort signal). * @returns The last message from the agent. */ - async run(userMessage: UserMessage): Promise { - const stream = await this.stream(userMessage); + async run( + userMessage: UserMessage, + streamOptions?: SessionStreamOptions, + ): Promise { + const stream = await this.stream(userMessage, streamOptions); let lastMessage: AssistantMessage | undefined; for await (const message of stream) { if (message.role === "assistant") { diff --git a/src/kernel/tasking/task-dispatcher.ts b/src/kernel/tasking/task-dispatcher.ts index b162167..529a244 100644 --- a/src/kernel/tasking/task-dispatcher.ts +++ b/src/kernel/tasking/task-dispatcher.ts @@ -1,6 +1,6 @@ import type { Job } from "bunqueue/client"; import { Queue, Worker } from "bunqueue/client"; -import { desc, eq } from "drizzle-orm"; +import { and, desc, eq, inArray } from "drizzle-orm"; import type { DrizzleDB } from "@/data"; import type { @@ -34,6 +34,7 @@ interface TaskJobData { * @param taskId - The bunqueue job ID for this task. * @param sessionId - The session that owns this task. * @param payload - The task payload. + * @param signal - Optional abort signal for cancelling the task. */ export type TaskHandler

= ( // eslint-disable-next-line no-unused-vars @@ -42,6 +43,8 @@ export type TaskHandler

= ( sessionId: string, // eslint-disable-next-line no-unused-vars payload: P, + // eslint-disable-next-line no-unused-vars + signal?: AbortSignal, ) => Promise; /** @@ -88,6 +91,8 @@ export class TaskDispatcher { private _handlers: Map; /** Per-session promise chain for serial execution. */ private _sessionLocks: Map>; + /** Tracks AbortController for currently running tasks. */ + private _runningTasks: Map; private _logger: Logger; constructor(options: TaskDispatcherOptions) { @@ -95,6 +100,7 @@ export class TaskDispatcher { this._db = options.db; this._handlers = new Map(); this._sessionLocks = new Map(); + this._runningTasks = new Map(); this._logger = createLogger("task-dispatcher"); this._queue = new Queue(QUEUE_NAME, { embedded: true, @@ -412,6 +418,7 @@ export class TaskDispatcher { /** * Delete a task by ID. For pending jobs, removes from the queue. + * For running tasks, aborts the handler and kills any spawned subprocesses. * Always deletes the persisted task row. For one-shot scheduled tasks, * also removes the scheduler row. * @param taskId - The task (job) ID to remove. @@ -422,6 +429,15 @@ export class TaskDispatcher { if (!row) { throw new Error(`Task not found: ${taskId}`); } + + // Abort running task if it exists + const controller = this._runningTasks.get(taskId); + if (controller) { + controller.abort(); + this._runningTasks.delete(taskId); + this._logger.info({ task_id: taskId }, "aborted running task"); + } + try { await this._queue.removeAsync(taskId); } catch { @@ -474,6 +490,44 @@ export class TaskDispatcher { return this._db.select().from(scheduledTasks).all() as ScheduledTaskRow[]; } + /** + * Get the currently running task for a session, if any. + * @param sessionId - The session ID to look up. + * @returns The task ID if found, undefined otherwise. + */ + getRunningTaskForSession(sessionId: string): string | undefined { + const row = this._db + .select({ id: tasks.id }) + .from(tasks) + .where(and(eq(tasks.session_id, sessionId), eq(tasks.status, "running"))) + .get(); + return row?.id; + } + + /** + * Get a pending or running task by its inbound message ID. + * @param messageId - The Feishu message ID to look up. + * @returns The task ID if found, undefined otherwise. + */ + getTaskByMessageId(messageId: string): string | undefined { + const rows = this._db + .select({ id: tasks.id, payload: tasks.payload }) + .from(tasks) + .where(inArray(tasks.status, ["pending", "running"])) + .all(); + + for (const row of rows) { + const payload = row.payload as TaskPayload; + if ( + payload.type === "inbound_message" && + payload.message.id === messageId + ) { + return row.id; + } + } + return undefined; + } + /** * Start the worker. Must be called once during app startup. * Re-registers all persisted scheduled tasks with bunqueue. @@ -593,9 +647,14 @@ export class TaskDispatcher { return; } this._updateTaskStatus(job.id, "running"); + + // Create AbortController for this task + const controller = new AbortController(); + this._runningTasks.set(job.id, controller); + try { const taskId = job.id; - await handler(taskId, sessionId, payload); + await handler(taskId, sessionId, payload, controller.signal); await job.updateProgress(100); this._updateTaskStatus(job.id, "completed"); const schedulerId = job.data.scheduler_id; @@ -608,12 +667,23 @@ export class TaskDispatcher { } } } catch (err) { - this._updateTaskStatus(job.id, "failed"); - this._logger.error( - { session_id: sessionId, type: payload.type, err }, - "task failed", - ); - throw err; + // Don't mark as failed if aborted — the task was intentionally cancelled + if (controller.signal.aborted) { + this._updateTaskStatus(job.id, "cancelled"); + this._logger.info( + { session_id: sessionId, type: payload.type }, + "task cancelled", + ); + } else { + this._updateTaskStatus(job.id, "failed"); + this._logger.error( + { session_id: sessionId, type: payload.type, err }, + "task failed", + ); + throw err; + } + } finally { + this._runningTasks.delete(job.id); } }); diff --git a/src/shared/agents/agent-runner.ts b/src/shared/agents/agent-runner.ts index b211434..fafe1ca 100644 --- a/src/shared/agents/agent-runner.ts +++ b/src/shared/agents/agent-runner.ts @@ -25,6 +25,12 @@ export const AgentRunOptions = z.object({ * Runner-specific session/thread id used by some providers for true resume. */ runnerSessionId: z.string().optional(), + + /** + * Abort signal for cancelling the running task. + * When aborted, the agent runner should kill any spawned subprocesses. + */ + signal: z.instanceof(AbortSignal).optional(), }); export interface AgentRunOptions extends z.infer {} diff --git a/src/shared/messaging/message-channel.ts b/src/shared/messaging/message-channel.ts index a4d0427..5da67fe 100644 --- a/src/shared/messaging/message-channel.ts +++ b/src/shared/messaging/message-channel.ts @@ -6,6 +6,8 @@ import type { AssistantMessage, UserMessage } from "./types"; export interface MessageChannelEventTypes { // eslint-disable-next-line no-unused-vars "message:inbound": (message: UserMessage) => void; + // eslint-disable-next-line no-unused-vars + "message:recalled": (messageId: string, channelId: string) => void; } /** Abstract message channel for sending and receiving messages. */ diff --git a/src/shared/messaging/message-gateway.ts b/src/shared/messaging/message-gateway.ts index 055be2b..5e1c7be 100644 --- a/src/shared/messaging/message-gateway.ts +++ b/src/shared/messaging/message-gateway.ts @@ -7,6 +7,8 @@ import type { AssistantMessage, UserMessage } from "./types"; export interface MessageGatewayEventTypes { // eslint-disable-next-line no-unused-vars "message:inbound": (message: UserMessage) => void; + // eslint-disable-next-line no-unused-vars + "message:recalled": (messageId: string, channelId: string) => void; } /**