diff --git a/README.md b/README.md index 834f1819..c381f410 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,19 @@ $env:QINIU_API_KEY = "your_key_here" go run ./cmd/neocode --workdir /path/to/workspace ``` +运行模式切换(默认 `local`): + +```bash +go run ./cmd/neocode --runtime-mode local +go run ./cmd/neocode --runtime-mode gateway +``` + +说明: + +- `--runtime-mode` 仅影响当前进程,不会回写 `config.yaml` +- `gateway` 模式会通过本地 Gateway(优先 IPC)转发 runtime 请求与事件流 +- 若 Gateway 不可达或握手失败会直接报错退出(Fail Fast),不会自动回退到 `local` + ### 4) 首次使用与常用命令 - `/help`:查看命令帮助 - `/provider`:打开 provider 选择器 @@ -127,6 +140,7 @@ go run ./cmd/neocode --workdir /path/to/workspace - API Key 通过环境变量注入,不写入 `config.yaml` - `--workdir` 只影响当前运行,不会回写到配置文件 +- `--runtime-mode` 默认 `local`,用于灰度切换到 `gateway` 模式 详细配置请参考:[docs/guides/configuration.md](docs/guides/configuration.md) diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index dd3d1fb7..3b4526b9 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -240,12 +240,14 @@ $env:GEMINI_API_KEY = "AI..." 不要把这两层职责混在一起理解。 -## CLI Workdir 覆盖 +## CLI 运行参数覆盖 -工作目录不写入 `config.yaml`,只通过启动参数覆盖: +工作目录与运行模式都不写入 `config.yaml`,只通过启动参数覆盖: ```bash go run ./cmd/neocode --workdir /path/to/workspace +go run ./cmd/neocode --runtime-mode local +go run ./cmd/neocode --runtime-mode gateway ``` 说明: @@ -253,6 +255,9 @@ go run ./cmd/neocode --workdir /path/to/workspace - `--workdir` 只影响本次进程 - 不会回写到 `config.yaml` - 工具根目录与 session 隔离都会使用该工作区 +- `--runtime-mode` 默认为 `local`,可切换为 `gateway` +- `gateway` 模式会通过本地 Gateway(优先 IPC)转发 runtime 请求 +- 连接或握手失败会直接退出(Fail Fast),不会自动回退到 `local` ## 常见错误 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 426e3585..9ced4308 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -2,6 +2,7 @@ package app import ( "context" + "errors" "log" "path/filepath" "strings" @@ -29,14 +30,23 @@ import ( "neo-code/internal/tools/todo" "neo-code/internal/tools/webfetch" "neo-code/internal/tui" + "neo-code/internal/tui/services" ) const utf8CodePage = 65001 +const ( + // RuntimeModeLocal 表示继续使用进程内 runtime 直连模式。 + RuntimeModeLocal = "local" + // RuntimeModeGateway 表示通过 Gateway JSON-RPC 转发 runtime 调用。 + RuntimeModeGateway = "gateway" +) + var ( setConsoleOutputCodePage = platformSetConsoleOutputCodePage setConsoleInputCodePage = platformSetConsoleInputCodePage buildToolManagerFunc = buildToolManager + newRemoteRuntimeAdapter = defaultNewRemoteRuntimeAdapter newTUIWithMemo = tui.NewWithMemo cleanupExpiredSessions = func( ctx context.Context, @@ -49,13 +59,19 @@ var ( // BootstrapOptions 描述应用启动时可注入的运行时选项。 type BootstrapOptions struct { - Workdir string + Workdir string + RuntimeMode string } type memoExtractorScheduler interface { ScheduleWithExtractor(sessionID string, messages []providertypes.Message, extractor memo.Extractor) } +type runtimeWithClose interface { + agentruntime.Runtime + Close() error +} + func newMemoExtractorAdapter( factory agentruntime.ProviderFactory, cm *config.Manager, @@ -114,6 +130,11 @@ func EnsureConsoleUTF8() { // BuildRuntime 构建 CLI 与 TUI 共用的运行时依赖。 func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, error) { + runtimeMode, err := resolveBootstrapRuntimeMode(opts.RuntimeMode) + if err != nil { + return RuntimeBundle{}, err + } + defaultCfg, err := bootstrapDefaultConfig(opts) if err != nil { return RuntimeBundle{}, err @@ -210,14 +231,26 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er memo.NewAutoExtractor(nil, memoSvc, time.Duration(cfg.Memo.ExtractTimeoutSec)*time.Second), )) } + + runtimeImpl := agentruntime.Runtime(runtimeSvc) + closeFns := []func() error{toolsCleanup, sessionStore.Close} + if runtimeMode == RuntimeModeGateway { + remoteRuntime, remoteErr := newRemoteRuntimeAdapter(services.RemoteRuntimeAdapterOptions{}) + if remoteErr != nil { + return RuntimeBundle{}, remoteErr + } + runtimeImpl = remoteRuntime + closeFns = append([]func() error{remoteRuntime.Close}, closeFns...) + } + needCleanup = false - closeBundle := combineRuntimeClosers(toolsCleanup, sessionStore.Close) + closeBundle := combineRuntimeClosers(closeFns...) return RuntimeBundle{ Config: cfg, ConfigManager: manager, - Runtime: runtimeSvc, + Runtime: runtimeImpl, ProviderSelection: providerSelection, MemoService: memoSvc, Close: closeBundle, @@ -266,6 +299,20 @@ func resolveBootstrapWorkdir(workdir string) (string, error) { return agentsession.ResolveExistingDir(workdir) } +// resolveBootstrapRuntimeMode 归一化并校验 runtime 运行模式。 +func resolveBootstrapRuntimeMode(mode string) (string, error) { + normalized := strings.ToLower(strings.TrimSpace(mode)) + if normalized == "" { + return RuntimeModeLocal, nil + } + switch normalized { + case RuntimeModeLocal, RuntimeModeGateway: + return normalized, nil + default: + return "", errors.New("bootstrap: runtime mode must be local or gateway") + } +} + func buildToolRegistry(cfg config.Config) (*tools.Registry, func() error, error) { toolRegistry := tools.NewRegistry() toolRegistry.Register(filesystem.New(cfg.Workdir)) @@ -323,6 +370,15 @@ func buildMCPAgentExposureRules(configs []config.MCPAgentExposureConfig) []mcp.A return rules } +// defaultNewRemoteRuntimeAdapter 构建默认的 Gateway runtime 适配器。 +func defaultNewRemoteRuntimeAdapter(options services.RemoteRuntimeAdapterOptions) (runtimeWithClose, error) { + adapter, err := services.NewRemoteRuntimeAdapter(options) + if err != nil { + return nil, err + } + return adapter, nil +} + func buildToolManager(registry *tools.Registry) (tools.Manager, error) { engine, err := security.NewRecommendedPolicyEngine() if err != nil { diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index f8caf4a9..87ed15ef 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -28,6 +28,7 @@ import ( "neo-code/internal/tools" "neo-code/internal/tools/mcp" "neo-code/internal/tui" + "neo-code/internal/tui/services" ) func TestNewProgram(t *testing.T) { @@ -1439,11 +1440,183 @@ func TestNewMemoExtractorAdapterPropagatesFactoryBuildError(t *testing.T) { } } +func TestResolveBootstrapRuntimeMode(t *testing.T) { + mode, err := resolveBootstrapRuntimeMode("") + if err != nil { + t.Fatalf("resolveBootstrapRuntimeMode() error = %v", err) + } + if mode != RuntimeModeLocal { + t.Fatalf("expected default mode %q, got %q", RuntimeModeLocal, mode) + } + + mode, err = resolveBootstrapRuntimeMode(" GATEWAY ") + if err != nil { + t.Fatalf("resolveBootstrapRuntimeMode() error = %v", err) + } + if mode != RuntimeModeGateway { + t.Fatalf("expected gateway mode %q, got %q", RuntimeModeGateway, mode) + } + + _, err = resolveBootstrapRuntimeMode("invalid") + if err == nil { + t.Fatalf("expected invalid runtime mode error") + } +} + +func TestBuildRuntimeRejectsInvalidRuntimeMode(t *testing.T) { + t.Parallel() + + _, err := BuildRuntime(context.Background(), BootstrapOptions{RuntimeMode: "invalid"}) + if err == nil { + t.Fatalf("expected invalid runtime mode error") + } +} + +func TestDefaultNewRemoteRuntimeAdapterReturnsInitError(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + _, err := defaultNewRemoteRuntimeAdapter(services.RemoteRuntimeAdapterOptions{ + ListenAddress: "ipc://127.0.0.1", + TokenFile: home + "/missing-token.json", + }) + if err == nil { + t.Fatalf("expected defaultNewRemoteRuntimeAdapter to fail when token is missing") + } +} + +func TestBuildRuntimeGatewayModeUsesRemoteAdapter(t *testing.T) { + disableBuiltinProviderAPIKeys(t) + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + originalFactory := newRemoteRuntimeAdapter + t.Cleanup(func() { newRemoteRuntimeAdapter = originalFactory }) + + stubRuntime := &stubRemoteRuntimeForBootstrap{ + events: make(chan agentruntime.RuntimeEvent), + } + newRemoteRuntimeAdapter = func(_ services.RemoteRuntimeAdapterOptions) (runtimeWithClose, error) { + return stubRuntime, nil + } + + bundle, err := BuildRuntime(context.Background(), BootstrapOptions{RuntimeMode: RuntimeModeGateway}) + if err != nil { + t.Fatalf("BuildRuntime() error = %v", err) + } + if bundle.Runtime != stubRuntime { + t.Fatalf("expected gateway runtime adapter to be wired") + } + if bundle.Close == nil { + t.Fatalf("expected non-nil close function") + } + if err := bundle.Close(); err != nil { + t.Fatalf("bundle.Close() error = %v", err) + } + if !stubRuntime.closed { + t.Fatalf("expected remote runtime close to be called") + } +} + +func TestBuildRuntimeGatewayModeFailsFastWhenAdapterInitFails(t *testing.T) { + disableBuiltinProviderAPIKeys(t) + + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + originalFactory := newRemoteRuntimeAdapter + t.Cleanup(func() { newRemoteRuntimeAdapter = originalFactory }) + + newRemoteRuntimeAdapter = func(_ services.RemoteRuntimeAdapterOptions) (runtimeWithClose, error) { + return nil, errors.New("gateway connect failed") + } + + _, err := BuildRuntime(context.Background(), BootstrapOptions{RuntimeMode: RuntimeModeGateway}) + if err == nil { + t.Fatalf("expected gateway mode fail-fast error") + } + if !strings.Contains(err.Error(), "gateway connect failed") { + t.Fatalf("unexpected error: %v", err) + } +} + type stubToolForBootstrap struct { name string content string } +type stubRemoteRuntimeForBootstrap struct { + closed bool + events chan agentruntime.RuntimeEvent +} + +func (s *stubRemoteRuntimeForBootstrap) Submit(context.Context, agentruntime.PrepareInput) error { + return nil +} + +func (s *stubRemoteRuntimeForBootstrap) PrepareUserInput( + context.Context, + agentruntime.PrepareInput, +) (agentruntime.UserInput, error) { + return agentruntime.UserInput{}, nil +} + +func (s *stubRemoteRuntimeForBootstrap) Run(context.Context, agentruntime.UserInput) error { + return nil +} + +func (s *stubRemoteRuntimeForBootstrap) Compact(context.Context, agentruntime.CompactInput) (agentruntime.CompactResult, error) { + return agentruntime.CompactResult{}, nil +} + +func (s *stubRemoteRuntimeForBootstrap) ExecuteSystemTool( + context.Context, + agentruntime.SystemToolInput, +) (tools.ToolResult, error) { + return tools.ToolResult{}, nil +} + +func (s *stubRemoteRuntimeForBootstrap) ResolvePermission(context.Context, agentruntime.PermissionResolutionInput) error { + return nil +} + +func (s *stubRemoteRuntimeForBootstrap) CancelActiveRun() bool { + return false +} + +func (s *stubRemoteRuntimeForBootstrap) Events() <-chan agentruntime.RuntimeEvent { + return s.events +} + +func (s *stubRemoteRuntimeForBootstrap) ListSessions(context.Context) ([]agentsession.Summary, error) { + return nil, nil +} + +func (s *stubRemoteRuntimeForBootstrap) LoadSession(context.Context, string) (agentsession.Session, error) { + return agentsession.Session{}, nil +} + +func (s *stubRemoteRuntimeForBootstrap) ActivateSessionSkill(context.Context, string, string) error { + return nil +} + +func (s *stubRemoteRuntimeForBootstrap) DeactivateSessionSkill(context.Context, string, string) error { + return nil +} + +func (s *stubRemoteRuntimeForBootstrap) ListSessionSkills(context.Context, string) ([]agentruntime.SessionSkillState, error) { + return nil, nil +} + +func (s *stubRemoteRuntimeForBootstrap) Close() error { + s.closed = true + return nil +} + func (s stubToolForBootstrap) Name() string { return s.name } func (s stubToolForBootstrap) Description() string { return "stub" } func (s stubToolForBootstrap) Schema() map[string]any { return map[string]any{"type": "object"} } diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go index ccfd7177..20058ef5 100644 --- a/internal/cli/gateway_runtime_bridge.go +++ b/internal/cli/gateway_runtime_bridge.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "strings" "sync" @@ -11,6 +12,7 @@ import ( "neo-code/internal/gateway" providertypes "neo-code/internal/provider/types" agentruntime "neo-code/internal/runtime" + agentsession "neo-code/internal/session" ) const bridgeLocalSubjectID = "local_admin" @@ -20,6 +22,10 @@ type runtimeRunCanceler interface { CancelRun(runID string) bool } +type runtimeSessionCreator interface { + CreateSession(ctx context.Context, id string) (agentsession.Session, error) +} + // defaultBuildGatewayRuntimePort 构建网关运行时 RuntimePort 适配器,并返回对应资源清理函数。 func defaultBuildGatewayRuntimePort(ctx context.Context, workdir string) (gateway.RuntimePort, func() error, error) { bundle, err := app.BuildRuntime(ctx, app.BootstrapOptions{Workdir: strings.TrimSpace(workdir)}) @@ -184,19 +190,20 @@ func (b *gatewayRuntimePortBridge) LoadSession(ctx context.Context, input gatewa session, err := b.runtime.LoadSession(ctx, sessionID) if err != nil { if isRuntimeNotFoundError(err) { - return gateway.Session{}, gateway.ErrRuntimeResourceNotFound + creator, ok := b.runtime.(runtimeSessionCreator) + if !ok { + return gateway.Session{}, gateway.ErrRuntimeResourceNotFound + } + created, createErr := creator.CreateSession(ctx, sessionID) + if createErr != nil { + return gateway.Session{}, createErr + } + return convertRuntimeSessionToGatewaySession(created), nil } return gateway.Session{}, err } - return gateway.Session{ - ID: strings.TrimSpace(session.ID), - Title: strings.TrimSpace(session.Title), - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt, - Workdir: strings.TrimSpace(session.Workdir), - Messages: convertSessionMessages(session.Messages), - }, nil + return convertRuntimeSessionToGatewaySession(session), nil } // Close 主动停止桥接事件泵,避免网关关闭后后台协程悬挂。 @@ -346,6 +353,18 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM return converted } +// convertRuntimeSessionToGatewaySession 将 runtime 会话结构映射为 gateway 契约返回值。 +func convertRuntimeSessionToGatewaySession(session agentsession.Session) gateway.Session { + return gateway.Session{ + ID: strings.TrimSpace(session.ID), + Title: strings.TrimSpace(session.Title), + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Workdir: strings.TrimSpace(session.Workdir), + Messages: convertSessionMessages(session.Messages), + } +} + // renderSessionMessageContent 将 provider 多段内容渲染为对外展示的单段文本摘要。 func renderSessionMessageContent(parts []providertypes.ContentPart) string { if len(parts) == 0 { @@ -395,7 +414,7 @@ func isRuntimeNotFoundError(err error) bool { if err == nil { return false } - return strings.Contains(strings.ToLower(err.Error()), "not found") + return errors.Is(err, agentsession.ErrSessionNotFound) || errors.Is(err, os.ErrNotExist) } var _ gateway.RuntimePort = (*gatewayRuntimePortBridge)(nil) diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index b39fafb8..863716ee 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -3,6 +3,7 @@ package cli import ( "context" "errors" + "os" "testing" "time" @@ -28,6 +29,9 @@ type runtimeStub struct { loadID string loadSession agentsession.Session loadErr error + createID string + createSession agentsession.Session + createErr error } const testBridgeSubjectID = bridgeLocalSubjectID @@ -80,6 +84,11 @@ func (s *runtimeStub) LoadSession(_ context.Context, id string) (agentsession.Se return s.loadSession, s.loadErr } +func (s *runtimeStub) CreateSession(_ context.Context, id string) (agentsession.Session, error) { + s.createID = id + return s.createSession, s.createErr +} + func (s *runtimeStub) ActivateSessionSkill(context.Context, string, string) error { return nil } @@ -92,6 +101,50 @@ func (s *runtimeStub) ListSessionSkills(context.Context, string) ([]agentruntime return nil, nil } +type runtimeWithoutCreator struct { + base *runtimeStub +} + +func (r *runtimeWithoutCreator) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + return r.base.Submit(ctx, input) +} +func (r *runtimeWithoutCreator) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + return r.base.PrepareUserInput(ctx, input) +} +func (r *runtimeWithoutCreator) Run(ctx context.Context, input agentruntime.UserInput) error { + return r.base.Run(ctx, input) +} +func (r *runtimeWithoutCreator) Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) { + return r.base.Compact(ctx, input) +} +func (r *runtimeWithoutCreator) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + return r.base.ExecuteSystemTool(ctx, input) +} +func (r *runtimeWithoutCreator) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { + return r.base.ResolvePermission(ctx, input) +} +func (r *runtimeWithoutCreator) CancelActiveRun() bool { + return r.base.CancelActiveRun() +} +func (r *runtimeWithoutCreator) Events() <-chan agentruntime.RuntimeEvent { + return r.base.Events() +} +func (r *runtimeWithoutCreator) ListSessions(ctx context.Context) ([]agentsession.Summary, error) { + return r.base.ListSessions(ctx) +} +func (r *runtimeWithoutCreator) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + return r.base.LoadSession(ctx, id) +} +func (r *runtimeWithoutCreator) ActivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.ActivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCreator) DeactivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + return r.base.DeactivateSessionSkill(ctx, sessionID, skillID) +} +func (r *runtimeWithoutCreator) ListSessionSkills(ctx context.Context, sessionID string) ([]agentruntime.SessionSkillState, error) { + return r.base.ListSessionSkills(ctx, sessionID) +} + func TestNewGatewayRuntimePortBridgeRuntimeUnavailable(t *testing.T) { bridge, err := newGatewayRuntimePortBridge(context.Background(), nil) if err == nil { @@ -294,6 +347,57 @@ func TestGatewayRuntimePortBridgeRuntimeMethods(t *testing.T) { } } +func TestGatewayRuntimePortBridgeLoadSessionNotFoundBranches(t *testing.T) { + t.Parallel() + + base := &runtimeStub{ + loadErr: agentsession.ErrSessionNotFound, + } + bridgeWithoutCreator, err := newGatewayRuntimePortBridge(context.Background(), &runtimeWithoutCreator{base: base}) + if err != nil { + t.Fatalf("new bridge without creator: %v", err) + } + t.Cleanup(func() { _ = bridgeWithoutCreator.Close() }) + + if _, err := bridgeWithoutCreator.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: "s-1", + }); !errors.Is(err, gateway.ErrRuntimeResourceNotFound) { + t.Fatalf("expected ErrRuntimeResourceNotFound, got %v", err) + } + + stub := &runtimeStub{ + loadErr: os.ErrNotExist, + createErr: errors.New("create failed"), + } + bridgeWithCreator, err := newGatewayRuntimePortBridge(context.Background(), stub) + if err != nil { + t.Fatalf("new bridge with creator: %v", err) + } + t.Cleanup(func() { _ = bridgeWithCreator.Close() }) + + if _, err := bridgeWithCreator.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: "s-2", + }); err == nil || err.Error() != "create failed" { + t.Fatalf("expected create failed error, got %v", err) + } +} + +func TestIsRuntimeNotFoundErrorIncludesOSErrNotExist(t *testing.T) { + t.Parallel() + + if !isRuntimeNotFoundError(os.ErrNotExist) { + t.Fatalf("os.ErrNotExist should be treated as runtime not found") + } + if !isRuntimeNotFoundError(agentsession.ErrSessionNotFound) { + t.Fatalf("ErrSessionNotFound should be treated as runtime not found") + } + if isRuntimeNotFoundError(errors.New("session not found")) { + t.Fatalf("plain string not-found error should not be treated as runtime not found") + } +} + func TestGatewayRuntimePortBridgeRuntimeMethodErrors(t *testing.T) { stub := &runtimeStub{ submitErr: errors.New("submit failed"), @@ -330,6 +434,72 @@ func TestGatewayRuntimePortBridgeRuntimeMethodErrors(t *testing.T) { } } +func TestGatewayRuntimePortBridgeLoadSessionUpsertWhenMissing(t *testing.T) { + now := time.Now() + stub := &runtimeStub{ + loadErr: agentsession.ErrSessionNotFound, + createSession: agentsession.Session{ + ID: "session-new", + Title: "New Session", + Workdir: "/tmp/work", + CreatedAt: now, + UpdatedAt: now, + }, + } + bridge, err := newGatewayRuntimePortBridge(context.Background(), stub) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + session, err := bridge.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: " session-new ", + }) + if err != nil { + t.Fatalf("load_session upsert: %v", err) + } + if stub.loadID != "session-new" { + t.Fatalf("load id = %q, want %q", stub.loadID, "session-new") + } + if stub.createID != "session-new" { + t.Fatalf("create id = %q, want %q", stub.createID, "session-new") + } + if session.ID != "session-new" || session.Title != "New Session" || session.Workdir != "/tmp/work" { + t.Fatalf("upsert session = %#v, want created session snapshot", session) + } +} + +func TestGatewayRuntimePortBridgeLoadSessionNoUpsertOnPlainStringNotFoundError(t *testing.T) { + now := time.Now() + stub := &runtimeStub{ + loadErr: errors.New("open sessions/session-new.json: no such file"), + createSession: agentsession.Session{ + ID: "session-new", + Title: "New Session", + Workdir: "/tmp/work", + CreatedAt: now, + UpdatedAt: now, + }, + } + bridge, err := newGatewayRuntimePortBridge(context.Background(), stub) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + _, err = bridge.LoadSession(context.Background(), gateway.LoadSessionInput{ + SubjectID: testBridgeSubjectID, + SessionID: " session-new ", + }) + if err == nil || err.Error() != "open sessions/session-new.json: no such file" { + t.Fatalf("expected original string error passthrough, got %v", err) + } + if stub.createID != "" { + t.Fatalf("create should not be called for plain string error, got createID=%q", stub.createID) + } +} + func TestGatewayRuntimePortBridgeRunEventBridge(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/internal/cli/root.go b/internal/cli/root.go index efbac23d..e3a25a92 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -34,12 +34,13 @@ var ( silentUpdateCheckDone <-chan struct{} ) -// GlobalFlags 描述 CLI 根命令当前支持的全局参数。 +// GlobalFlags 描述根命令共享的全局启动参数。 type GlobalFlags struct { - Workdir string + Workdir string + RuntimeMode string } -// Execute 负责执行 NeoCode 的 CLI 根命令。 +// Execute 执行 NeoCode 根命令入口,并在退出前等待静默更新检查收尾。 func Execute(ctx context.Context) error { app.EnsureConsoleUTF8() _ = ConsumeUpdateNotice() @@ -50,7 +51,7 @@ func Execute(ctx context.Context) error { return err } -// NewRootCommand 创建 NeoCode 的 CLI 根命令。 +// NewRootCommand 构建 NeoCode 的根命令及全局参数绑定。 func NewRootCommand() *cobra.Command { settings := viper.New() flags := &GlobalFlags{} @@ -74,14 +75,24 @@ func NewRootCommand() *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { flags.Workdir = strings.TrimSpace(settings.GetString("workdir")) + flags.RuntimeMode = strings.ToLower(strings.TrimSpace(settings.GetString("runtime-mode"))) + switch flags.RuntimeMode { + case "", app.RuntimeModeLocal: + flags.RuntimeMode = app.RuntimeModeLocal + case app.RuntimeModeGateway: + default: + return fmt.Errorf("invalid --runtime-mode %q, must be local or gateway", flags.RuntimeMode) + } return launchRootProgram(cmd.Context(), app.BootstrapOptions{ - Workdir: flags.Workdir, + Workdir: flags.Workdir, + RuntimeMode: flags.RuntimeMode, }) }, } - - cmd.PersistentFlags().String("workdir", "", "工作目录(覆盖本次运行工作区)") + cmd.PersistentFlags().String("workdir", "", "workdir override for current run") + cmd.PersistentFlags().String("runtime-mode", app.RuntimeModeLocal, "runtime mode (local/gateway)") _ = settings.BindPFlag("workdir", cmd.PersistentFlags().Lookup("workdir")) + _ = settings.BindPFlag("runtime-mode", cmd.PersistentFlags().Lookup("runtime-mode")) cmd.AddCommand( newGatewayCommand(), newURLDispatchCommand(), @@ -91,7 +102,7 @@ func NewRootCommand() *cobra.Command { return cmd } -// defaultRootProgramLauncher 负责在默认根命令路径下启动 TUI。 +// defaultRootProgramLauncher 负责创建并运行 TUI Program,同时保证清理函数被正确执行。 func defaultRootProgramLauncher(ctx context.Context, opts app.BootstrapOptions) (err error) { program, cleanup, err := newRootProgram(ctx, opts) if err != nil { @@ -114,7 +125,7 @@ func defaultRootProgramLauncher(ctx context.Context, opts app.BootstrapOptions) return err } -// defaultGlobalPreload 负责执行启动前预检查,避免在命令执行前遗漏上下文取消信号。 +// defaultGlobalPreload 执行全局预加载钩子;当前仅做上下文取消检查。 func defaultGlobalPreload(ctx context.Context) error { if err := ctx.Err(); err != nil { return err @@ -122,7 +133,7 @@ func defaultGlobalPreload(ctx context.Context) error { return nil } -// defaultSilentUpdateCheck 在后台异步检查新版本并缓存退出后提示文案。 +// defaultSilentUpdateCheck 在后台静默检查是否有新版本,并写入一次性升级提示。 func defaultSilentUpdateCheck(ctx context.Context) { currentVersion := readCurrentVersion() if !version.IsSemverRelease(currentVersion) { @@ -155,12 +166,12 @@ func defaultSilentUpdateCheck(ctx context.Context) { }(parentCtx, currentVersion, done) } -// shouldSkipGlobalPreload 判断当前命令是否应跳过全局预加载逻辑。 +// shouldSkipGlobalPreload 判断当前子命令是否跳过全局预加载。 func shouldSkipGlobalPreload(cmd *cobra.Command) bool { return normalizedCommandName(cmd) == "url-dispatch" } -// shouldSkipSilentUpdateCheck 判断当前命令是否应跳过静默更新检测。 +// shouldSkipSilentUpdateCheck 判断当前子命令是否跳过静默更新检查。 func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { switch normalizedCommandName(cmd) { case "url-dispatch", "update": @@ -170,7 +181,7 @@ func shouldSkipSilentUpdateCheck(cmd *cobra.Command) bool { } } -// sanitizeVersionForTerminal 清洗远端版本字符串,避免 ANSI 控制序列或不可见字符污染终端输出。 +// sanitizeVersionForTerminal 清理版本号中的 ANSI 控制字符与不可打印字符,避免污染终端输出。 func sanitizeVersionForTerminal(version string) string { cleaned := ansiEscapeSequencePattern.ReplaceAllString(version, "") var builder strings.Builder @@ -183,7 +194,7 @@ func sanitizeVersionForTerminal(version string) string { return strings.TrimSpace(builder.String()) } -// normalizedCommandName 返回标准化后的命令名,统一处理空命令与大小写。 +// normalizedCommandName 返回小写且去空白后的命令名,便于统一比较。 func normalizedCommandName(cmd *cobra.Command) string { if cmd == nil { return "" @@ -191,14 +202,14 @@ func normalizedCommandName(cmd *cobra.Command) string { return strings.ToLower(strings.TrimSpace(cmd.Name())) } -// setSilentUpdateCheckDone 保存当前静默检测任务的完成信号通道。 +// setSilentUpdateCheckDone 原子地更新静默检查完成信号通道。 func setSilentUpdateCheckDone(done <-chan struct{}) { silentUpdateCheckMu.Lock() silentUpdateCheckDone = done silentUpdateCheckMu.Unlock() } -// waitSilentUpdateCheckDone 在命令退出阶段等待静默检测短暂收口,降低提示丢失概率。 +// waitSilentUpdateCheckDone 在给定超时时间内等待静默更新检查结束,超时后直接返回。 func waitSilentUpdateCheckDone(timeout time.Duration) { if timeout <= 0 { return diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 540ab369..679e669f 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -65,6 +65,45 @@ func TestNewRootCommandAllowsEmptyWorkdir(t *testing.T) { if captured.Workdir != "" { t.Fatalf("expected empty workdir override, got %q", captured.Workdir) } + if captured.RuntimeMode != app.RuntimeModeLocal { + t.Fatalf("expected default runtime mode %q, got %q", app.RuntimeModeLocal, captured.RuntimeMode) + } +} + +func TestNewRootCommandPassesRuntimeModeFlagToLauncher(t *testing.T) { + originalLauncher := launchRootProgram + t.Cleanup(func() { launchRootProgram = originalLauncher }) + + var captured app.BootstrapOptions + launchRootProgram = func(ctx context.Context, opts app.BootstrapOptions) error { + captured = opts + return nil + } + + cmd := NewRootCommand() + cmd.SetArgs([]string{"--runtime-mode", app.RuntimeModeGateway}) + if err := cmd.ExecuteContext(context.Background()); err != nil { + t.Fatalf("ExecuteContext() error = %v", err) + } + if captured.RuntimeMode != app.RuntimeModeGateway { + t.Fatalf("expected runtime mode %q, got %q", app.RuntimeModeGateway, captured.RuntimeMode) + } +} + +func TestNewRootCommandRejectsInvalidRuntimeMode(t *testing.T) { + originalPreload := runGlobalPreload + t.Cleanup(func() { runGlobalPreload = originalPreload }) + runGlobalPreload = func(context.Context) error { return nil } + + cmd := NewRootCommand() + cmd.SetArgs([]string{"--runtime-mode", "invalid"}) + err := cmd.ExecuteContext(context.Background()) + if err == nil { + t.Fatalf("expected invalid runtime mode error") + } + if !strings.Contains(err.Error(), "invalid --runtime-mode") { + t.Fatalf("unexpected error: %v", err) + } } func TestNewRootCommandReturnsLauncherError(t *testing.T) { diff --git a/internal/gateway/network_server_additional_test.go b/internal/gateway/network_server_additional_test.go index dc6c60d1..dccf31a2 100644 --- a/internal/gateway/network_server_additional_test.go +++ b/internal/gateway/network_server_additional_test.go @@ -293,6 +293,7 @@ func TestNetworkServerSSELimitAndWriteErrorBranches(t *testing.T) { t.Fatalf("dial first ws: %v", err) } defer func() { _ = firstConn.Close() }() + waitForWebSocketConnectionCount(t, server, 1, 2*time.Second) sseResponse, err := http.Get("http://" + listenAddress + "/sse?method=gateway.ping&id=sse-limit") if err != nil { diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 1c878d12..d01f21ec 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -540,6 +540,70 @@ func TestDispatchRPCRequestMetricsUnknownMethodCollapsed(t *testing.T) { } } +func TestDispatchRPCRequestMetricsGrowForTUIMethodSequence(t *testing.T) { + metrics := NewGatewayMetrics() + authState := NewConnectionAuthState() + ctx := WithRequestSource(context.Background(), RequestSourceIPC) + ctx = WithGatewayMetrics(ctx, metrics) + ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) + ctx = WithConnectionAuthState(ctx, authState) + ctx = WithTokenAuthenticator(ctx, staticTokenAuthenticator{token: "token-tui"}) + + authenticate := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-auth-tui"`), + Method: protocol.MethodGatewayAuthenticate, + Params: json.RawMessage(`{"token":"token-tui"}`), + }, &runtimePortCompileStub{}) + if authenticate.Error != nil { + t.Fatalf("authenticate response error: %+v", authenticate.Error) + } + + run := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-run-tui"`), + Method: protocol.MethodGatewayRun, + Params: json.RawMessage(`{"session_id":"session-tui","input_text":"hello"}`), + }, &runtimePortCompileStub{}) + if run.Error != nil { + t.Fatalf("run response error: %+v", run.Error) + } + + compact := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-compact-tui"`), + Method: protocol.MethodGatewayCompact, + Params: json.RawMessage(`{"session_id":"session-tui"}`), + }, &runtimePortCompileStub{}) + if compact.Error != nil { + t.Fatalf("compact response error: %+v", compact.Error) + } + + listSessions := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-list-tui"`), + Method: protocol.MethodGatewayListSessions, + Params: json.RawMessage(`{}`), + }, &runtimePortCompileStub{}) + if listSessions.Error != nil { + t.Fatalf("listSessions response error: %+v", listSessions.Error) + } + + snapshot := metrics.Snapshot()["gateway_requests_total"] + if snapshot["ipc|gateway.authenticate|ok"] == 0 { + t.Fatalf("expected authenticate metric to grow, snapshot=%#v", snapshot) + } + if snapshot["ipc|gateway.run|ok"] == 0 { + t.Fatalf("expected run metric to grow, snapshot=%#v", snapshot) + } + if snapshot["ipc|gateway.compact|ok"] == 0 { + t.Fatalf("expected compact metric to grow, snapshot=%#v", snapshot) + } + if snapshot["ipc|gateway.listsessions|ok"] == 0 { + t.Fatalf("expected listSessions metric to grow, snapshot=%#v", snapshot) + } +} + func TestDispatchRPCRequestMetricsACLDeniedAndFrameErrorLabels(t *testing.T) { metrics := NewGatewayMetrics() denyACL := &ControlPlaneACL{ diff --git a/internal/runtime/create_session_test.go b/internal/runtime/create_session_test.go new file mode 100644 index 00000000..ed6a3d73 --- /dev/null +++ b/internal/runtime/create_session_test.go @@ -0,0 +1,205 @@ +package runtime + +import ( + "context" + "fmt" + "os" + "testing" + + agentsession "neo-code/internal/session" +) + +type createSessionUpsertStore struct { + *memoryStore + missingErr error +} + +func (s *createSessionUpsertStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + s.memoryStore.mu.Lock() + _, exists := s.memoryStore.sessions[id] + s.memoryStore.mu.Unlock() + if !exists { + return agentsession.Session{}, s.missingErr + } + return s.memoryStore.LoadSession(ctx, id) +} + +func TestServiceCreateSessionUpsertWhenMissing(t *testing.T) { + t.Parallel() + + store := &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + created, err := service.CreateSession(context.Background(), "session-upsert") + if err != nil { + t.Fatalf("CreateSession() upsert error = %v", err) + } + if created.ID != "session-upsert" { + t.Fatalf("created session id = %q, want %q", created.ID, "session-upsert") + } + if created.Title != "New Session" { + t.Fatalf("created session title = %q, want %q", created.Title, "New Session") + } + + savesAfterCreate := store.memoryStore.saves + loaded, err := service.CreateSession(context.Background(), "session-upsert") + if err != nil { + t.Fatalf("CreateSession() load existing error = %v", err) + } + if loaded.ID != "session-upsert" { + t.Fatalf("loaded session id = %q, want %q", loaded.ID, "session-upsert") + } + if store.memoryStore.saves != savesAfterCreate { + t.Fatalf("unexpected additional create, saves=%d want %d", store.memoryStore.saves, savesAfterCreate) + } +} + +func TestServiceCreateSessionReturnsOriginalErrorWhenMissingErrorIsNotSentinel(t *testing.T) { + t.Parallel() + + store := &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("dependency not found"), + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + _, err := service.CreateSession(context.Background(), "session-upsert") + if err == nil { + t.Fatalf("CreateSession() expected error when missing error is not sentinel") + } + if err.Error() != "dependency not found" { + t.Fatalf("CreateSession() error = %v, want dependency not found", err) + } + if store.memoryStore.saves != 0 { + t.Fatalf("CreateSession() should not create on non-sentinel error, saves=%d", store.memoryStore.saves) + } +} + +type createSessionDuplicateStore struct { + *createSessionUpsertStore + createErr error + loadHits int + loaded agentsession.Session + loadErr error +} + +func (s *createSessionDuplicateStore) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + s.loadHits++ + if s.loadHits == 1 { + return agentsession.Session{}, s.missingErr + } + if s.loadErr != nil { + return agentsession.Session{}, s.loadErr + } + return s.loaded, nil +} + +func (s *createSessionDuplicateStore) CreateSession(ctx context.Context, input agentsession.CreateSessionInput) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + if s.createErr != nil { + return agentsession.Session{}, s.createErr + } + return s.memoryStore.CreateSession(ctx, input) +} + +func TestServiceCreateSessionBranches(t *testing.T) { + t.Parallel() + + store := &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := service.CreateSession(ctx, "session-canceled"); err == nil { + t.Fatalf("CreateSession() should reject canceled context") + } + if _, err := service.CreateSession(context.Background(), " "); err == nil { + t.Fatalf("CreateSession() should reject empty session id") + } +} + +func TestServiceCreateSessionReturnsWorkdirResolutionError(t *testing.T) { + t.Parallel() + + service := &Service{ + sessionStore: newMemoryStore(), + // 不注入 configManager 会使默认 workdir 为空,触发 resolveWorkdirForSession 错误路径。 + } + if _, err := service.CreateSession(context.Background(), "session-workdir"); err == nil { + t.Fatalf("CreateSession() should fail when default workdir cannot be resolved") + } +} + +func TestServiceCreateSessionDuplicateCreateFallsBackToLoad(t *testing.T) { + t.Parallel() + + store := &createSessionDuplicateStore{ + createSessionUpsertStore: &createSessionUpsertStore{ + memoryStore: newMemoryStore(), + missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), + }, + createErr: fmt.Errorf("sqlite: %w", agentsession.ErrSessionAlreadyExists), + loaded: agentsession.Session{ID: "session-dup", Title: "loaded"}, + } + service := &Service{ + configManager: newRuntimeConfigManager(t), + sessionStore: store, + } + + loaded, err := service.CreateSession(context.Background(), "session-dup") + if err != nil { + t.Fatalf("CreateSession() duplicate fallback error = %v", err) + } + if loaded.ID != "session-dup" || loaded.Title != "loaded" { + t.Fatalf("CreateSession() loaded session = %#v", loaded) + } +} + +func TestCreateSessionErrorPredicates(t *testing.T) { + t.Parallel() + + if isRuntimeSessionNotFoundError(nil) { + t.Fatalf("isRuntimeSessionNotFoundError(nil) should be false") + } + if !isRuntimeSessionNotFoundError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionNotFound)) { + t.Fatalf("wrapped ErrSessionNotFound should be detected") + } + + if isRuntimeSessionAlreadyExistsError(nil) { + t.Fatalf("isRuntimeSessionAlreadyExistsError(nil) should be false") + } + if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionAlreadyExists)) { + t.Fatalf("wrapped ErrSessionAlreadyExists should be detected") + } + if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", os.ErrExist)) { + t.Fatalf("wrapped os.ErrExist should be detected") + } + for _, text := range []string{"already exists", "UNIQUE CONSTRAINT", "duplicate key"} { + if isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) { + t.Fatalf("plain text %q should not be treated as already exists", text) + } + } +} diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 4d0c7f91..99de987b 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -2,6 +2,8 @@ package runtime import ( "context" + "errors" + "os" "fmt" "strings" "sync" @@ -264,6 +266,60 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess return session, nil } +// CreateSession 按给定 id 执行会话创建/加载(Upsert)并返回可用会话头。 +func (s *Service) CreateSession(ctx context.Context, id string) (agentsession.Session, error) { + if err := ctx.Err(); err != nil { + return agentsession.Session{}, err + } + sessionID := strings.TrimSpace(id) + if sessionID == "" { + return agentsession.Session{}, errors.New("runtime: session id is empty") + } + defaultWorkdir := "" + if s.configManager != nil { + defaultWorkdir = strings.TrimSpace(s.configManager.Get().Workdir) + } + sessionWorkdir, err := resolveWorkdirForSession(defaultWorkdir, "", "") + if err != nil { + return agentsession.Session{}, err + } + + existing, err := s.sessionStore.LoadSession(ctx, sessionID) + if err == nil { + return existing, nil + } + if !isRuntimeSessionNotFoundError(err) { + return agentsession.Session{}, err + } + + newSession := agentsession.NewWithWorkdir("New Session", sessionWorkdir) + newSession.ID = sessionID + created, createErr := s.sessionStore.CreateSession(ctx, createSessionInputFromSession(newSession)) + if createErr == nil { + return created, nil + } + if isRuntimeSessionAlreadyExistsError(createErr) { + return s.sessionStore.LoadSession(ctx, sessionID) + } + return agentsession.Session{}, createErr +} + +// isRuntimeSessionNotFoundError 判断错误是否代表会话文件/记录不存在。 +func isRuntimeSessionNotFoundError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, agentsession.ErrSessionNotFound) +} + +// isRuntimeSessionAlreadyExistsError 判断错误是否代表会话已被并发创建。 +func isRuntimeSessionAlreadyExistsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, agentsession.ErrSessionAlreadyExists) || errors.Is(err, os.ErrExist) +} + // SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。 func (s *Service) SetAutoCompactThresholdResolver(resolver AutoCompactThresholdResolver) { s.autoCompactThresholdResolver = resolver diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 0da93f71..036e4948 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -137,6 +137,9 @@ func (s *SQLiteStore) CreateSession(ctx context.Context, input CreateSessionInpu if err := ctx.Err(); err != nil { return Session{}, err } + if err := s.ensureStorageDirs(); err != nil { + return Session{}, err + } db, err := s.ensureDB(ctx) if err != nil { return Session{}, err @@ -174,6 +177,9 @@ INSERT INTO sessions ( session.TokenOutputTotal, ) if err != nil { + if isSQLiteSessionUniqueConstraintError(err) { + return Session{}, wrapSessionAlreadyExists(err) + } return Session{}, fmt.Errorf("session: insert session %s: %w", session.ID, err) } if err := tx.Commit(); err != nil { @@ -715,6 +721,21 @@ func (s *SQLiteStore) initialize(ctx context.Context) error { return nil } +// ensureStorageDirs 统一保证 session 存储相关目录存在,避免新会话写入时因父目录缺失失败。 +func (s *SQLiteStore) ensureStorageDirs() error { + dbDir := filepath.Dir(s.dbPath) + if err := os.MkdirAll(dbDir, 0o755); err != nil { + return fmt.Errorf("session: create db dir: %w", err) + } + if err := os.MkdirAll(s.projectDir, 0o755); err != nil { + return fmt.Errorf("session: create project dir: %w", err) + } + if err := os.MkdirAll(s.assetsDir, 0o755); err != nil { + return fmt.Errorf("session: create assets dir: %w", err) + } + return nil +} + // loadAssetMeta 查询附件元数据并解析绝对路径。 func (s *SQLiteStore) loadAssetMeta(ctx context.Context, sessionID string, assetID string) (AssetMeta, string, error) { if err := validateStorageID("session id", sessionID); err != nil { @@ -968,7 +989,7 @@ WHERE id = ? ) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return sqliteSessionRow{}, os.ErrNotExist + return sqliteSessionRow{}, wrapSessionNotFound(sql.ErrNoRows) } return sqliteSessionRow{}, fmt.Errorf("session: query session %s: %w", sessionID, err) } @@ -1096,7 +1117,7 @@ func currentLastSeq(ctx context.Context, tx *sql.Tx, sessionID string) (int, err err := tx.QueryRowContext(ctx, `SELECT last_seq FROM sessions WHERE id = ?`, sessionID).Scan(&lastSeq) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return 0, os.ErrNotExist + return 0, wrapSessionNotFound(sql.ErrNoRows) } return 0, fmt.Errorf("session: query last_seq for %s: %w", sessionID, err) } @@ -1275,11 +1296,27 @@ func expectRowsAffected(result sql.Result, sessionID string) error { return fmt.Errorf("session: inspect rows affected for %s: %w", sessionID, err) } if rowsAffected == 0 { - return os.ErrNotExist + return wrapSessionNotFound(os.ErrNotExist) } return nil } +// wrapSessionNotFound 统一包装会话缺失错误,确保上层可通过 ErrSessionNotFound 做精确判断。 +func wrapSessionNotFound(cause error) error { + if cause == nil { + cause = os.ErrNotExist + } + return fmt.Errorf("%w: %w", ErrSessionNotFound, fmt.Errorf("%w: %w", os.ErrNotExist, cause)) +} + +// wrapSessionAlreadyExists 统一包装会话重复创建错误,确保上层可通过 ErrSessionAlreadyExists 做精确判断。 +func wrapSessionAlreadyExists(cause error) error { + if cause == nil { + cause = os.ErrExist + } + return fmt.Errorf("%w: %w", ErrSessionAlreadyExists, fmt.Errorf("%w: %w", os.ErrExist, cause)) +} + // cloneMessage 深拷贝消息,避免共享底层切片和映射。 // mapSessionAssetInsertError 统一收敛附件元数据插入阶段的缺失会话语义,避免向上泄漏底层 SQLite 错误。 func mapSessionAssetInsertError(assetID string, err error) error { @@ -1298,6 +1335,18 @@ func isSQLiteForeignKeyConstraintError(err error) bool { return false } +// isSQLiteSessionUniqueConstraintError 判断底层错误是否为 SQLite 主键/唯一约束失败。 +func isSQLiteSessionUniqueConstraintError(err error) bool { + var sqliteErr *sqlitedriver.Error + if !errors.As(err, &sqliteErr) { + return false + } + code := sqliteErr.Code() + return code == sqlite3.SQLITE_CONSTRAINT || + code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY || + code == sqlite3.SQLITE_CONSTRAINT_UNIQUE +} + func cloneMessage(message providertypes.Message) providertypes.Message { next := message next.Parts = providertypes.CloneParts(message.Parts) diff --git a/internal/session/sqlite_store_additional_test.go b/internal/session/sqlite_store_additional_test.go index 4d967756..fa1307b6 100644 --- a/internal/session/sqlite_store_additional_test.go +++ b/internal/session/sqlite_store_additional_test.go @@ -159,6 +159,8 @@ func TestExpectRowsAffectedBranches(t *testing.T) { } if err := expectRowsAffected(fakeResult{rows: 0}, "s1"); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected os.ErrNotExist when rows=0, got %v", err) + } else if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound when rows=0, got %v", err) } if err := expectRowsAffected(fakeResult{rows: 1}, "s1"); err != nil { t.Fatalf("expected rows=1 to pass, got %v", err) @@ -227,6 +229,76 @@ func TestNormalizeCreateSessionInputDefaultsGeneratedID(t *testing.T) { } } +func TestSQLiteStoreCreateSessionPropagatesEnsureStorageDirsError(t *testing.T) { + t.Parallel() + + store := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + _, err := store.CreateSession(context.Background(), CreateSessionInput{ID: "s1", Title: "title"}) + if err == nil { + t.Fatalf("expected CreateSession() to fail when db dir cannot be created") + } +} + +func TestSQLiteStoreEnsureStorageDirsErrorBranches(t *testing.T) { + t.Parallel() + + dbDirErrStore := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + if err := dbDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create db dir") { + t.Fatalf("expected create db dir error, got %v", err) + } + + projectDirErrStore := &SQLiteStore{ + projectDir: filepath.Join("/dev/null", "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + if err := projectDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create project dir") { + t.Fatalf("expected create project dir error, got %v", err) + } + + assetsDirErrStore := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join("/dev/null", "assets"), + dbPath: filepath.Join(t.TempDir(), "db.sqlite"), + } + if err := assetsDirErrStore.ensureStorageDirs(); err == nil || !strings.Contains(err.Error(), "create assets dir") { + t.Fatalf("expected create assets dir error, got %v", err) + } +} + +func TestSQLiteStoreInitializePropagatesStorageDirError(t *testing.T) { + t.Parallel() + + store := &SQLiteStore{ + projectDir: filepath.Join(t.TempDir(), "project"), + assetsDir: filepath.Join(t.TempDir(), "assets"), + dbPath: filepath.Join("/dev/null", "db.sqlite"), + } + if err := store.initialize(context.Background()); err == nil { + t.Fatalf("expected initialize() to fail when storage dirs are invalid") + } +} + +func TestWrapSessionNotFoundWithNilCause(t *testing.T) { + t.Parallel() + + err := wrapSessionNotFound(nil) + if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected ErrSessionNotFound, got %v", err) + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) + } +} + func TestResolveUpdatedAtReturnsProvidedValue(t *testing.T) { t.Parallel() diff --git a/internal/session/store.go b/internal/session/store.go index 6fcd7a25..87646a9c 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -2,6 +2,7 @@ package session import ( "context" + "errors" "fmt" "regexp" "strings" @@ -24,6 +25,12 @@ const ( var storageIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,127}$`) +// ErrSessionNotFound 表示会话在存储层不存在,用于 runtime 做精确错误分流。 +var ErrSessionNotFound = errors.New("session: session not found") + +// ErrSessionAlreadyExists 表示会话在存储层已存在,用于 runtime 处理并发创建冲突。 +var ErrSessionAlreadyExists = errors.New("session: session already exists") + // Session 表示单个会话的运行态与持久化聚合模型。 type Session struct { ID string diff --git a/internal/session/store_test.go b/internal/session/store_test.go index e38193b5..e9a76862 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -122,6 +122,28 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { } } +func TestSQLiteStoreCreateSessionDuplicateReturnsSentinel(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := newTestStore(t) + input := CreateSessionInput{ID: "dup_session", Title: "dup"} + if _, err := store.CreateSession(ctx, input); err != nil { + t.Fatalf("first CreateSession() error = %v", err) + } + + _, err := store.CreateSession(ctx, input) + if err == nil { + t.Fatalf("expected duplicate CreateSession() to fail") + } + if !errors.Is(err, ErrSessionAlreadyExists) { + t.Fatalf("expected ErrSessionAlreadyExists, got %v", err) + } + if !errors.Is(err, os.ErrExist) { + t.Fatalf("expected os.ErrExist chain, got %v", err) + } +} + func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) { ctx := context.Background() baseDir, err := os.MkdirTemp("", "session-base-") @@ -282,9 +304,13 @@ func TestSQLiteStoreErrors(t *testing.T) { } if err := store.UpdateSessionState(ctx, UpdateSessionStateInput{SessionID: "missing", Title: "x"}); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected update missing session to return os.ErrNotExist, got %v", err) + } else if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected update missing session to return ErrSessionNotFound, got %v", err) } if _, err := store.LoadSession(ctx, "missing"); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected load missing session to return os.ErrNotExist, got %v", err) + } else if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected load missing session to return ErrSessionNotFound, got %v", err) } } @@ -370,6 +396,8 @@ func TestSQLiteStoreAppendReplaceAndSchemaErrors(t *testing.T) { }, }); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected append missing session to return os.ErrNotExist, got %v", err) + } else if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected append missing session to return ErrSessionNotFound, got %v", err) } session, err := store.CreateSession(ctx, CreateSessionInput{ID: "invalid_message", Title: "invalid"}) @@ -406,6 +434,8 @@ func TestSQLiteStoreAppendReplaceAndSchemaErrors(t *testing.T) { }, }); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected replace transcript missing session to return os.ErrNotExist, got %v", err) + } else if !errors.Is(err, ErrSessionNotFound) { + t.Fatalf("expected replace transcript missing session to return ErrSessionNotFound, got %v", err) } } diff --git a/internal/tui/services/gateway_rpc_client.go b/internal/tui/services/gateway_rpc_client.go new file mode 100644 index 00000000..6236f075 --- /dev/null +++ b/internal/tui/services/gateway_rpc_client.go @@ -0,0 +1,620 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + gatewayauth "neo-code/internal/gateway/auth" + "neo-code/internal/gateway/protocol" + "neo-code/internal/gateway/transport" +) + +const ( + defaultGatewayRPCRequestTimeout = 8 * time.Second + defaultGatewayRPCRetryCount = 1 + defaultGatewayNotificationBuffer = 64 + defaultGatewayNotificationQueue = 256 + defaultGatewayNotificationEnqueueTimeout = 3 * time.Second +) + +// GatewayRPCClientOptions 描述网关 JSON-RPC 客户端的初始化参数。 +type GatewayRPCClientOptions struct { + ListenAddress string + TokenFile string + RequestTimeout time.Duration + RetryCount int + Dial func(address string) (net.Conn, error) + ResolveListenAddress func(override string) (string, error) +} + +// GatewayRPCCallOptions 描述单次 RPC 调用的覆盖参数。 +type GatewayRPCCallOptions struct { + Timeout time.Duration + Retries int +} + +// GatewayRPCError 描述网关返回的结构化 RPC 错误。 +type GatewayRPCError struct { + Method string + Code int + GatewayCode string + Message string +} + +func (e *GatewayRPCError) Error() string { + if e == nil { + return "" + } + if strings.TrimSpace(e.GatewayCode) != "" { + return fmt.Sprintf("gateway rpc %s failed (%s): %s", e.Method, e.GatewayCode, e.Message) + } + return fmt.Sprintf("gateway rpc %s failed: %s", e.Method, e.Message) +} + +type gatewayRPCTransportError struct { + Method string + Err error +} + +func (e *gatewayRPCTransportError) Error() string { + if e == nil { + return "" + } + if strings.TrimSpace(e.Method) == "" { + return fmt.Sprintf("gateway rpc transport error: %v", e.Err) + } + return fmt.Sprintf("gateway rpc %s transport error: %v", e.Method, e.Err) +} + +func (e *gatewayRPCTransportError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +type gatewayRPCNotification struct { + Method string + Params json.RawMessage +} + +type gatewayRPCResponse struct { + ID string + Result json.RawMessage + RPCError *protocol.JSONRPCError + TransportErr error +} + +// GatewayRPCClient 维护与 Gateway 的长连接、请求关联与通知分发。 +type GatewayRPCClient struct { + listenAddress string + token string + requestTimeout time.Duration + retryCount int + dialFn func(address string) (net.Conn, error) + + closeOnce sync.Once + closed chan struct{} + + writeMu sync.Mutex + stateMu sync.Mutex + conn net.Conn + pending map[string]chan gatewayRPCResponse + + notifications chan gatewayRPCNotification + notificationQueue chan gatewayRPCNotification + notificationEnqueueTimeout time.Duration + notificationWG sync.WaitGroup + notificationStart sync.Once + sequence uint64 +} + +// NewGatewayRPCClient 创建网关 RPC 客户端,并在启动时静默读取认证 Token。 +func NewGatewayRPCClient(options GatewayRPCClientOptions) (*GatewayRPCClient, error) { + resolveListenAddressFn := options.ResolveListenAddress + if resolveListenAddressFn == nil { + resolveListenAddressFn = transport.ResolveListenAddress + } + listenAddress, err := resolveListenAddressFn(strings.TrimSpace(options.ListenAddress)) + if err != nil { + return nil, fmt.Errorf("gateway rpc client: resolve listen address: %w", err) + } + + token, err := loadGatewayAuthToken(options.TokenFile) + if err != nil { + return nil, err + } + + requestTimeout := options.RequestTimeout + if requestTimeout <= 0 { + requestTimeout = defaultGatewayRPCRequestTimeout + } + + retryCount := options.RetryCount + if retryCount <= 0 { + retryCount = defaultGatewayRPCRetryCount + } + + dialFn := options.Dial + if dialFn == nil { + dialFn = transport.Dial + } + + return &GatewayRPCClient{ + listenAddress: listenAddress, + token: token, + requestTimeout: requestTimeout, + retryCount: retryCount, + dialFn: dialFn, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer), + notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue), + notificationEnqueueTimeout: defaultGatewayNotificationEnqueueTimeout, + }, nil +} + +// Notifications 返回网关 JSON-RPC 通知流。 +func (c *GatewayRPCClient) Notifications() <-chan gatewayRPCNotification { + return c.notifications +} + +// Authenticate 显式调用 gateway.authenticate,建立连接级认证状态。 +func (c *GatewayRPCClient) Authenticate(ctx context.Context) error { + var frame map[string]any + err := c.CallWithOptions( + ctx, + protocol.MethodGatewayAuthenticate, + protocol.AuthenticateParams{Token: c.token}, + &frame, + GatewayRPCCallOptions{ + Timeout: c.requestTimeout, + Retries: c.retryCount, + }, + ) + if err != nil { + return err + } + return nil +} + +// Call 按默认超时与重试策略发起一次 JSON-RPC 调用。 +func (c *GatewayRPCClient) Call(ctx context.Context, method string, params any, result any) error { + return c.CallWithOptions(ctx, method, params, result, GatewayRPCCallOptions{ + Timeout: c.requestTimeout, + Retries: c.retryCount, + }) +} + +// CallWithOptions 发起一次可覆盖超时与重试策略的 JSON-RPC 调用。 +func (c *GatewayRPCClient) CallWithOptions( + ctx context.Context, + method string, + params any, + result any, + options GatewayRPCCallOptions, +) error { + method = strings.TrimSpace(method) + if method == "" { + return errors.New("gateway rpc client: method is empty") + } + + timeout := options.Timeout + if timeout <= 0 { + timeout = c.requestTimeout + } + retries := options.Retries + if retries < 0 { + retries = c.retryCount + } + + var lastErr error + for attempt := 0; attempt <= retries; attempt++ { + lastErr = c.callOnce(ctx, method, params, result, timeout) + if lastErr == nil { + return nil + } + if !isRetryableGatewayCallError(lastErr) || attempt == retries { + return lastErr + } + c.resetConnection() + } + return lastErr +} + +// Close 关闭客户端连接并结束内部通知流。 +func (c *GatewayRPCClient) Close() error { + var firstErr error + c.closeOnce.Do(func() { + close(c.closed) + firstErr = c.forceCloseWithError(errors.New("gateway rpc client closed")) + c.notificationWG.Wait() + close(c.notifications) + }) + return firstErr +} + +func (c *GatewayRPCClient) callOnce( + ctx context.Context, + method string, + params any, + result any, + timeout time.Duration, +) error { + callCtx := ctx + var cancel context.CancelFunc + if timeout > 0 { + callCtx, cancel = context.WithTimeout(ctx, timeout) + } + if cancel != nil { + defer cancel() + } + if err := callCtx.Err(); err != nil { + return err + } + + conn, err := c.ensureConnected() + if err != nil { + return &gatewayRPCTransportError{Method: method, Err: err} + } + + requestID := fmt.Sprintf("tui-%d", atomic.AddUint64(&c.sequence, 1)) + idRaw, err := marshalJSONRawMessage(requestID) + if err != nil { + return fmt.Errorf("gateway rpc client: encode request id: %w", err) + } + + request := protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: idRaw, + Method: method, + } + if params != nil { + paramsRaw, marshalErr := marshalJSONRawMessage(params) + if marshalErr != nil { + return fmt.Errorf("gateway rpc client: encode request params: %w", marshalErr) + } + request.Params = paramsRaw + } + + responseCh := make(chan gatewayRPCResponse, 1) + if !c.registerPending(requestID, responseCh) { + return &gatewayRPCTransportError{Method: method, Err: errors.New("gateway rpc client is closed")} + } + defer c.unregisterPending(requestID) + + if writeErr := c.writeRequest(conn, request); writeErr != nil { + return &gatewayRPCTransportError{Method: method, Err: writeErr} + } + + select { + case <-c.closed: + return &gatewayRPCTransportError{Method: method, Err: errors.New("gateway rpc client is closed")} + case <-callCtx.Done(): + return callCtx.Err() + case response := <-responseCh: + if response.TransportErr != nil { + return &gatewayRPCTransportError{Method: method, Err: response.TransportErr} + } + if response.RPCError != nil { + return mapGatewayRPCError(method, response.RPCError) + } + if result == nil { + return nil + } + if len(response.Result) == 0 { + return &gatewayRPCTransportError{Method: method, Err: errors.New("gateway rpc response result is empty")} + } + if err := json.Unmarshal(response.Result, result); err != nil { + return fmt.Errorf("gateway rpc client: decode %s response: %w", method, err) + } + return nil + } +} + +func (c *GatewayRPCClient) writeRequest(conn net.Conn, request protocol.JSONRPCRequest) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + encoder := json.NewEncoder(conn) + if err := encoder.Encode(request); err != nil { + c.resetConnection() + return fmt.Errorf("write rpc request failed: %w", err) + } + return nil +} + +func (c *GatewayRPCClient) ensureConnected() (net.Conn, error) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + + if c.conn != nil { + return c.conn, nil + } + select { + case <-c.closed: + return nil, errors.New("gateway rpc client is closed") + default: + } + + conn, err := c.dialFn(c.listenAddress) + if err != nil { + return nil, fmt.Errorf("dial gateway %s: %w", c.listenAddress, err) + } + c.conn = conn + c.startNotificationDispatcher() + go c.readLoop(conn) + return conn, nil +} + +func (c *GatewayRPCClient) readLoop(conn net.Conn) { + decoder := json.NewDecoder(conn) + for { + var envelope map[string]json.RawMessage + if err := decoder.Decode(&envelope); err != nil { + _ = c.forceCloseWithError(err) + return + } + + if methodRaw, hasMethod := envelope["method"]; hasMethod { + method := decodeRawJSONString(methodRaw) + if strings.TrimSpace(method) == "" { + continue + } + notification := gatewayRPCNotification{ + Method: method, + } + if paramsRaw, hasParams := envelope["params"]; hasParams { + notification.Params = cloneJSONRawMessage(paramsRaw) + } + if !c.enqueueNotification(notification) { + return + } + continue + } + + if idRaw, hasID := envelope["id"]; hasID { + response, err := decodeGatewayRPCResponse(envelope) + if err != nil { + continue + } + response.ID = normalizeJSONRPCResponseID(idRaw) + c.dispatchResponse(response) + } + } +} + +// startNotificationDispatcher 启动通知转发协程,配合 enqueue 超时保护避免 readLoop 长时间背压阻塞。 +func (c *GatewayRPCClient) startNotificationDispatcher() { + c.notificationStart.Do(func() { + c.notificationWG.Add(1) + go func() { + defer c.notificationWG.Done() + for { + select { + case <-c.closed: + return + case notification, ok := <-c.notificationQueue: + if !ok { + return + } + select { + case <-c.closed: + return + case c.notifications <- notification: + } + } + } + }() + }) +} + +// enqueueNotification 投递通知到内部队列;若背压持续超时则主动断开连接,避免 readLoop 无限阻塞。 +func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) bool { + enqueueTimeout := c.notificationEnqueueTimeout + if enqueueTimeout <= 0 { + enqueueTimeout = defaultGatewayNotificationEnqueueTimeout + } + timer := time.NewTimer(enqueueTimeout) + defer timer.Stop() + + select { + case <-c.closed: + return false + case c.notificationQueue <- notification: + return true + case <-timer.C: + err := fmt.Errorf("gateway rpc client: notification queue blocked for %s", enqueueTimeout) + log.Printf("warning: gateway rpc client force close due to notification backpressure method=%s err=%v", notification.Method, err) + _ = c.forceCloseWithError(err) + return false + } +} + +func (c *GatewayRPCClient) dispatchResponse(response gatewayRPCResponse) { + if strings.TrimSpace(response.ID) == "" { + return + } + c.stateMu.Lock() + ch := c.pending[response.ID] + delete(c.pending, response.ID) + c.stateMu.Unlock() + if ch == nil { + return + } + ch <- response +} + +func (c *GatewayRPCClient) registerPending(requestID string, ch chan gatewayRPCResponse) bool { + c.stateMu.Lock() + defer c.stateMu.Unlock() + select { + case <-c.closed: + return false + default: + } + c.pending[requestID] = ch + return true +} + +func (c *GatewayRPCClient) unregisterPending(requestID string) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + delete(c.pending, requestID) +} + +func (c *GatewayRPCClient) resetConnection() { + c.stateMu.Lock() + conn := c.conn + c.conn = nil + c.stateMu.Unlock() + if conn != nil { + _ = conn.Close() + } +} + +func (c *GatewayRPCClient) forceCloseWithError(cause error) error { + c.stateMu.Lock() + conn := c.conn + c.conn = nil + pending := c.pending + c.pending = make(map[string]chan gatewayRPCResponse) + c.stateMu.Unlock() + + if conn != nil { + _ = conn.Close() + } + + transportErr := cause + if transportErr == nil { + transportErr = errors.New("gateway rpc connection closed") + } + for _, ch := range pending { + ch <- gatewayRPCResponse{TransportErr: transportErr} + } + return nil +} + +func mapGatewayRPCError(method string, rpcError *protocol.JSONRPCError) error { + if rpcError == nil { + return &GatewayRPCError{ + Method: method, + Code: protocol.JSONRPCCodeInternalError, + GatewayCode: protocol.GatewayCodeInternalError, + Message: "gateway returned empty rpc error", + } + } + + message := strings.TrimSpace(rpcError.Message) + if message == "" { + message = "gateway returned empty rpc error message" + } + return &GatewayRPCError{ + Method: method, + Code: rpcError.Code, + GatewayCode: strings.TrimSpace(protocol.GatewayCodeFromJSONRPCError(rpcError)), + Message: message, + } +} + +func decodeGatewayRPCResponse(envelope map[string]json.RawMessage) (gatewayRPCResponse, error) { + raw, err := json.Marshal(envelope) + if err != nil { + return gatewayRPCResponse{}, err + } + var response protocol.JSONRPCResponse + if err := json.Unmarshal(raw, &response); err != nil { + return gatewayRPCResponse{}, err + } + return gatewayRPCResponse{ + Result: cloneJSONRawMessage(response.Result), + RPCError: response.Error, + }, nil +} + +func normalizeJSONRPCResponseID(raw json.RawMessage) string { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" || trimmed == "null" { + return "" + } + var asString string + if err := json.Unmarshal(raw, &asString); err == nil { + return strings.TrimSpace(asString) + } + return trimmed +} + +func decodeRawJSONString(raw json.RawMessage) string { + var out string + if err := json.Unmarshal(raw, &out); err != nil { + return "" + } + return strings.TrimSpace(out) +} + +// marshalJSONRawMessage 将任意值编码为独立的 RawMessage,避免复用外部可变切片。 +func marshalJSONRawMessage(value any) (json.RawMessage, error) { + raw, err := json.Marshal(value) + if err != nil { + return nil, err + } + return cloneJSONRawMessage(raw), nil +} + +// cloneJSONRawMessage 复制 RawMessage 底层字节,避免跨协程共享同一底层数组。 +func cloneJSONRawMessage(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 { + return nil + } + cloned := make([]byte, len(raw)) + copy(cloned, raw) + return json.RawMessage(cloned) +} + +func isRetryableGatewayCallError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + + var transportErr *gatewayRPCTransportError + if errors.As(err, &transportErr) { + return true + } + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + + var rpcErr *GatewayRPCError + if errors.As(err, &rpcErr) { + return strings.EqualFold(strings.TrimSpace(rpcErr.GatewayCode), protocol.GatewayCodeTimeout) + } + return false +} + +// loadGatewayAuthToken 读取 Gateway 静默认证 Token。 +func loadGatewayAuthToken(tokenFile string) (string, error) { + token, err := gatewayauth.LoadTokenFromFile(strings.TrimSpace(tokenFile)) + if err != nil { + return "", fmt.Errorf("gateway rpc client: load auth token: %w", err) + } + token = strings.TrimSpace(token) + if token == "" { + return "", errors.New("gateway rpc client: auth token is empty") + } + return token, nil +} diff --git a/internal/tui/services/gateway_rpc_client_additional_test.go b/internal/tui/services/gateway_rpc_client_additional_test.go new file mode 100644 index 00000000..10041047 --- /dev/null +++ b/internal/tui/services/gateway_rpc_client_additional_test.go @@ -0,0 +1,617 @@ +package services + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" + + "neo-code/internal/gateway/protocol" +) + +type stubConn struct { + writeErr error + closed bool + mu sync.Mutex +} + +func (s *stubConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (s *stubConn) Write(p []byte) (int, error) { + if s.writeErr != nil { + return 0, s.writeErr + } + return len(p), nil +} +func (s *stubConn) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} +func (s *stubConn) LocalAddr() net.Addr { return &net.UnixAddr{} } +func (s *stubConn) RemoteAddr() net.Addr { return &net.UnixAddr{} } +func (s *stubConn) SetDeadline(_ time.Time) error { return nil } +func (s *stubConn) SetReadDeadline(_ time.Time) error { return nil } +func (s *stubConn) SetWriteDeadline(_ time.Time) error { return nil } + +func TestGatewayRPCErrorAndTransportErrorFormatting(t *testing.T) { + t.Parallel() + + var rpcErr *GatewayRPCError + if rpcErr.Error() != "" { + t.Fatalf("nil GatewayRPCError should render empty string") + } + + errWithCode := (&GatewayRPCError{Method: "gateway.run", GatewayCode: "timeout", Message: "boom"}).Error() + if !strings.Contains(errWithCode, "timeout") { + t.Fatalf("GatewayRPCError with code = %q", errWithCode) + } + + var transportErr *gatewayRPCTransportError + if transportErr.Error() != "" || transportErr.Unwrap() != nil { + t.Fatalf("nil transport error should render empty and unwrap nil") + } + + methodErr := &gatewayRPCTransportError{Method: "gateway.run", Err: errors.New("down")} + if !strings.Contains(methodErr.Error(), "gateway.run") { + t.Fatalf("unexpected transport error text: %q", methodErr.Error()) + } + noMethodErr := (&gatewayRPCTransportError{Err: errors.New("down")}).Error() + if !strings.Contains(noMethodErr, "transport error") { + t.Fatalf("unexpected no-method transport error text: %q", noMethodErr) + } + if !errors.Is(methodErr, methodErr.Err) { + t.Fatalf("transport error should unwrap original cause") + } +} + +func TestGatewayRPCClientHelperFunctions(t *testing.T) { + t.Parallel() + + mapped := mapGatewayRPCError("gateway.ping", nil) + typed, ok := mapped.(*GatewayRPCError) + if !ok || typed.GatewayCode != protocol.GatewayCodeInternalError { + t.Fatalf("mapGatewayRPCError(nil) = %#v", mapped) + } + + emptyMessage := mapGatewayRPCError("gateway.ping", &protocol.JSONRPCError{Code: protocol.JSONRPCCodeInternalError}) + if !strings.Contains(emptyMessage.Error(), "empty rpc error message") { + t.Fatalf("empty message fallback missing: %v", emptyMessage) + } + + if normalizeJSONRPCResponseID(json.RawMessage(`"id-1"`)) != "id-1" { + t.Fatalf("normalize string id mismatch") + } + if normalizeJSONRPCResponseID(json.RawMessage(` 7 `)) != "7" { + t.Fatalf("normalize numeric id mismatch") + } + if normalizeJSONRPCResponseID(json.RawMessage(`null`)) != "" { + t.Fatalf("normalize null id mismatch") + } + if decodeRawJSONString(json.RawMessage(`" m "`)) != "m" { + t.Fatalf("decodeRawJSONString mismatch") + } + if decodeRawJSONString(json.RawMessage(`{`)) != "" { + t.Fatalf("decodeRawJSONString invalid payload should return empty") + } + + raw, err := marshalJSONRawMessage(map[string]any{"ok": true}) + if err != nil || len(raw) == 0 { + t.Fatalf("marshalJSONRawMessage() = (%q, %v)", raw, err) + } + if _, err := marshalJSONRawMessage(func() {}); err == nil { + t.Fatalf("expected marshalJSONRawMessage() error for function input") + } + + if cloneJSONRawMessage(nil) != nil { + t.Fatalf("clone nil should return nil") + } + origin := json.RawMessage(`{"k":"v"}`) + cloned := cloneJSONRawMessage(origin) + origin[0] = '[' + if string(cloned) != `{"k":"v"}` { + t.Fatalf("cloneJSONRawMessage should deep clone, got %q", string(cloned)) + } + + if !isRetryableGatewayCallError(context.DeadlineExceeded) { + t.Fatalf("deadline exceeded should be retryable") + } + if isRetryableGatewayCallError(context.Canceled) { + t.Fatalf("context canceled should not be retryable") + } + if !isRetryableGatewayCallError(&gatewayRPCTransportError{Err: errors.New("x")}) { + t.Fatalf("transport error should be retryable") + } + if !isRetryableGatewayCallError(&GatewayRPCError{GatewayCode: protocol.GatewayCodeTimeout}) { + t.Fatalf("gateway timeout should be retryable") + } + if isRetryableGatewayCallError(errors.New("plain")) { + t.Fatalf("plain error should not be retryable") + } + if isRetryableGatewayCallError(nil) { + t.Fatalf("nil error should not be retryable") + } + + if _, err := decodeGatewayRPCResponse(map[string]json.RawMessage{"id": json.RawMessage(`bad`)}); err == nil { + t.Fatalf("expected decodeGatewayRPCResponse marshal error") + } + if _, err := decodeGatewayRPCResponse(map[string]json.RawMessage{"id": json.RawMessage(`"id"`), "result": json.RawMessage(`{`)}); err == nil { + t.Fatalf("expected decodeGatewayRPCResponse unmarshal error") + } +} + +func TestGatewayRPCClientPendingAndForceCloseBranches(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: map[string]chan gatewayRPCResponse{}, + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + + ch := make(chan gatewayRPCResponse, 1) + if ok := client.registerPending("req-1", ch); !ok { + t.Fatalf("registerPending should succeed") + } + client.dispatchResponse(gatewayRPCResponse{ID: "req-1", Result: json.RawMessage(`{"ok":true}`)}) + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("dispatchResponse did not deliver response") + } + + client.dispatchResponse(gatewayRPCResponse{ID: ""}) + client.dispatchResponse(gatewayRPCResponse{ID: "missing"}) + client.unregisterPending("missing") + + pendingCh := make(chan gatewayRPCResponse, 1) + client.pending["req-2"] = pendingCh + if err := client.forceCloseWithError(nil); err != nil { + t.Fatalf("forceCloseWithError() error = %v", err) + } + select { + case response := <-pendingCh: + if response.TransportErr == nil { + t.Fatalf("expected transport error to be forwarded") + } + case <-time.After(time.Second): + t.Fatalf("pending response channel not notified") + } + + close(client.closed) + if ok := client.registerPending("req-3", make(chan gatewayRPCResponse, 1)); ok { + t.Fatalf("registerPending should fail after client closed") + } + client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}) + + resetConn := &stubConn{} + client.conn = resetConn + client.resetConnection() + if !resetConn.closed { + t.Fatalf("resetConnection should close active connection") + } +} + +func TestLoadGatewayAuthTokenErrorBranches(t *testing.T) { + t.Parallel() + + missingPath := filepath.Join(t.TempDir(), "missing-token.json") + if _, err := loadGatewayAuthToken(missingPath); err == nil { + t.Fatalf("expected load token error for missing file") + } + + path := filepath.Join(t.TempDir(), "auth.json") + err := os.WriteFile(path, []byte(`{"version":1,"token":" ","created_at":"2026-04-20T00:00:00Z","updated_at":"2026-04-20T00:00:00Z"}`), 0o600) + if err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := loadGatewayAuthToken(path); err == nil || !strings.Contains(err.Error(), "auth token is empty") { + if err == nil || !strings.Contains(err.Error(), "token is empty") { + t.Fatalf("expected empty auth token error, got %v", err) + } + } +} + +func TestGatewayRPCClientCallWithClosedClientAndInvalidResult(t *testing.T) { + t.Parallel() + + tokenFile, _ := createTestAuthTokenFile(t) + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + dec := json.NewDecoder(serverConn) + enc := json.NewEncoder(serverConn) + req := readRPCRequestOrFail(t, dec) + response := protocol.JSONRPCResponse{JSONRPC: protocol.JSONRPCVersion, ID: req.ID, Result: json.RawMessage(`1`)} + if encodeErr := enc.Encode(response); encodeErr != nil { + t.Errorf("encode response: %v", encodeErr) + } + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + var out map[string]any + callErr := client.CallWithOptions(context.Background(), protocol.MethodGatewayPing, map[string]any{}, &out, GatewayRPCCallOptions{Timeout: time.Second}) + if callErr == nil || !strings.Contains(callErr.Error(), "decode") { + t.Fatalf("expected decode error, got %v", callErr) + } + + _ = client.Close() + if err := client.CallWithOptions(context.Background(), protocol.MethodGatewayPing, nil, nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected closed client call error") + } +} + +func TestNewGatewayRPCClientConstructorBranches(t *testing.T) { + t.Parallel() + + tokenFile, _ := createTestAuthTokenFile(t) + _, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: tokenFile, + ResolveListenAddress: func(string) (string, error) { + return "", errors.New("resolve failed") + }, + }) + if err == nil || !strings.Contains(err.Error(), "resolve listen address") { + t.Fatalf("expected resolve listen address error, got %v", err) + } + + _, err = NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: filepath.Join(t.TempDir(), "missing.json"), + ResolveListenAddress: func(string) (string, error) { + return "ipc://x", nil + }, + }) + if err == nil || !strings.Contains(err.Error(), "load auth token") { + t.Fatalf("expected load auth token error, got %v", err) + } + + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "x", + TokenFile: tokenFile, + ResolveListenAddress: func(string) (string, error) { + return "ipc://x", nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + if client.requestTimeout != defaultGatewayRPCRequestTimeout || client.retryCount != defaultGatewayRPCRetryCount || client.dialFn == nil { + t.Fatalf("default options not applied: timeout=%v retry=%d dialFnNil=%v", client.requestTimeout, client.retryCount, client.dialFn == nil) + } + _ = client.Close() +} + +func TestGatewayRPCClientCallOnceBranches(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + listenAddress: "x", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := client.callOnce(ctx, "m", nil, nil, time.Second); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled error, got %v", err) + } + + client.dialFn = func(string) (net.Conn, error) { return nil, errors.New("dial failed") } + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err == nil || !strings.Contains(err.Error(), "transport") { + t.Fatalf("expected dial transport error, got %v", err) + } + + conn := &stubConn{} + client.conn = conn + if err := client.callOnce(context.Background(), "m", func() {}, nil, time.Second); err == nil || !strings.Contains(err.Error(), "encode request params") { + t.Fatalf("expected params encode error, got %v", err) + } + + close(client.closed) + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err == nil || !strings.Contains(err.Error(), "closed") { + t.Fatalf("expected closed client error, got %v", err) + } +} + +func TestGatewayRPCClientCallOnceResponseBranches(t *testing.T) { + t.Parallel() + + newClient := func() *GatewayRPCClient { + return &GatewayRPCClient{ + listenAddress: "x", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + conn: &stubConn{}, + } + } + + t.Run("transport error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", TransportErr: errors.New("broken")}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "transport") { + t.Fatalf("expected transport response error, got %v", err) + } + }) + + t.Run("rpc error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", RPCError: &protocol.JSONRPCError{Code: -32000, Message: "bad"}}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "gateway rpc m failed") { + t.Fatalf("expected rpc mapped error, got %v", err) + } + }) + + t.Run("result nil", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1"}) + }() + if err := client.callOnce(context.Background(), "m", nil, nil, time.Second); err != nil { + t.Fatalf("expected nil result path, got %v", err) + } + }) + + t.Run("empty result", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1"}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "result is empty") { + t.Fatalf("expected empty result error, got %v", err) + } + }) + + t.Run("decode error", func(t *testing.T) { + client := newClient() + go func() { + time.Sleep(10 * time.Millisecond) + client.dispatchResponse(gatewayRPCResponse{ID: "tui-1", Result: json.RawMessage(`1`)}) + }() + err := client.callOnce(context.Background(), "m", nil, &map[string]any{}, time.Second) + if err == nil || !strings.Contains(err.Error(), "decode m response") { + t.Fatalf("expected decode error, got %v", err) + } + }) +} + +func TestGatewayRPCClientReadLoopAdditionalBranches(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + client.startNotificationDispatcher() + go client.readLoop(clientConn) + + encoder := json.NewEncoder(serverConn) + _ = encoder.Encode(map[string]any{"method": " "}) + _ = encoder.Encode(map[string]any{"id": json.RawMessage(`\"x\"`), "result": json.RawMessage(`{`)}) + _ = encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"x": 1}}) + _ = serverConn.Close() + + select { + case <-client.notifications: + case <-time.After(2 * time.Second): + t.Fatalf("expected one forwarded notification") + } + + _ = client.Close() +} + +func TestGatewayRPCClientNotificationDispatcherStopsOnCloseSignal(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + client.startNotificationDispatcher() + close(client.closed) + client.notificationWG.Wait() +} + +func TestGatewayRPCClientEnqueueNotificationDoesNotDropUnderQueuePressure(t *testing.T) { + t.Parallel() + + const total = 256 + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + client.startNotificationDispatcher() + t.Cleanup(func() { _ = client.Close() }) + + receivedCh := make(chan struct{}, total) + go func() { + for range client.Notifications() { + receivedCh <- struct{}{} + } + }() + + var enqueueWG sync.WaitGroup + for i := 0; i < total; i++ { + enqueueWG.Add(1) + go func(index int) { + defer enqueueWG.Done() + client.enqueueNotification(gatewayRPCNotification{ + Method: protocol.MethodGatewayEvent, + Params: json.RawMessage(`{"index":` + strconv.Itoa(index) + `}`), + }) + }(i) + } + + waitDone := make(chan struct{}) + go func() { + enqueueWG.Wait() + close(waitDone) + }() + + select { + case <-waitDone: + case <-time.After(5 * time.Second): + t.Fatalf("enqueue notifications timed out under queue pressure") + } + + for i := 0; i < total; i++ { + select { + case <-receivedCh: + case <-time.After(5 * time.Second): + t.Fatalf("expected %d notifications, got %d", total, i) + } + } +} + +func TestGatewayRPCClientReadLoopFailsFastOnNotificationBackpressure(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + }) + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification), + notificationQueue: make(chan gatewayRPCNotification, 1), + notificationEnqueueTimeout: 50 * time.Millisecond, + } + client.startNotificationDispatcher() + t.Cleanup(func() { _ = client.Close() }) + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + client.readLoop(clientConn) + }() + encoder := json.NewEncoder(serverConn) + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 1}}); err != nil { + t.Fatalf("encode first notification: %v", err) + } + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 2}}); err != nil { + t.Fatalf("encode second notification: %v", err) + } + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 3}}); err != nil { + t.Fatalf("encode third notification: %v", err) + } + + select { + case <-readDone: + case <-time.After(time.Second): + t.Fatalf("expected readLoop to fail-fast on sustained notification backpressure") + } +} + +func TestGatewayRPCClientEnqueueNotificationUnblocksOnClose(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification), + notificationQueue: make(chan gatewayRPCNotification, 1), + notificationEnqueueTimeout: time.Second, + } + client.startNotificationDispatcher() + + // 首条通知占满队列,第二条通知会阻塞在 enqueue,关闭客户端后应立即退出。 + client.notificationQueue <- gatewayRPCNotification{Method: protocol.MethodGatewayEvent} + + done := make(chan struct{}) + go func() { + defer close(done) + client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}) + }() + + time.Sleep(20 * time.Millisecond) + _ = client.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("enqueueNotification should unblock when client closes") + } +} + +func TestGatewayRPCClientWriteRequestFailure(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 1), + notificationQueue: make(chan gatewayRPCNotification, 1), + } + conn := &stubConn{writeErr: errors.New("write failed")} + err := client.writeRequest(conn, protocol.JSONRPCRequest{JSONRPC: protocol.JSONRPCVersion, ID: json.RawMessage(`\"id\"`), Method: "m"}) + if err == nil || !strings.Contains(err.Error(), "write rpc request failed") { + t.Fatalf("expected write request error, got %v", err) + } +} + +func TestGatewayRPCClientDecodeResponseSuccessAndRetryableNetError(t *testing.T) { + t.Parallel() + + response, err := decodeGatewayRPCResponse(map[string]json.RawMessage{ + "id": json.RawMessage(`"id"`), + "result": json.RawMessage(`{"ok":true}`), + }) + if err != nil || !bytes.Contains(response.Result, []byte(`ok`)) { + t.Fatalf("decodeGatewayRPCResponse success mismatch: (%#v, %v)", response, err) + } + + netErr := &net.DNSError{IsTimeout: true} + if !isRetryableGatewayCallError(netErr) { + t.Fatalf("net timeout error should be retryable") + } +} diff --git a/internal/tui/services/gateway_rpc_client_test.go b/internal/tui/services/gateway_rpc_client_test.go new file mode 100644 index 00000000..cd9cfa6f --- /dev/null +++ b/internal/tui/services/gateway_rpc_client_test.go @@ -0,0 +1,313 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "net" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "neo-code/internal/gateway" + gatewayauth "neo-code/internal/gateway/auth" + "neo-code/internal/gateway/protocol" +) + +func TestGatewayRPCClientAuthenticateCallAndNotification(t *testing.T) { + tokenFile, token := createTestAuthTokenFile(t) + + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + + request := readRPCRequestOrFail(t, decoder) + if request.Method != protocol.MethodGatewayAuthenticate { + t.Fatalf("authenticate method = %q", request.Method) + } + var params protocol.AuthenticateParams + if err := json.Unmarshal(request.Params, ¶ms); err != nil { + t.Fatalf("decode authenticate params: %v", err) + } + if params.Token != token { + t.Fatalf("authenticate token = %q, want %q", params.Token, token) + } + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionAuthenticate, + }) + + request = readRPCRequestOrFail(t, decoder) + if request.Method != protocol.MethodGatewayPing { + t.Fatalf("call method = %q, want %q", request.Method, protocol.MethodGatewayPing) + } + writeRPCNotificationOrFail(t, encoder, protocol.MethodGatewayEvent, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-1", + RunID: "run-1", + Payload: map[string]any{ + "runtime_event_type": string("agent_chunk"), + "payload": "hello", + }, + }) + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionPing, + SessionID: "session-1", + RunID: "run-1", + }) + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + if err := client.Authenticate(context.Background()); err != nil { + t.Fatalf("Authenticate() error = %v", err) + } + + var frame gateway.MessageFrame + if err := client.Call(context.Background(), protocol.MethodGatewayPing, map[string]any{}, &frame); err != nil { + t.Fatalf("Call() error = %v", err) + } + if frame.Type != gateway.FrameTypeAck || frame.Action != gateway.FrameActionPing { + t.Fatalf("unexpected rpc result frame: %#v", frame) + } + + select { + case notification := <-client.Notifications(): + if notification.Method != protocol.MethodGatewayEvent { + t.Fatalf("notification method = %q, want %q", notification.Method, protocol.MethodGatewayEvent) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for notification") + } +} + +func TestGatewayRPCClientRetriesAfterTransportError(t *testing.T) { + tokenFile, _ := createTestAuthTokenFile(t) + + var dialCount int32 + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + attempt := atomic.AddInt32(&dialCount, 1) + if attempt == 1 { + return nil, errors.New("dial failed once") + } + + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + request := readRPCRequestOrFail(t, decoder) + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionPing, + }) + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + var frame gateway.MessageFrame + err = client.CallWithOptions( + context.Background(), + protocol.MethodGatewayPing, + map[string]any{}, + &frame, + GatewayRPCCallOptions{ + Timeout: 2 * time.Second, + Retries: 1, + }, + ) + if err != nil { + t.Fatalf("CallWithOptions() error = %v", err) + } + if atomic.LoadInt32(&dialCount) != 2 { + t.Fatalf("dial count = %d, want %d", atomic.LoadInt32(&dialCount), 2) + } + if frame.Action != gateway.FrameActionPing { + t.Fatalf("unexpected frame: %#v", frame) + } +} + +func TestGatewayRPCClientUsesDefaultRetryCountWhenOptionIsZero(t *testing.T) { + tokenFile, _ := createTestAuthTokenFile(t) + + var dialCount int32 + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + RetryCount: 0, + Dial: func(_ string) (net.Conn, error) { + attempt := atomic.AddInt32(&dialCount, 1) + if attempt == 1 { + return nil, errors.New("dial failed once") + } + + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + request := readRPCRequestOrFail(t, decoder) + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionPing, + }) + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + if client.retryCount != defaultGatewayRPCRetryCount { + t.Fatalf("retryCount = %d, want %d", client.retryCount, defaultGatewayRPCRetryCount) + } + + var frame gateway.MessageFrame + if err := client.Call(context.Background(), protocol.MethodGatewayPing, map[string]any{}, &frame); err != nil { + t.Fatalf("Call() error = %v", err) + } + if atomic.LoadInt32(&dialCount) != 2 { + t.Fatalf("dial count = %d, want %d", atomic.LoadInt32(&dialCount), 2) + } + if frame.Action != gateway.FrameActionPing { + t.Fatalf("unexpected frame: %#v", frame) + } +} + +func TestGatewayRPCClientCallWithEmptyMethodReturnsError(t *testing.T) { + tokenFile, _ := createTestAuthTokenFile(t) + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + return nil, errors.New("should not dial") + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + err = client.CallWithOptions(context.Background(), " ", nil, nil, GatewayRPCCallOptions{}) + if err == nil || !strings.Contains(err.Error(), "method is empty") { + t.Fatalf("expected method empty error, got %v", err) + } +} + +func TestGatewayRPCClientReadLoopSustainsBackpressureWhenNotificationsAreConsumed(t *testing.T) { + tokenFile, _ := createTestAuthTokenFile(t) + + client, err := NewGatewayRPCClient(GatewayRPCClientOptions{ + ListenAddress: "test://gateway", + TokenFile: tokenFile, + Dial: func(_ string) (net.Conn, error) { + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + + request := readRPCRequestOrFail(t, decoder) + for idx := 0; idx < defaultGatewayNotificationQueue+defaultGatewayNotificationBuffer+128; idx++ { + writeRPCNotificationOrFail(t, encoder, protocol.MethodGatewayEvent, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-1", + RunID: "run-1", + Payload: map[string]any{ + "index": idx, + }, + }) + } + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionPing, + }) + }() + return clientConn, nil + }, + }) + if err != nil { + t.Fatalf("NewGatewayRPCClient() error = %v", err) + } + t.Cleanup(func() { _ = client.Close() }) + + go func() { + for range client.Notifications() { + } + }() + + callErr := client.CallWithOptions( + context.Background(), + protocol.MethodGatewayPing, + map[string]any{}, + &gateway.MessageFrame{}, + GatewayRPCCallOptions{Timeout: 2 * time.Second}, + ) + if callErr != nil { + t.Fatalf("CallWithOptions() should succeed when notifications are back-pressured, got %v", callErr) + } +} + +func createTestAuthTokenFile(t *testing.T) (string, string) { + t.Helper() + path := filepath.Join(t.TempDir(), "auth.json") + manager, err := gatewayauth.NewManager(path) + if err != nil { + t.Fatalf("gatewayauth.NewManager() error = %v", err) + } + return path, manager.Token() +} + +func readRPCRequestOrFail(t *testing.T, decoder *json.Decoder) protocol.JSONRPCRequest { + t.Helper() + var request protocol.JSONRPCRequest + if err := decoder.Decode(&request); err != nil { + t.Fatalf("decode rpc request: %v", err) + } + return request +} + +func writeRPCResultOrFail(t *testing.T, encoder *json.Encoder, id json.RawMessage, result any) { + t.Helper() + response, rpcErr := protocol.NewJSONRPCResultResponse(id, result) + if rpcErr != nil { + t.Fatalf("build jsonrpc result: %+v", rpcErr) + } + if err := encoder.Encode(response); err != nil { + t.Fatalf("encode jsonrpc result: %v", err) + } +} + +func writeRPCNotificationOrFail(t *testing.T, encoder *json.Encoder, method string, params any) { + t.Helper() + notification := protocol.NewJSONRPCNotification(method, params) + if err := encoder.Encode(notification); err != nil { + t.Fatalf("encode notification: %v", err) + } +} diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go new file mode 100644 index 00000000..c7091b43 --- /dev/null +++ b/internal/tui/services/gateway_stream_client.go @@ -0,0 +1,416 @@ +package services + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" + "neo-code/internal/runtime/controlplane" + "neo-code/internal/tools" +) + +// GatewayStreamClient 负责消费 gateway.event 通知并恢复为 runtime 事件。 +type GatewayStreamClient struct { + source <-chan gatewayRPCNotification + + closeOnce sync.Once + closeCh chan struct{} + done chan struct{} + events chan agentruntime.RuntimeEvent +} + +// NewGatewayStreamClient 创建并启动网关事件流消费者。 +func NewGatewayStreamClient(source <-chan gatewayRPCNotification) *GatewayStreamClient { + client := &GatewayStreamClient{ + source: source, + closeCh: make(chan struct{}), + done: make(chan struct{}), + events: make(chan agentruntime.RuntimeEvent, 128), + } + go client.run() + return client +} + +// Events 返回恢复后的 runtime 事件流。 +func (c *GatewayStreamClient) Events() <-chan agentruntime.RuntimeEvent { + return c.events +} + +// Close 停止事件消费并释放内部资源。 +func (c *GatewayStreamClient) Close() error { + c.closeOnce.Do(func() { + close(c.closeCh) + <-c.done + }) + return nil +} + +// run 持续读取网关通知并向上游输出 runtime 事件。 +func (c *GatewayStreamClient) run() { + defer close(c.done) + defer close(c.events) + + for { + select { + case <-c.closeCh: + return + case notification, ok := <-c.source: + if !ok { + return + } + if !strings.EqualFold(strings.TrimSpace(notification.Method), protocol.MethodGatewayEvent) { + continue + } + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + select { + case <-c.closeCh: + return + case c.events <- agentruntime.RuntimeEvent{ + Type: agentruntime.EventError, + Timestamp: time.Now().UTC(), + Payload: fmt.Sprintf("gateway stream decode error: %v", err), + }: + } + continue + } + + select { + case <-c.closeCh: + return + case c.events <- event: + } + } + } +} + +// decodeRuntimeEventFromGatewayNotification 将单条 gateway.event 通知还原为 runtime 事件。 +func decodeRuntimeEventFromGatewayNotification(notification gatewayRPCNotification) (agentruntime.RuntimeEvent, error) { + var frame gateway.MessageFrame + if len(notification.Params) == 0 { + return agentruntime.RuntimeEvent{}, fmt.Errorf("gateway.event params is empty") + } + if err := json.Unmarshal(notification.Params, &frame); err != nil { + return agentruntime.RuntimeEvent{}, fmt.Errorf("decode gateway.event frame: %w", err) + } + + envelope, ok := extractRuntimeEnvelope(frame.Payload) + if !ok { + return agentruntime.RuntimeEvent{}, fmt.Errorf("missing runtime event envelope") + } + + eventType := agentruntime.EventType(strings.TrimSpace(streamReadMapString(envelope, "runtime_event_type"))) + if eventType == "" { + return agentruntime.RuntimeEvent{}, fmt.Errorf("missing runtime_event_type") + } + + event := agentruntime.RuntimeEvent{ + Type: eventType, + RunID: strings.TrimSpace(frame.RunID), + SessionID: strings.TrimSpace(frame.SessionID), + Turn: streamReadMapInt(envelope, "turn"), + Phase: streamReadMapString(envelope, "phase"), + Timestamp: streamReadMapTime(envelope, "timestamp"), + PayloadVersion: streamReadMapInt(envelope, "payload_version"), + } + if event.Timestamp.IsZero() { + event.Timestamp = time.Now().UTC() + } + + rawPayload, _ := streamReadMapValue(envelope, "payload") + restoredPayload, err := restoreRuntimePayload(event.Type, rawPayload) + if err != nil { + return agentruntime.RuntimeEvent{}, err + } + event.Payload = restoredPayload + return event, nil +} + +// extractRuntimeEnvelope 从网关事件 payload 中抽取 runtime 事件包裹层。 +func extractRuntimeEnvelope(payload any) (map[string]any, bool) { + switch typed := payload.(type) { + case map[string]any: + if _, exists := streamReadMapValue(typed, "runtime_event_type"); exists { + return typed, true + } + if nested, exists := streamReadMapValue(typed, "payload"); exists { + if nestedMap, ok := nested.(map[string]any); ok { + if _, hasEventType := streamReadMapValue(nestedMap, "runtime_event_type"); hasEventType { + return nestedMap, true + } + } + } + case nil: + return nil, false + } + + raw, err := json.Marshal(payload) + if err != nil { + return nil, false + } + + var asMap map[string]any + if err := json.Unmarshal(raw, &asMap); err != nil { + return nil, false + } + if _, exists := streamReadMapValue(asMap, "runtime_event_type"); exists { + return asMap, true + } + if nested, exists := streamReadMapValue(asMap, "payload"); exists { + if nestedMap, ok := nested.(map[string]any); ok { + if _, hasEventType := streamReadMapValue(nestedMap, "runtime_event_type"); hasEventType { + return nestedMap, true + } + } + } + return nil, false +} + +// restoreRuntimePayload 按事件类型将 payload 恢复为 TUI 可消费的强类型结构。 +func restoreRuntimePayload(eventType agentruntime.EventType, payload any) (any, error) { + switch eventType { + case agentruntime.EventUserMessage, agentruntime.EventAgentDone: + return decodeRuntimePayload[providertypes.Message](payload) + case agentruntime.EventToolStart: + return decodeRuntimePayload[providertypes.ToolCall](payload) + case agentruntime.EventToolResult: + return decodeRuntimePayload[tools.ToolResult](payload) + case agentruntime.EventPermissionRequested: + return decodeRuntimePayload[agentruntime.PermissionRequestPayload](payload) + case agentruntime.EventPermissionResolved: + return decodeRuntimePayload[agentruntime.PermissionResolvedPayload](payload) + case agentruntime.EventCompactApplied: + return decodeRuntimePayload[agentruntime.CompactResult](payload) + case agentruntime.EventCompactError: + return decodeRuntimePayload[agentruntime.CompactErrorPayload](payload) + case agentruntime.EventPhaseChanged: + return decodeRuntimePayload[agentruntime.PhaseChangedPayload](payload) + case agentruntime.EventStopReasonDecided: + return decodeStopReasonPayload(payload) + case agentruntime.EventInputNormalized: + return decodeRuntimePayload[agentruntime.InputNormalizedPayload](payload) + case agentruntime.EventAssetSaved: + return decodeRuntimePayload[agentruntime.AssetSavedPayload](payload) + case agentruntime.EventAssetSaveFailed: + return decodeRuntimePayload[agentruntime.AssetSaveFailedPayload](payload) + case agentruntime.EventTodoUpdated, agentruntime.EventTodoConflict: + return decodeRuntimePayload[agentruntime.TodoEventPayload](payload) + case agentruntime.EventType(RuntimeEventRunContext): + return decodeRuntimePayload[RuntimeRunContextPayload](payload) + case agentruntime.EventType(RuntimeEventToolStatus): + return decodeRuntimePayload[RuntimeToolStatusPayload](payload) + case agentruntime.EventType(RuntimeEventUsage): + return decodeRuntimePayload[RuntimeUsagePayload](payload) + case agentruntime.EventAgentChunk, agentruntime.EventToolChunk, agentruntime.EventError, + agentruntime.EventProviderRetry, agentruntime.EventToolCallThinking: + return decodeStringPayload(payload), nil + default: + return payload, nil + } +} + +// decodeStopReasonPayload 额外约束 stop reason 的枚举类型,避免字符串漂移。 +func decodeStopReasonPayload(payload any) (agentruntime.StopReasonDecidedPayload, error) { + decoded, err := decodeRuntimePayload[agentruntime.StopReasonDecidedPayload](payload) + if err != nil { + return agentruntime.StopReasonDecidedPayload{}, err + } + decoded.Reason = controlplane.StopReason(strings.TrimSpace(string(decoded.Reason))) + return decoded, nil +} + +// decodeStringPayload 兼容字符串类事件的 payload 解码。 +func decodeStringPayload(payload any) string { + switch typed := payload.(type) { + case string: + return typed + case nil: + return "" + default: + return strings.TrimSpace(fmt.Sprintf("%v", typed)) + } +} + +// decodeRuntimePayload 使用 JSON 兜底做泛型反序列化,确保 map/struct 输入都可处理。 +func decodeRuntimePayload[T any](payload any) (T, error) { + var zero T + switch typed := payload.(type) { + case T: + return typed, nil + case *T: + if typed == nil { + return zero, fmt.Errorf("payload is nil") + } + return *typed, nil + } + + raw, err := json.Marshal(payload) + if err != nil { + return zero, fmt.Errorf("encode payload: %w", err) + } + if len(raw) == 0 || string(raw) == "null" { + return zero, fmt.Errorf("payload is empty") + } + + var decoded T + if err := json.Unmarshal(raw, &decoded); err != nil { + return zero, fmt.Errorf("decode payload: %w", err) + } + return decoded, nil +} + +// streamReadMapValue 提供对 snake/camel/大小写的兼容键读取。 +func streamReadMapValue(m map[string]any, key string) (any, bool) { + if len(m) == 0 { + return nil, false + } + + trimmedKey := strings.TrimSpace(key) + if trimmedKey == "" { + return nil, false + } + + if value, ok := m[trimmedKey]; ok { + return value, true + } + if value, ok := m[strings.ToLower(trimmedKey)]; ok { + return value, true + } + if value, ok := m[toSnakeCase(trimmedKey)]; ok { + return value, true + } + if value, ok := m[toLowerCamelCase(trimmedKey)]; ok { + return value, true + } + + target := normalizeMapLookupKey(trimmedKey) + for existingKey, value := range m { + if normalizeMapLookupKey(existingKey) == target { + return value, true + } + } + return nil, false +} + +// streamReadMapString 从动态 map 中读取字符串字段。 +func streamReadMapString(m map[string]any, key string) string { + value, ok := streamReadMapValue(m, key) + if !ok || value == nil { + return "" + } + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + default: + return strings.TrimSpace(fmt.Sprintf("%v", typed)) + } +} + +// streamReadMapInt 从动态 map 中读取整数字段,兼容 number/string。 +func streamReadMapInt(m map[string]any, key string) int { + value, ok := streamReadMapValue(m, key) + if !ok || value == nil { + return 0 + } + switch typed := value.(type) { + case int: + return typed + case int64: + return int(typed) + case int32: + return int(typed) + case float64: + return int(typed) + case float32: + return int(typed) + case json.Number: + number, err := typed.Int64() + if err != nil { + return 0 + } + return int(number) + case string: + number, err := strconv.Atoi(strings.TrimSpace(typed)) + if err != nil { + return 0 + } + return number + default: + return 0 + } +} + +// streamReadMapTime 从动态 map 中读取时间字段,支持 RFC3339Nano 字符串。 +func streamReadMapTime(m map[string]any, key string) time.Time { + value, ok := streamReadMapValue(m, key) + if !ok || value == nil { + return time.Time{} + } + switch typed := value.(type) { + case time.Time: + return typed + case string: + trimmed := strings.TrimSpace(typed) + if trimmed == "" { + return time.Time{} + } + parsed, err := time.Parse(time.RFC3339Nano, trimmed) + if err != nil { + return time.Time{} + } + return parsed + default: + return time.Time{} + } +} + +// normalizeMapLookupKey 将键名归一化后用于宽松匹配。 +func normalizeMapLookupKey(key string) string { + replacer := strings.NewReplacer("_", "", "-", "", " ", "") + return strings.ToLower(replacer.Replace(strings.TrimSpace(key))) +} + +// toSnakeCase 将字符串转为 snake_case,用于键名兼容读取。 +func toSnakeCase(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + + var builder strings.Builder + for index, r := range trimmed { + if r >= 'A' && r <= 'Z' { + if index > 0 { + builder.WriteByte('_') + } + builder.WriteRune(r + ('a' - 'A')) + continue + } + builder.WriteRune(r) + } + return strings.ToLower(builder.String()) +} + +// toLowerCamelCase 将首字母转小写,用于 lowerCamel 键名兼容。 +func toLowerCamelCase(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + runes := []rune(trimmed) + if len(runes) == 0 { + return "" + } + if runes[0] >= 'A' && runes[0] <= 'Z' { + runes[0] = runes[0] + ('a' - 'A') + } + return string(runes) +} diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go new file mode 100644 index 00000000..2d32ce1b --- /dev/null +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -0,0 +1,447 @@ +package services + +import ( + "encoding/json" + "reflect" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" + "neo-code/internal/runtime/controlplane" +) + +type streamInvalidJSONMarshaler struct { + raw []byte +} + +func (m streamInvalidJSONMarshaler) MarshalJSON() ([]byte, error) { + return m.raw, nil +} + +func TestDecodeRuntimeEventFromGatewayNotificationErrorBranches(t *testing.T) { + t.Parallel() + + if _, err := decodeRuntimeEventFromGatewayNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}); err == nil { + t.Fatalf("expected empty params error") + } + + if _, err := decodeRuntimeEventFromGatewayNotification(gatewayRPCNotification{ + Method: protocol.MethodGatewayEvent, + Params: json.RawMessage(`{"payload":{}}`), + }); err == nil { + t.Fatalf("expected missing envelope error") + } + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "payload": map[string]any{}, + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(notification); err == nil { + t.Fatalf("expected missing runtime_event_type error") + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationUsesCurrentTimeWhenTimestampMissing(t *testing.T) { + t.Parallel() + + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventError), + "payload": "boom", + }, + }) + + before := time.Now().UTC().Add(-time.Second) + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + if event.Timestamp.Before(before) { + t.Fatalf("event timestamp should fallback to now, got %v", event.Timestamp) + } +} + +func TestExtractRuntimeEnvelopeFallbackMarshalling(t *testing.T) { + t.Parallel() + + type payloadEnvelope struct { + Payload map[string]any `json:"payload"` + } + envelope, ok := extractRuntimeEnvelope(payloadEnvelope{Payload: map[string]any{ + "RuntimeEventType": string(agentruntime.EventError), + "payload": "x", + }}) + if !ok { + t.Fatalf("expected envelope to be detected") + } + if got := streamReadMapString(envelope, "runtime_event_type"); got != string(agentruntime.EventError) { + t.Fatalf("runtime_event_type = %q", got) + } + + if _, ok := extractRuntimeEnvelope(nil); ok { + t.Fatalf("nil payload should not decode") + } +} + +func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + eventType agentruntime.EventType + payload any + assertFn func(t *testing.T, got any) + }{ + { + name: "user message", + eventType: agentruntime.EventUserMessage, + payload: map[string]any{"Role": string(providertypes.RoleAssistant)}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if _, ok := got.(providertypes.Message); !ok { + t.Fatalf("payload type = %T", got) + } + }, + }, + { + name: "permission request", + eventType: agentruntime.EventPermissionRequested, + payload: map[string]any{"RequestID": "req-1"}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if v, ok := got.(agentruntime.PermissionRequestPayload); !ok || v.RequestID != "req-1" { + t.Fatalf("payload = %#v", got) + } + }, + }, + { + name: "stop reason", + eventType: agentruntime.EventStopReasonDecided, + payload: map[string]any{"reason": " max_rounds "}, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(agentruntime.StopReasonDecidedPayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.Reason != controlplane.StopReason("max_rounds") { + t.Fatalf("reason = %q", value.Reason) + } + }, + }, + { + name: "runtime usage payload", + eventType: agentruntime.EventType(RuntimeEventUsage), + payload: map[string]any{"delta": map[string]any{"inputtokens": 1}}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if _, ok := got.(RuntimeUsagePayload); !ok { + t.Fatalf("payload type = %T", got) + } + }, + }, + { + name: "string payload", + eventType: agentruntime.EventToolChunk, + payload: 42, + assertFn: func(t *testing.T, got any) { + t.Helper() + if got != "42" { + t.Fatalf("payload = %#v", got) + } + }, + }, + { + name: "default passthrough", + eventType: agentruntime.EventType("unknown"), + payload: map[string]any{"k": "v"}, + assertFn: func(t *testing.T, got any) { + t.Helper() + if !reflect.DeepEqual(got, map[string]any{"k": "v"}) { + t.Fatalf("payload = %#v", got) + } + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + got, err := restoreRuntimePayload(tc.eventType, tc.payload) + if err != nil { + t.Fatalf("restoreRuntimePayload() error = %v", err) + } + tc.assertFn(t, got) + }) + } +} + +func TestDecodeRuntimePayloadAndMapHelpers(t *testing.T) { + t.Parallel() + + typed, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](agentruntime.InputNormalizedPayload{TextLength: 1}) + if err != nil || typed.TextLength != 1 { + t.Fatalf("typed decode mismatch, got (%#v, %v)", typed, err) + } + + ptrValue := &agentruntime.InputNormalizedPayload{ImageCount: 3} + decodedPtr, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](ptrValue) + if err != nil || decodedPtr.ImageCount != 3 { + t.Fatalf("pointer decode mismatch, got (%#v, %v)", decodedPtr, err) + } + + var nilPtr *agentruntime.InputNormalizedPayload + if _, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](nilPtr); err == nil { + t.Fatalf("expected nil pointer decode error") + } + if _, err := decodeRuntimePayload[agentruntime.InputNormalizedPayload](nil); err == nil { + t.Fatalf("expected nil payload decode error") + } + + m := map[string]any{ + "runtimeEventType": "agent_chunk", + "turn": json.Number("7"), + "payloadVersion": "5", + "time_stamp": "2026-04-20T12:00:00Z", + } + if value, ok := streamReadMapValue(m, "runtime_event_type"); !ok || value != "agent_chunk" { + t.Fatalf("streamReadMapValue mismatch: (%v, %v)", value, ok) + } + if got := streamReadMapInt(m, "turn"); got != 7 { + t.Fatalf("streamReadMapInt(turn) = %d", got) + } + if got := streamReadMapInt(m, "payload_version"); got != 5 { + t.Fatalf("streamReadMapInt(payload_version) = %d", got) + } + if got := streamReadMapString(m, "runtime_event_type"); got != "agent_chunk" { + t.Fatalf("streamReadMapString = %q", got) + } + if got := streamReadMapTime(m, "time_stamp"); got.IsZero() { + t.Fatalf("streamReadMapTime should parse timestamp") + } + + if normalizeMapLookupKey(" Runtime_Event-Type ") != "runtimeeventtype" { + t.Fatalf("normalizeMapLookupKey mismatch") + } + if toSnakeCase("RuntimeEventType") != "runtime_event_type" { + t.Fatalf("toSnakeCase mismatch") + } + if toLowerCamelCase(" RuntimeEventType ") != "runtimeEventType" { + t.Fatalf("toLowerCamelCase mismatch") + } +} + +func TestGatewayStreamClientRunSkipsNonGatewayEventsAndStopsOnClose(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification, 4) + client := NewGatewayStreamClient(source) + + source <- gatewayRPCNotification{Method: "gateway.ping", Params: json.RawMessage(`{"foo":"bar"}`)} + source <- buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventAgentChunk), + "payload": "ok", + }, + }) + + select { + case event := <-client.Events(): + if event.Type != agentruntime.EventAgentChunk { + t.Fatalf("event.Type = %q", event.Type) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for stream event") + } + + if err := client.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if err := client.Close(); err != nil { + t.Fatalf("Close() second call error = %v", err) + } +} + +func TestGatewayStreamClientRunStopsWhenSourceClosed(t *testing.T) { + t.Parallel() + + source := make(chan gatewayRPCNotification) + client := NewGatewayStreamClient(source) + close(source) + + select { + case _, ok := <-client.Events(): + if ok { + t.Fatalf("events channel should be closed after source channel is closed") + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for events channel close") + } +} + +func TestRestoreRuntimePayloadAdditionalBranches(t *testing.T) { + t.Parallel() + + payloadCases := []struct { + eventType agentruntime.EventType + payload any + }{ + {eventType: agentruntime.EventAgentDone, payload: map[string]any{"Role": string(providertypes.RoleAssistant)}}, + {eventType: agentruntime.EventToolStart, payload: map[string]any{"Name": "bash"}}, + {eventType: agentruntime.EventPermissionResolved, payload: map[string]any{"RequestID": "req-1"}}, + {eventType: agentruntime.EventCompactApplied, payload: map[string]any{"Applied": true}}, + {eventType: agentruntime.EventCompactError, payload: map[string]any{"message": "boom"}}, + {eventType: agentruntime.EventPhaseChanged, payload: map[string]any{"from": "a", "to": "b"}}, + {eventType: agentruntime.EventInputNormalized, payload: map[string]any{"text_length": 3}}, + {eventType: agentruntime.EventAssetSaved, payload: map[string]any{"asset_id": "asset-1"}}, + {eventType: agentruntime.EventAssetSaveFailed, payload: map[string]any{"message": "x"}}, + {eventType: agentruntime.EventTodoUpdated, payload: map[string]any{"action": "replace"}}, + {eventType: agentruntime.EventTodoConflict, payload: map[string]any{"action": "conflict"}}, + {eventType: agentruntime.EventType(RuntimeEventRunContext), payload: map[string]any{"provider": "openai"}}, + {eventType: agentruntime.EventType(RuntimeEventToolStatus), payload: map[string]any{"status": "running"}}, + } + + for _, tc := range payloadCases { + if _, err := restoreRuntimePayload(tc.eventType, tc.payload); err != nil { + t.Fatalf("restoreRuntimePayload(%q) error = %v", tc.eventType, err) + } + } + + if _, err := restoreRuntimePayload(agentruntime.EventStopReasonDecided, map[string]any{"reason": func() {}}); err == nil { + t.Fatalf("stop reason payload should return decode error for non-serializable field") + } +} + +func TestStreamHelperBranches(t *testing.T) { + t.Parallel() + + if decodeStringPayload(nil) != "" { + t.Fatalf("decodeStringPayload(nil) should return empty string") + } + if decodeStringPayload("x") != "x" { + t.Fatalf("decodeStringPayload(string) mismatch") + } + + if _, err := decodeRuntimePayload[agentruntime.PhaseChangedPayload](func() {}); err == nil { + t.Fatalf("decodeRuntimePayload should fail on marshal error") + } + if _, err := decodeRuntimePayload[agentruntime.PhaseChangedPayload](map[string]any{"from": map[string]any{"bad": make(chan int)}}); err == nil { + t.Fatalf("decodeRuntimePayload should fail on invalid nested payload") + } + + if value, ok := streamReadMapValue(map[string]any{"RUNTIMEEVENTTYPE": "v"}, "runtime_event_type"); !ok || value != "v" { + t.Fatalf("streamReadMapValue normalized scan mismatch") + } + if _, ok := streamReadMapValue(nil, "key"); ok { + t.Fatalf("nil map lookup should fail") + } + if _, ok := streamReadMapValue(map[string]any{"k": 1}, " "); ok { + t.Fatalf("blank key lookup should fail") + } + + intCases := map[string]any{ + "i": 1, + "i64": int64(2), + "i32": int32(3), + "f64": float64(4), + "f32": float32(5), + "num": json.Number("6"), + "badnum": json.Number("x"), + "str": "7", + "badstr": "x", + } + if streamReadMapInt(intCases, "i") != 1 || + streamReadMapInt(intCases, "i64") != 2 || + streamReadMapInt(intCases, "i32") != 3 || + streamReadMapInt(intCases, "f64") != 4 || + streamReadMapInt(intCases, "f32") != 5 || + streamReadMapInt(intCases, "num") != 6 || + streamReadMapInt(intCases, "str") != 7 { + t.Fatalf("streamReadMapInt numeric coercion mismatch") + } + if streamReadMapInt(intCases, "badnum") != 0 || streamReadMapInt(intCases, "badstr") != 0 || streamReadMapInt(intCases, "missing") != 0 { + t.Fatalf("streamReadMapInt invalid values should return zero") + } + + now := time.Now().UTC() + timeCases := map[string]any{ + "as_time": now, + "as_text": now.Format(time.RFC3339Nano), + "invalid": "not-time", + "blanktxt": " ", + } + if !streamReadMapTime(timeCases, "as_time").Equal(now) { + t.Fatalf("streamReadMapTime(time.Time) mismatch") + } + if streamReadMapTime(timeCases, "as_text").IsZero() { + t.Fatalf("streamReadMapTime(string) should parse") + } + if !streamReadMapTime(timeCases, "invalid").IsZero() || + !streamReadMapTime(timeCases, "blanktxt").IsZero() || + !streamReadMapTime(timeCases, "missing").IsZero() { + t.Fatalf("streamReadMapTime invalid values should return zero time") + } + + if toSnakeCase("") != "" || toLowerCamelCase("") != "" { + t.Fatalf("empty case conversion should return empty string") + } +} + +func TestGatewayStreamDecodeAndEnvelopeExtraBranches(t *testing.T) { + t.Parallel() + + missingTypeNotification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": " ", + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(missingTypeNotification); err == nil { + t.Fatalf("expected missing runtime_event_type error") + } + + invalidPayloadNotification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventToolResult), + "payload": "not-an-object", + }, + }) + if _, err := decodeRuntimeEventFromGatewayNotification(invalidPayloadNotification); err == nil { + t.Fatalf("expected restore payload decode error") + } + + if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("{")}); ok { + t.Fatalf("expected marshal error path to fail envelope extraction") + } + if _, ok := extractRuntimeEnvelope(streamInvalidJSONMarshaler{raw: []byte("[]")}); ok { + t.Fatalf("expected unmarshal-to-map error path to fail envelope extraction") + } + if envelope, ok := extractRuntimeEnvelope(struct { + RuntimeEventType string `json:"runtime_event_type"` + }{RuntimeEventType: string(agentruntime.EventError)}); !ok || streamReadMapString(envelope, "runtime_event_type") == "" { + t.Fatalf("expected runtime_event_type detection after marshal/unmarshal") + } + + if got := streamReadMapString(map[string]any{"v": 123}, "v"); got != "123" { + t.Fatalf("streamReadMapString default conversion mismatch: %q", got) + } + if got := streamReadMapInt(map[string]any{"v": true}, "v"); got != 0 { + t.Fatalf("streamReadMapInt unsupported type should return 0, got %d", got) + } + if got := streamReadMapTime(map[string]any{"v": 1}, "v"); !got.IsZero() { + t.Fatalf("streamReadMapTime unsupported type should return zero, got %v", got) + } +} diff --git a/internal/tui/services/gateway_stream_client_test.go b/internal/tui/services/gateway_stream_client_test.go new file mode 100644 index 00000000..88e1fb13 --- /dev/null +++ b/internal/tui/services/gateway_stream_client_test.go @@ -0,0 +1,144 @@ +package services + +import ( + "encoding/json" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + agentruntime "neo-code/internal/runtime" + "neo-code/internal/tools" +) + +func TestDecodeRuntimeEventFromGatewayNotificationRestoresStringPayload(t *testing.T) { + timestamp := time.Date(2026, 4, 20, 10, 30, 0, 0, time.UTC) + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-1", + RunID: "run-1", + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventAgentChunk), + "turn": 2, + "phase": "thinking", + "timestamp": timestamp.Format(time.RFC3339Nano), + "payload_version": 1, + "payload": "hello", + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + if event.Type != agentruntime.EventAgentChunk { + t.Fatalf("event.Type = %q, want %q", event.Type, agentruntime.EventAgentChunk) + } + if event.SessionID != "session-1" || event.RunID != "run-1" { + t.Fatalf("unexpected ids: %#v", event) + } + if event.Turn != 2 || event.Phase != "thinking" { + t.Fatalf("unexpected turn/phase: %#v", event) + } + if !event.Timestamp.Equal(timestamp) { + t.Fatalf("event.Timestamp = %v, want %v", event.Timestamp, timestamp) + } + payload, ok := event.Payload.(string) + if !ok || payload != "hello" { + t.Fatalf("event.Payload = %#v, want %q", event.Payload, "hello") + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationRestoresToolResultPayload(t *testing.T) { + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-2", + RunID: "run-2", + Payload: map[string]any{ + "runtime_event_type": string(agentruntime.EventToolResult), + "payload": map[string]any{ + "ToolCallID": "call-1", + "Name": "bash", + "Content": "ok", + "IsError": false, + }, + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + toolResult, ok := event.Payload.(tools.ToolResult) + if !ok { + t.Fatalf("event.Payload type = %T, want tools.ToolResult", event.Payload) + } + if toolResult.ToolCallID != "call-1" || toolResult.Name != "bash" || toolResult.Content != "ok" || toolResult.IsError { + t.Fatalf("unexpected tool result payload: %#v", toolResult) + } +} + +func TestDecodeRuntimeEventFromGatewayNotificationSupportsNestedEnvelope(t *testing.T) { + notification := buildGatewayEventNotification(t, gateway.MessageFrame{ + Type: gateway.FrameTypeEvent, + Action: gateway.FrameActionRun, + SessionID: "session-3", + RunID: "run-3", + Payload: map[string]any{ + "type": "run_progress", + "payload": map[string]any{ + "runtime_event_type": string(agentruntime.EventError), + "payload": "boom", + }, + }, + }) + + event, err := decodeRuntimeEventFromGatewayNotification(notification) + if err != nil { + t.Fatalf("decodeRuntimeEventFromGatewayNotification() error = %v", err) + } + if event.Type != agentruntime.EventError { + t.Fatalf("event.Type = %q, want %q", event.Type, agentruntime.EventError) + } + if payload, ok := event.Payload.(string); !ok || payload != "boom" { + t.Fatalf("event.Payload = %#v, want %q", event.Payload, "boom") + } +} + +func TestGatewayStreamClientEmitsDecodeErrorAsRuntimeErrorEvent(t *testing.T) { + source := make(chan gatewayRPCNotification, 1) + client := NewGatewayStreamClient(source) + t.Cleanup(func() { _ = client.Close() }) + + source <- gatewayRPCNotification{ + Method: protocol.MethodGatewayEvent, + Params: json.RawMessage(`{"bad":`), + } + + select { + case event := <-client.Events(): + if event.Type != agentruntime.EventError { + t.Fatalf("event.Type = %q, want %q", event.Type, agentruntime.EventError) + } + payload, ok := event.Payload.(string) + if !ok || payload == "" { + t.Fatalf("event.Payload = %#v, want non-empty string", event.Payload) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for decode error event") + } +} + +func buildGatewayEventNotification(t *testing.T, frame gateway.MessageFrame) gatewayRPCNotification { + t.Helper() + raw, err := json.Marshal(frame) + if err != nil { + t.Fatalf("marshal frame: %v", err) + } + return gatewayRPCNotification{ + Method: protocol.MethodGatewayEvent, + Params: raw, + } +} diff --git a/internal/tui/services/remote_runtime_adapter.go b/internal/tui/services/remote_runtime_adapter.go new file mode 100644 index 00000000..aec6e361 --- /dev/null +++ b/internal/tui/services/remote_runtime_adapter.go @@ -0,0 +1,642 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +const ( + unsupportedActionInGatewayMode = "unsupported_action_in_gateway_mode" + defaultRemoteRuntimeTimeout = 8 * time.Second +) + +var ( + newGatewayRPCClientFactory = NewGatewayRPCClient + newGatewayStreamClientFactory = NewGatewayStreamClient +) + +// RemoteRuntimeAdapterOptions 描述远程 Runtime 适配器的初始化参数。 +type RemoteRuntimeAdapterOptions struct { + ListenAddress string + TokenFile string + RequestTimeout time.Duration + RetryCount int +} + +type remoteGatewayRPCClient interface { + Authenticate(ctx context.Context) error + CallWithOptions(ctx context.Context, method string, params any, result any, options GatewayRPCCallOptions) error + Notifications() <-chan gatewayRPCNotification + Close() error +} + +type remoteGatewayStreamClient interface { + Events() <-chan agentruntime.RuntimeEvent + Close() error +} + +// RemoteRuntimeAdapter 将 TUI runtime 调用转发到 Gateway JSON-RPC 控制面。 +type RemoteRuntimeAdapter struct { + rpcClient remoteGatewayRPCClient + streamClient remoteGatewayStreamClient + timeout time.Duration + retryCount int + + closeOnce sync.Once + closeCh chan struct{} + done chan struct{} + events chan agentruntime.RuntimeEvent + + activeMu sync.Mutex + activeRunID string + activeSession string + lastCancelSent time.Time +} + +// NewRemoteRuntimeAdapter 创建远程 Runtime 适配器,并在启动阶段执行 fail-fast 认证连通性检查。 +func NewRemoteRuntimeAdapter(options RemoteRuntimeAdapterOptions) (*RemoteRuntimeAdapter, error) { + timeout := options.RequestTimeout + if timeout <= 0 { + timeout = defaultRemoteRuntimeTimeout + } + retryCount := normalizeRemoteRuntimeRetryCount(options.RetryCount) + + rpcClient, err := newGatewayRPCClientFactory(GatewayRPCClientOptions{ + ListenAddress: strings.TrimSpace(options.ListenAddress), + TokenFile: strings.TrimSpace(options.TokenFile), + RequestTimeout: timeout, + RetryCount: retryCount, + }) + if err != nil { + return nil, err + } + + streamClient := newGatewayStreamClientFactory(rpcClient.Notifications()) + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, timeout, retryCount) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if err := adapter.authenticate(ctx); err != nil { + _ = adapter.Close() + return nil, err + } + return adapter, nil +} + +func newRemoteRuntimeAdapterWithClients( + rpcClient remoteGatewayRPCClient, + streamClient remoteGatewayStreamClient, + timeout time.Duration, + retryCount int, +) *RemoteRuntimeAdapter { + retryCount = normalizeRemoteRuntimeRetryCount(retryCount) + adapter := &RemoteRuntimeAdapter{ + rpcClient: rpcClient, + streamClient: streamClient, + timeout: timeout, + retryCount: retryCount, + closeCh: make(chan struct{}), + done: make(chan struct{}), + events: make(chan agentruntime.RuntimeEvent, 128), + } + go adapter.forwardEvents() + return adapter +} + +// Submit 将用户输入提交到网关:先 authenticate,再 bindStream,随后 loadSession,最后 run。 +func (r *RemoteRuntimeAdapter) Submit(ctx context.Context, input agentruntime.PrepareInput) error { + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + sessionID = agentsession.NewID("session") + } + runID := strings.TrimSpace(input.RunID) + if runID == "" { + runID = fmt.Sprintf("run-%d", time.Now().UnixNano()) + } + + if err := r.authenticate(ctx); err != nil { + return err + } + if err := r.bindStream(ctx, sessionID, runID); err != nil { + return err + } + if err := r.preloadSession(ctx, sessionID); err != nil { + return err + } + + params := buildGatewayRunParams(sessionID, runID, input) + frame, err := r.callFrame(ctx, protocol.MethodGatewayRun, params, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: 0, + }) + if err != nil { + return err + } + + ackRunID := strings.TrimSpace(frame.RunID) + if ackRunID == "" { + ackRunID = runID + } + r.setActiveRun(ackRunID, sessionID) + return nil +} + +// PrepareUserInput 在 gateway 模式下提供最小可用输入归一化结果,保持接口兼容。 +func (r *RemoteRuntimeAdapter) PrepareUserInput(ctx context.Context, input agentruntime.PrepareInput) (agentruntime.UserInput, error) { + if err := ctx.Err(); err != nil { + return agentruntime.UserInput{}, err + } + + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + sessionID = agentsession.NewID("session") + } + runID := strings.TrimSpace(input.RunID) + if runID == "" { + runID = fmt.Sprintf("run-%d", time.Now().UnixNano()) + } + + parts := make([]providertypes.ContentPart, 0, 1+len(input.Images)) + if strings.TrimSpace(input.Text) != "" { + parts = append(parts, providertypes.NewTextPart(input.Text)) + } + for _, image := range input.Images { + path := strings.TrimSpace(image.Path) + if path == "" { + continue + } + parts = append(parts, providertypes.NewRemoteImagePart(path)) + } + + return agentruntime.UserInput{ + SessionID: sessionID, + RunID: runID, + Parts: parts, + Workdir: strings.TrimSpace(input.Workdir), + }, nil +} + +// Run 保持 runtime 接口兼容,在 gateway 模式下回落到 Submit 通道。 +func (r *RemoteRuntimeAdapter) Run(ctx context.Context, input agentruntime.UserInput) error { + prepareInput := agentruntime.PrepareInput{ + SessionID: strings.TrimSpace(input.SessionID), + RunID: strings.TrimSpace(input.RunID), + Workdir: strings.TrimSpace(input.Workdir), + Text: renderInputTextFromParts(input.Parts), + Images: renderInputImagesFromParts(input.Parts), + } + return r.Submit(ctx, prepareInput) +} + +// Compact 转发 gateway.compact 请求并映射回 runtime CompactResult。 +func (r *RemoteRuntimeAdapter) Compact(ctx context.Context, input agentruntime.CompactInput) (agentruntime.CompactResult, error) { + sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + return agentruntime.CompactResult{}, errors.New("gateway runtime adapter: compact session_id is empty") + } + if err := r.authenticate(ctx); err != nil { + return agentruntime.CompactResult{}, err + } + if err := r.bindStream(ctx, sessionID, strings.TrimSpace(input.RunID)); err != nil { + return agentruntime.CompactResult{}, err + } + + frame, err := r.callFrame(ctx, protocol.MethodGatewayCompact, protocol.CompactParams{ + SessionID: sessionID, + RunID: strings.TrimSpace(input.RunID), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return agentruntime.CompactResult{}, err + } + + gatewayResult, err := decodeFramePayload[gateway.CompactResult](frame.Payload) + if err != nil { + return agentruntime.CompactResult{}, err + } + return agentruntime.CompactResult{ + Applied: gatewayResult.Applied, + BeforeChars: gatewayResult.BeforeChars, + AfterChars: gatewayResult.AfterChars, + SavedRatio: gatewayResult.SavedRatio, + TriggerMode: gatewayResult.TriggerMode, + TranscriptID: gatewayResult.TranscriptID, + TranscriptPath: gatewayResult.TranscriptPath, + }, nil +} + +// ExecuteSystemTool 在 gateway 模式下显式不支持,避免任何本地 fallback。 +func (r *RemoteRuntimeAdapter) ExecuteSystemTool(ctx context.Context, input agentruntime.SystemToolInput) (tools.ToolResult, error) { + _ = ctx + _ = input + return tools.ToolResult{}, errors.New(unsupportedActionInGatewayMode) +} + +// ResolvePermission 转发 gateway.resolvePermission 请求。 +func (r *RemoteRuntimeAdapter) ResolvePermission(ctx context.Context, input agentruntime.PermissionResolutionInput) error { + if err := r.authenticate(ctx); err != nil { + return err + } + _, err := r.callFrame(ctx, protocol.MethodGatewayResolvePermission, protocol.ResolvePermissionParams{ + RequestID: strings.TrimSpace(input.RequestID), + Decision: strings.ToLower(strings.TrimSpace(string(input.Decision))), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + return err +} + +// preloadSession 在 run 之前触发一次 gateway.loadSession,用于会话建档/预热。 +func (r *RemoteRuntimeAdapter) preloadSession(ctx context.Context, sessionID string) error { + _, err := r.callFrame(ctx, protocol.MethodGatewayLoadSession, protocol.LoadSessionParams{ + SessionID: strings.TrimSpace(sessionID), + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + return err +} + +// CancelActiveRun 尝试取消当前活跃 run,并返回是否成功发起取消请求。 +func (r *RemoteRuntimeAdapter) CancelActiveRun() bool { + runID, sessionID := r.activeRun() + if runID == "" { + return false + } + + go func(runID string, sessionID string) { + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + if err := r.authenticate(ctx); err != nil { + return + } + _, _ = r.callFrame(ctx, protocol.MethodGatewayCancel, protocol.CancelParams{ + SessionID: sessionID, + RunID: runID, + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: 0, + }) + }(runID, sessionID) + return true +} + +// Events 返回适配后的 runtime 事件流。 +func (r *RemoteRuntimeAdapter) Events() <-chan agentruntime.RuntimeEvent { + return r.events +} + +// ListSessions 转发 gateway.listSessions,并映射为 runtime 层会话摘要。 +func (r *RemoteRuntimeAdapter) ListSessions(ctx context.Context) ([]agentsession.Summary, error) { + if err := r.authenticate(ctx); err != nil { + return nil, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayListSessions, nil, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return nil, err + } + + payload := struct { + Sessions []gateway.SessionSummary `json:"sessions"` + }{} + if err := decodeIntoValue(frame.Payload, &payload); err != nil { + return nil, err + } + + summaries := make([]agentsession.Summary, 0, len(payload.Sessions)) + for _, item := range payload.Sessions { + summaries = append(summaries, agentsession.Summary{ + ID: strings.TrimSpace(item.ID), + Title: strings.TrimSpace(item.Title), + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + }) + } + return summaries, nil +} + +// LoadSession 转发 gateway.loadSession,并执行最小可用语义映射。 +func (r *RemoteRuntimeAdapter) LoadSession(ctx context.Context, id string) (agentsession.Session, error) { + sessionID := strings.TrimSpace(id) + if sessionID == "" { + return agentsession.Session{}, errors.New("gateway runtime adapter: session id is empty") + } + if err := r.authenticate(ctx); err != nil { + return agentsession.Session{}, err + } + frame, err := r.callFrame(ctx, protocol.MethodGatewayLoadSession, protocol.LoadSessionParams{ + SessionID: sessionID, + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + if err != nil { + return agentsession.Session{}, err + } + + loaded, err := decodeFramePayload[gateway.Session](frame.Payload) + if err != nil { + return agentsession.Session{}, err + } + return mapGatewaySessionToRuntimeSession(loaded), nil +} + +// ActivateSessionSkill 在 gateway 模式下显式不支持。 +func (r *RemoteRuntimeAdapter) ActivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + _ = ctx + _ = sessionID + _ = skillID + return errors.New(unsupportedActionInGatewayMode) +} + +// DeactivateSessionSkill 在 gateway 模式下显式不支持。 +func (r *RemoteRuntimeAdapter) DeactivateSessionSkill(ctx context.Context, sessionID string, skillID string) error { + _ = ctx + _ = sessionID + _ = skillID + return errors.New(unsupportedActionInGatewayMode) +} + +// ListSessionSkills 在 gateway 模式下显式不支持。 +func (r *RemoteRuntimeAdapter) ListSessionSkills(ctx context.Context, sessionID string) ([]agentruntime.SessionSkillState, error) { + _ = ctx + _ = sessionID + return nil, errors.New(unsupportedActionInGatewayMode) +} + +// Close 关闭远程适配器并结束事件桥接。 +func (r *RemoteRuntimeAdapter) Close() error { + var closeErr error + r.closeOnce.Do(func() { + close(r.closeCh) + if r.streamClient != nil { + closeErr = errors.Join(closeErr, r.streamClient.Close()) + } + if r.rpcClient != nil { + closeErr = errors.Join(closeErr, r.rpcClient.Close()) + } + <-r.done + }) + return closeErr +} + +func (r *RemoteRuntimeAdapter) authenticate(ctx context.Context) error { + if r.rpcClient == nil { + return errors.New("gateway runtime adapter: rpc client is nil") + } + return r.rpcClient.Authenticate(ctx) +} + +func (r *RemoteRuntimeAdapter) bindStream(ctx context.Context, sessionID string, runID string) error { + _, err := r.callFrame(ctx, protocol.MethodGatewayBindStream, protocol.BindStreamParams{ + SessionID: strings.TrimSpace(sessionID), + RunID: strings.TrimSpace(runID), + Channel: "all", + }, GatewayRPCCallOptions{ + Timeout: r.timeout, + Retries: r.retryCount, + }) + return err +} + +func (r *RemoteRuntimeAdapter) callFrame( + ctx context.Context, + method string, + params any, + options GatewayRPCCallOptions, +) (gateway.MessageFrame, error) { + if r.rpcClient == nil { + return gateway.MessageFrame{}, errors.New("gateway runtime adapter: rpc client is nil") + } + + var frame gateway.MessageFrame + if err := r.rpcClient.CallWithOptions(ctx, method, params, &frame, options); err != nil { + return gateway.MessageFrame{}, err + } + if frame.Type == gateway.FrameTypeError { + if frame.Error == nil { + return gateway.MessageFrame{}, fmt.Errorf("gateway %s returned error frame without payload", method) + } + return gateway.MessageFrame{}, fmt.Errorf("%s: %s", strings.TrimSpace(frame.Error.Code), strings.TrimSpace(frame.Error.Message)) + } + if frame.Type != gateway.FrameTypeAck { + return gateway.MessageFrame{}, fmt.Errorf("gateway %s returned unexpected frame type %q", method, frame.Type) + } + return frame, nil +} + +func (r *RemoteRuntimeAdapter) forwardEvents() { + defer close(r.done) + defer close(r.events) + + if r.streamClient == nil { + return + } + + source := r.streamClient.Events() + for { + select { + case <-r.closeCh: + return + case event, ok := <-source: + if !ok { + return + } + r.observeEvent(event) + select { + case <-r.closeCh: + return + case r.events <- event: + } + } + } +} + +func (r *RemoteRuntimeAdapter) observeEvent(event agentruntime.RuntimeEvent) { + runID := strings.TrimSpace(event.RunID) + sessionID := strings.TrimSpace(event.SessionID) + if runID != "" || sessionID != "" { + r.setActiveRun(runID, sessionID) + } + + switch event.Type { + case agentruntime.EventAgentDone, agentruntime.EventError, agentruntime.EventRunCanceled, agentruntime.EventStopReasonDecided: + r.clearActiveRun(runID) + } +} + +func (r *RemoteRuntimeAdapter) setActiveRun(runID string, sessionID string) { + r.activeMu.Lock() + defer r.activeMu.Unlock() + if strings.TrimSpace(runID) != "" { + r.activeRunID = strings.TrimSpace(runID) + } + if strings.TrimSpace(sessionID) != "" { + r.activeSession = strings.TrimSpace(sessionID) + } +} + +func (r *RemoteRuntimeAdapter) clearActiveRun(runID string) { + r.activeMu.Lock() + defer r.activeMu.Unlock() + normalizedRunID := strings.TrimSpace(runID) + if normalizedRunID == "" { + return + } + if strings.EqualFold(normalizedRunID, strings.TrimSpace(r.activeRunID)) { + r.activeRunID = "" + } +} + +// normalizeRemoteRuntimeRetryCount 统一归一化重试次数,避免零值关闭默认重试兜底。 +func normalizeRemoteRuntimeRetryCount(retryCount int) int { + if retryCount <= 0 { + return defaultGatewayRPCRetryCount + } + return retryCount +} + +func (r *RemoteRuntimeAdapter) activeRun() (string, string) { + r.activeMu.Lock() + defer r.activeMu.Unlock() + return strings.TrimSpace(r.activeRunID), strings.TrimSpace(r.activeSession) +} + +func buildGatewayRunParams(sessionID string, runID string, input agentruntime.PrepareInput) protocol.RunParams { + parts := make([]protocol.RunInputPart, 0, len(input.Images)) + for _, image := range input.Images { + path := strings.TrimSpace(image.Path) + if path == "" { + continue + } + parts = append(parts, protocol.RunInputPart{ + Type: string(gateway.InputPartTypeImage), + Media: &protocol.RunInputMedia{ + URI: path, + MimeType: strings.TrimSpace(image.MimeType), + }, + }) + } + + return protocol.RunParams{ + SessionID: strings.TrimSpace(sessionID), + RunID: strings.TrimSpace(runID), + InputText: strings.TrimSpace(input.Text), + InputParts: parts, + Workdir: strings.TrimSpace(input.Workdir), + } +} + +func renderInputTextFromParts(parts []providertypes.ContentPart) string { + textParts := make([]string, 0, len(parts)) + for _, part := range parts { + if part.Kind != providertypes.ContentPartText { + continue + } + text := strings.TrimSpace(part.Text) + if text == "" { + continue + } + textParts = append(textParts, text) + } + return strings.Join(textParts, "\n") +} + +func renderInputImagesFromParts(parts []providertypes.ContentPart) []agentruntime.UserImageInput { + images := make([]agentruntime.UserImageInput, 0, len(parts)) + for _, part := range parts { + if part.Kind != providertypes.ContentPartImage || part.Image == nil { + continue + } + path := strings.TrimSpace(part.Image.URL) + if path == "" { + continue + } + mimeType := "" + if part.Image.Asset != nil { + mimeType = strings.TrimSpace(part.Image.Asset.MimeType) + } + images = append(images, agentruntime.UserImageInput{ + Path: path, + MimeType: mimeType, + }) + } + return images +} + +func mapGatewaySessionToRuntimeSession(source gateway.Session) agentsession.Session { + messages := make([]providertypes.Message, 0, len(source.Messages)) + for _, item := range source.Messages { + content := strings.TrimSpace(item.Content) + message := providertypes.Message{ + Role: strings.TrimSpace(item.Role), + ToolCallID: strings.TrimSpace(item.ToolCallID), + IsError: item.IsError, + } + if content != "" { + message.Parts = []providertypes.ContentPart{providertypes.NewTextPart(content)} + } + if len(item.ToolCalls) > 0 { + message.ToolCalls = make([]providertypes.ToolCall, 0, len(item.ToolCalls)) + for _, call := range item.ToolCalls { + message.ToolCalls = append(message.ToolCalls, providertypes.ToolCall{ + ID: strings.TrimSpace(call.ID), + Name: strings.TrimSpace(call.Name), + Arguments: call.Arguments, + }) + } + } + messages = append(messages, message) + } + + return agentsession.Session{ + ID: strings.TrimSpace(source.ID), + Title: strings.TrimSpace(source.Title), + CreatedAt: source.CreatedAt, + UpdatedAt: source.UpdatedAt, + Workdir: strings.TrimSpace(source.Workdir), + Messages: messages, + } +} + +func decodeFramePayload[T any](payload any) (T, error) { + var out T + if err := decodeIntoValue(payload, &out); err != nil { + return out, err + } + return out, nil +} + +func decodeIntoValue(payload any, target any) error { + if target == nil { + return errors.New("decode payload target is nil") + } + raw, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("encode frame payload: %w", err) + } + if err := json.Unmarshal(raw, target); err != nil { + return fmt.Errorf("decode frame payload: %w", err) + } + return nil +} + +var _ agentruntime.Runtime = (*RemoteRuntimeAdapter)(nil) diff --git a/internal/tui/services/remote_runtime_adapter_additional_test.go b/internal/tui/services/remote_runtime_adapter_additional_test.go new file mode 100644 index 00000000..977cff3a --- /dev/null +++ b/internal/tui/services/remote_runtime_adapter_additional_test.go @@ -0,0 +1,504 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "net" + "strings" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" +) + +func TestNewRemoteRuntimeAdapterBranches(t *testing.T) { + t.Parallel() + + originalRPCFactory := newGatewayRPCClientFactory + originalStreamFactory := newGatewayStreamClientFactory + t.Cleanup(func() { + newGatewayRPCClientFactory = originalRPCFactory + newGatewayStreamClientFactory = originalStreamFactory + }) + + newGatewayRPCClientFactory = func(options GatewayRPCClientOptions) (*GatewayRPCClient, error) { + if strings.TrimSpace(options.ListenAddress) == "error" { + return nil, errors.New("build rpc failed") + } + client := &GatewayRPCClient{ + listenAddress: options.ListenAddress, + token: "token", + requestTimeout: time.Second, + retryCount: 1, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, 4), + notificationQueue: make(chan gatewayRPCNotification, 4), + } + client.dialFn = func(_ string) (net.Conn, error) { + if options.ListenAddress == "dial-failed" { + return nil, errors.New("dial failed") + } + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + decoder := json.NewDecoder(serverConn) + encoder := json.NewEncoder(serverConn) + request := readRPCRequestOrFail(t, decoder) + writeRPCResultOrFail(t, encoder, request.ID, gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionAuthenticate, + }) + }() + return clientConn, nil + } + return client, nil + } + newGatewayStreamClientFactory = func(source <-chan gatewayRPCNotification) *GatewayStreamClient { + return NewGatewayStreamClient(source) + } + + if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "error"}); err == nil { + t.Fatalf("expected rpc factory error") + } + if _, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ListenAddress: "dial-failed", RequestTimeout: -1}); err == nil { + t.Fatalf("expected authenticate fail-fast error") + } + + adapter, err := NewRemoteRuntimeAdapter(RemoteRuntimeAdapterOptions{ + ListenAddress: "ok", + RequestTimeout: -1, + RetryCount: 0, + }) + if err != nil { + t.Fatalf("NewRemoteRuntimeAdapter() error = %v", err) + } + if adapter.timeout != defaultRemoteRuntimeTimeout { + t.Fatalf("timeout = %v, want %v", adapter.timeout, defaultRemoteRuntimeTimeout) + } + if adapter.retryCount != defaultGatewayRPCRetryCount { + t.Fatalf("retryCount = %d, want %d", adapter.retryCount, defaultGatewayRPCRetryCount) + } + _ = adapter.Close() +} + +func TestRemoteRuntimeAdapterPrepareUserInputAndRun(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionLoadSession, SessionID: "s-1"}, + protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream, SessionID: "s-1", RunID: "r-1"}, + protocol.MethodGatewayRun: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionRun, SessionID: "s-1", RunID: "r-1"}, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := adapter.PrepareUserInput(ctx, agentruntime.PrepareInput{}); err == nil { + t.Fatalf("expected context cancellation error") + } + + input, err := adapter.PrepareUserInput(context.Background(), agentruntime.PrepareInput{ + SessionID: " ", + RunID: "", + Text: " hello ", + Images: []agentruntime.UserImageInput{ + {Path: " "}, + {Path: " /tmp/a.png ", MimeType: " image/png "}, + }, + Workdir: " /repo ", + }) + if err != nil { + t.Fatalf("PrepareUserInput() error = %v", err) + } + if strings.TrimSpace(input.SessionID) == "" || strings.TrimSpace(input.RunID) == "" { + t.Fatalf("session/run id should be generated") + } + if len(input.Parts) != 2 { + t.Fatalf("parts len = %d, want 2", len(input.Parts)) + } + + if err := adapter.Run(context.Background(), input); err != nil { + t.Fatalf("Run() error = %v", err) + } + methods := rpcClient.snapshotMethods() + if len(methods) != 3 || methods[2] != protocol.MethodGatewayRun { + t.Fatalf("unexpected method chain: %#v", methods) + } +} + +func TestRemoteRuntimeAdapterCompactResolvePermissionAndListSessions(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayBindStream: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionBindStream}, + protocol.MethodGatewayCompact: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionCompact, + Payload: gateway.CompactResult{ + Applied: true, + BeforeChars: 100, + AfterChars: 40, + TriggerMode: "auto", + TranscriptID: "t-1", + }, + }, + protocol.MethodGatewayResolvePermission: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionResolvePermission}, + protocol.MethodGatewayListSessions: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionListSessions, + Payload: map[string]any{ + "sessions": []gateway.SessionSummary{{ID: " s1 ", Title: " t1 "}}, + }, + }, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 2) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{}); err == nil { + t.Fatalf("expected compact empty session id error") + } + + compactResult, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s1", RunID: "r1"}) + if err != nil { + t.Fatalf("Compact() error = %v", err) + } + if !compactResult.Applied || compactResult.BeforeChars != 100 || compactResult.AfterChars != 40 { + t.Fatalf("compact result mismatch: %#v", compactResult) + } + + if err := adapter.ResolvePermission(context.Background(), agentruntime.PermissionResolutionInput{RequestID: " req ", Decision: "APPROVE"}); err != nil { + t.Fatalf("ResolvePermission() error = %v", err) + } + + summaries, err := adapter.ListSessions(context.Background()) + if err != nil { + t.Fatalf("ListSessions() error = %v", err) + } + if len(summaries) != 1 || summaries[0].ID != "s1" || summaries[0].Title != "t1" { + t.Fatalf("summaries mismatch: %#v", summaries) + } +} + +func TestRemoteRuntimeAdapterUnsupportedSkillMethods(t *testing.T) { + t.Parallel() + + adapter := newRemoteRuntimeAdapterWithClients( + &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)}, + &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + if err := adapter.ActivateSessionSkill(context.Background(), "s", "skill"); err == nil { + t.Fatalf("ActivateSessionSkill should be unsupported") + } + if err := adapter.DeactivateSessionSkill(context.Background(), "s", "skill"); err == nil { + t.Fatalf("DeactivateSessionSkill should be unsupported") + } + if _, err := adapter.ListSessionSkills(context.Background(), "s"); err == nil { + t.Fatalf("ListSessionSkills should be unsupported") + } +} + +func TestRemoteRuntimeAdapterCallFrameAndDecodeHelpers(t *testing.T) { + t.Parallel() + + adapter := newRemoteRuntimeAdapterWithClients(nil, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.callFrame(context.Background(), protocol.MethodGatewayPing, nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected nil rpc client error") + } + if err := adapter.authenticate(context.Background()); err == nil { + t.Fatalf("authenticate should fail on nil rpc client") + } + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + "error-without-payload": {Type: gateway.FrameTypeError}, + "error-with-payload": { + Type: gateway.FrameTypeError, + Error: &gateway.FrameError{Code: "bad", Message: "oops"}, + }, + "unexpected-type": {Type: gateway.FrameTypeEvent}, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter.rpcClient = rpcClient + + if _, err := adapter.callFrame(context.Background(), "error-without-payload", nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected error frame without payload") + } + if _, err := adapter.callFrame(context.Background(), "error-with-payload", nil, GatewayRPCCallOptions{}); err == nil || !strings.Contains(err.Error(), "bad") { + t.Fatalf("expected frame error mapping, got %v", err) + } + if _, err := adapter.callFrame(context.Background(), "unexpected-type", nil, GatewayRPCCallOptions{}); err == nil { + t.Fatalf("expected unexpected frame type error") + } + + if err := decodeIntoValue(map[string]any{"a": 1}, nil); err == nil { + t.Fatalf("decodeIntoValue should reject nil target") + } + if err := decodeIntoValue(func() {}, &map[string]any{}); err == nil { + t.Fatalf("decodeIntoValue should fail on marshal error") + } + if err := decodeIntoValue(map[string]any{"value": "x"}, &[]int{}); err == nil { + t.Fatalf("decodeIntoValue should fail on unmarshal mismatch") + } + + decoded, err := decodeFramePayload[gateway.CompactResult](map[string]any{"applied": true}) + if err != nil || !decoded.Applied { + t.Fatalf("decodeFramePayload() = (%#v, %v)", decoded, err) + } +} + +func TestRemoteRuntimeAdapterEventObservationAndActiveRunState(t *testing.T) { + t.Parallel() + + eventCh := make(chan agentruntime.RuntimeEvent, 3) + streamClient := &stubRemoteStreamClient{events: eventCh} + adapter := newRemoteRuntimeAdapterWithClients( + &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)}, + streamClient, + time.Second, + 1, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + eventCh <- agentruntime.RuntimeEvent{Type: agentruntime.EventAgentChunk, RunID: "run-a", SessionID: "session-a"} + eventCh <- agentruntime.RuntimeEvent{Type: agentruntime.EventAgentDone, RunID: "run-a", SessionID: "session-a"} + close(eventCh) + + for i := 0; i < 2; i++ { + select { + case <-adapter.Events(): + case <-time.After(time.Second): + t.Fatalf("timed out waiting for forwarded event") + } + } + + runID, sessionID := adapter.activeRun() + if runID != "" || sessionID != "session-a" { + t.Fatalf("active run state mismatch: run=%q session=%q", runID, sessionID) + } + + adapter.setActiveRun(" run-b ", " session-b ") + adapter.clearActiveRun("other-run") + runID, _ = adapter.activeRun() + if runID != "run-b" { + t.Fatalf("clearActiveRun should keep different run, got %q", runID) + } + adapter.clearActiveRun("run-b") + runID, _ = adapter.activeRun() + if runID != "" { + t.Fatalf("expected cleared run id, got %q", runID) + } + + adapter.setActiveRun("run-c", "session-c") + adapter.observeEvent(agentruntime.RuntimeEvent{Type: agentruntime.EventError}) + runID, sessionID = adapter.activeRun() + if runID != "run-c" || sessionID != "session-c" { + t.Fatalf("event error without run id should not clear active run, got run=%q session=%q", runID, sessionID) + } +} + +func TestNewRemoteRuntimeAdapterWithClientsNormalizesRetryCount(t *testing.T) { + t.Parallel() + + adapter := newRemoteRuntimeAdapterWithClients( + &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)}, + &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, + time.Second, + 0, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + if adapter.retryCount != defaultGatewayRPCRetryCount { + t.Fatalf("retryCount = %d, want %d", adapter.retryCount, defaultGatewayRPCRetryCount) + } +} + +func TestRemoteRuntimeAdapterUsesDefaultRetryWhenOptionsZero(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayListSessions: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionListSessions, + Payload: map[string]any{"sessions": []gateway.SessionSummary{}}, + }, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients( + rpcClient, + &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, + time.Second, + 0, + ) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.ListSessions(context.Background()); err != nil { + t.Fatalf("ListSessions() error = %v", err) + } + options, ok := rpcClient.snapshotOptions()[protocol.MethodGatewayListSessions] + if !ok { + t.Fatalf("expected listSessions call options to be captured") + } + if options.Retries != defaultGatewayRPCRetryCount { + t.Fatalf("listSessions retries = %d, want %d", options.Retries, defaultGatewayRPCRetryCount) + } +} + +func TestRemoteRuntimeAdapterLoadSessionAndCancelErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + callErrs: map[string]error{protocol.MethodGatewayCancel: errors.New("cancel failed")}, + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionLoadSession, Payload: func() {}}, + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if _, err := adapter.LoadSession(context.Background(), " "); err == nil { + t.Fatalf("expected empty id validation error") + } + if _, err := adapter.LoadSession(context.Background(), "session-1"); err == nil { + t.Fatalf("expected payload decode error") + } + + adapter.setActiveRun("run-1", "session-1") + if !adapter.CancelActiveRun() { + t.Fatalf("expected cancel attempt for active run") + } +} + +func TestRemoteRuntimeAdapterSubmitAndCompactErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + callErrs: map[string]error{ + protocol.MethodGatewayBindStream: errors.New("bind failed"), + }, + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if err := adapter.Submit(context.Background(), agentruntime.PrepareInput{}); err == nil || !strings.Contains(err.Error(), "bind failed") { + t.Fatalf("expected bind failed submit error, got %v", err) + } + methods := rpcClient.snapshotMethods() + if len(methods) != 1 || methods[0] != protocol.MethodGatewayBindStream { + t.Fatalf("Submit() should fail after bindStream and before loadSession, methods=%#v", methods) + } + bindParams, ok := rpcClient.snapshotParams()[protocol.MethodGatewayBindStream].(protocol.BindStreamParams) + if !ok || strings.TrimSpace(bindParams.SessionID) == "" { + t.Fatalf("Submit() should generate default session id for bindStream, params=%#v", rpcClient.snapshotParams()[protocol.MethodGatewayBindStream]) + } + + rpcClient.authErr = errors.New("auth failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "auth failed") { + t.Fatalf("expected compact auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs[protocol.MethodGatewayBindStream] = errors.New("bind compact failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "bind compact failed") { + t.Fatalf("expected compact bind error, got %v", err) + } + rpcClient.callErrs[protocol.MethodGatewayBindStream] = nil + rpcClient.callErrs[protocol.MethodGatewayCompact] = errors.New("compact failed") + if _, err := adapter.Compact(context.Background(), agentruntime.CompactInput{SessionID: "s-1"}); err == nil || !strings.Contains(err.Error(), "compact failed") { + t.Fatalf("expected compact rpc error, got %v", err) + } +} + +func TestRemoteRuntimeAdapterListAndLoadSessionErrorPaths(t *testing.T) { + t.Parallel() + + rpcClient := &stubRemoteRPCClient{ + notifications: make(chan gatewayRPCNotification), + } + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)}, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + rpcClient.authErr = errors.New("auth failed") + if _, err := adapter.ListSessions(context.Background()); err == nil || !strings.Contains(err.Error(), "auth failed") { + t.Fatalf("expected list auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs = map[string]error{protocol.MethodGatewayListSessions: errors.New("list failed")} + if _, err := adapter.ListSessions(context.Background()); err == nil || !strings.Contains(err.Error(), "list failed") { + t.Fatalf("expected list rpc error, got %v", err) + } + rpcClient.callErrs = nil + rpcClient.frames = map[string]gateway.MessageFrame{ + protocol.MethodGatewayListSessions: {Type: gateway.FrameTypeAck, Action: gateway.FrameActionListSessions, Payload: func() {}}, + } + if _, err := adapter.ListSessions(context.Background()); err == nil { + t.Fatalf("expected list decode error") + } + + rpcClient.authErr = errors.New("auth load failed") + if _, err := adapter.LoadSession(context.Background(), "s-1"); err == nil || !strings.Contains(err.Error(), "auth load failed") { + t.Fatalf("expected load auth error, got %v", err) + } + rpcClient.authErr = nil + rpcClient.callErrs = map[string]error{protocol.MethodGatewayLoadSession: errors.New("load failed")} + if _, err := adapter.LoadSession(context.Background(), "s-1"); err == nil || !strings.Contains(err.Error(), "load failed") { + t.Fatalf("expected load rpc error, got %v", err) + } +} + +func TestRemoteRuntimeAdapterRenderInputHelpers(t *testing.T) { + t.Parallel() + + text := renderInputTextFromParts([]providertypes.ContentPart{ + providertypes.NewTextPart(" first "), + providertypes.NewRemoteImagePart("/tmp/a.png"), + providertypes.NewTextPart("second"), + providertypes.NewTextPart(" "), + }) + if text != "first\nsecond" { + t.Fatalf("renderInputTextFromParts() = %q", text) + } + + images := renderInputImagesFromParts([]providertypes.ContentPart{ + providertypes.NewTextPart("x"), + providertypes.NewRemoteImagePart(" "), + providertypes.ContentPart{ + Kind: providertypes.ContentPartImage, + Image: &providertypes.ImagePart{ + URL: " /tmp/b.png ", + Asset: &providertypes.AssetRef{MimeType: " image/png "}, + }, + }, + }) + if len(images) != 1 || images[0].Path != "/tmp/b.png" || images[0].MimeType != "image/png" { + t.Fatalf("renderInputImagesFromParts() = %#v", images) + } + + params := buildGatewayRunParams(" s ", " r ", agentruntime.PrepareInput{Text: " hi ", Workdir: " /w ", Images: []agentruntime.UserImageInput{{Path: " /img.png ", MimeType: " image/png "}, {Path: " "}}}) + if params.SessionID != "s" || params.RunID != "r" || params.Workdir != "/w" || params.InputText != "hi" || len(params.InputParts) != 1 { + t.Fatalf("buildGatewayRunParams() = %#v", params) + } +} diff --git a/internal/tui/services/remote_runtime_adapter_test.go b/internal/tui/services/remote_runtime_adapter_test.go new file mode 100644 index 00000000..9729eaae --- /dev/null +++ b/internal/tui/services/remote_runtime_adapter_test.go @@ -0,0 +1,402 @@ +package services + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "neo-code/internal/gateway" + "neo-code/internal/gateway/protocol" + providertypes "neo-code/internal/provider/types" + agentruntime "neo-code/internal/runtime" + agentsession "neo-code/internal/session" + "neo-code/internal/tools" +) + +func TestRemoteRuntimeAdapterSubmitAuthenticatesBindsPreloadsAndRuns(t *testing.T) { + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionLoadSession, + SessionID: "session-1", + }, + protocol.MethodGatewayBindStream: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionBindStream, + SessionID: "session-1", + RunID: "run-1", + }, + protocol.MethodGatewayRun: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionRun, + SessionID: "session-1", + RunID: "run-1", + }, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + err := adapter.Submit(context.Background(), agentruntime.PrepareInput{ + SessionID: "session-1", + RunID: "run-1", + Workdir: "/repo", + Text: " hello ", + Images: []agentruntime.UserImageInput{ + {Path: " /tmp/a.png ", MimeType: " image/png "}, + }, + }) + if err != nil { + t.Fatalf("Submit() error = %v", err) + } + if rpcClient.authCalls != 1 { + t.Fatalf("authenticate call count = %d, want %d", rpcClient.authCalls, 1) + } + + methods := rpcClient.snapshotMethods() + if len(methods) != 3 || + methods[0] != protocol.MethodGatewayBindStream || + methods[1] != protocol.MethodGatewayLoadSession || + methods[2] != protocol.MethodGatewayRun { + t.Fatalf("rpc methods = %#v", methods) + } + loadSessionParams, ok := rpcClient.snapshotParams()[protocol.MethodGatewayLoadSession].(protocol.LoadSessionParams) + if !ok { + t.Fatalf( + "loadSession params type = %T, want protocol.LoadSessionParams", + rpcClient.snapshotParams()[protocol.MethodGatewayLoadSession], + ) + } + if loadSessionParams.SessionID != "session-1" { + t.Fatalf("loadSession session_id = %q, want %q", loadSessionParams.SessionID, "session-1") + } + + params, ok := rpcClient.snapshotParams()[protocol.MethodGatewayRun].(protocol.RunParams) + if !ok { + t.Fatalf("run params type = %T, want protocol.RunParams", rpcClient.snapshotParams()[protocol.MethodGatewayRun]) + } + if params.SessionID != "session-1" || params.RunID != "run-1" || params.Workdir != "/repo" { + t.Fatalf("unexpected run params ids/workdir: %#v", params) + } + if params.InputText != "hello" { + t.Fatalf("run input_text = %q, want %q", params.InputText, "hello") + } + if len(params.InputParts) != 1 || params.InputParts[0].Media == nil { + t.Fatalf("run input_parts = %#v, want one image part", params.InputParts) + } + if params.InputParts[0].Media.URI != "/tmp/a.png" || params.InputParts[0].Media.MimeType != "image/png" { + t.Fatalf("unexpected image part media: %#v", params.InputParts[0].Media) + } +} + +func TestRemoteRuntimeAdapterSubmitFailFastOnAuthenticateError(t *testing.T) { + rpcClient := &stubRemoteRPCClient{ + authErr: errors.New("auth failed"), + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + err := adapter.Submit(context.Background(), agentruntime.PrepareInput{ + SessionID: "session-1", + RunID: "run-1", + Text: "hello", + }) + if err == nil || !strings.Contains(err.Error(), "auth failed") { + t.Fatalf("expected auth failure, got %v", err) + } + if methods := rpcClient.snapshotMethods(); len(methods) != 0 { + t.Fatalf("expected no rpc call after auth failure, got %#v", methods) + } +} + +func TestRemoteRuntimeAdapterSubmitFailFastOnBindStreamError(t *testing.T) { + rpcClient := &stubRemoteRPCClient{ + callErrs: map[string]error{ + protocol.MethodGatewayBindStream: errors.New("stream bind failed"), + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + err := adapter.Submit(context.Background(), agentruntime.PrepareInput{ + SessionID: "session-1", + RunID: "run-1", + Text: "hello", + }) + if err == nil || !strings.Contains(err.Error(), "stream bind failed") { + t.Fatalf("expected bindStream failure, got %v", err) + } + + methods := rpcClient.snapshotMethods() + if len(methods) != 1 || methods[0] != protocol.MethodGatewayBindStream { + t.Fatalf("expected only bindStream call before failure, got %#v", methods) + } +} + +func TestRemoteRuntimeAdapterExecuteSystemToolUnsupported(t *testing.T) { + rpcClient := &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)} + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + _, err := adapter.ExecuteSystemTool(context.Background(), agentruntime.SystemToolInput{ + ToolName: "bash", + }) + if err == nil || err.Error() != unsupportedActionInGatewayMode { + t.Fatalf("expected unsupported_action_in_gateway_mode, got %v", err) + } +} + +func TestRemoteRuntimeAdapterLoadSessionMinimalMapping(t *testing.T) { + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayLoadSession: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionLoadSession, + Payload: gateway.Session{ + ID: "session-9", + Title: "title-9", + Workdir: "/repo", + Messages: []gateway.SessionMessage{ + { + Role: providertypes.RoleAssistant, + Content: "hello", + ToolCallID: "call-1", + ToolCalls: []gateway.ToolCall{ + {ID: "call-1", Name: "bash", Arguments: `{"command":"pwd"}`}, + }, + }, + }, + }, + }, + }, + notifications: make(chan gatewayRPCNotification), + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + session, err := adapter.LoadSession(context.Background(), "session-9") + if err != nil { + t.Fatalf("LoadSession() error = %v", err) + } + if session.ID != "session-9" || session.Title != "title-9" || session.Workdir != "/repo" { + t.Fatalf("unexpected session mapping: %#v", session) + } + if len(session.Messages) != 1 { + t.Fatalf("message count = %d, want %d", len(session.Messages), 1) + } + if text := renderPartsForRemoteAdapterTest(session.Messages[0].Parts); text != "hello" { + t.Fatalf("message parts text = %q, want %q", text, "hello") + } + if len(session.Messages[0].ToolCalls) != 1 || session.Messages[0].ToolCalls[0].Name != "bash" { + t.Fatalf("tool call mapping mismatch: %#v", session.Messages[0].ToolCalls) + } +} + +func TestRemoteRuntimeAdapterCancelActiveRunSendsGatewayCancel(t *testing.T) { + methodCh := make(chan string, 1) + rpcClient := &stubRemoteRPCClient{ + frames: map[string]gateway.MessageFrame{ + protocol.MethodGatewayCancel: { + Type: gateway.FrameTypeAck, + Action: gateway.FrameActionCancel, + }, + }, + notifications: make(chan gatewayRPCNotification), + methodCh: methodCh, + } + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + t.Cleanup(func() { _ = adapter.Close() }) + + if canceled := adapter.CancelActiveRun(); canceled { + t.Fatalf("expected no active run to cancel") + } + + adapter.setActiveRun("run-cancel", "session-cancel") + if canceled := adapter.CancelActiveRun(); !canceled { + t.Fatalf("expected cancel request to be scheduled") + } + + select { + case method := <-methodCh: + if method != protocol.MethodGatewayCancel { + t.Fatalf("cancel method = %q, want %q", method, protocol.MethodGatewayCancel) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for cancel rpc call") + } +} + +func TestRemoteRuntimeAdapterCloseClosesUnderlyingClients(t *testing.T) { + rpcClient := &stubRemoteRPCClient{notifications: make(chan gatewayRPCNotification)} + streamClient := &stubRemoteStreamClient{events: make(chan agentruntime.RuntimeEvent)} + adapter := newRemoteRuntimeAdapterWithClients(rpcClient, streamClient, time.Second, 1) + + if err := adapter.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if !rpcClient.closed { + t.Fatalf("expected rpc client to be closed") + } + if !streamClient.closed { + t.Fatalf("expected stream client to be closed") + } +} + +type stubRemoteRPCClient struct { + mu sync.Mutex + + authCalls int + authErr error + + methods []string + params map[string]any + options map[string]GatewayRPCCallOptions + + callErrs map[string]error + frames map[string]gateway.MessageFrame + methodCh chan string + + notifications chan gatewayRPCNotification + closed bool +} + +func (s *stubRemoteRPCClient) Authenticate(_ context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + s.authCalls++ + return s.authErr +} + +func (s *stubRemoteRPCClient) CallWithOptions( + _ context.Context, + method string, + params any, + result any, + options GatewayRPCCallOptions, +) error { + s.mu.Lock() + s.methods = append(s.methods, method) + if s.params == nil { + s.params = map[string]any{} + } + if s.options == nil { + s.options = map[string]GatewayRPCCallOptions{} + } + s.params[method] = params + s.options[method] = options + callErr := s.callErrs[method] + frame, hasFrame := s.frames[method] + s.mu.Unlock() + + if s.methodCh != nil { + select { + case s.methodCh <- method: + default: + } + } + if callErr != nil { + return callErr + } + if typed, ok := result.(*gateway.MessageFrame); ok { + if !hasFrame { + frame = gateway.MessageFrame{ + Type: gateway.FrameTypeAck, + Action: gateway.FrameAction(method), + } + } + *typed = frame + } + return nil +} + +func (s *stubRemoteRPCClient) Notifications() <-chan gatewayRPCNotification { + return s.notifications +} + +func (s *stubRemoteRPCClient) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if !s.closed { + s.closed = true + if s.notifications != nil { + close(s.notifications) + } + } + return nil +} + +func (s *stubRemoteRPCClient) snapshotMethods() []string { + s.mu.Lock() + defer s.mu.Unlock() + return append([]string(nil), s.methods...) +} + +func (s *stubRemoteRPCClient) snapshotParams() map[string]any { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]any, len(s.params)) + for key, value := range s.params { + cloned[key] = value + } + return cloned +} + +func (s *stubRemoteRPCClient) snapshotOptions() map[string]GatewayRPCCallOptions { + s.mu.Lock() + defer s.mu.Unlock() + cloned := make(map[string]GatewayRPCCallOptions, len(s.options)) + for key, value := range s.options { + cloned[key] = value + } + return cloned +} + +type stubRemoteStreamClient struct { + events <-chan agentruntime.RuntimeEvent + closed bool + mu sync.Mutex +} + +func (s *stubRemoteStreamClient) Events() <-chan agentruntime.RuntimeEvent { + return s.events +} + +func (s *stubRemoteStreamClient) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} + +func renderPartsForRemoteAdapterTest(parts []providertypes.ContentPart) string { + builder := strings.Builder{} + for _, part := range parts { + if part.Kind != providertypes.ContentPartText { + continue + } + if builder.Len() > 0 { + builder.WriteByte('\n') + } + builder.WriteString(part.Text) + } + return builder.String() +} + +var _ remoteGatewayRPCClient = (*stubRemoteRPCClient)(nil) +var _ remoteGatewayStreamClient = (*stubRemoteStreamClient)(nil) +var _ agentruntime.Runtime = (*RemoteRuntimeAdapter)(nil) +var _ = tools.ToolResult{} +var _ = agentsession.Summary{}