From 0935ae86007c20d85a6442672b9c82874a3abbf7 Mon Sep 17 00:00:00 2001 From: Zach Bai Date: Wed, 22 Apr 2026 16:01:35 -0700 Subject: [PATCH] Send TaskCompleted message upon CLI exit with no error. --- internal/types/messages.go | 7 +++ internal/worker/worker.go | 27 ++++++++++ internal/worker/worker_test.go | 97 ++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) diff --git a/internal/types/messages.go b/internal/types/messages.go index 3b086f8..cb98d46 100644 --- a/internal/types/messages.go +++ b/internal/types/messages.go @@ -11,6 +11,7 @@ type MessageType string const ( MessageTypeTaskAssignment MessageType = "task_assignment" MessageTypeTaskClaimed MessageType = "task_claimed" + MessageTypeTaskCompleted MessageType = "task_completed" MessageTypeTaskFailed MessageType = "task_failed" MessageTypeTaskRejected MessageType = "task_rejected" MessageTypeHeartbeat MessageType = "heartbeat" @@ -50,6 +51,12 @@ type TaskClaimedMessage struct { WorkerID string `json:"worker_id"` } +// TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit. +type TaskCompletedMessage struct { + TaskID string `json:"task_id"` + Message string `json:"message"` +} + // TaskFailedMessage is sent from worker to server if task launch fails type TaskFailedMessage struct { TaskID string `json:"task_id"` diff --git a/internal/worker/worker.go b/internal/worker/worker.go index fa1d8cb..8214fdc 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -455,6 +455,9 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme } log.Infof(ctx, "Task execution completed successfully: taskID=%s", taskID) + if err := w.sendTaskCompleted(taskID, "Task completed successfully"); err != nil { + log.Errorf(ctx, "Failed to send task completed message: %v", err) + } } func (w *Worker) sendTaskClaimed(taskID string) error { @@ -505,6 +508,30 @@ func (w *Worker) sendTaskRejected(taskID, reason string) error { return w.sendMessage(msgBytes) } +func (w *Worker) sendTaskCompleted(taskID, message string) error { + completedMsg := types.TaskCompletedMessage{ + TaskID: taskID, + Message: message, + } + + data, err := json.Marshal(completedMsg) + if err != nil { + return fmt.Errorf("failed to marshal task completed message: %w", err) + } + + msg := types.WebSocketMessage{ + Type: types.MessageTypeTaskCompleted, + Data: data, + } + + msgBytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal websocket message: %w", err) + } + + return w.sendMessage(msgBytes) +} + func (w *Worker) sendTaskFailed(taskID, message string) error { failedMsg := types.TaskFailedMessage{ TaskID: taskID, diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index 1ba1c86..cce3f81 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -2,6 +2,8 @@ package worker import ( "context" + "encoding/json" + "errors" "testing" "github.com/warpdotdev/oz-agent-worker/internal/types" @@ -21,6 +23,101 @@ func (b *shutdownRecordingBackend) Shutdown(ctx context.Context) { b.shutdownCtxErr = ctx.Err() } +type recordingBackend struct { + err error +} + +func (b *recordingBackend) ExecuteTask(context.Context, *TaskParams) error { + return b.err +} + +func (b *recordingBackend) Shutdown(context.Context) {} + +func TestExecuteTaskReportsTaskCompletedOnSuccess(t *testing.T) { + w := &Worker{ + ctx: context.Background(), + config: Config{}, + sendChan: make(chan []byte, 1), + activeTasks: map[string]context.CancelFunc{"task-1": func() {}}, + backend: &recordingBackend{}, + } + + w.executeTask(context.Background(), &types.TaskAssignmentMessage{ + TaskID: "task-1", + Task: &types.Task{ID: "task-1", Title: "test task"}, + }) + + msg := readWebSocketMessage(t, w.sendChan) + if msg.Type != types.MessageTypeTaskCompleted { + t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskCompleted) + } + + var completed types.TaskCompletedMessage + if err := json.Unmarshal(msg.Data, &completed); err != nil { + t.Fatalf("failed to unmarshal task completed message: %v", err) + } + if completed.TaskID != "task-1" { + t.Errorf("task ID = %q, want %q", completed.TaskID, "task-1") + } + if completed.Message != "Task completed successfully" { + t.Errorf("message = %q, want %q", completed.Message, "Task completed successfully") + } + if _, ok := w.activeTasks["task-1"]; ok { + t.Fatal("task should be removed from active tasks") + } +} + +func TestExecuteTaskReportsTaskFailedOnBackendError(t *testing.T) { + w := &Worker{ + ctx: context.Background(), + config: Config{}, + sendChan: make(chan []byte, 1), + activeTasks: map[string]context.CancelFunc{"task-1": func() {}}, + backend: &recordingBackend{err: errors.New("boom")}, + } + + w.executeTask(context.Background(), &types.TaskAssignmentMessage{ + TaskID: "task-1", + Task: &types.Task{ID: "task-1", Title: "test task"}, + }) + + msg := readWebSocketMessage(t, w.sendChan) + if msg.Type != types.MessageTypeTaskFailed { + t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskFailed) + } + + var failed types.TaskFailedMessage + if err := json.Unmarshal(msg.Data, &failed); err != nil { + t.Fatalf("failed to unmarshal task failed message: %v", err) + } + if failed.TaskID != "task-1" { + t.Errorf("task ID = %q, want %q", failed.TaskID, "task-1") + } + if failed.Message != "Failed to execute task: boom" { + t.Errorf("message = %q, want %q", failed.Message, "Failed to execute task: boom") + } + if _, ok := w.activeTasks["task-1"]; ok { + t.Fatal("task should be removed from active tasks") + } +} + +func readWebSocketMessage(t *testing.T, messages <-chan []byte) types.WebSocketMessage { + t.Helper() + + select { + case msgBytes := <-messages: + var msg types.WebSocketMessage + if err := json.Unmarshal(msgBytes, &msg); err != nil { + t.Fatalf("failed to unmarshal websocket message: %v", err) + } + return msg + default: + t.Fatal("expected websocket message") + } + + return types.WebSocketMessage{} +} + func TestDefaultImageForTask(t *testing.T) { newWorker := func(defaultImage string) *Worker { ctx := context.Background()