From 2b8f77ae7961071f3e8b9300299c309c0d4e471c Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 17 Apr 2026 01:14:00 +0000 Subject: [PATCH] test: improve coverage on gateway/config/subagent targets Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/config/gateway_test.go | 15 ++ .../adapters/urlscheme/dispatcher_test.go | 26 ++ internal/gateway/auth/manager_test.go | 10 + internal/gateway/bootstrap_test.go | 89 +++++++ internal/gateway/network_server_test.go | 82 ++++++ internal/gateway/protocol/jsonrpc_test.go | 26 ++ internal/gateway/request_logging_test.go | 19 ++ internal/gateway/rpc_dispatch_test.go | 34 +++ internal/gateway/stream_relay_test.go | 46 ++++ internal/subagent/scheduler_test.go | 249 ++++++++++++++++++ 10 files changed, 596 insertions(+) diff --git a/internal/config/gateway_test.go b/internal/config/gateway_test.go index 9697900a..0a446e5f 100644 --- a/internal/config/gateway_test.go +++ b/internal/config/gateway_test.go @@ -419,3 +419,18 @@ func TestGatewayApplyDefaultsNilReceivers(t *testing.T) { var observabilityCfg *GatewayObservabilityConfig observabilityCfg.ApplyDefaults(GatewayObservabilityConfig{MetricsEnabled: boolPtr(false)}) } + +func TestLoadGatewayConfigReadFileError(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + configPath := filepath.Join(baseDir, configName) + if err := os.Mkdir(configPath, 0o755); err != nil { + t.Fatalf("mkdir config path: %v", err) + } + + _, err := LoadGatewayConfig(context.Background(), baseDir) + if err == nil || !strings.Contains(err.Error(), "read gateway config file") { + t.Fatalf("expected read gateway config file error, got %v", err) + } +} diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index 1e4a47c8..4d300fe1 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -824,6 +824,32 @@ func TestDispatcherAuthenticateBranches(t *testing.T) { t.Fatalf("expected unexpected auth frame error, got %v", err) } }) + + t.Run("auth version mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + requestIDFn: func() string { return "wake-auth-4" }, + } + conn := &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"1.0","id":"wake-auth-4-auth","result":{}}` + "\n"), + } + err := dispatcher.authenticate(context.Background(), conn, "token-1") + if err == nil || !strings.Contains(err.Error(), "jsonrpc version") { + t.Fatalf("expected auth version mismatch error, got %v", err) + } + }) + + t.Run("auth id mismatch", func(t *testing.T) { + dispatcher := &Dispatcher{ + requestIDFn: func() string { return "wake-auth-5" }, + } + conn := &stubDispatchConn{ + readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"other-auth","result":{}}` + "\n"), + } + err := dispatcher.authenticate(context.Background(), conn, "token-1") + if err == nil || !strings.Contains(err.Error(), "id mismatch") { + t.Fatalf("expected auth id mismatch error, got %v", err) + } + }) } func TestDispatcherDispatchWithAuthHandshake(t *testing.T) { diff --git a/internal/gateway/auth/manager_test.go b/internal/gateway/auth/manager_test.go index dfc1443e..057e00b3 100644 --- a/internal/gateway/auth/manager_test.go +++ b/internal/gateway/auth/manager_test.go @@ -224,3 +224,13 @@ func TestDefaultAuthPathAndLoadOrCreateNilManager(t *testing.T) { t.Fatal("expected nil manager loadOrCreate error") } } + +func TestManagerGenerateTokenAndWriteCredentialsErrorBranches(t *testing.T) { + t.Run("write credentials fails on directory path", func(t *testing.T) { + dir := t.TempDir() + err := writeCredentials(dir, Credentials{Version: 1, Token: "token"}) + if err == nil || !strings.Contains(err.Error(), "write auth file") { + t.Fatalf("expected write auth file error, got %v", err) + } + }) +} diff --git a/internal/gateway/bootstrap_test.go b/internal/gateway/bootstrap_test.go index 427ddb0b..d465381e 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -327,3 +327,92 @@ func TestHandleAuthenticateFrameBranches(t *testing.T) { } }) } + +type invalidJSONMarshaler struct{} + +func (invalidJSONMarshaler) MarshalJSON() ([]byte, error) { + return []byte("{"), nil +} + +func TestBootstrapDecodeAdditionalBranches(t *testing.T) { + t.Run("decode authenticate pointer nil", func(t *testing.T) { + _, frameErr := decodeAuthenticatePayload((*protocol.AuthenticateParams)(nil)) + if frameErr == nil || frameErr.Code != ErrorCodeMissingRequiredField.String() { + t.Fatalf("frameErr = %#v, want missing_required_field", frameErr) + } + }) + + t.Run("decode authenticate unmarshal error", func(t *testing.T) { + _, frameErr := decodeAuthenticatePayload(invalidJSONMarshaler{}) + if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("frameErr = %#v, want invalid_frame", frameErr) + } + }) + + t.Run("decode bind stream unmarshal error", func(t *testing.T) { + _, frameErr := decodeBindStreamParams(invalidJSONMarshaler{}) + if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("frameErr = %#v, want invalid_frame", frameErr) + } + }) + + t.Run("decode bind stream pointer nil", func(t *testing.T) { + _, frameErr := decodeBindStreamParams((*protocol.BindStreamParams)(nil)) + if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("frameErr = %#v, want invalid_frame", frameErr) + } + }) + + t.Run("decode bind stream map missing session", func(t *testing.T) { + _, frameErr := decodeBindStreamParams(map[string]any{"run_id": "r-1"}) + if frameErr == nil || frameErr.Code != ErrorCodeMissingRequiredField.String() { + t.Fatalf("frameErr = %#v, want missing_required_field", frameErr) + } + }) + + t.Run("decode bind stream default struct", func(t *testing.T) { + payload := struct { + SessionID string `json:"session_id"` + RunID string `json:"run_id"` + Channel string `json:"channel"` + }{ + SessionID: "s-1", + RunID: "r-1", + Channel: "ipc", + } + params, frameErr := decodeBindStreamParams(payload) + if frameErr != nil { + t.Fatalf("frameErr = %v", frameErr) + } + if params.SessionID != "s-1" || params.Channel != StreamChannelIPC { + t.Fatalf("params = %#v", params) + } + }) + + t.Run("decode wake pointer nil", func(t *testing.T) { + _, err := decodeWakeIntent((*protocol.WakeIntent)(nil)) + if err == nil { + t.Fatal("expected nil pointer wake payload error") + } + }) + + t.Run("decode wake unmarshal error", func(t *testing.T) { + _, err := decodeWakeIntent(invalidJSONMarshaler{}) + if err == nil { + t.Fatal("expected wake decode error") + } + }) + + t.Run("decode wake direct struct", func(t *testing.T) { + intent, err := decodeWakeIntent(protocol.WakeIntent{ + Action: "review", + Params: map[string]string{"path": "README.md"}, + }) + if err != nil { + t.Fatalf("decode wake intent: %v", err) + } + if intent.Action != "review" { + t.Fatalf("action = %q, want review", intent.Action) + } + }) +} diff --git a/internal/gateway/network_server_test.go b/internal/gateway/network_server_test.go index d5205665..ee997d99 100644 --- a/internal/gateway/network_server_test.go +++ b/internal/gateway/network_server_test.go @@ -668,6 +668,88 @@ func TestNetworkServerVersionAndObservabilityAuthHelpers(t *testing.T) { }) } +func TestNetworkServerObservabilityHandlerBranches(t *testing.T) { + t.Run("prometheus handler branches", func(t *testing.T) { + server := &NetworkServer{} + + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodPost, "/metrics", nil) + server.handlePrometheusMetrics(recorder, request) + if recorder.Code != http.StatusMethodNotAllowed { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusMethodNotAllowed) + } + + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodGet, "/metrics", nil) + server.handlePrometheusMetrics(recorder, request) + if recorder.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusServiceUnavailable) + } + + server.metrics = NewGatewayMetrics() + server.authenticator = staticTokenAuthenticator{token: "token-1"} + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodGet, "/metrics", nil) + server.handlePrometheusMetrics(recorder, request) + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + + recorder = httptest.NewRecorder() + request = httptest.NewRequest(http.MethodGet, "/metrics", nil) + request.Header.Set("Authorization", "Bearer token-1") + server.handlePrometheusMetrics(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK) + } + }) + + t.Run("json metrics unauthorized", func(t *testing.T) { + server := &NetworkServer{ + metrics: NewGatewayMetrics(), + authenticator: staticTokenAuthenticator{token: "token-2"}, + allowedOrigins: defaultControlPlaneOrigins(), + } + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/metrics.json", nil) + server.handleJSONMetrics(recorder, request) + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + }) +} + +func TestNetworkServerOriginAndTokenHelpersAdditional(t *testing.T) { + normalized := normalizeControlPlaneOrigins([]string{" HTTP://LOCALHOST ", "", " app://desktop "}) + if len(normalized) != 2 { + t.Fatalf("normalized len = %d, want 2", len(normalized)) + } + if normalized[0] != "http://localhost" || normalized[1] != "app://desktop" { + t.Fatalf("normalized = %#v", normalized) + } + + if originMatchesAllowRule("http://localhost:3000", "") { + t.Fatal("blank allow rule should reject") + } + if !originMatchesAllowRule("http://[::1]:8080", "http://[::1]") { + t.Fatal("ipv6 localhost with port should match") + } + + server := &NetworkServer{allowedOrigins: []string{"http://localhost"}} + if err := server.validateWebSocketOrigin(nil); err == nil { + t.Fatal("nil request should be rejected") + } + request := httptest.NewRequest(http.MethodGet, "/ws", nil) + request.Header.Set("Origin", "http://evil.example") + if err := server.validateWebSocketOrigin(request); err == nil { + t.Fatal("disallowed origin should be rejected") + } + + if token := extractBearerToken("Basic abc"); token != "" { + t.Fatalf("token = %q, want empty", token) + } +} + func TestNetworkServerCloseInterruptsStreams(t *testing.T) { server := newTestNetworkServer(t, NetworkServerOptions{}) testContext, cancel := context.WithCancel(context.Background()) diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 1c1450b5..8a61a2da 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -395,3 +395,29 @@ func TestNewJSONRPCErrorResponseWithNilIDEncodesNull(t *testing.T) { t.Fatalf("encoded response id = %#v, want nil", payload["id"]) } } + +func TestJSONRPCHelpersAdditionalBranches(t *testing.T) { + if got := MapGatewayCodeToJSONRPCCode(GatewayCodeInternalError); got != JSONRPCCodeInternalError { + t.Fatalf("internal mapping = %d, want %d", got, JSONRPCCodeInternalError) + } + if got := MapGatewayCodeToJSONRPCCode("unknown-code"); got != JSONRPCCodeInternalError { + t.Fatalf("default mapping = %d, want %d", got, JSONRPCCodeInternalError) + } + + if _, err := normalizeJSONRPCID(json.RawMessage(`null`)); err == nil { + t.Fatal("expected null id error") + } + if _, err := normalizeJSONRPCID(json.RawMessage(`" "`)); err == nil { + t.Fatal("expected blank string id error") + } + + if _, err := decodeAuthenticateParams(json.RawMessage(`null`)); err == nil { + t.Fatal("expected missing authenticate params error") + } + if _, err := decodeBindStreamParams(json.RawMessage(`null`)); err == nil { + t.Fatal("expected missing bind stream params error") + } + if _, err := decodeWakeIntentParams(json.RawMessage(`null`)); err == nil { + t.Fatal("expected missing wake params error") + } +} diff --git a/internal/gateway/request_logging_test.go b/internal/gateway/request_logging_test.go index 0dfbbccc..0fc40703 100644 --- a/internal/gateway/request_logging_test.go +++ b/internal/gateway/request_logging_test.go @@ -81,3 +81,22 @@ func TestRequestLatencyMS(t *testing.T) { t.Fatal("requestStartTime should not return zero time") } } + +func TestEmitRequestLogUsesContextSource(t *testing.T) { + t.Parallel() + + buffer := &bytes.Buffer{} + logger := log.New(buffer, "", 0) + ctx := WithRequestSource(context.Background(), RequestSourceHTTP) + + emitRequestLog(ctx, logger, RequestLogEntry{ + RequestID: "req-http", + Method: "gateway.ping", + Status: "ok", + }) + + output := buffer.String() + if !strings.Contains(output, `"source":"http"`) { + t.Fatalf("output = %q, want http source", output) + } +} diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 35576cce..4f7d6b7f 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -352,3 +352,37 @@ func TestDispatchRPCRequestMetricsBranches(t *testing.T) { t.Fatalf("expected ok request metric, snapshot=%#v", snapshot["gateway_requests_total"]) } } + +func TestDispatchRPCRequestFrameErrorWithoutPayload(t *testing.T) { + originalHandlers := requestFrameHandlers + requestFrameHandlers = map[FrameAction]requestFrameHandler{ + FrameActionPing: func(_ context.Context, frame MessageFrame) MessageFrame { + return MessageFrame{Type: FrameTypeError, Action: frame.Action, RequestID: frame.RequestID} + }, + } + t.Cleanup(func() { requestFrameHandlers = originalHandlers }) + + response := dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"rpc-noerr-1"`), + Method: protocol.MethodGatewayPing, + Params: json.RawMessage(`{}`), + }, nil) + if response.Error == nil { + t.Fatal("expected rpc error response") + } + if code := protocol.GatewayCodeFromJSONRPCError(response.Error); code != ErrorCodeInternalError.String() { + t.Fatalf("gateway_code = %q, want %q", code, ErrorCodeInternalError.String()) + } +} + +func TestHydrateFrameSessionFromPayloadBranch(t *testing.T) { + frame := hydrateFrameSessionFromConnection(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionPing, + Payload: map[string]any{"session_id": "s-from-payload"}, + }) + if frame.SessionID != "s-from-payload" { + t.Fatalf("session_id = %q, want %q", frame.SessionID, "s-from-payload") + } +} diff --git a/internal/gateway/stream_relay_test.go b/internal/gateway/stream_relay_test.go index 2e980bcf..2b93a66c 100644 --- a/internal/gateway/stream_relay_test.go +++ b/internal/gateway/stream_relay_test.go @@ -485,3 +485,49 @@ func registerConnectionForRelayTest( t.Fatalf("bind connection: %v", bindErr) } } + +func TestStreamRelayAdditionalCoverageBranches(t *testing.T) { + t.Run("snapshot counts skips nil connection entry", func(t *testing.T) { + relay := NewStreamRelay(StreamRelayOptions{}) + relay.mu.Lock() + relay.connections[ConnectionID("nil-entry")] = nil + relay.mu.Unlock() + + snapshot := relay.SnapshotConnectionCounts() + if snapshot[StreamChannelIPC] != 0 || snapshot[StreamChannelWS] != 0 || snapshot[StreamChannelSSE] != 0 { + t.Fatalf("unexpected snapshot: %#v", snapshot) + } + }) + + t.Run("send sync response branches", func(t *testing.T) { + var nilRelay *StreamRelay + if nilRelay.SendJSONRPCResponseSync("cid", protocol.JSONRPCResponse{}) { + t.Fatal("nil relay should return false") + } + + relay := NewStreamRelay(StreamRelayOptions{}) + if relay.SendJSONRPCResponseSync("", protocol.JSONRPCResponse{}) { + t.Fatal("empty connection id should return false") + } + if relay.SendJSONRPCResponseSync("missing", protocol.JSONRPCResponse{}) { + t.Fatal("missing connection should return false") + } + }) + + t.Run("update active connection metrics", func(t *testing.T) { + metrics := NewGatewayMetrics() + relay := NewStreamRelay(StreamRelayOptions{Metrics: metrics}) + relay.mu.Lock() + relay.connections[ConnectionID("ipc")] = &relayConnection{channel: StreamChannelIPC} + relay.connections[ConnectionID("ws")] = &relayConnection{channel: StreamChannelWS} + relay.connections[ConnectionID("sse")] = &relayConnection{channel: StreamChannelSSE} + relay.connections[ConnectionID("nil")] = nil + relay.updateActiveConnectionMetricsLocked() + relay.mu.Unlock() + + entries := metrics.Snapshot()["gateway_connections_active"] + if entries["ipc"] != 1 || entries["ws"] != 1 || entries["sse"] != 1 { + t.Fatalf("unexpected metrics snapshot: %#v", entries) + } + }) +} diff --git a/internal/subagent/scheduler_test.go b/internal/subagent/scheduler_test.go index 82ad411a..01663b1f 100644 --- a/internal/subagent/scheduler_test.go +++ b/internal/subagent/scheduler_test.go @@ -25,6 +25,11 @@ type schedulerStoreWithClaimError struct { claimErrors map[string]error } +type schedulerStoreWithUpdateError struct { + *schedulerStore + updateErrors map[string]error +} + func newSchedulerStore(t *testing.T, items []agentsession.TodoItem) *schedulerStore { t.Helper() session := agentsession.New("scheduler") @@ -93,6 +98,17 @@ func (s *schedulerStoreWithClaimError) ClaimTodo(id string, ownerType string, ow return s.schedulerStore.ClaimTodo(id, ownerType, ownerID, expectedRevision) } +func (s *schedulerStoreWithUpdateError) UpdateTodo(id string, patch agentsession.TodoPatch, expectedRevision int64) error { + s.mu.Lock() + if err, ok := s.updateErrors[id]; ok && err != nil { + delete(s.updateErrors, id) + s.mu.Unlock() + return err + } + s.mu.Unlock() + return s.schedulerStore.UpdateTodo(id, patch, expectedRevision) +} + type scriptedFactory struct { mu sync.Mutex attempts map[string]int @@ -430,6 +446,239 @@ func TestSchedulerRunConcurrencyLimit(t *testing.T) { } } +func TestSchedulerHelperBranchesAdditional(t *testing.T) { + t.Run("new state and poll delay", func(t *testing.T) { + state := newSchedulerState(0) + if cap(state.outcomes) != 1 { + t.Fatalf("outcome buffer = %d, want 1", cap(state.outcomes)) + } + + store := newSchedulerStore(t, []agentsession.TodoItem{{ID: "a", Content: "a"}}) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{PollInterval: 0}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + + now := scheduler.cfg.Clock() + delay := scheduler.nextPollDelay(map[string]agentsession.TodoItem{ + "a": { + ID: "a", + Status: agentsession.TodoStatusBlocked, + NextRetryAt: now.Add(5 * time.Millisecond), + }, + }) + if delay <= 0 { + t.Fatalf("nextPollDelay = %v, want > 0", delay) + } + }) + + t.Run("emit fills timestamp and effective priority", func(t *testing.T) { + store := newSchedulerStore(t, []agentsession.TodoItem{{ID: "a", Content: "a"}}) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + var received SchedulerEvent + scheduler, err := NewScheduler(store, factory, SchedulerConfig{ + Observer: func(event SchedulerEvent) { received = event }, + }) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + + scheduler.emit(SchedulerEvent{Type: SchedulerEventQueued, TaskID: "a"}) + if received.At.IsZero() { + t.Fatal("emit should fill event timestamp") + } + + now := time.Now() + if got := effectivePriority(agentsession.TodoItem{Priority: 1}, time.Time{}, now, time.Second); got != 1 { + t.Fatalf("priority without readySince = %d, want 1", got) + } + }) +} + +func TestSchedulerStateMachineBranches(t *testing.T) { + t.Run("ensure blocked and ready status branches", func(t *testing.T) { + baseStore := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "a", Content: "a", Status: agentsession.TodoStatusPending}, + {ID: "b", Content: "b", Status: agentsession.TodoStatusBlocked, NextRetryAt: time.Now().Add(time.Second)}, + }) + store := &schedulerStoreWithUpdateError{ + schedulerStore: baseStore, + updateErrors: map[string]error{"a": errors.New("update failed")}, + } + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + + if err := scheduler.ensureBlocked(agentsession.TodoItem{ID: "a", Status: agentsession.TodoStatusPending}, "dep"); err == nil { + t.Fatal("expected ensureBlocked update error") + } + item, ok, err := scheduler.ensureReadyStatus(agentsession.TodoItem{ + ID: "b", + Status: agentsession.TodoStatusBlocked, + NextRetryAt: time.Now().Add(time.Second), + }) + if err != nil || ok || item.ID != "b" { + t.Fatalf("ensureReadyStatus blocked future = (%+v, %v, %v)", item, ok, err) + } + }) + + t.Run("apply outcome success and failure branches", func(t *testing.T) { + store := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "ok", Content: "ok", Status: agentsession.TodoStatusInProgress}, + {ID: "fail", Content: "fail", Status: agentsession.TodoStatusInProgress, RetryLimit: 0}, + }) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{MaxRetries: 0}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + + state := newSchedulerState(1) + summary := &ScheduleResult{Retried: map[string]int{}} + if err := scheduler.applyOutcome(taskOutcome{ + id: "ok", + attempt: 1, + result: Result{State: StateSucceeded, Output: Output{Artifacts: []string{"ok.artifact"}}}, + }, state, summary); err != nil { + t.Fatalf("applyOutcome success error = %v", err) + } + if len(summary.Succeeded) != 1 || summary.Succeeded[0] != "ok" { + t.Fatalf("summary succeeded = %#v", summary.Succeeded) + } + + if err := scheduler.applyOutcome(taskOutcome{ + id: "fail", + attempt: 1, + err: errors.New("boom"), + }, state, summary); err != nil { + t.Fatalf("applyOutcome failure error = %v", err) + } + item, _ := store.FindTodo("fail") + if item.Status != agentsession.TodoStatusFailed { + t.Fatalf("fail status = %q, want failed", item.Status) + } + }) + + t.Run("cancel running todos updates status", func(t *testing.T) { + store := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "run-1", Content: "run-1", Status: agentsession.TodoStatusInProgress}, + }) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + + state := newSchedulerState(1) + state.running["run-1"] = runningTask{id: "run-1", attempt: 1} + scheduler.cancelRunningTodos(state, errors.New("stop now")) + + item, _ := store.FindTodo("run-1") + if item.Status != agentsession.TodoStatusCanceled { + t.Fatalf("status = %q, want canceled", item.Status) + } + }) +} + +func TestSchedulerRetryAndClaimBranches(t *testing.T) { + t.Run("handle task failure schedules retry", func(t *testing.T) { + store := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "r1", Content: "retry", Status: agentsession.TodoStatusInProgress, RetryLimit: 2}, + }) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return StepOutput{}, errors.New("boom") + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + state := newSchedulerState(1) + summary := &ScheduleResult{Retried: map[string]int{}} + err = scheduler.handleTaskFailure( + agentsession.TodoItem{ID: "r1", Status: agentsession.TodoStatusInProgress, RetryCount: 0, RetryLimit: 2, Revision: 1}, + taskOutcome{id: "r1", attempt: 1, err: errors.New("boom")}, + state, + summary, + ) + if err != nil { + t.Fatalf("handleTaskFailure() error = %v", err) + } + if summary.Retried["r1"] != 1 { + t.Fatalf("retry count = %d, want 1", summary.Retried["r1"]) + } + }) + + t.Run("start ready tasks claim conflict and claim error", func(t *testing.T) { + base := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "c1", Content: "claim-1", Status: agentsession.TodoStatusPending, Revision: 1}, + {ID: "c2", Content: "claim-2", Status: agentsession.TodoStatusPending, Revision: 1}, + }) + base.claimConflicts["c1"] = 1 + store := &schedulerStoreWithClaimError{ + schedulerStore: base, + claimErrors: map[string]error{"c2": errors.New("claim failed")}, + } + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{MaxConcurrency: 1}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + state := newSchedulerState(1) + _, err = scheduler.startReadyTasks(context.Background(), []agentsession.TodoItem{ + {ID: "c1", Content: "claim-1", Status: agentsession.TodoStatusPending, Revision: 1}, + {ID: "c2", Content: "claim-2", Status: agentsession.TodoStatusPending, Revision: 1}, + }, state) + if err == nil || !strings.Contains(err.Error(), "claim todo") { + t.Fatalf("startReadyTasks error = %v, want claim todo", err) + } + }) + + t.Run("collect ready tasks marks unmet dependency blocked", func(t *testing.T) { + store := newSchedulerStore(t, []agentsession.TodoItem{ + {ID: "a", Content: "a", Status: agentsession.TodoStatusPending}, + {ID: "b", Content: "b", Status: agentsession.TodoStatusPending, Dependencies: []string{"a"}}, + }) + factory := newScriptedFactory(func(ctx context.Context, taskID string, attempt int, input StepInput) (StepOutput, error) { + return successStep(taskID), nil + }) + scheduler, err := NewScheduler(store, factory, SchedulerConfig{}) + if err != nil { + t.Fatalf("NewScheduler() error = %v", err) + } + graph, err := buildTaskGraph(store.ListTodos()) + if err != nil { + t.Fatalf("buildTaskGraph() error = %v", err) + } + state := newSchedulerState(1) + ready, err := scheduler.collectReadyTasks(mapTodosByID(store.ListTodos()), graph, state) + if err != nil { + t.Fatalf("collectReadyTasks() error = %v", err) + } + if len(ready) != 1 || ready[0].ID != "a" { + t.Fatalf("ready = %#v, want only a", ready) + } + item, _ := store.FindTodo("b") + if item.Status != agentsession.TodoStatusBlocked { + t.Fatalf("b status = %q, want blocked", item.Status) + } + }) +} + func TestSchedulerRunRetryAndGiveUp(t *testing.T) { t.Parallel()