From b29a4d901201dfca4ab596ba72bd84e0a055fe16 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Mon, 20 Apr 2026 13:15:10 +0000 Subject: [PATCH] test(coverage): raise patch coverage for gateway runtime migration Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/app/bootstrap_test.go | 23 + internal/cli/gateway_runtime_bridge_test.go | 90 ++++ internal/runtime/create_session_test.go | 111 ++++ .../session/sqlite_store_additional_test.go | 70 +++ .../gateway_rpc_client_additional_test.go | 490 ++++++++++++++++++ .../gateway_stream_client_additional_test.go | 447 ++++++++++++++++ .../remote_runtime_adapter_additional_test.go | 440 ++++++++++++++++ 7 files changed, 1671 insertions(+) create mode 100644 internal/tui/services/gateway_rpc_client_additional_test.go create mode 100644 internal/tui/services/gateway_stream_client_additional_test.go create mode 100644 internal/tui/services/remote_runtime_adapter_additional_test.go diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index a152e906..732be446 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -1418,6 +1418,29 @@ func TestResolveBootstrapRuntimeMode(t *testing.T) { } } +func TestBuildRuntimeRejectsInvalidRuntimeMode(t *testing.T) { + t.Parallel() + + _, err := BuildRuntime(context.Background(), BootstrapOptions{RuntimeMode: "invalid"}) + if err == nil { + t.Fatalf("expected invalid runtime mode error") + } +} + +func TestDefaultNewRemoteRuntimeAdapterReturnsInitError(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + _, err := defaultNewRemoteRuntimeAdapter(services.RemoteRuntimeAdapterOptions{ + ListenAddress: "ipc://127.0.0.1", + TokenFile: home + "/missing-token.json", + }) + if err == nil { + t.Fatalf("expected defaultNewRemoteRuntimeAdapter to fail when token is missing") + } +} + func TestBuildRuntimeGatewayModeUsesRemoteAdapter(t *testing.T) { disableBuiltinProviderAPIKeys(t) diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 53517bcd..2cc572d2 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -3,6 +3,7 @@ package cli import ( "context" "errors" + "os" "testing" "time" @@ -100,6 +101,50 @@ func (s *runtimeStub) ListSessionSkills(context.Context, string) ([]agentruntime return nil, nil } +type runtimeWithoutCreator struct { + base *runtimeStub +} + +func (r *runtimeWithoutCreator) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + return r.base.Submit(ctx, input) +} +func (r *runtimeWithoutCreator) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return r.base.PrepareUserInput(ctx, input) +} +func (r *runtimeWithoutCreator) Run(ctx context.Context, input agentruntime.UserInput) error { + return r.base.Run(ctx, input) +} +func (r *runtimeWithoutCreator) Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) { + return r.base.Compact(ctx, input) +} +func (r *runtimeWithoutCreator) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return r.base.ExecuteSystemTool(ctx, input) +} +func (r *runtimeWithoutCreator) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { + return r.base.ResolvePermission(ctx, input) +} +func (r *runtimeWithoutCreator) CancelActiveRun() bool { + return r.base.CancelActiveRun() +} +func (r *runtimeWithoutCreator) Events() <-chan agentruntime.RuntimeEvent { + return r.base.Events() +} +func (r *runtimeWithoutCreator) ListSessions(ctx context.Context) ([]agentsession.Summary, error) { + return r.base.ListSessions(ctx) +} +func (r *runtimeWithoutCreator) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + return r.base.LoadSession(ctx, id) +} +func (r *runtimeWithoutCreator) ActivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.ActivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCreator) DeactivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.DeactivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCreator) ListSessionSkills(ctx context.Context, sessionID string) ([]agentruntime.SessionSkillState, error) { + return r.base.ListSessionSkills(ctx, sessionID) +} + func TestNewGatewayRuntimePortBridgeRuntimeUnavailable(t *testing.T) { bridge, err := newGatewayRuntimePortBridge(context.Background(), nil) if err == nil { @@ -302,6 +347,51 @@ func TestGatewayRuntimePortBridgeRuntimeMethods(t *testing.T) { } } +func TestGatewayRuntimePortBridgeLoadSessionNotFoundBranches(t *testing.T) { + t.Parallel() + + base := &runtimeStub{ + loadErr: errors.New("session not found"), + } + bridgeWithoutCreator, err := newGatewayRuntimePortBridge(context.Background(), &runtimeWithoutCreator{base: base}) + if err != nil { + t.Fatalf("new bridge without creator: %v", err) + } + t.Cleanup(func() { _ = bridgeWithoutCreator.Close() }) + + if _, err := bridgeWithoutCreator.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: "s-1", + }); !errors.Is(err, gateway.ErrRuntimeResourceNotFound) { + t.Fatalf("expected ErrRuntimeResourceNotFound, got %v", err) + } + + stub := &runtimeStub{ + loadErr: errors.New("file does not exist"), + createErr: errors.New("create failed"), + } + bridgeWithCreator, err := newGatewayRuntimePortBridge(context.Background(), stub) + if err != nil { + t.Fatalf("new bridge with creator: %v", err) + } + t.Cleanup(func() { _ = bridgeWithCreator.Close() }) + + if _, err := bridgeWithCreator.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: "s-2", + }); err == nil || err.Error() != "create failed" { + t.Fatalf("expected create failed error, got %v", err) + } +} + +func TestIsRuntimeNotFoundErrorIncludesOSErrNotExist(t *testing.T) { + t.Parallel() + + if !isRuntimeNotFoundError(os.ErrNotExist) { + t.Fatalf("os.ErrNotExist should be treated as runtime not found") + } +} + func TestGatewayRuntimePortBridgeRuntimeMethodErrors(t *testing.T) { stub := &runtimeStub{ submitErr: errors.New("submit failed"), diff --git a/internal/runtime/create_session_test.go b/internal/runtime/create_session_test.go index f1778a7d..a4d6ef21 100644 --- a/internal/runtime/create_session_test.go +++ b/internal/runtime/create_session_test.go @@ -85,3 +85,114 @@ func TestServiceCreateSessionReturnsOriginalErrorWhenMissingErrorIsNotSentinel(t t.Fatalf("CreateSession() should not create on non-sentinel error, saves=%d", store.memoryStore.saves) } } + +type createSessionDuplicateStore struct { + *createSessionUpsertStore + createErr error + loadHits int + loaded agentsession.Session + loadErr error +} + +func (s *createSessionDuplicateStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + s.loadHits++ + if s.loadHits == 1 { + return agentsession.Session{}, s.missingErr + } + if s.loadErr != nil { + return agentsession.Session{}, s.loadErr + } + return s.loaded, nil +} + +func (s *createSessionDuplicateStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + if s.createErr != nil { + return agentsession.Session{}, s.createErr + } + return s.memoryStore.CreateSession(ctx, input) +} + +func TestServiceCreateSessionBranches(t *testing.T) { + t.Parallel() + + store := &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := service.CreateSession(ctx, "session-canceled"); err == nil { + t.Fatalf("CreateSession() should reject canceled context") + } + if _, err := service.CreateSession(context.Background(), " "); err == nil { + t.Fatalf("CreateSession() should reject empty session id") + } +} + +func TestServiceCreateSessionReturnsWorkdirResolutionError(t *testing.T) { + t.Parallel() + + service := &Service{ + sessionStore: newMemoryStore(), + // 不注入 configManager 会使默认 workdir 为空,触发 resolveWorkdirForSession 错误路径。 + } + if _, err := service.CreateSession(context.Background(), "session-workdir"); err == nil { + t.Fatalf("CreateSession() should fail when default workdir cannot be resolved") + } +} + +func TestServiceCreateSessionDuplicateCreateFallsBackToLoad(t *testing.T) { + t.Parallel() + + store := &createSessionDuplicateStore{ + createSessionUpsertStore: &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), + }, + createErr: fmt.Errorf("unique constraint failed"), + loaded: agentsession.Session{ID: "session-dup", Title: "loaded"}, + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + loaded, err := service.CreateSession(context.Background(), "session-dup") + if err != nil { + t.Fatalf("CreateSession() duplicate fallback error = %v", err) + } + if loaded.ID != "session-dup" || loaded.Title != "loaded" { + t.Fatalf("CreateSession() loaded session = %#v", loaded) + } +} + +func TestCreateSessionErrorPredicates(t *testing.T) { + t.Parallel() + + if isRuntimeSessionNotFoundError(nil) { + t.Fatalf("isRuntimeSessionNotFoundError(nil) should be false") + } + if !isRuntimeSessionNotFoundError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionNotFound)) { + t.Fatalf("wrapped ErrSessionNotFound should be detected") + } + + if isRuntimeSessionAlreadyExistsError(nil) { + t.Fatalf("isRuntimeSessionAlreadyExistsError(nil) should be false") + } + for _, text := range []string{"already exists", "UNIQUE CONSTRAINT", "duplicate key"} { + if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) { + t.Fatalf("expected %q to be treated as already exists", text) + } + } +} diff --git a/internal/session/sqlite_store_additional_test.go b/internal/session/sqlite_store_additional_test.go index f86710dd..05f08a37 100644 --- a/internal/session/sqlite_store_additional_test.go +++ b/internal/session/sqlite_store_additional_test.go @@ -228,6 +228,76 @@ func TestNormalizeCreateSessionInputDefaultsGeneratedID(t *testing.T) { } } +func TestSQLiteStoreCreateSessionPropagatesEnsureStorageDirsError(t *testing.T) { + t.Parallel() + + store := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + _, err := store.CreateSession(context.Background(), CreateSessionInput{ID: "s1", Title: "title"}) + if err == nil { + t.Fatalf("expected CreateSession() to fail when db dir cannot be created") + } +} + +func TestSQLiteStoreEnsureStorageDirsErrorBranches(t *testing.T) { + t.Parallel() + + dbDirErrStore := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + if err := dbDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create db dir") { + t.Fatalf("expected create db dir error, got %v", err) + } + + projectDirErrStore := &SQLiteStore{ + projectDir: filepath.Join("/dev/null", "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + if err := projectDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create project dir") { + t.Fatalf("expected create project dir error, got %v", err) + } + + assetsDirErrStore := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join("/dev/null", "assets"), + dbPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + if err := assetsDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create assets dir") { + t.Fatalf("expected create assets dir error, got %v", err) + } +} + +func TestSQLiteStoreInitializePropagatesStorageDirError(t *testing.T) { + t.Parallel() + + store := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + if err := store.initialize(context.Background()); err == nil { + t.Fatalf("expected initialize() to fail when storage dirs are invalid") + } +} + +func TestWrapSessionNotFoundWithNilCause(t *testing.T) { + t.Parallel() + + err := wrapSessionNotFound(nil) + if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound, got %v", err) + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) + } +} + func TestResolveUpdatedAtReturnsProvidedValue(t *testing.T) { t.Parallel() diff --git a/internal/tui/services/gateway_rpc_client_additional_test.go b/internal/tui/services/gateway_rpc_client_additional_test.go new file mode 100644 index 00000000..c0520fa0 --- /dev/null +++ b/internal/tui/services/gateway_rpc_client_additional_test.go @@ -0,0 +1,490 @@ +package services + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "neo-code/internal/gateway/protocol" +) + +type stubConn struct { + writeErr error + closed bool + mu sync.Mutex +} + +func (s *stubConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (s *stubConn) Write(p []byte) (int, error) { + if s.writeErr != nil { + return 0, s.writeErr + } + return len(p), nil +} +func (s *stubConn) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} +func (s *stubConn) LocalAddr() net.Addr { return &net.UnixAddr{} } +func (s *stubConn) RemoteAddr() net.Addr { return &net.UnixAddr{} } +func (s *stubConn) SetDeadline(_ time.Time) error { return nil } +func (s *stubConn) SetReadDeadline(_ time.Time) error { return nil } +func (s *stubConn) SetWriteDeadline(_ time.Time) error { return nil } + +func TestGatewayRPCErrorAndTransportErrorFormatting(t *testing.T) { + t.Parallel() + + var rpcErr *GatewayRPCError + if rpcErr.Error() != "" { + t.Fatalf("nil GatewayRPCError should render empty string") + } + + errWithCode := (&GatewayRPCError{Method: "gateway.run", GatewayCode: "timeout", Message: "boom"}).Error() + if !strings.Contains(errWithCode, "timeout") { + t.Fatalf("GatewayRPCError with code = %q", errWithCode) + } + + var transportErr *gatewayRPCTransportError + if transportErr.Error() != "" || transportErr.Unwrap() != nil { + t.Fatalf("nil transport error should render empty and unwrap nil") + } + + methodErr := &gatewayRPCTransportError{Method: "gateway.run", Err: errors.New("down")} + if !strings.Contains(methodErr.Error(), "gateway.run") { + t.Fatalf("unexpected transport error text: %q", methodErr.Error()) + } + noMethodErr := (&gatewayRPCTransportError{Err: errors.New("down")}).Error() + if !strings.Contains(noMethodErr, "transport error") { + t.Fatalf("unexpected no-method transport error text: %q", noMethodErr) + } + if !errors.Is(methodErr, methodErr.Err) { + t.Fatalf("transport error should unwrap original cause") + } +} + +func TestGatewayRPCClientHelperFunctions(t *testing.T) { + t.Parallel() + + mapped := mapGatewayRPCError("gateway.ping", nil) + typed, ok := mapped.(*GatewayRPCError) + if !ok || typed.GatewayCode != protocol.GatewayCodeInternalError { + t.Fatalf("mapGatewayRPCError(nil) = %#v", mapped) + } + + emptyMessage := mapGatewayRPCError("gateway.ping", &protocol.JSONRPCError{Code: protocol.JSONRPCCodeInternalError}) + if !strings.Contains(emptyMessage.Error(), "empty rpc error message") { + t.Fatalf("empty message fallback missing: %v", emptyMessage) + } + + if normalizeJSONRPCResponseID(json.RawMessage(`"id-1"`)) != "id-1" { + t.Fatalf("normalize string id mismatch") + } + if normalizeJSONRPCResponseID(json.RawMessage(` 7 `)) != "7" { + t.Fatalf("normalize numeric id mismatch") + } + if normalizeJSONRPCResponseID(json.RawMessage(`null`)) != "" { + t.Fatalf("normalize null id mismatch") + } + if decodeRawJSONString(json.RawMessage(`" m "`)) != "m" { + t.Fatalf("decodeRawJSONString mismatch") + } + if decodeRawJSONString(json.RawMessage(`{`)) != "" { + t.Fatalf("decodeRawJSONString invalid payload should return empty") + } + + raw, err := marshalJSONRawMessage(map[string]any{"ok": true}) + if err != nil || len(raw) == 0 { + t.Fatalf("marshalJSONRawMessage() = (%q, %v)", raw, err) + } + if _, err := marshalJSONRawMessage(func() {}); err == nil { + t.Fatalf("expected marshalJSONRawMessage() error for function input") + } + + if cloneJSONRawMessage(nil) != nil { + t.Fatalf("clone nil should return nil") + } + origin := json.RawMessage(`{"k":"v"}`) + cloned := cloneJSONRawMessage(origin) + origin[0] = '[' + if string(cloned) != `{"k":"v"}` { + t.Fatalf("cloneJSONRawMessage should deep clone, got %q", string(cloned)) + } + + if !isRetryableGatewayCallError(context.DeadlineExceeded) { + t.Fatalf("deadline exceeded should be retryable") + } + if isRetryableGatewayCallError(context.Canceled) { + t.Fatalf("context canceled should not be retryable") + } + if !isRetryableGatewayCallError(&gatewayRPCTransportError{Err: errors.New("x")}) { + t.Fatalf("transport error should be retryable") + } + if !isRetryableGatewayCallError(&GatewayRPCError{GatewayCode: protocol.GatewayCodeTimeout}) { + t.Fatalf("gateway timeout should be retryable") + } + if isRetryableGatewayCallError(errors.New("plain")) { + t.Fatalf("plain error should not be retryable") + } + if isRetryableGatewayCallError(nil) { + t.Fatalf("nil error should not be retryable") + } + + if _, err := decodeGatewayRPCResponse(map[string]json.RawMessage{"id": json.RawMessage(`bad`)}); err == nil { + t.Fatalf("expected decodeGatewayRPCResponse marshal error") + } + if _, err := decodeGatewayRPCResponse(map[string]json.RawMessage{"id": json.RawMessage(`"id"`), "result": json.RawMessage(`{`)}); err == nil { + t.Fatalf("expected decodeGatewayRPCResponse unmarshal error") + } +} + +func TestGatewayRPCClientPendingAndForceCloseBranches(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: map[string]chan gatewayRPCResponse{}, + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + + ch := make(chan gatewayRPCResponse, 1) + if ok := client.registerPending("req-1", ch); !ok { + t.Fatalf("registerPending should succeed") + } + client.dispatchResponse(gatewayRPCResponse{ID: "req-1", Result: json.RawMessage(`{"ok":true}`)}) + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("dispatchResponse did not deliver response") + } + + client.dispatchResponse(gatewayRPCResponse{ID: ""}) + client.dispatchResponse(gatewayRPCResponse{ID: "missing"}) + client.unregisterPending("missing") + + pendingCh := make(chan gatewayRPCResponse, 1) + client.pending["req-2"] = pendingCh + if err := client.forceCloseWithError(nil); err != nil { + t.Fatalf("forceCloseWithError() error = %v", err) + } + select { + case response := <-pendingCh: + if response.TransportErr == nil { + t.Fatalf("expected transport error to be forwarded") + } + case <-time.After(time.Second): + t.Fatalf("pending response channel not notified") + } + + close(client.closed) + if ok := client.registerPending("req-3", make(chan gatewayRPCResponse, 1)); ok { + t.Fatalf("registerPending should fail after client closed") + } + client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}) + + resetConn := &stubConn{} + client.conn = resetConn + client.resetConnection() + if !resetConn.closed { + t.Fatalf("resetConnection should close active connection") + } +} + +func TestLoadGatewayAuthTokenErrorBranches(t *testing.T) { + t.Parallel() + + missingPath := filepath.Join(t.TempDir(), "missing-token.json") + if _, err := loadGatewayAuthToken(missingPath); err == nil { + t.Fatalf("expected load token error for missing file") + } + + path := filepath.Join(t.TempDir(), "auth.json") + err := os.WriteFile(path, []byte(`{"version":1,"token":" ","created_at":"2026-04-20T00:00:00Z","updated_at":"2026-04-20T00:00:00Z"}`), 0o600) + if err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := loadGatewayAuthToken(path); err == nil || !strings.Contains(err.Error(), "auth token is empty") { + if err == nil || !strings.Contains(err.Error(), "token is empty") { + t.Fatalf("expected empty auth token error, got %v", err) + } + } +} + +func TestGatewayRPCClientCallWithClosedClientAndInvalidResult(t *testing.T) { + t.Parallel() + + tokenFile, _ := createTestAuthTokenFile(t) + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + dec := json.NewDecoder(serverConn) + enc := json.NewEncoder(serverConn) + req := readRPCRequestOrFail(t, dec) + response := protocol.JSONRPCResponse{JSONRPC: protocol.JSONRPCVersion, ID: req.ID, Result: json.RawMessage(`1`)} + if encodeErr := enc.Encode(response); encodeErr != nil { + t.Errorf("encode response: %v", encodeErr) + } + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + var out map[string]any + callErr := client.CallWithOptions(context.Background(), protocol.MethodGatewayPing, map[string]any{}, &out, GatewayRPCCallOptions{Timeout: time.Second}) + if callErr == nil || !strings.Contains(callErr.Error(), "decode") { + t.Fatalf("expected decode error, got %v", callErr) + } + + _ = client.Close() + if err := client.CallWithOptions(context.Background(), protocol.MethodGatewayPing, nil, nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected closed client call error") + } +} + +func TestNewGatewayRPCClientConstructorBranches(t *testing.T) { + t.Parallel() + + tokenFile, _ := createTestAuthTokenFile(t) + _, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: tokenFile, + ResolveListenAddress: func(string) (string, error) { + return "", errors.New("resolve failed") + }, + }) + if err == nil || !strings.Contains(err.Error(), "resolve listen address") { + t.Fatalf("expected resolve listen address error, got %v", err) + } + + _, err = NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: filepath.Join(t.TempDir(), "missing.json"), + ResolveListenAddress: func(string) (string, error) { + return "ipc://x", nil + }, + }) + if err == nil || !strings.Contains(err.Error(), "load auth token") { + t.Fatalf("expected load auth token error, got %v", err) + } + + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: tokenFile, + ResolveListenAddress: func(string) (string, error) { + return "ipc://x", nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + if client.requestTimeout != defaultGatewayRPCRequestTimeout || client.retryCount != defaultGatewayRPCRetryCount || client.dialFn == nil { + t.Fatalf("default options not applied: timeout=%v retry=%d dialFnNil=%v", client.requestTimeout, client.retryCount, client.dialFn == nil) + } + _ = client.Close() +} + +func TestGatewayRPCClientCallOnceBranches(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + listenAddress: "x", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := client.callOnce(ctx, "m", nil, nil, time.Second); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled error, got %v", err) + } + + client.dialFn = func(string) (net.Conn, error) { return nil, errors.New("dial failed") } + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err == nil || !strings.Contains(err.Error(), "transport") { + t.Fatalf("expected dial transport error, got %v", err) + } + + conn := &stubConn{} + client.conn = conn + if err := client.callOnce(context.Background(), "m", func() {}, nil, time.Second); err == nil || !strings.Contains(err.Error(), "encode request params") { + t.Fatalf("expected params encode error, got %v", err) + } + + close(client.closed) + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err == nil || !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected closed client error, got %v", err) + } +} + +func TestGatewayRPCClientCallOnceResponseBranches(t *testing.T) { + t.Parallel() + + newClient := func() *GatewayRPCClient { + return &GatewayRPCClient{ + listenAddress: "x", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + conn: &stubConn{}, + } + } + + t.Run("transport error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", TransportErr: errors.New("broken")}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "transport") { + t.Fatalf("expected transport response error, got %v", err) + } + }) + + t.Run("rpc error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", RPCError: &protocol.JSONRPCError{Code: -32000, Message: "bad"}}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "gateway rpc m failed") { + t.Fatalf("expected rpc mapped error, got %v", err) + } + }) + + t.Run("result nil", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1"}) + }() + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err != nil { + t.Fatalf("expected nil result path, got %v", err) + } + }) + + t.Run("empty result", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1"}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "result is empty") { + t.Fatalf("expected empty result error, got %v", err) + } + }) + + t.Run("decode error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", Result: json.RawMessage(`1`)}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "decode m response") { + t.Fatalf("expected decode error, got %v", err) + } + }) +} + +func TestGatewayRPCClientReadLoopAdditionalBranches(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + client.startNotificationDispatcher() + go client.readLoop(clientConn) + + encoder := json.NewEncoder(serverConn) + _ = encoder.Encode(map[string]any{"method": " "}) + _ = encoder.Encode(map[string]any{"id": json.RawMessage(`\"x\"`), "result": json.RawMessage(`{`)}) + _ = encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"x": 1}}) + _ = serverConn.Close() + + select { + case <-client.notifications: + case <-time.After(2 * time.Second): + t.Fatalf("expected one forwarded notification") + } + + _ = client.Close() +} + +func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + client.startNotificationDispatcher() + close(client.notificationQueue) + client.notificationWG.Wait() +} + +func TestGatewayRPCClientWriteRequestFailure(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + conn := &stubConn{writeErr: errors.New("write failed")} + err := client.writeRequest(conn, protocol.JSONRPCRequest{JSONRPC: protocol.JSONRPCVersion, ID: json.RawMessage(`\"id\"`), Method: "m"}) + if err == nil || !strings.Contains(err.Error(), "write rpc request failed") { + t.Fatalf("expected write request error, got %v", err) + } +} + +func TestGatewayRPCClientDecodeResponseSuccessAndRetryableNetError(t *testing.T) { + t.Parallel() + + response, err := decodeGatewayRPCResponse(map[string]json.RawMessage{ + "id": json.RawMessage(`"id"`), + "result": json.RawMessage(`{"ok":true}`), + }) + if err != nil || !bytes.Contains(response.Result, []byte(`ok`)) { + t.Fatalf("decodeGatewayRPCResponse success mismatch: (%#v, %v)", response, err) + } + + netErr := &net.DNSError{IsTimeout: true} + if !isRetryableGatewayCallError(netErr) { + t.Fatalf("net timeout error should be retryable") + } +} diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go new file mode 100644 index 00000000..2d32ce1b --- /dev/null +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -0,0 +1,447 @@ +package services + +import ( + "encoding/json" + "reflect" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" + "neo-code/internal/runtime/controlplane" +) + +type streamInvalidJSONMarshaler struct { + raw []byte +} + +func (m streamInvalidJSONMarshaler) MarshalJSON() ([]byte, error) { + return m.raw, nil +} + +func TestDecodeRuntimeEventFromGatewayNotificationErrorBranches(t *testing.T) { + t.Parallel() + + if _, err := decodeRuntimeEventFromGatewayNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}); err == nil { + t.Fatalf("expected empty params error") + } + + if _, err := decodeRuntimeEventFromGatewayNotification(gatewayRPCNotification{ + Method: protocol.MethodGatewayEvent, + Params: json.RawMessage(`{"payload":{}}`), + }); err == nil { + t.Fatalf("expected missing envelope error") + } + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "payload": map[string]any{}, + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(notification); err == nil { + t.Fatalf("expected missing runtime_event_type error") + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationUsesCurrentTimeWhenTimestampMissing(t *testing.T) { + t.Parallel() + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventError), + "payload": "boom", + }, + }) + + before := time.Now().UTC().Add(-time.Second) + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + if event.Timestamp.Before(before) { + t.Fatalf("event timestamp should fallback to now, got %v", event.Timestamp) + } +} + +func TestExtractRuntimeEnvelopeFallbackMarshalling(t *testing.T) { + t.Parallel() + + type payloadEnvelope struct { + Payload map[string]any `json:"payload"` + } + envelope, ok := extractRuntimeEnvelope(payloadEnvelope{Payload: map[string]any{ + "RuntimeEventType": string(agentruntime.EventError), + "payload": "x", + }}) + if !ok { + t.Fatalf("expected envelope to be detected") + } + if got := streamReadMapString(envelope, "runtime_event_type"); got != string(agentruntime.EventError) { + t.Fatalf("runtime_event_type = %q", got) + } + + if _, ok := extractRuntimeEnvelope(nil); ok { + t.Fatalf("nil payload should not decode") + } +} + +func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + eventType agentruntime.EventType + payload any + assertFn func(t *testing.T, got any) + }{ + { + name: "user message", + eventType: agentruntime.EventUserMessage, + payload: map[string]any{"Role": string(providertypes.RoleAssistant)}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if _, ok := got.(providertypes.Message); !ok { + t.Fatalf("payload type = %T", got) + } + }, + }, + { + name: "permission request", + eventType: agentruntime.EventPermissionRequested, + payload: map[string]any{"RequestID": "req-1"}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if v, ok := got.(agentruntime.PermissionRequestPayload); !ok || v.RequestID != "req-1" { + t.Fatalf("payload = %#v", got) + } + }, + }, + { + name: "stop reason", + eventType: agentruntime.EventStopReasonDecided, + payload: map[string]any{"reason": " max_rounds "}, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(agentruntime.StopReasonDecidedPayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.Reason != controlplane.StopReason("max_rounds") { + t.Fatalf("reason = %q", value.Reason) + } + }, + }, + { + name: "runtime usage payload", + eventType: agentruntime.EventType(RuntimeEventUsage), + payload: map[string]any{"delta": map[string]any{"inputtokens": 1}}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if _, ok := got.(RuntimeUsagePayload); !ok { + t.Fatalf("payload type = %T", got) + } + }, + }, + { + name: "string payload", + eventType: agentruntime.EventToolChunk, + payload: 42, + assertFn: func(t *testing.T, got any) { + t.Helper() + if got != "42" { + t.Fatalf("payload = %#v", got) + } + }, + }, + { + name: "default passthrough", + eventType: agentruntime.EventType("unknown"), + payload: map[string]any{"k": "v"}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if !reflect.DeepEqual(got, map[string]any{"k": "v"}) { + t.Fatalf("payload = %#v", got) + } + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + got, err := restoreRuntimePayload(tc.eventType, tc.payload) + if err != nil { + t.Fatalf("restoreRuntimePayload() error = %v", err) + } + tc.assertFn(t, got) + }) + } +} + +func TestDecodeRuntimePayloadAndMapHelpers(t *testing.T) { + t.Parallel() + + typed, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](agentruntime.InputNormalizedPayload{TextLength: 1}) + if err != nil || typed.TextLength != 1 { + t.Fatalf("typed decode mismatch, got (%#v, %v)", typed, err) + } + + ptrValue := &agentruntime.InputNormalizedPayload{ImageCount: 3} + decodedPtr, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](ptrValue) + if err != nil || decodedPtr.ImageCount != 3 { + t.Fatalf("pointer decode mismatch, got (%#v, %v)", decodedPtr, err) + } + + var nilPtr *agentruntime.InputNormalizedPayload + if _, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](nilPtr); err == nil { + t.Fatalf("expected nil pointer decode error") + } + if _, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](nil); err == nil { + t.Fatalf("expected nil payload decode error") + } + + m := map[string]any{ + "runtimeEventType": "agent_chunk", + "turn": json.Number("7"), + "payloadVersion": "5", + "time_stamp": "2026-04-20T12:00:00Z", + } + if value, ok := streamReadMapValue(m, "runtime_event_type"); !ok || value != "agent_chunk" { + t.Fatalf("streamReadMapValue mismatch: (%v, %v)", value, ok) + } + if got := streamReadMapInt(m, "turn"); got != 7 { + t.Fatalf("streamReadMapInt(turn) = %d", got) + } + if got := streamReadMapInt(m, "payload_version"); got != 5 { + t.Fatalf("streamReadMapInt(payload_version) = %d", got) + } + if got := streamReadMapString(m, "runtime_event_type"); got != "agent_chunk" { + t.Fatalf("streamReadMapString = %q", got) + } + if got := streamReadMapTime(m, "time_stamp"); got.IsZero() { + t.Fatalf("streamReadMapTime should parse timestamp") + } + + if normalizeMapLookupKey(" Runtime_Event-Type ") != "runtimeeventtype" { + t.Fatalf("normalizeMapLookupKey mismatch") + } + if toSnakeCase("RuntimeEventType") != "runtime_event_type" { + t.Fatalf("toSnakeCase mismatch") + } + if toLowerCamelCase(" RuntimeEventType ") != "runtimeEventType" { + t.Fatalf("toLowerCamelCase mismatch") + } +} + +func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification, 4) + client := NewGatewayStreamClient(source) + + source <- gatewayRPCNotification{Method: "gateway.ping", Params: json.RawMessage(`{"foo":"bar"}`)} + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventAgentChunk), + "payload": "ok", + }, + }) + + select { + case event := <-client.Events(): + if event.Type != agentruntime.EventAgentChunk { + t.Fatalf("event.Type = %q", event.Type) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for stream event") + } + + if err := client.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if err := client.Close(); err != nil { + t.Fatalf("Close() second call error = %v", err) + } +} + +func TestGatewayStreamClientRunStopsWhenSourceClosed(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification) + client := NewGatewayStreamClient(source) + close(source) + + select { + case _, ok := <-client.Events(): + if ok { + t.Fatalf("events channel should be closed after source channel is closed") + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for events channel close") + } +} + +func TestRestoreRuntimePayloadAdditionalBranches(t *testing.T) { + t.Parallel() + + payloadCases := []struct { + eventType agentruntime.EventType + payload any + }{ + {eventType: agentruntime.EventAgentDone, payload: map[string]any{"Role": string(providertypes.RoleAssistant)}}, + {eventType: agentruntime.EventToolStart, payload: map[string]any{"Name": "bash"}}, + {eventType: agentruntime.EventPermissionResolved, payload: map[string]any{"RequestID": "req-1"}}, + {eventType: agentruntime.EventCompactApplied, payload: map[string]any{"Applied": true}}, + {eventType: agentruntime.EventCompactError, payload: map[string]any{"message": "boom"}}, + {eventType: agentruntime.EventPhaseChanged, payload: map[string]any{"from": "a", "to": "b"}}, + {eventType: agentruntime.EventInputNormalized, payload: map[string]any{"text_length": 3}}, + {eventType: agentruntime.EventAssetSaved, payload: map[string]any{"asset_id": "asset-1"}}, + {eventType: agentruntime.EventAssetSaveFailed, payload: map[string]any{"message": "x"}}, + {eventType: agentruntime.EventTodoUpdated, payload: map[string]any{"action": "replace"}}, + {eventType: agentruntime.EventTodoConflict, payload: map[string]any{"action": "conflict"}}, + {eventType: agentruntime.EventType(RuntimeEventRunContext), payload: map[string]any{"provider": "openai"}}, + {eventType: agentruntime.EventType(RuntimeEventToolStatus), payload: map[string]any{"status": "running"}}, + } + + for _, tc := range payloadCases { + if _, err := restoreRuntimePayload(tc.eventType, tc.payload); err != nil { + t.Fatalf("restoreRuntimePayload(%q) error = %v", tc.eventType, err) + } + } + + if _, err := restoreRuntimePayload(agentruntime.EventStopReasonDecided, map[string]any{"reason": func() {}}); err == nil { + t.Fatalf("stop reason payload should return decode error for non-serializable field") + } +} + +func TestStreamHelperBranches(t *testing.T) { + t.Parallel() + + if decodeStringPayload(nil) != "" { + t.Fatalf("decodeStringPayload(nil) should return empty string") + } + if decodeStringPayload("x") != "x" { + t.Fatalf("decodeStringPayload(string) mismatch") + } + + if _, err := decodeRuntimePayload[agentruntime.PhaseChangedPayload](func() {}); err == nil { + t.Fatalf("decodeRuntimePayload should fail on marshal error") + } + if _, err := decodeRuntimePayload[agentruntime.PhaseChangedPayload](map[string]any{"from": map[string]any{"bad": make(chan int)}}); err == nil { + t.Fatalf("decodeRuntimePayload should fail on invalid nested payload") + } + + if value, ok := streamReadMapValue(map[string]any{"RUNTIMEEVENTTYPE": "v"}, "runtime_event_type"); !ok || value != "v" { + t.Fatalf("streamReadMapValue normalized scan mismatch") + } + if _, ok := streamReadMapValue(nil, "key"); ok { + t.Fatalf("nil map lookup should fail") + } + if _, ok := streamReadMapValue(map[string]any{"k": 1}, " "); ok { + t.Fatalf("blank key lookup should fail") + } + + intCases := map[string]any{ + "i": 1, + "i64": int64(2), + "i32": int32(3), + "f64": float64(4), + "f32": float32(5), + "num": json.Number("6"), + "badnum": json.Number("x"), + "str": "7", + "badstr": "x", + } + if streamReadMapInt(intCases, "i") != 1 || + streamReadMapInt(intCases, "i64") != 2 || + streamReadMapInt(intCases, "i32") != 3 || + streamReadMapInt(intCases, "f64") != 4 || + streamReadMapInt(intCases, "f32") != 5 || + streamReadMapInt(intCases, "num") != 6 || + streamReadMapInt(intCases, "str") != 7 { + t.Fatalf("streamReadMapInt numeric coercion mismatch") + } + if streamReadMapInt(intCases, "badnum") != 0 || streamReadMapInt(intCases, "badstr") != 0 || streamReadMapInt(intCases, "missing") != 0 { + t.Fatalf("streamReadMapInt invalid values should return zero") + } + + now := time.Now().UTC() + timeCases := map[string]any{ + "as_time": now, + "as_text": now.Format(time.RFC3339Nano), + "invalid": "not-time", + "blanktxt": " ", + } + if !streamReadMapTime(timeCases, "as_time").Equal(now) { + t.Fatalf("streamReadMapTime(time.Time) mismatch") + } + if streamReadMapTime(timeCases, "as_text").IsZero() { + t.Fatalf("streamReadMapTime(string) should parse") + } + if !streamReadMapTime(timeCases, "invalid").IsZero() || + !streamReadMapTime(timeCases, "blanktxt").IsZero() || + !streamReadMapTime(timeCases, "missing").IsZero() { + t.Fatalf("streamReadMapTime invalid values should return zero time") + } + + if toSnakeCase("") != "" || toLowerCamelCase("") != "" { + t.Fatalf("empty case conversion should return empty string") + } +} + +func TestGatewayStreamDecodeAndEnvelopeExtraBranches(t *testing.T) { + t.Parallel() + + missingTypeNotification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": " ", + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(missingTypeNotification); err == nil { + t.Fatalf("expected missing runtime_event_type error") + } + + invalidPayloadNotification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventToolResult), + "payload": "not-an-object", + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(invalidPayloadNotification); err == nil { + t.Fatalf("expected restore payload decode error") + } + + if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("{")}); ok { + t.Fatalf("expected marshal error path to fail envelope extraction") + } + if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("[]")}); ok { + t.Fatalf("expected unmarshal-to-map error path to fail envelope extraction") + } + if envelope, ok := extractRuntimeEnvelope(struct { + RuntimeEventType string `json:"runtime_event_type"` + }{RuntimeEventType: string(agentruntime.EventError)}); !ok || streamReadMapString(envelope, "runtime_event_type") == "" { + t.Fatalf("expected runtime_event_type detection after marshal/unmarshal") + } + + if got := streamReadMapString(map[string]any{"v": 123}, "v"); got != "123" { + t.Fatalf("streamReadMapString default conversion mismatch: %q", got) + } + if got := streamReadMapInt(map[string]any{"v": true}, "v"); got != 0 { + t.Fatalf("streamReadMapInt unsupported type should return 0, got %d", got) + } + if got := streamReadMapTime(map[string]any{"v": 1}, "v"); !got.IsZero() { + t.Fatalf("streamReadMapTime unsupported type should return zero, got %v", got) + } +} diff --git a/internal/tui/services/remote_runtime_adapter_additional_test.go b/internal/tui/services/remote_runtime_adapter_additional_test.go new file mode 100644 index 00000000..bfef2773 --- /dev/null +++ b/internal/tui/services/remote_runtime_adapter_additional_test.go @@ -0,0 +1,440 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "net" + "strings" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" +) + +func TestNewRemoteRuntimeAdapterBranches(t *testing.T) { + t.Parallel() + + originalRPCFactory := newGatewayRPCClientFactory + originalStreamFactory := newGatewayStreamClientFactory + t.Cleanup(func() { + newGatewayRPCClientFactory = originalRPCFactory + newGatewayStreamClientFactory = originalStreamFactory + }) + + newGatewayRPCClientFactory = func(options GatewayRPCClientOptions) (*GatewayRPCClient, error) { + if strings.TrimSpace(options.ListenAddress) == "error" { + return nil, errors.New("build rpc failed") + } + client := &GatewayRPCClient{ + listenAddress: options.ListenAddress, + token: "token", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + client.dialFn = func(_ string) (net.Conn, error) { + if options.ListenAddress == "dial-failed" { + return nil, errors.New("dial failed") + } + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + request := readRPCRequestOrFail(t, decoder) + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionAuthenticate, + }) + }() + return clientConn, nil + } + return client, nil + } + newGatewayStreamClientFactory = func(source <-chan gatewayRPCNotification) *GatewayStreamClient { + return NewGatewayStreamClient(source) + } + + if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "error"}); err == nil { + t.Fatalf("expected rpc factory error") + } + if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "dial-failed", RequestTimeout: -1}); err == nil { + t.Fatalf("expected authenticate fail-fast error") + } + + adapter, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ + ListenAddress: "ok", + RequestTimeout: -1, + }) + if err != nil { + t.Fatalf("NewRemoteRuntimeAdapter() error = %v", err) + } + if adapter.timeout != defaultRemoteRuntimeTimeout { + t.Fatalf("timeout = %v, want %v", adapter.timeout, defaultRemoteRuntimeTimeout) + } + _ = adapter.Close() +} + +func TestRemoteRuntimeAdapterPrepareUserInputAndRun(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionLoadSession, SessionID: "s-1"}, + protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream, SessionID: "s-1", RunID: "r-1"}, + protocol.MethodGatewayRun: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionRun, SessionID: "s-1", RunID: "r-1"}, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := adapter.PrepareUserInput(ctx, agentruntime.PrepareInput{}); err == nil { + t.Fatalf("expected context cancellation error") + } + + input, err := adapter.PrepareUserInput(context.Background(), agentruntime.PrepareInput{ + SessionID: " ", + RunID: "", + Text: " hello ", + Images: []agentruntime.UserImageInput{ + {Path: " "}, + {Path: " /tmp/a.png ", MimeType: " image/png "}, + }, + Workdir: " /repo ", + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if strings.TrimSpace(input.SessionID) == "" || strings.TrimSpace(input.RunID) == "" { + t.Fatalf("session/run id should be generated") + } + if len(input.Parts) != 2 { + t.Fatalf("parts len = %d, want 2", len(input.Parts)) + } + + if err := adapter.Run(context.Background(), input); err != nil { + t.Fatalf("Run() error = %v", err) + } + methods := rpcClient.snapshotMethods() + if len(methods) != 3 || methods[2] != protocol.MethodGatewayRun { + t.Fatalf("unexpected method chain: %#v", methods) + } +} + +func TestRemoteRuntimeAdapterCompactResolvePermissionAndListSessions(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream}, + protocol.MethodGatewayCompact: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionCompact, + Payload: gateway.CompactResult{ + Applied: true, + BeforeChars: 100, + AfterChars: 40, + TriggerMode: "auto", + TranscriptID: "t-1", + }, + }, + protocol.MethodGatewayResolvePermission: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionResolvePermission}, + protocol.MethodGatewayListSessions: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionListSessions, + Payload: map[string]any{ + "sessions": []gateway.SessionSummary{{ID: " s1 ", Title: " t1 "}}, + }, + }, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 2) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{}); err == nil { + t.Fatalf("expected compact empty session id error") + } + + compactResult, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s1", RunID: "r1"}) + if err != nil { + t.Fatalf("Compact() error = %v", err) + } + if !compactResult.Applied || compactResult.BeforeChars != 100 || compactResult.AfterChars != 40 { + t.Fatalf("compact result mismatch: %#v", compactResult) + } + + if err := adapter.ResolvePermission(context.Background(), agentruntime.PermissionResolutionInput{RequestID: " req ", Decision: "APPROVE"}); err != nil { + t.Fatalf("ResolvePermission() error = %v", err) + } + + summaries, err := adapter.ListSessions(context.Background()) + if err != nil { + t.Fatalf("ListSessions() error = %v", err) + } + if len(summaries) != 1 || summaries[0].ID != "s1" || summaries[0].Title != "t1" { + t.Fatalf("summaries mismatch: %#v", summaries) + } +} + +func TestRemoteRuntimeAdapterUnsupportedSkillMethods(t *testing.T) { + t.Parallel() + + adapter := newRemoteRuntimeAdapterWithClients( + &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)}, + &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + if err := adapter.ActivateSessionSkill(context.Background(), "s", "skill"); err == nil { + t.Fatalf("ActivateSessionSkill should be unsupported") + } + if err := adapter.DeactivateSessionSkill(context.Background(), "s", "skill"); err == nil { + t.Fatalf("DeactivateSessionSkill should be unsupported") + } + if _, err := adapter.ListSessionSkills(context.Background(), "s"); err == nil { + t.Fatalf("ListSessionSkills should be unsupported") + } +} + +func TestRemoteRuntimeAdapterCallFrameAndDecodeHelpers(t *testing.T) { + t.Parallel() + + adapter := newRemoteRuntimeAdapterWithClients(nil, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.callFrame(context.Background(), protocol.MethodGatewayPing, nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected nil rpc client error") + } + if err := adapter.authenticate(context.Background()); err == nil { + t.Fatalf("authenticate should fail on nil rpc client") + } + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + "error-without-payload": {Type: gateway.FrameTypeError}, + "error-with-payload": { + Type: gateway.FrameTypeError, + Error: &gateway.FrameError{Code: "bad", Message: "oops"}, + }, + "unexpected-type": {Type: gateway.FrameTypeEvent}, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter.rpcClient = rpcClient + + if _, err := adapter.callFrame(context.Background(), "error-without-payload", nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected error frame without payload") + } + if _, err := adapter.callFrame(context.Background(), "error-with-payload", nil, GatewayRPCCallOptions{}); err == nil || !strings.Contains(err.Error(), "bad") { + t.Fatalf("expected frame error mapping, got %v", err) + } + if _, err := adapter.callFrame(context.Background(), "unexpected-type", nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected unexpected frame type error") + } + + if err := decodeIntoValue(map[string]any{"a": 1}, nil); err == nil { + t.Fatalf("decodeIntoValue should reject nil target") + } + if err := decodeIntoValue(func() {}, &map[string]any{}); err == nil { + t.Fatalf("decodeIntoValue should fail on marshal error") + } + if err := decodeIntoValue(map[string]any{"value": "x"}, &[]int{}); err == nil { + t.Fatalf("decodeIntoValue should fail on unmarshal mismatch") + } + + decoded, err := decodeFramePayload[gateway.CompactResult](map[string]any{"applied": true}) + if err != nil || !decoded.Applied { + t.Fatalf("decodeFramePayload() = (%#v, %v)", decoded, err) + } +} + +func TestRemoteRuntimeAdapterEventObservationAndActiveRunState(t *testing.T) { + t.Parallel() + + eventCh := make(chan agentruntime.RuntimeEvent, 3) + streamClient := &stubRemoteStreamClient{events: eventCh} + adapter := newRemoteRuntimeAdapterWithClients( + &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)}, + streamClient, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + eventCh <- agentruntime.RuntimeEvent{Type: agentruntime.EventAgentChunk, RunID: "run-a", SessionID: "session-a"} + eventCh <- agentruntime.RuntimeEvent{Type: agentruntime.EventAgentDone, RunID: "run-a", SessionID: "session-a"} + close(eventCh) + + for i := 0; i < 2; i++ { + select { + case <-adapter.Events(): + case <-time.After(time.Second): + t.Fatalf("timed out waiting for forwarded event") + } + } + + runID, sessionID := adapter.activeRun() + if runID != "" || sessionID != "session-a" { + t.Fatalf("active run state mismatch: run=%q session=%q", runID, sessionID) + } + + adapter.setActiveRun(" run-b ", " session-b ") + adapter.clearActiveRun("other-run") + runID, _ = adapter.activeRun() + if runID != "run-b" { + t.Fatalf("clearActiveRun should keep different run, got %q", runID) + } + adapter.clearActiveRun("run-b") + runID, _ = adapter.activeRun() + if runID != "" { + t.Fatalf("expected cleared run id, got %q", runID) + } +} + +func TestRemoteRuntimeAdapterLoadSessionAndCancelErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + callErrs: map[string]error{protocol.MethodGatewayCancel: errors.New("cancel failed")}, + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionLoadSession, Payload: func() {}}, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.LoadSession(context.Background(), " "); err == nil { + t.Fatalf("expected empty id validation error") + } + if _, err := adapter.LoadSession(context.Background(), "session-1"); err == nil { + t.Fatalf("expected payload decode error") + } + + adapter.setActiveRun("run-1", "session-1") + if !adapter.CancelActiveRun() { + t.Fatalf("expected cancel attempt for active run") + } +} + +func TestRemoteRuntimeAdapterSubmitAndCompactErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + callErrs: map[string]error{ + protocol.MethodGatewayBindStream: errors.New("bind failed"), + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if err := adapter.Submit(context.Background(), agentruntime.PrepareInput{}); err == nil || !strings.Contains(err.Error(), "bind failed") { + t.Fatalf("expected bind failed submit error, got %v", err) + } + params := rpcClient.snapshotParams()[protocol.MethodGatewayLoadSession].(protocol.LoadSessionParams) + if strings.TrimSpace(params.SessionID) == "" { + t.Fatalf("Submit() should generate default session id") + } + + rpcClient.authErr = errors.New("auth failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "auth failed") { + t.Fatalf("expected compact auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs[protocol.MethodGatewayBindStream] = errors.New("bind compact failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "bind compact failed") { + t.Fatalf("expected compact bind error, got %v", err) + } + rpcClient.callErrs[protocol.MethodGatewayBindStream] = nil + rpcClient.callErrs[protocol.MethodGatewayCompact] = errors.New("compact failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "compact failed") { + t.Fatalf("expected compact rpc error, got %v", err) + } +} + +func TestRemoteRuntimeAdapterListAndLoadSessionErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + rpcClient.authErr = errors.New("auth failed") + if _, err := adapter.ListSessions(context.Background()); err == nil || !strings.Contains(err.Error(), "auth failed") { + t.Fatalf("expected list auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs = map[string]error{protocol.MethodGatewayListSessions: errors.New("list failed")} + if _, err := adapter.ListSessions(context.Background()); err == nil || !strings.Contains(err.Error(), "list failed") { + t.Fatalf("expected list rpc error, got %v", err) + } + rpcClient.callErrs = nil + rpcClient.frames = map[string]gateway.MessageFrame{ + protocol.MethodGatewayListSessions: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionListSessions, Payload: func() {}}, + } + if _, err := adapter.ListSessions(context.Background()); err == nil { + t.Fatalf("expected list decode error") + } + + rpcClient.authErr = errors.New("auth load failed") + if _, err := adapter.LoadSession(context.Background(), "s-1"); err == nil || !strings.Contains(err.Error(), "auth load failed") { + t.Fatalf("expected load auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs = map[string]error{protocol.MethodGatewayLoadSession: errors.New("load failed")} + if _, err := adapter.LoadSession(context.Background(), "s-1"); err == nil || !strings.Contains(err.Error(), "load failed") { + t.Fatalf("expected load rpc error, got %v", err) + } +} + +func TestRemoteRuntimeAdapterRenderInputHelpers(t *testing.T) { + t.Parallel() + + text := renderInputTextFromParts([]providertypes.ContentPart{ + providertypes.NewTextPart(" first "), + providertypes.NewRemoteImagePart("/tmp/a.png"), + providertypes.NewTextPart("second"), + providertypes.NewTextPart(" "), + }) + if text != "first\nsecond" { + t.Fatalf("renderInputTextFromParts() = %q", text) + } + + images := renderInputImagesFromParts([]providertypes.ContentPart{ + providertypes.NewTextPart("x"), + providertypes.NewRemoteImagePart(" "), + providertypes.ContentPart{ + Kind: providertypes.ContentPartImage, + Image: &providertypes.ImagePart{ + URL: " /tmp/b.png ", + Asset: &providertypes.AssetRef{MimeType: " image/png "}, + }, + }, + }) + if len(images) != 1 || images[0].Path != "/tmp/b.png" || images[0].MimeType != "image/png" { + t.Fatalf("renderInputImagesFromParts() = %#v", images) + } + + params := buildGatewayRunParams(" s ", " r ", agentruntime.PrepareInput{Text: " hi ", Workdir: " /w ", Images: []agentruntime.UserImageInput{{Path: " /img.png ", MimeType: " image/png "}, {Path: " "}}}) + if params.SessionID != "s" || params.RunID != "r" || params.Workdir != "/w" || params.InputText != "hi" || len(params.InputParts) != 1 { + t.Fatalf("buildGatewayRunParams() = %#v", params) + } +}