Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions internal/types/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type MessageType string
const (
MessageTypeTaskAssignment MessageType = "task_assignment"
MessageTypeTaskClaimed MessageType = "task_claimed"
MessageTypeTaskCompleted MessageType = "task_completed"
MessageTypeTaskFailed MessageType = "task_failed"
MessageTypeTaskRejected MessageType = "task_rejected"
MessageTypeHeartbeat MessageType = "heartbeat"
Expand Down Expand Up @@ -50,6 +51,12 @@ type TaskClaimedMessage struct {
WorkerID string `json:"worker_id"`
}

// TaskCompletedMessage tells the server to end the active run execution after a successful agent process exit.
type TaskCompletedMessage struct {
TaskID string `json:"task_id"`
Message string `json:"message"`
}

// TaskFailedMessage is sent from worker to server if task launch fails
type TaskFailedMessage struct {
TaskID string `json:"task_id"`
Expand Down
27 changes: 27 additions & 0 deletions internal/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ func (w *Worker) executeTask(ctx context.Context, assignment *types.TaskAssignme
}

log.Infof(ctx, "Task execution completed successfully: taskID=%s", taskID)
if err := w.sendTaskCompleted(taskID, "Task completed successfully"); err != nil {
log.Errorf(ctx, "Failed to send task completed message: %v", err)
}
}

func (w *Worker) sendTaskClaimed(taskID string) error {
Expand Down Expand Up @@ -505,6 +508,30 @@ func (w *Worker) sendTaskRejected(taskID, reason string) error {
return w.sendMessage(msgBytes)
}

func (w *Worker) sendTaskCompleted(taskID, message string) error {
completedMsg := types.TaskCompletedMessage{
TaskID: taskID,
Message: message,
}

data, err := json.Marshal(completedMsg)
if err != nil {
return fmt.Errorf("failed to marshal task completed message: %w", err)
}

msg := types.WebSocketMessage{
Type: types.MessageTypeTaskCompleted,
Data: data,
}

msgBytes, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal websocket message: %w", err)
}

return w.sendMessage(msgBytes)
}

func (w *Worker) sendTaskFailed(taskID, message string) error {
failedMsg := types.TaskFailedMessage{
TaskID: taskID,
Expand Down
97 changes: 97 additions & 0 deletions internal/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package worker

import (
"context"
"encoding/json"
"errors"
"testing"

"github.com/warpdotdev/oz-agent-worker/internal/types"
Expand All @@ -21,6 +23,101 @@ func (b *shutdownRecordingBackend) Shutdown(ctx context.Context) {
b.shutdownCtxErr = ctx.Err()
}

type recordingBackend struct {
err error
}

func (b *recordingBackend) ExecuteTask(context.Context, *TaskParams) error {
return b.err
}

func (b *recordingBackend) Shutdown(context.Context) {}

func TestExecuteTaskReportsTaskCompletedOnSuccess(t *testing.T) {
w := &Worker{
ctx: context.Background(),
config: Config{},
sendChan: make(chan []byte, 1),
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
backend: &recordingBackend{},
}

w.executeTask(context.Background(), &types.TaskAssignmentMessage{
TaskID: "task-1",
Task: &types.Task{ID: "task-1", Title: "test task"},
})

msg := readWebSocketMessage(t, w.sendChan)
if msg.Type != types.MessageTypeTaskCompleted {
t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskCompleted)
}

var completed types.TaskCompletedMessage
if err := json.Unmarshal(msg.Data, &completed); err != nil {
t.Fatalf("failed to unmarshal task completed message: %v", err)
}
if completed.TaskID != "task-1" {
t.Errorf("task ID = %q, want %q", completed.TaskID, "task-1")
}
if completed.Message != "Task completed successfully" {
t.Errorf("message = %q, want %q", completed.Message, "Task completed successfully")
}
if _, ok := w.activeTasks["task-1"]; ok {
t.Fatal("task should be removed from active tasks")
}
}

func TestExecuteTaskReportsTaskFailedOnBackendError(t *testing.T) {
w := &Worker{
ctx: context.Background(),
config: Config{},
sendChan: make(chan []byte, 1),
activeTasks: map[string]context.CancelFunc{"task-1": func() {}},
backend: &recordingBackend{err: errors.New("boom")},
}

w.executeTask(context.Background(), &types.TaskAssignmentMessage{
TaskID: "task-1",
Task: &types.Task{ID: "task-1", Title: "test task"},
})

msg := readWebSocketMessage(t, w.sendChan)
if msg.Type != types.MessageTypeTaskFailed {
t.Fatalf("message type = %q, want %q", msg.Type, types.MessageTypeTaskFailed)
}

var failed types.TaskFailedMessage
if err := json.Unmarshal(msg.Data, &failed); err != nil {
t.Fatalf("failed to unmarshal task failed message: %v", err)
}
if failed.TaskID != "task-1" {
t.Errorf("task ID = %q, want %q", failed.TaskID, "task-1")
}
if failed.Message != "Failed to execute task: boom" {
t.Errorf("message = %q, want %q", failed.Message, "Failed to execute task: boom")
}
if _, ok := w.activeTasks["task-1"]; ok {
t.Fatal("task should be removed from active tasks")
}
}

func readWebSocketMessage(t *testing.T, messages <-chan []byte) types.WebSocketMessage {
t.Helper()

select {
case msgBytes := <-messages:
var msg types.WebSocketMessage
if err := json.Unmarshal(msgBytes, &msg); err != nil {
t.Fatalf("failed to unmarshal websocket message: %v", err)
}
return msg
default:
t.Fatal("expected websocket message")
}

return types.WebSocketMessage{}
}

func TestDefaultImageForTask(t *testing.T) {
newWorker := func(defaultImage string) *Worker {
ctx := context.Background()
Expand Down
Loading