diff --git a/cmd/mass/commands/run/command.go b/cmd/mass/commands/run/command.go index 6b7ea40..11845f3 100644 --- a/cmd/mass/commands/run/command.go +++ b/cmd/mass/commands/run/command.go @@ -141,8 +141,10 @@ func run(cmd *cobra.Command, bundle, stateDir, permissions, id string, logCfg *l } if cfg.Session.SystemPrompt != "" { - trans.NotifyTurnStart() - trans.NotifyUserPrompt([]runapi.ContentBlock{runapi.TextBlock(acpruntime.BuildSeedSystemPrompt(cfg.Session.SystemPrompt))}) + // Seed prompt targets the initial session — empty sessionID resolves + // to it inside the Translator. + trans.NotifyTurnStart("") + trans.NotifyUserPrompt("", []runapi.ContentBlock{runapi.TextBlock(acpruntime.BuildSeedSystemPrompt(cfg.Session.SystemPrompt))}) resp, err := mgr.SeedSystemPrompt(ctx) stopReason := "error" if err == nil { @@ -151,7 +153,7 @@ func run(cmd *cobra.Command, bundle, stateDir, permissions, id string, logCfg *l if err != nil { trans.NotifyError(err.Error()) } - trans.NotifyTurnEnd(acp.StopReason(stopReason)) + trans.NotifyTurnEnd("", acp.StopReason(stopReason)) if err != nil { return fmt.Errorf("agent-run: seed system prompt: %w", err) } diff --git a/cmd/massctl/commands/agent/mock_test.go b/cmd/massctl/commands/agent/mock_test.go index 87004c1..4150399 100644 --- a/cmd/massctl/commands/agent/mock_test.go +++ b/cmd/massctl/commands/agent/mock_test.go @@ -91,6 +91,20 @@ func (m *mockAgentRunOps) TaskRetry(context.Context, *pkgariapi.AgentRunTaskRetr return nil, nil } +// Multi-session stubs (no-ops — agent command tests don't exercise them). +func (m *mockAgentRunOps) PromptSession(context.Context, pkgariapi.ObjectKey, string, []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { + return nil, nil +} +func (m *mockAgentRunOps) NewSession(context.Context, *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + return nil, nil +} +func (m *mockAgentRunOps) EndSession(context.Context, pkgariapi.ObjectKey, string) error { + return nil +} +func (m *mockAgentRunOps) ListSessions(context.Context, pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) { + return nil, nil +} + // mock WorkspaceOps (stub — not used in agent tests) type mockWorkspaceOps struct{} diff --git a/cmd/massctl/commands/agentrun/command.go b/cmd/massctl/commands/agentrun/command.go index 62809a0..54dda31 100644 --- a/cmd/massctl/commands/agentrun/command.go +++ b/cmd/massctl/commands/agentrun/command.go @@ -64,5 +64,8 @@ Poll with: massctl ar get -w `, cmd.AddCommand(newTaskCmd(getClient)) cmd.AddCommand(newChatCmd(getClient)) cmd.AddCommand(newDebugCmd()) + cmd.AddCommand(newNewSessionCmd(getClient)) + cmd.AddCommand(newEndSessionCmd(getClient)) + cmd.AddCommand(newListSessionsCmd(getClient)) return cmd } diff --git a/cmd/massctl/commands/agentrun/end_session.go b/cmd/massctl/commands/agentrun/end_session.go new file mode 100644 index 0000000..275915e --- /dev/null +++ b/cmd/massctl/commands/agentrun/end_session.go @@ -0,0 +1,46 @@ +package agentrun + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/zoumo/mass/cmd/massctl/commands/cliutil" + pkgariapi "github.com/zoumo/mass/pkg/ari/api" +) + +// newEndSessionCmd implements “massctl agentrun end-session“. +// +// Releases runtime tracking of a session id. The agent process keeps its +// per-session state until cancelled or process exits — ACP has no +// explicit end-session RPC. Refuses to end an agent's initial session +// (the one created by the agentrun handshake); kill the whole agent +// instead via “stop“. +func newEndSessionCmd(getClient cliutil.ClientFn) *cobra.Command { + var ws, sessionID string + cmd := &cobra.Command{ + Use: "end-session name", + Short: "Release runtime tracking of an ACP session", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client, err := getClient() + if err != nil { + return err + } + defer client.Close() + + if err := client.AgentRuns().EndSession(context.Background(), + pkgariapi.ObjectKey{Workspace: ws, Name: args[0]}, sessionID); err != nil { + return fmt.Errorf("ending session %s on %s/%s: %w", sessionID, ws, args[0], err) + } + fmt.Fprintf(cmd.OutOrStdout(), "session %s ended on %s/%s\n", sessionID, ws, args[0]) + return nil + }, + } + cmd.Flags().StringVarP(&ws, "workspace", "w", "", "Workspace name (required)") + cmd.Flags().StringVar(&sessionID, "session-id", "", "Session id to end (required) — must not be the initial session") + _ = cmd.MarkFlagRequired("workspace") + _ = cmd.MarkFlagRequired("session-id") + return cmd +} diff --git a/cmd/massctl/commands/agentrun/list_sessions.go b/cmd/massctl/commands/agentrun/list_sessions.go new file mode 100644 index 0000000..630f6dd --- /dev/null +++ b/cmd/massctl/commands/agentrun/list_sessions.go @@ -0,0 +1,46 @@ +package agentrun + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/zoumo/mass/cmd/massctl/commands/cliutil" + pkgariapi "github.com/zoumo/mass/pkg/ari/api" +) + +// newListSessionsCmd implements “massctl agentrun list-sessions“. +// +// Prints active session ids on an agent-run, one per line. Useful for +// scripts that drive multi-session pools and need to enumerate or +// clean up sessions. +func newListSessionsCmd(getClient cliutil.ClientFn) *cobra.Command { + var ws string + cmd := &cobra.Command{ + Use: "list-sessions name", + Aliases: []string{"ls-sessions"}, + Short: "List active ACP session ids on an agent-run", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client, err := getClient() + if err != nil { + return err + } + defer client.Close() + + out, err := client.AgentRuns().ListSessions(context.Background(), + pkgariapi.ObjectKey{Workspace: ws, Name: args[0]}) + if err != nil { + return fmt.Errorf("listing sessions on %s/%s: %w", ws, args[0], err) + } + for _, id := range out.SessionIDs { + fmt.Fprintln(cmd.OutOrStdout(), id) + } + return nil + }, + } + cmd.Flags().StringVarP(&ws, "workspace", "w", "", "Workspace name (required)") + _ = cmd.MarkFlagRequired("workspace") + return cmd +} diff --git a/cmd/massctl/commands/agentrun/mock_test.go b/cmd/massctl/commands/agentrun/mock_test.go index 6bd0c7c..fc57f0d 100644 --- a/cmd/massctl/commands/agentrun/mock_test.go +++ b/cmd/massctl/commands/agentrun/mock_test.go @@ -79,6 +79,23 @@ func (m *mockAgentRunOps) TaskRetry(ctx context.Context, params *pkgariapi.Agent return &pkgariapi.AgentTask{}, nil } +// Multi-session stubs — tests in this package don't exercise these yet. +func (m *mockAgentRunOps) PromptSession(ctx context.Context, key pkgariapi.ObjectKey, sessionID string, prompt []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { + return m.Prompt(ctx, key, prompt) +} + +func (m *mockAgentRunOps) NewSession(_ context.Context, _ *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + return &pkgariapi.AgentRunNewSessionResult{SessionID: "mock-session"}, nil +} + +func (m *mockAgentRunOps) EndSession(_ context.Context, _ pkgariapi.ObjectKey, _ string) error { + return nil +} + +func (m *mockAgentRunOps) ListSessions(_ context.Context, _ pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) { + return &pkgariapi.AgentRunListSessionsResult{}, nil +} + // ── mock WorkspaceOps (stub — not used in agentrun tests) ──────────────────── type mockWorkspaceOps struct{} diff --git a/cmd/massctl/commands/agentrun/new_session.go b/cmd/massctl/commands/agentrun/new_session.go new file mode 100644 index 0000000..dc3099d --- /dev/null +++ b/cmd/massctl/commands/agentrun/new_session.go @@ -0,0 +1,60 @@ +package agentrun + +import ( + "context" + "fmt" + + "github.com/spf13/cobra" + + "github.com/zoumo/mass/cmd/massctl/commands/cliutil" + pkgariapi "github.com/zoumo/mass/pkg/ari/api" +) + +// newNewSessionCmd implements “massctl agentrun new-session“. +// +// Opens an additional ACP session on a running agent-run — agent process +// is reused (no fork+exec), but the session has its own cwd and state. +// Returns the sessionId to stdout; pass it via subsequent “prompt +// --session-id“ etc. +func newNewSessionCmd(getClient cliutil.ClientFn) *cobra.Command { + var ws, cwd string + cmd := &cobra.Command{ + Use: "new-session name", + Short: "Open an additional ACP session on a running agent-run", + Long: `Opens an additional ACP session on the running agent process. + +The session is multiplexed onto the existing process — no new fork/exec +— but has its own cwd, model state, and message history. Returns the +agent-issued sessionId on stdout; pass it to subsequent prompt / cancel +/ end-session via --session-id.`, + Example: ` # Open a fresh session scoped to /tmp/case-7 + sid=$(massctl ar new-session worker -w my-ws --cwd /tmp/case-7) + massctl ar prompt worker -w my-ws --session-id "$sid" --text "Fix the bug" + massctl ar end-session worker -w my-ws --session-id "$sid"`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + client, err := getClient() + if err != nil { + return err + } + defer client.Close() + + out, err := client.AgentRuns().NewSession(context.Background(), &pkgariapi.AgentRunNewSessionParams{ + Workspace: ws, + Name: args[0], + Cwd: cwd, + }) + if err != nil { + return fmt.Errorf("opening session for %s/%s: %w", ws, args[0], err) + } + // Print sessionId only (script-friendly). + fmt.Fprintln(cmd.OutOrStdout(), out.SessionID) + return nil + }, + } + cmd.Flags().StringVarP(&ws, "workspace", "w", "", "Workspace name (required)") + cmd.Flags().StringVar(&cwd, "cwd", "", "Session cwd (required) — the working directory the agent uses for this session") + _ = cmd.MarkFlagRequired("workspace") + _ = cmd.MarkFlagRequired("cwd") + return cmd +} diff --git a/cmd/massctl/commands/agentrun/prompt.go b/cmd/massctl/commands/agentrun/prompt.go index d991f1b..7b487f0 100644 --- a/cmd/massctl/commands/agentrun/prompt.go +++ b/cmd/massctl/commands/agentrun/prompt.go @@ -17,9 +17,10 @@ import ( func newPromptCmd(getClient cliutil.ClientFn) *cobra.Command { var ( - ws string - text string - wait bool + ws string + text string + wait bool + sessionID string ) cmd := &cobra.Command{ Use: "prompt name", @@ -37,7 +38,9 @@ func newPromptCmd(getClient cliutil.ClientFn) *cobra.Command { key := pkgariapi.ObjectKey{Workspace: ws, Name: name} if !wait { - result, err := client.AgentRuns().Prompt(ctx, key, []runapi.ContentBlock{runapi.TextBlock(text)}) + // PromptSession with empty sessionID == Prompt — single entry point. + result, err := client.AgentRuns().PromptSession(ctx, key, sessionID, + []runapi.ContentBlock{runapi.TextBlock(text)}) if err != nil { return err } @@ -75,12 +78,20 @@ func newPromptCmd(getClient cliutil.ClientFn) *cobra.Command { // Send prompt (fire-and-forget). if err := runClient.SendPrompt(ctx, &runapi.SessionPromptParams{ - Prompt: []runapi.ContentBlock{runapi.TextBlock(text)}, + SessionID: sessionID, + Prompt: []runapi.ContentBlock{runapi.TextBlock(text)}, }); err != nil { return fmt.Errorf("send_prompt: %w", err) } // Collect agent_message text until turn_end. + // + // When --session-id is set, filter events to that session — without + // this, two concurrent sessions on the same agent will cross-talk + // (agent_message from session B counted as session A's, turn_end + // from B exits early). When sessionID is empty (single-session + // legacy default), every event is accepted; Translator stamps + // initial-session events with the initial session id. var parts []string timeout := time.After(5 * time.Minute) for { @@ -89,6 +100,9 @@ func newPromptCmd(getClient cliutil.ClientFn) *cobra.Command { if !ok { return fmt.Errorf("event stream closed before turn_end") } + if sessionID != "" && ev.SessionID != sessionID { + continue + } if ev.Type == runapi.EventTypeTurnEnd { fmt.Fprintln(cmd.OutOrStdout(), strings.Join(parts, "")) return nil @@ -109,6 +123,8 @@ func newPromptCmd(getClient cliutil.ClientFn) *cobra.Command { cmd.Flags().StringVarP(&ws, "workspace", "w", "", "Workspace name (required)") cmd.Flags().StringVar(&text, "text", "", "Prompt text (required)") cmd.Flags().BoolVar(&wait, "wait", false, "Wait for turn to complete and print agent response") + cmd.Flags().StringVar(&sessionID, "session-id", "", + "Session id to prompt (defaults to the agent's initial session). Use new-session to open one.") _ = cmd.MarkFlagRequired("workspace") _ = cmd.MarkFlagRequired("text") return cmd diff --git a/cmd/massctl/commands/workspace/create/command_test.go b/cmd/massctl/commands/workspace/create/command_test.go index a38b464..e7b649a 100644 --- a/cmd/massctl/commands/workspace/create/command_test.go +++ b/cmd/massctl/commands/workspace/create/command_test.go @@ -43,6 +43,20 @@ func (m *mockAgentRunOps) TaskRetry(context.Context, *pkgariapi.AgentRunTaskRetr return &pkgariapi.AgentTask{}, nil } +// Multi-session stubs (no-ops). +func (m *mockAgentRunOps) PromptSession(context.Context, pkgariapi.ObjectKey, string, []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { + return &pkgariapi.AgentRunPromptResult{}, nil +} +func (m *mockAgentRunOps) NewSession(context.Context, *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + return &pkgariapi.AgentRunNewSessionResult{}, nil +} +func (m *mockAgentRunOps) EndSession(context.Context, pkgariapi.ObjectKey, string) error { + return nil +} +func (m *mockAgentRunOps) ListSessions(context.Context, pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) { + return &pkgariapi.AgentRunListSessionsResult{}, nil +} + // ── mock WorkspaceOps (stub — not used in create tests) ────────────────────── type mockWorkspaceOps struct{} diff --git a/cmd/massctl/commands/workspace/mock_test.go b/cmd/massctl/commands/workspace/mock_test.go index ea0a117..4efaaa3 100644 --- a/cmd/massctl/commands/workspace/mock_test.go +++ b/cmd/massctl/commands/workspace/mock_test.go @@ -59,6 +59,20 @@ func (m *mockAgentRunOps) TaskRetry(context.Context, *pkgariapi.AgentRunTaskRetr return &pkgariapi.AgentTask{}, nil } +// Multi-session stubs (no-ops — workspace command tests don't exercise them). +func (m *mockAgentRunOps) PromptSession(context.Context, pkgariapi.ObjectKey, string, []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { + return &pkgariapi.AgentRunPromptResult{}, nil +} +func (m *mockAgentRunOps) NewSession(context.Context, *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + return &pkgariapi.AgentRunNewSessionResult{}, nil +} +func (m *mockAgentRunOps) EndSession(context.Context, pkgariapi.ObjectKey, string) error { + return nil +} +func (m *mockAgentRunOps) ListSessions(context.Context, pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) { + return &pkgariapi.AgentRunListSessionsResult{}, nil +} + // ── mock SystemOps (stub — not used in workspace tests) ──────────────────────── type mockSystemOps struct{} diff --git a/pkg/agentrun/api/methods.go b/pkg/agentrun/api/methods.go index 8d470f6..a8a9ce1 100644 --- a/pkg/agentrun/api/methods.go +++ b/pkg/agentrun/api/methods.go @@ -9,6 +9,14 @@ const ( MethodSessionSetModel = "session/set_model" MethodRuntimePhase = "runtime/status" MethodRuntimeStop = "runtime/stop" + + // Multi-session RPCs — open / end additional ACP sessions on a long- + // lived agent process so callers can multiplex tasks without + // fork+exec per task. See pkg/agentrun/runtime/acp/runtime.go's + // NewSession / EndSession for the underlying runtime contract. + MethodSessionNew = "session/new" + MethodSessionEnd = "session/end" + MethodSessionList = "session/list" ) // Run notification methods. diff --git a/pkg/agentrun/api/types.go b/pkg/agentrun/api/types.go index 5cac3a3..08364ba 100644 --- a/pkg/agentrun/api/types.go +++ b/pkg/agentrun/api/types.go @@ -14,8 +14,14 @@ import ( // SessionPromptParams is the JSON body for the "session/prompt" method. // Prompt is an array of ACP ContentBlocks supporting text, image, audio, // resource, and resource-link content types. +// +// SessionID is optional: when empty, the agent-run's initial session +// (the one opened during Create's handshake) is used — preserves +// backward compatibility for single-session callers. Multi-session +// callers must pass an explicit SessionID obtained from session/new. type SessionPromptParams struct { - Prompt []ContentBlock `json:"prompt"` + SessionID string `json:"sessionId,omitempty"` + Prompt []ContentBlock `json:"prompt"` } // SessionPromptResult is returned by the "session/prompt" method. @@ -23,6 +29,14 @@ type SessionPromptResult struct { StopReason string `json:"stopReason"` } +// SessionCancelParams is the JSON body for the "session/cancel" method. +// Optional SessionID — same semantics as SessionPromptParams.SessionID. +// Pre-multi-session callers passed no params at all; both empty body and +// missing SessionID work as "cancel the initial session". +type SessionCancelParams struct { + SessionID string `json:"sessionId,omitempty"` +} + // SessionLoadParams is the JSON body for the "session/load" RPC method. // agentd always calls this during recovery for best-effort session restore. // agent-run checks ACP loadSession capability internally and auto-fallbacks. @@ -55,13 +69,57 @@ type RuntimeStatusRecovery struct { } // SessionSetModelParams is the JSON body for "session/set_model". +// Optional SessionID — same semantics as SessionPromptParams.SessionID. type SessionSetModelParams struct { - ModelID string `json:"modelId"` + SessionID string `json:"sessionId,omitempty"` + ModelID string `json:"modelId"` } // SessionSetModelResult is returned by "session/set_model". type SessionSetModelResult struct{} +// SessionNewParams is the JSON body for the "session/new" method — +// opens an additional ACP session on the running agent. +// +// Cwd is required: each session is scoped to its own working directory. +// McpServers is optional per-session MCP overrides layered on the +// bundle-level config. +type SessionNewParams struct { + Cwd string `json:"cwd"` + McpServers []SessionNewMcpServer `json:"mcpServers,omitempty"` +} + +// SessionNewMcpServer mirrors acp.McpServer's wire shape for transport +// over JSON-RPC. Kept minimal — extend as actual cases demand. +type SessionNewMcpServer struct { + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` +} + +// SessionNewResult is returned by "session/new". +type SessionNewResult struct { + SessionID string `json:"sessionId"` +} + +// SessionEndParams is the JSON body for "session/end" — removes runtime +// tracking of a session. The agent process's per-session state remains +// until cancelled or the process exits (ACP has no explicit end-session +// RPC); this method only releases the agent-run's local map entry. +type SessionEndParams struct { + SessionID string `json:"sessionId"` +} + +// SessionEndResult is returned by "session/end". +type SessionEndResult struct{} + +// SessionListResult lists active session IDs on the agent. Used by +// callers (e.g. massctl) to inspect the agent's current sessions. +type SessionListResult struct { + SessionIDs []string `json:"sessionIds"` +} + // RuntimePhaseResult is returned by "runtime/status". type RuntimePhaseResult struct { State apiruntime.State `json:"state"` diff --git a/pkg/agentrun/client/client.go b/pkg/agentrun/client/client.go index b3800bb..57aa562 100644 --- a/pkg/agentrun/client/client.go +++ b/pkg/agentrun/client/client.go @@ -52,8 +52,44 @@ func (c *Client) SendPrompt(ctx context.Context, req *runapi.SessionPromptParams return c.c.CallAsync(ctx, runapi.MethodSessionPrompt, req) } +// Cancel cancels in-flight work on the agent's initial session. +// To cancel a specific session opened via NewSession, use CancelSession. func (c *Client) Cancel(ctx context.Context) error { - return c.c.Call(ctx, runapi.MethodSessionCancel, nil, nil) + return c.c.Call(ctx, runapi.MethodSessionCancel, &runapi.SessionCancelParams{}, nil) +} + +// CancelSession cancels in-flight work on a specific session. +func (c *Client) CancelSession(ctx context.Context, sessionID string) error { + return c.c.Call(ctx, runapi.MethodSessionCancel, + &runapi.SessionCancelParams{SessionID: sessionID}, nil) +} + +// NewSession opens an additional ACP session on the running agent. cwd +// scopes the session's working directory; mcpServers are optional per- +// session MCP overrides. Returns the new session id which callers must +// pass to Prompt / Cancel / SetModel via the params SessionID field. +func (c *Client) NewSession(ctx context.Context, req *runapi.SessionNewParams) (*runapi.SessionNewResult, error) { + var result runapi.SessionNewResult + if err := c.c.Call(ctx, runapi.MethodSessionNew, req, &result); err != nil { + return nil, err + } + return &result, nil +} + +// EndSession releases runtime tracking of a session. The agent process's +// per-session state remains until cancelled or the process exits. +func (c *Client) EndSession(ctx context.Context, sessionID string) error { + return c.c.Call(ctx, runapi.MethodSessionEnd, + &runapi.SessionEndParams{SessionID: sessionID}, nil) +} + +// ListSessions returns the agent's active session ids. +func (c *Client) ListSessions(ctx context.Context) (*runapi.SessionListResult, error) { + var result runapi.SessionListResult + if err := c.c.Call(ctx, runapi.MethodSessionList, nil, &result); err != nil { + return nil, err + } + return &result, nil } func (c *Client) Load(ctx context.Context, req *runapi.SessionLoadParams) error { diff --git a/pkg/agentrun/client/client_test.go b/pkg/agentrun/client/client_test.go index 97b7079..c672c73 100644 --- a/pkg/agentrun/client/client_test.go +++ b/pkg/agentrun/client/client_test.go @@ -28,7 +28,7 @@ type stubRunService struct { func (s *stubRunService) Prompt(_ context.Context, req *runapi.SessionPromptParams) (*runapi.SessionPromptResult, error) { return &s.promptResult, nil } -func (s *stubRunService) Cancel(_ context.Context) error { return nil } +func (s *stubRunService) Cancel(_ context.Context, _ *runapi.SessionCancelParams) error { return nil } func (s *stubRunService) Load(_ context.Context, _ *runapi.SessionLoadParams) error { return nil } @@ -46,6 +46,18 @@ func (s *stubRunService) SetModel(_ context.Context, _ *runapi.SessionSetModelPa } func (s *stubRunService) Stop(_ context.Context) error { return nil } +func (s *stubRunService) NewSession(_ context.Context, req *runapi.SessionNewParams) (*runapi.SessionNewResult, error) { + return &runapi.SessionNewResult{SessionID: "stub-session-" + req.Cwd}, nil +} + +func (s *stubRunService) EndSession(_ context.Context, _ *runapi.SessionEndParams) (*runapi.SessionEndResult, error) { + return &runapi.SessionEndResult{}, nil +} + +func (s *stubRunService) ListSessions(_ context.Context) (*runapi.SessionListResult, error) { + return &runapi.SessionListResult{SessionIDs: []string{}}, nil +} + // startTestServer starts a jsonrpc.Server with Register on a temp socket. func startTestServer(t *testing.T, svc runserver.Handler) string { t.Helper() diff --git a/pkg/agentrun/client/watch_test.go b/pkg/agentrun/client/watch_test.go index dd60585..9500266 100644 --- a/pkg/agentrun/client/watch_test.go +++ b/pkg/agentrun/client/watch_test.go @@ -30,8 +30,8 @@ type replayService struct { func (s *replayService) Prompt(context.Context, *runapi.SessionPromptParams) (*runapi.SessionPromptResult, error) { return &runapi.SessionPromptResult{}, nil } -func (s *replayService) Cancel(context.Context) error { return nil } -func (s *replayService) Load(context.Context, *runapi.SessionLoadParams) error { return nil } +func (s *replayService) Cancel(context.Context, *runapi.SessionCancelParams) error { return nil } +func (s *replayService) Load(context.Context, *runapi.SessionLoadParams) error { return nil } func (s *replayService) SetModel(context.Context, *runapi.SessionSetModelParams) (*runapi.SessionSetModelResult, error) { return &runapi.SessionSetModelResult{}, nil } @@ -41,6 +41,16 @@ func (s *replayService) Status(context.Context) (*runapi.RuntimePhaseResult, err } func (s *replayService) Stop(context.Context) error { return nil } +func (s *replayService) NewSession(context.Context, *runapi.SessionNewParams) (*runapi.SessionNewResult, error) { + return &runapi.SessionNewResult{}, nil +} +func (s *replayService) EndSession(context.Context, *runapi.SessionEndParams) (*runapi.SessionEndResult, error) { + return &runapi.SessionEndResult{}, nil +} +func (s *replayService) ListSessions(context.Context) (*runapi.SessionListResult, error) { + return &runapi.SessionListResult{}, nil +} + func (s *replayService) WatchEvent(ctx context.Context, req *runapi.SessionWatchEventParams, watchID string) (*runapi.SessionWatchEventResult, error) { peer := jsonrpc.PeerFromContext(ctx) diff --git a/pkg/agentrun/runtime/acp/runtime.go b/pkg/agentrun/runtime/acp/runtime.go index 877b4ac..4e35284 100644 --- a/pkg/agentrun/runtime/acp/runtime.go +++ b/pkg/agentrun/runtime/acp/runtime.go @@ -57,19 +57,60 @@ type StateChange struct { type StateChangeHook func(StateChange) // Manager manages the lifecycle of a single ACP agent process. +// sessionState holds per-session protocol metadata. Multiple sessions may +// co-exist on one agent process — ACP's session/new is multi-session by +// design (each session has its own cwd, sessionId, model state). The +// agent process (e.g. claude-agent-acp) keeps sessions in a map keyed by +// sessionId; this struct mirrors the runtime-side view of each. +// +// Multi-session lets callers keep a long-lived agent process and switch +// session per task, instead of fork+kill an agent process per task. +type sessionState struct { + id acp.SessionId + models *acp.SessionModelState + cwd string + // inflight counts active PromptSession turns. EndSession refuses to + // release a session with inflight > 0; the prompt's defer decrements it. + inflight int +} + type Manager struct { cfg apiruntime.Config bundleDir string stateDir string logger *slog.Logger - mu sync.Mutex - cmd *exec.Cmd - processDone chan struct{} - conn *acp.ClientSideConnection - sessionID acp.SessionId - events chan acp.SessionNotification - models *acp.SessionModelState // from session/new response + mu sync.Mutex + cmd *exec.Cmd + processDone chan struct{} + conn *acp.ClientSideConnection + events chan acp.SessionNotification + + // sessions holds all active ACP sessions on this agent process, keyed + // by sessionId. NewSession adds entries; EndSession removes them. + // Empty until Create()'s initial session/new completes. Per-session + // inflight prompt count lives on each sessionState — see EndSession. + sessions map[acp.SessionId]*sessionState + + // activePrompts is the total number of in-flight PromptSession calls + // across all sessions. state.json's Phase field is process-wide, so + // it must reflect "any session running" — not "the last prompt that + // finished". PromptSession bumps this; the deferred decrement + + // writeState then re-reads it under m.mu to decide Running vs Idle. + activePrompts int + + // sessionID retains the *first* session created by Create() so legacy + // callers of Prompt/Cancel/SetModel (which don't pass sessionId) keep + // working without code changes. New callers should use the explicit + // sessionId-taking variants (PromptSession etc.). To be removed once + // all callers migrate. + sessionID acp.SessionId + + // models tracks the *first* session's models for the same backward- + // compat reason as sessionID above. Per-session models live in + // sessions[id].models. + models *acp.SessionModelState + stateChangeHook StateChangeHook eventCountsFn func() map[string]int usageFn func() *apiruntime.UsageInfo @@ -83,6 +124,7 @@ func New(cfg apiruntime.Config, bundleDir, stateDir string, logger *slog.Logger) stateDir: stateDir, logger: logger.With("subsystem", "runtime"), events: make(chan acp.SessionNotification, 1024), + sessions: make(map[acp.SessionId]*sessionState), } } @@ -197,8 +239,17 @@ func (m *Manager) Create(ctx context.Context) error { return fmt.Errorf("runtime: acp session/new: %w", err) } m.mu.Lock() + // Initial session populates both the legacy single-session fields + // (sessionID/models) and the multi-session map. Future sessions opened + // via NewSession only register in the map; sessionID stays pinned to + // the first one for backward compat with Prompt/Cancel/SetModel. m.sessionID = sessionResp.SessionId m.models = sessionResp.Models + m.sessions[sessionResp.SessionId] = &sessionState{ + id: sessionResp.SessionId, + models: sessionResp.Models, + cwd: workDir, + } m.mu.Unlock() m.logger.Info("session created", "sessionID", sessionResp.SessionId) @@ -224,6 +275,7 @@ func (m *Manager) Create(ctx context.Context) error { defer close(processDone) _ = cmd.Wait() m.logger.Info("process exited") + m.clearSessions() _ = m.writeState(func(s *apiruntime.State) { s.MassVersion = m.cfg.MassVersion s.ID = m.cfg.Metadata.Name @@ -279,6 +331,7 @@ func (m *Manager) Kill(ctx context.Context) error { } } + m.clearSessions() return m.writeState(func(s *apiruntime.State) { s.MassVersion = m.cfg.MassVersion s.ID = m.cfg.Metadata.Name @@ -288,6 +341,25 @@ func (m *Manager) Kill(ctx context.Context) error { }, "runtime-stop") } +// clearSessions resets the multi-session bookkeeping after the agent +// process has exited (via Kill or natural exit). Without this, Sessions() +// returns stale IDs, EndSession silently succeeds on a dead session, and +// PromptSession passes its conn != nil check before failing on the dead +// pipe — invariants of the session map are broken across the teardown. +// +// In-flight RPCs hold their own captured conn pointer (read under m.mu at +// entry, used without lock); nil-ing m.conn here only guards *future* +// calls. Existing goroutines will fail on the closed pipe, which is fine. +func (m *Manager) clearSessions() { + m.mu.Lock() + defer m.mu.Unlock() + m.sessions = map[acp.SessionId]*sessionState{} + m.conn = nil + m.sessionID = "" + m.models = nil + m.activePrompts = 0 +} + // Delete removes the agent state directory. The agent must be stopped first. func (m *Manager) Delete() error { s, err := spec.ReadState(m.stateDir) @@ -305,24 +377,167 @@ func (m *Manager) GetState() (apiruntime.State, error) { return spec.ReadState(m.stateDir) } -// Prompt sends a user prompt to the agent and blocks until the agent -// returns a PromptResponse. Session notifications emitted by the agent -// during the turn are forwarded to the Events channel. -// On completion (success or error), LastTurn is persisted to state.json. -func (m *Manager) Prompt(ctx context.Context, prompt []acp.ContentBlock) (acp.PromptResponse, error) { +// NewSession opens an additional ACP session on the running agent process. +// The agent must already be started via Create() (which creates the initial +// session). Returns the new session's ID, which callers must pass to +// PromptSession / CancelSession / EndSession to address this session. +// +// cwd is required — each session is scoped to its own working directory. +// All wire-level callers (Service / ARI / CLI) already reject empty cwd; +// the runtime validates here too for in-process callers and so the +// contract is consistent across layers. +// +// mcpServers (optional) are extra MCP servers scoped to this session, layered +// on top of the agent's bundle-level mcpServers. +func (m *Manager) NewSession(ctx context.Context, cwd string, mcpServers []acp.McpServer) (acp.SessionId, error) { + if cwd == "" { + return "", fmt.Errorf("runtime: new session: cwd is required") + } + m.mu.Lock() conn := m.conn - sessionID := m.sessionID m.mu.Unlock() if conn == nil { + return "", fmt.Errorf("runtime: agent not started") + } + + // Layer session-scoped mcpServers on top of bundle-level config. Callers + // that pass nil/empty get the bundle defaults. + merged := convertMcpServers(m.cfg.Session.McpServers) + merged = append(merged, mcpServers...) + + req := acp.NewSessionRequest{ + Meta: m.cfg.Session.Meta, + Cwd: cwd, + McpServers: merged, + } + resp, err := conn.NewSession(ctx, req) + if err != nil { + return "", fmt.Errorf("runtime: acp session/new: %w", err) + } + + m.mu.Lock() + m.sessions[resp.SessionId] = &sessionState{ + id: resp.SessionId, + models: resp.Models, + cwd: cwd, + } + m.mu.Unlock() + m.logger.Info("session opened", "sessionID", resp.SessionId, "cwd", cwd) + return resp.SessionId, nil +} + +// EndSession releases runtime tracking of a session id. The ACP protocol +// has no explicit "end session" RPC today, so this method only clears the +// runtime's local map entry — the agent process retains its own session +// state until cancelled or the process exits. +// +// EndSession refuses to release a session that has in-flight Prompt work +// (returns ErrSessionBusy). Callers must CancelSession (or wait for prompt +// completion) first; otherwise the prompt completes against a session no +// longer tracked, and the resulting state.json Phase / log lines diverge +// from the sessions map. +// +// Refuses to end the initial session (the one Create() opened) — callers +// that want to fully tear down the agent should call Kill() instead. +func (m *Manager) EndSession(sessionID acp.SessionId) error { + m.mu.Lock() + defer m.mu.Unlock() + if sessionID == m.sessionID { + return fmt.Errorf("runtime: cannot end initial session %q (kill the agent instead)", sessionID) + } + sess, ok := m.sessions[sessionID] + if !ok { + return fmt.Errorf("runtime: session %q not found", sessionID) + } + if sess.inflight > 0 { + return fmt.Errorf("runtime: session %q busy: %d in-flight prompt(s) (cancel first)", sessionID, sess.inflight) + } + delete(m.sessions, sessionID) + m.logger.Info("session ended", "sessionID", sessionID) + return nil +} + +// resolveSessionLocked maps an empty sessionID to the initial session. +// Caller must hold m.mu. +func (m *Manager) resolveSessionLocked(sessionID acp.SessionId) acp.SessionId { + if sessionID == "" { + return m.sessionID + } + return sessionID +} + +// decrementPromptInflight is the deferred cleanup for PromptSession. The +// session may have been deleted by Kill/clearSessions in the window between +// PromptSession's entry and conn.Prompt returning; both decrements guard +// for that — sess.inflight only if the session still exists, and +// activePrompts only when positive so a racing clearSessions (which sets +// activePrompts to 0) doesn't drive the counter negative. +func (m *Manager) decrementPromptInflight(sessionID acp.SessionId) { + m.mu.Lock() + defer m.mu.Unlock() + if s, ok := m.sessions[sessionID]; ok { + s.inflight-- + } + if m.activePrompts > 0 { + m.activePrompts-- + } +} + +// SessionIDs returns a snapshot of currently-active session IDs. Order is +// not stable — caller should sort if a deterministic order is needed. +func (m *Manager) SessionIDs() []acp.SessionId { + m.mu.Lock() + defer m.mu.Unlock() + ids := make([]acp.SessionId, 0, len(m.sessions)) + for id := range m.sessions { + ids = append(ids, id) + } + return ids +} + +// Prompt is a backward-compat shim — equivalent to PromptSession(ctx, "", prompt), +// which resolves empty SessionId to the initial session. +func (m *Manager) Prompt(ctx context.Context, prompt []acp.ContentBlock) (acp.PromptResponse, error) { + return m.PromptSession(ctx, "", prompt) +} + +// PromptSession sends a user prompt to a specific session and blocks until +// the agent returns a PromptResponse. Empty sessionID is resolved to the +// agent's initial session (the one Create() opened) — single resolution +// point for the empty-string convention. +// +// Session notifications emitted by the agent during the turn are forwarded +// to the Events channel. On completion (success or error), state.json is +// updated. +// +// The state.json Phase field is single-valued today, so the "PhaseRunning" +// stamp applies process-wide — when multiple sessions are in-flight +// concurrently, phase is "running" if any session is. Per-session phase +// tracking is a follow-up state-schema change. +func (m *Manager) PromptSession(ctx context.Context, sessionID acp.SessionId, prompt []acp.ContentBlock) (acp.PromptResponse, error) { + m.mu.Lock() + conn := m.conn + sessionID = m.resolveSessionLocked(sessionID) + sess, known := m.sessions[sessionID] + if conn == nil { + m.mu.Unlock() return acp.PromptResponse{}, fmt.Errorf("runtime: agent not started") } - m.logger.Debug("prompt started", "blocks", len(prompt)) + if !known { + m.mu.Unlock() + return acp.PromptResponse{}, fmt.Errorf("runtime: prompt: session %q not found", sessionID) + } + sess.inflight++ + m.activePrompts++ + m.mu.Unlock() - _ = m.writeState(func(s *apiruntime.State) { - s.Phase = apiruntime.PhaseRunning - }, "prompt-started") + defer m.decrementPromptInflight(sessionID) + + m.logger.Debug("prompt started", "sessionID", sessionID, "blocks", len(prompt)) + + _ = m.writeState(m.phaseFromActivePromptsLocked, "prompt-started") resp, err := conn.Prompt(ctx, acp.PromptRequest{ SessionId: sessionID, @@ -334,10 +549,8 @@ func (m *Manager) Prompt(ctx context.Context, prompt []acp.ContentBlock) (acp.Pr if err != nil { reason = "prompt-failed" } - m.logger.Debug("prompt done", "reason", reason) - _ = m.writeState(func(s *apiruntime.State) { - s.Phase = apiruntime.PhaseIdle - }, reason) + m.logger.Debug("prompt done", "sessionID", sessionID, "reason", reason) + _ = m.writeState(m.phaseFromActivePromptsLocked, reason) } if err != nil { @@ -346,17 +559,47 @@ func (m *Manager) Prompt(ctx context.Context, prompt []acp.ContentBlock) (acp.Pr return resp, nil } -// Cancel sends a cancel notification to the agent for the current session. +// phaseFromActivePromptsLocked sets state.Phase based on the current +// activePrompts counter. Used as the writeState apply callback for prompt +// start/end so Phase reflects "any session running" rather than the last +// caller's local view. Runs under m.mu (writeState locks before calling). +// +// Skips the write when m.conn is nil — Kill / clearSessions has taken +// ownership of lifecycle phase (Stopped), and a late-firing PromptSession +// writeState (from a prompt that was in flight when Kill ran) must not +// clobber that with Idle / Running. +func (m *Manager) phaseFromActivePromptsLocked(s *apiruntime.State) { + if m.conn == nil { + return + } + if m.activePrompts > 0 { + s.Phase = apiruntime.PhaseRunning + } else { + s.Phase = apiruntime.PhaseIdle + } +} + +// Cancel is a backward-compat shim — equivalent to CancelSession(ctx, ""). func (m *Manager) Cancel(ctx context.Context) error { + return m.CancelSession(ctx, "") +} + +// CancelSession sends a cancel notification for a specific session. Empty +// sessionID resolves to the initial session. +func (m *Manager) CancelSession(ctx context.Context, sessionID acp.SessionId) error { m.mu.Lock() conn := m.conn - sessionID := m.sessionID + sessionID = m.resolveSessionLocked(sessionID) + _, known := m.sessions[sessionID] m.mu.Unlock() if conn == nil { return fmt.Errorf("runtime: agent not started") } - m.logger.Debug("cancel") + if !known { + return fmt.Errorf("runtime: cancel: session %q not found", sessionID) + } + m.logger.Debug("cancel", "sessionID", sessionID) if err := conn.Cancel(ctx, acp.CancelNotification{SessionId: sessionID}); err != nil { return fmt.Errorf("runtime: cancel: %w", err) @@ -364,17 +607,31 @@ func (m *Manager) Cancel(ctx context.Context) error { return nil } -// SetModel switches the agent to a different model via ACP session/set_model. +// SetModel is a backward-compat shim — equivalent to SetModelSession(ctx, "", modelID). func (m *Manager) SetModel(ctx context.Context, modelID string) error { + return m.SetModelSession(ctx, "", modelID) +} + +// SetModelSession switches a specific session to a different model via ACP +// session/set_model. Empty sessionID resolves to the initial session. +// Updates in-memory per-session models state; for the initial session +// also mirrors to the legacy state.json Session.Models field. Per-session +// models persistence is a follow-up state-schema change. +func (m *Manager) SetModelSession(ctx context.Context, sessionID acp.SessionId, modelID string) error { m.mu.Lock() conn := m.conn - sessionID := m.sessionID + sessionID = m.resolveSessionLocked(sessionID) + _, known := m.sessions[sessionID] + initialID := m.sessionID m.mu.Unlock() if conn == nil { return fmt.Errorf("runtime: agent not started") } - m.logger.Debug("set_model", "modelID", modelID) + if !known { + return fmt.Errorf("runtime: set_model: session %q not found", sessionID) + } + m.logger.Debug("set_model", "sessionID", sessionID, "modelID", modelID) _, err := conn.UnstableSetSessionModel(ctx, acp.UnstableSetSessionModelRequest{ SessionId: sessionID, @@ -384,19 +641,30 @@ func (m *Manager) SetModel(ctx context.Context, modelID string) error { return fmt.Errorf("runtime: set model: %w", err) } - // Update in-memory models state. + // Re-lookup under the second lock — the session could have been ended + // (or the whole map cleared by teardown) between the first unlock and + // here, in which case the in-memory model mutation should silently no-op + // rather than write through an orphan struct. m.mu.Lock() - if m.models != nil { + if sess, ok := m.sessions[sessionID]; ok && sess.models != nil { + sess.models.CurrentModelId = acp.ModelId(modelID) //nolint:gosec // ModelId is string + } + // Legacy single-session models mirror for the initial session only. + if sessionID == initialID && m.models != nil { m.models.CurrentModelId = acp.ModelId(modelID) //nolint:gosec // ModelId is string } m.mu.Unlock() - // Persist updated currentModelId to state.json. - _ = m.writeState(func(s *apiruntime.State) { - if s.Session != nil && s.Session.Models != nil { - s.Session.Models.CurrentModelId = modelID - } - }, "set-model") + // state.json's Session.Models is single-session; only update for the + // initial session to keep legacy semantics. Multi-session persistence + // is a follow-up. + if sessionID == initialID { + _ = m.writeState(func(s *apiruntime.State) { + if s.Session != nil && s.Session.Models != nil { + s.Session.Models.CurrentModelId = modelID + } + }, "set-model") + } return nil } diff --git a/pkg/agentrun/runtime/acp/runtime_internal_test.go b/pkg/agentrun/runtime/acp/runtime_internal_test.go index bb3568d..2b9187e 100644 --- a/pkg/agentrun/runtime/acp/runtime_internal_test.go +++ b/pkg/agentrun/runtime/acp/runtime_internal_test.go @@ -1,8 +1,15 @@ package acp import ( + "context" + "log/slog" + "sort" "strings" "testing" + + acp "github.com/coder/acp-go-sdk" + + apiruntime "github.com/zoumo/mass/pkg/runtime-spec/api" ) func TestBuildSeedSystemPrompt_AppendsGuard(t *testing.T) { @@ -20,3 +27,329 @@ func TestBuildSeedSystemPrompt_AppendsGuard(t *testing.T) { t.Fatalf("expected seed prompt to contain %q, got %q", want, seed) } } + +// newManagerForSessionTest builds a Manager with the sessions map +// initialized, bypassing Create() (which needs a real ACP agent +// process). Lets unit tests exercise the session bookkeeping logic +// (NewSession map insert, EndSession map delete, Sessions snapshot, +// error paths) without spinning up an agent. +// +// The returned Manager has nil conn; methods that go through ACP +// (PromptSession etc.) return "agent not started". +func newManagerForSessionTest(t *testing.T) *Manager { + t.Helper() + return &Manager{ + logger: slog.Default(), + sessions: make(map[acp.SessionId]*sessionState), + } +} + +func TestSessionIDs_EmptyWhenNoneRegistered(t *testing.T) { + m := newManagerForSessionTest(t) + if got := m.SessionIDs(); len(got) != 0 { + t.Fatalf("expected empty sessions, got %v", got) + } +} + +func TestSessionIDs_SnapshotIncludesRegistered(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessions["sess-a"] = &sessionState{id: "sess-a", cwd: "/a"} + m.sessions["sess-b"] = &sessionState{id: "sess-b", cwd: "/b"} + m.sessionID = "sess-a" + + got := m.SessionIDs() + if len(got) != 2 { + t.Fatalf("expected 2 sessions, got %v", got) + } + strs := make([]string, len(got)) + for i, s := range got { + strs[i] = string(s) + } + sort.Strings(strs) + want := []string{"sess-a", "sess-b"} + for i := range want { + if strs[i] != want[i] { + t.Fatalf("session %d: want %q got %q", i, want[i], strs[i]) + } + } +} + +func TestEndSession_RemovesFromMap(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + m.sessions["extra"] = &sessionState{id: "extra", cwd: "/x"} + + if err := m.EndSession("extra"); err != nil { + t.Fatalf("end extra: %v", err) + } + if _, ok := m.sessions["extra"]; ok { + t.Fatalf("extra session not removed from map") + } + if _, ok := m.sessions["initial"]; !ok { + t.Fatalf("initial session unexpectedly removed") + } +} + +func TestEndSession_RefusesInitialSession(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + + err := m.EndSession("initial") + if err == nil { + t.Fatalf("expected EndSession to refuse the initial session") + } + if !strings.Contains(err.Error(), "initial session") { + t.Fatalf("expected error to mention initial session, got: %v", err) + } + if _, ok := m.sessions["initial"]; !ok { + t.Fatalf("initial session removed despite error") + } +} + +func TestEndSession_UnknownSessionErrors(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + + err := m.EndSession("never-existed") + if err == nil { + t.Fatalf("expected EndSession to error on unknown session") + } + if !strings.Contains(err.Error(), "not found") { + t.Fatalf("expected error to mention 'not found', got: %v", err) + } +} + +func TestPromptSession_RejectsBeforeAgentStarted(t *testing.T) { + // Manager with no conn (Create never ran). PromptSession should + // return a clear error rather than panic on nil conn — protects + // callers that race or skip Create(). + m := newManagerForSessionTest(t) + m.sessions["known"] = &sessionState{id: "known"} + + _, err := m.PromptSession(context.Background(), "known", + []acp.ContentBlock{acp.TextBlock("hi")}) + if err == nil { + t.Fatalf("expected error from PromptSession with nil conn") + } + if !strings.Contains(err.Error(), "agent not started") { + t.Fatalf("expected 'agent not started' error, got: %v", err) + } +} + +func TestCancelSession_RejectsBeforeAgentStarted(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessions["known"] = &sessionState{id: "known"} + + err := m.CancelSession(context.Background(), "known") + if err == nil { + t.Fatalf("expected error from CancelSession with nil conn") + } + if !strings.Contains(err.Error(), "agent not started") { + t.Fatalf("expected 'agent not started' error, got: %v", err) + } +} + +// TestEndSession_RefusesBusySession verifies that EndSession returns a +// busy error when a session has in-flight Prompt work (refcount > 0). +// Without this check, ending a session mid-prompt leaves the prompt +// running against a session no longer in the map — state.json bookkeeping +// drifts silently. +func TestEndSession_RefusesBusySession(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + m.sessions["busy"] = &sessionState{id: "busy", cwd: "/b", inflight: 1} + + err := m.EndSession("busy") + if err == nil { + t.Fatalf("expected EndSession to refuse a busy session") + } + if !strings.Contains(err.Error(), "busy") { + t.Fatalf("expected error to mention 'busy', got: %v", err) + } + if _, ok := m.sessions["busy"]; !ok { + t.Fatalf("busy session removed despite error") + } +} + +// TestEndSession_AllowsAfterRefcountClears verifies that a session can +// be ended once its in-flight refcount drops back to zero — the busy +// check is per-state, not permanent. +func TestEndSession_AllowsAfterRefcountClears(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + m.sessions["s"] = &sessionState{id: "s", cwd: "/s", inflight: 1} + + if err := m.EndSession("s"); err == nil { + t.Fatalf("expected busy error first") + } + + m.sessions["s"].inflight = 0 // prompt completed + + if err := m.EndSession("s"); err != nil { + t.Fatalf("expected EndSession to succeed after refcount cleared: %v", err) + } + if _, ok := m.sessions["s"]; ok { + t.Fatalf("session not removed after EndSession") + } +} + +// TestPromptSession_EmptySessionIDResolvesToInitial verifies the +// single-resolution-point convention: an empty sessionID is resolved to +// m.sessionID inside PromptSession (rather than at the Service or +// Manager.Prompt-shim layer). Reaching the "session not found" branch +// would mean the empty-string convention leaked through unresolved. +func TestPromptSession_EmptySessionIDResolvesToInitial(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + + // nil conn so we get "agent not started" instead of nil-deref — + // the point is to confirm we got past the session-lookup step, + // which would have returned "session %q not found" for the literal + // empty string had resolution not happened. + _, err := m.PromptSession(context.Background(), "", + []acp.ContentBlock{acp.TextBlock("hi")}) + if err == nil { + t.Fatalf("expected error (agent not started)") + } + if strings.Contains(err.Error(), "not found") { + t.Fatalf("empty sessionID was not resolved to initial; got: %v", err) + } + if !strings.Contains(err.Error(), "agent not started") { + t.Fatalf("expected 'agent not started', got: %v", err) + } +} + +// TestPhaseFromActivePrompts verifies that state.Phase reflects the +// process-wide activePrompts counter, not a single caller's local view. +// Before this fix, two concurrent PromptSessions racing on completion +// could leave Phase=Idle while one was still running. +func TestPhaseFromActivePrompts(t *testing.T) { + m := newManagerForSessionTest(t) + // Non-nil conn means the agent is alive — phaseFromActivePromptsLocked + // only writes Phase when m.conn != nil (it cedes lifecycle to Kill + // otherwise; see TestPhaseFromActivePromptsLocked_SkipsAfterKill). + m.conn = &acp.ClientSideConnection{} + + cases := []struct { + name string + active int + want apiruntime.Phase + }{ + {"zero prompts → idle", 0, apiruntime.PhaseIdle}, + {"one prompt → running", 1, apiruntime.PhaseRunning}, + {"many prompts → running", 5, apiruntime.PhaseRunning}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + m.activePrompts = tc.active + var s apiruntime.State + m.phaseFromActivePromptsLocked(&s) + if s.Phase != tc.want { + t.Fatalf("activePrompts=%d: want Phase=%q got %q", tc.active, tc.want, s.Phase) + } + }) + } +} + +// TestPhaseFromActivePromptsLocked_SkipsAfterKill verifies that the apply +// callback leaves state.Phase untouched when m.conn is nil. This is the +// post-Kill state — Kill / clearSessions owns lifecycle phase (Stopped), +// and a late-firing PromptSession writeState (from a prompt that was in +// flight when Kill ran and finally errored on the dead pipe) must NOT +// clobber Stopped with Idle/Running. +func TestPhaseFromActivePromptsLocked_SkipsAfterKill(t *testing.T) { + m := newManagerForSessionTest(t) + m.conn = nil // post-Kill: clearSessions nilled the conn + m.activePrompts = 1 // would normally cause apply to write Running + + s := apiruntime.State{Phase: apiruntime.PhaseStopped} + m.phaseFromActivePromptsLocked(&s) + + if s.Phase != apiruntime.PhaseStopped { + t.Fatalf("Phase should remain Stopped after Kill, got %q", s.Phase) + } +} + +// TestClearSessions verifies that the bookkeeping reset on agent +// teardown (Kill / process exit) drops every field that could leave a +// stale view behind. Without this, Sessions() would return dead IDs and +// PromptSession would pass its conn != nil guard before failing on the +// dead pipe. +func TestClearSessions(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + m.sessions["extra"] = &sessionState{id: "extra"} + m.activePrompts = 2 + m.models = &acp.SessionModelState{} + // conn is left nil — clearing a nil conn is a no-op, which is the + // branch we exercise here. The non-nil-clearing path is exercised + // indirectly via Kill()'s integration with a real process. + + m.clearSessions() + + if len(m.sessions) != 0 { + t.Errorf("sessions not cleared: %v", m.sessions) + } + if m.sessionID != "" { + t.Errorf("sessionID not cleared: %q", m.sessionID) + } + if m.activePrompts != 0 { + t.Errorf("activePrompts not cleared: %d", m.activePrompts) + } + if m.models != nil { + t.Errorf("models not cleared: %v", m.models) + } + if m.conn != nil { + t.Errorf("conn not cleared: %v", m.conn) + } +} + +// TestDecrementPromptInflight_GuardsAgainstNegativeAfterClear pins the +// guard against a PromptSession defer racing with Kill: clearSessions +// resets activePrompts to 0, then the still-in-flight prompt's deferred +// cleanup runs. Without the > 0 check the counter would slip to -1 and +// stay there until next clearSessions; with it, the decrement is a +// no-op and the counter remains coherent. +func TestDecrementPromptInflight_GuardsAgainstNegativeAfterClear(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial"} + + // Simulate the post-clearSessions state: counters reset, session map + // emptied (the in-flight prompt's session is already gone). + m.activePrompts = 0 + delete(m.sessions, "initial") + + // Run the deferred cleanup that PromptSession would have queued. + m.decrementPromptInflight("initial") + + if m.activePrompts != 0 { + t.Errorf("activePrompts should stay at 0 after racing clearSessions, got %d", m.activePrompts) + } +} + +// TestDecrementPromptInflight_NormalPath verifies that the happy-path +// decrement (no race) still works — sess.inflight and m.activePrompts +// both go down by one. +func TestDecrementPromptInflight_NormalPath(t *testing.T) { + m := newManagerForSessionTest(t) + m.sessionID = "initial" + m.sessions["initial"] = &sessionState{id: "initial", inflight: 1} + m.activePrompts = 1 + + m.decrementPromptInflight("initial") + + if got := m.sessions["initial"].inflight; got != 0 { + t.Errorf("session inflight should be 0, got %d", got) + } + if m.activePrompts != 0 { + t.Errorf("activePrompts should be 0, got %d", m.activePrompts) + } +} diff --git a/pkg/agentrun/server/register.go b/pkg/agentrun/server/register.go index 7bb53a7..9d9d121 100644 --- a/pkg/agentrun/server/register.go +++ b/pkg/agentrun/server/register.go @@ -17,7 +17,7 @@ type watchEventWire struct { // These are the methods exposed by agent-run over a Unix socket. type Handler interface { Prompt(ctx context.Context, req *runapi.SessionPromptParams) (*runapi.SessionPromptResult, error) - Cancel(ctx context.Context) error + Cancel(ctx context.Context, req *runapi.SessionCancelParams) error Load(ctx context.Context, req *runapi.SessionLoadParams) error // WatchEvent implements K8s List-Watch style event subscription. // When FromSeq is nil, only live events are streamed. @@ -34,16 +34,26 @@ type Handler interface { SetModel(ctx context.Context, req *runapi.SessionSetModelParams) (*runapi.SessionSetModelResult, error) Status(ctx context.Context) (*runapi.RuntimePhaseResult, error) Stop(ctx context.Context) error + + // Multi-session methods (see runtime/acp Manager.NewSession etc.). + NewSession(ctx context.Context, req *runapi.SessionNewParams) (*runapi.SessionNewResult, error) + EndSession(ctx context.Context, req *runapi.SessionEndParams) (*runapi.SessionEndResult, error) + ListSessions(ctx context.Context) (*runapi.SessionListResult, error) } // Register registers a Handler implementation with the server. func Register(s *jsonrpc.Server, svc Handler) { s.RegisterService("session", &jsonrpc.ServiceDesc{ Methods: map[string]jsonrpc.Method{ - "prompt": jsonrpc.UnaryMethod(svc.Prompt), - "cancel": jsonrpc.NullaryCommand(svc.Cancel), + "prompt": jsonrpc.UnaryMethod(svc.Prompt), + // cancel tolerates absent params (back-compat for callers that + // omit the field — pre-multi-session wire shape). + "cancel": jsonrpc.OptionalUnaryCommand(svc.Cancel), "load": jsonrpc.UnaryCommand(svc.Load), "set_model": jsonrpc.UnaryMethod(svc.SetModel), + "new": jsonrpc.UnaryMethod(svc.NewSession), + "end": jsonrpc.UnaryMethod(svc.EndSession), + "list": jsonrpc.NullaryMethod(svc.ListSessions), }, }) s.RegisterService("runtime", &jsonrpc.ServiceDesc{ diff --git a/pkg/agentrun/server/service.go b/pkg/agentrun/server/service.go index f1d8164..1be6752 100644 --- a/pkg/agentrun/server/service.go +++ b/pkg/agentrun/server/service.go @@ -11,6 +11,15 @@ import ( "github.com/zoumo/mass/pkg/jsonrpc" ) +// resolveSessionID maps the wire-level optional sessionId string to the +// typed acp.SessionId. Empty → resolved to the Manager's initial session +// inside PromptSession / CancelSession / SetModelSession themselves. This +// keeps the empty-string-means-initial convention in exactly one layer +// (the runtime) rather than duplicating it across client / service / runtime. +func resolveSessionID(s string) acp.SessionId { + return acp.SessionId(s) +} + // Service implements Handler. type Service struct { mgr *acpruntime.Manager @@ -27,10 +36,14 @@ func (s *Service) Prompt(ctx context.Context, req *runapi.SessionPromptParams) ( if len(req.Prompt) == 0 { return nil, jsonrpc.ErrInvalidParams("missing prompt") } - s.logger.Debug("prompt", "blocks", len(req.Prompt)) - s.trans.NotifyTurnStart() - s.trans.NotifyUserPrompt(req.Prompt) - resp, err := s.mgr.Prompt(ctx, req.Prompt) + s.logger.Debug("prompt", "sessionId", req.SessionID, "blocks", len(req.Prompt)) + // Empty sessionID is resolved to the initial session inside PromptSession + // (and inside Translator's Notify* methods). + s.trans.NotifyTurnStart(req.SessionID) + s.trans.NotifyUserPrompt(req.SessionID, req.Prompt) + + resp, err := s.mgr.PromptSession(ctx, resolveSessionID(req.SessionID), req.Prompt) + stopReason := "error" if err == nil { stopReason = string(resp.StopReason) @@ -38,20 +51,27 @@ func (s *Service) Prompt(ctx context.Context, req *runapi.SessionPromptParams) ( if err != nil { s.trans.NotifyError(err.Error()) } - s.trans.NotifyTurnEnd(acp.StopReason(stopReason)) - s.logger.Debug("prompt done", "stopReason", stopReason) + s.trans.NotifyTurnEnd(req.SessionID, acp.StopReason(stopReason)) + s.logger.Debug("prompt done", "sessionId", req.SessionID, "stopReason", stopReason) if err != nil { return nil, jsonrpc.ErrInternal(err.Error()) } return &runapi.SessionPromptResult{StopReason: string(resp.StopReason)}, nil } -func (s *Service) Cancel(ctx context.Context) (retErr error) { - s.logger.Debug("cancel") +func (s *Service) Cancel(ctx context.Context, req *runapi.SessionCancelParams) (retErr error) { + sessionID := req.SessionID + s.logger.Debug("cancel", "sessionId", sessionID) defer func() { - s.trans.NotifyOperationAudit("cancel", nil, retErr) + var auditArgs map[string]string + if sessionID != "" { + auditArgs = map[string]string{"sessionId": sessionID} + } + s.trans.NotifyOperationAudit("cancel", auditArgs, retErr) }() - if err := s.mgr.Cancel(ctx); err != nil { + + // Empty sessionID is resolved to the initial session inside CancelSession. + if err := s.mgr.CancelSession(ctx, resolveSessionID(sessionID)); err != nil { retErr = jsonrpc.ErrInternal(err.Error()) return retErr } @@ -214,15 +234,20 @@ func (s *Service) Status(_ context.Context) (*runapi.RuntimePhaseResult, error) } func (s *Service) SetModel(ctx context.Context, req *runapi.SessionSetModelParams) (_ *runapi.SessionSetModelResult, retErr error) { - s.logger.Debug("set_model", "modelID", req.ModelID) + s.logger.Debug("set_model", "sessionId", req.SessionID, "modelID", req.ModelID) defer func() { - s.trans.NotifyOperationAudit("set_model", map[string]string{"modelId": req.ModelID}, retErr) + auditArgs := map[string]string{"modelId": req.ModelID} + if req.SessionID != "" { + auditArgs["sessionId"] = req.SessionID + } + s.trans.NotifyOperationAudit("set_model", auditArgs, retErr) }() if req.ModelID == "" { retErr = jsonrpc.ErrInvalidParams("missing modelId") return nil, retErr } - if err := s.mgr.SetModel(ctx, req.ModelID); err != nil { + // Empty sessionID is resolved to the initial session inside SetModelSession. + if err := s.mgr.SetModelSession(ctx, resolveSessionID(req.SessionID), req.ModelID); err != nil { retErr = jsonrpc.ErrInternal(err.Error()) return nil, retErr } @@ -234,3 +259,68 @@ func (s *Service) Stop(_ context.Context) error { s.trans.NotifyOperationAudit("stop", nil, nil) return nil } + +// NewSession opens an additional ACP session on the running agent. See +// runtime/acp Manager.NewSession for cwd / mcpServers semantics. +func (s *Service) NewSession(ctx context.Context, req *runapi.SessionNewParams) (_ *runapi.SessionNewResult, retErr error) { + s.logger.Debug("session/new", "cwd", req.Cwd, "mcpServers", len(req.McpServers)) + defer func() { + s.trans.NotifyOperationAudit("session/new", map[string]string{"cwd": req.Cwd}, retErr) + }() + if req.Cwd == "" { + retErr = jsonrpc.ErrInvalidParams("missing cwd") + return nil, retErr + } + + mcp := make([]acp.McpServer, 0, len(req.McpServers)) + for _, m := range req.McpServers { + envVars := make([]acp.EnvVariable, 0, len(m.Env)) + for k, v := range m.Env { + envVars = append(envVars, acp.EnvVariable{Name: k, Value: v}) + } + args := m.Args + if args == nil { + args = []string{} + } + mcp = append(mcp, acp.McpServer{Stdio: &acp.McpServerStdio{ + Name: m.Name, + Command: m.Command, + Args: args, + Env: envVars, + }}) + } + + sid, err := s.mgr.NewSession(ctx, req.Cwd, mcp) + if err != nil { + retErr = jsonrpc.ErrInternal(err.Error()) + return nil, retErr + } + return &runapi.SessionNewResult{SessionID: string(sid)}, nil +} + +// EndSession releases runtime tracking of a session id. +func (s *Service) EndSession(_ context.Context, req *runapi.SessionEndParams) (_ *runapi.SessionEndResult, retErr error) { + s.logger.Debug("session/end", "sessionId", req.SessionID) + defer func() { + s.trans.NotifyOperationAudit("session/end", map[string]string{"sessionId": req.SessionID}, retErr) + }() + if req.SessionID == "" { + retErr = jsonrpc.ErrInvalidParams("missing sessionId") + return nil, retErr + } + if err := s.mgr.EndSession(resolveSessionID(req.SessionID)); err != nil { + retErr = jsonrpc.ErrInternal(err.Error()) + return nil, retErr + } + return &runapi.SessionEndResult{}, nil +} + +// ListSessions returns the active session IDs snapshot from the Manager. +func (s *Service) ListSessions(_ context.Context) (*runapi.SessionListResult, error) { + ids := s.mgr.SessionIDs() + out := make([]string, len(ids)) + for i, id := range ids { + out[i] = string(id) + } + return &runapi.SessionListResult{SessionIDs: out}, nil +} diff --git a/pkg/agentrun/server/service_test.go b/pkg/agentrun/server/service_test.go index aa6eed9..3c2505c 100644 --- a/pkg/agentrun/server/service_test.go +++ b/pkg/agentrun/server/service_test.go @@ -146,3 +146,24 @@ func TestService_SetModel_AuditOnValidationFailure(t *testing.T) { assert.False(t, ru.OperationAudit.Success) assert.NotEmpty(t, ru.OperationAudit.Error) } + +// TestService_NewSession_MissingCwdValidation verifies that the service layer +// rejects a missing cwd before reaching the runtime layer. The contract is +// cwd-required across all wire layers (CLI / ARI / agentrun); see runtime/acp +// NewSession's matching guard for the symmetric runtime-side check. +func TestService_NewSession_MissingCwdValidation(t *testing.T) { + svc := newTestService(t) + _, err := svc.NewSession(context.Background(), &runapi.SessionNewParams{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing cwd") +} + +// TestService_EndSession_MissingSessionIDValidation pins the symmetric +// validation for end-session — the runtime can't disambiguate "no sessionId +// supplied" from "empty resolves to initial", so the wire layer rejects. +func TestService_EndSession_MissingSessionIDValidation(t *testing.T) { + svc := newTestService(t) + _, err := svc.EndSession(context.Background(), &runapi.SessionEndParams{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing sessionId") +} diff --git a/pkg/agentrun/server/translator.go b/pkg/agentrun/server/translator.go index 27ccf1d..cf2d5d8 100644 --- a/pkg/agentrun/server/translator.go +++ b/pkg/agentrun/server/translator.go @@ -208,14 +208,19 @@ func (t *Translator) LastSeq() int { // NotifyTurnStart broadcasts a turn_start AgentRunEvent. // The new turnId is assigned atomically inside the broadcast callback, // which runs under mu.Lock. -func (t *Translator) NotifyTurnStart() { +// +// sessionID identifies which ACP session this turn belongs to. Empty +// resolves to the agentrun's initial session (back-compat for single- +// session callers). Watchers filter events by sessionId; mis-stamping +// would cause cross-session leakage on `--wait --session-id`. +func (t *Translator) NotifyTurnStart(sessionID string) { newTurnID := uuid.New().String() - t.logger.Debug("turn_start", "turnID", newTurnID) + t.logger.Debug("turn_start", "turnID", newTurnID, "sessionID", sessionID) t.broadcast(func(seq int, at time.Time) runapi.AgentRunEvent { t.currentTurnId = newTurnID return runapi.AgentRunEvent{ RunID: t.runID, - SessionID: t.sessionID, + SessionID: t.resolveSessionLocked(sessionID), Seq: seq, Time: at, Type: runapi.EventTypeTurnStart, @@ -230,13 +235,16 @@ func (t *Translator) NotifyTurnStart() { // Each ContentBlock becomes a separate event, matching the per-block pattern // used on the response side (agent_message). // This must be called after NotifyTurnStart and before mgr.Prompt. -func (t *Translator) NotifyUserPrompt(blocks []runapi.ContentBlock) { +// +// sessionID identifies which session the prompt is destined for. Empty +// resolves to the agentrun's initial session. +func (t *Translator) NotifyUserPrompt(sessionID string, blocks []runapi.ContentBlock) { for _, block := range blocks { b := block // capture for closure t.broadcast(func(seq int, at time.Time) runapi.AgentRunEvent { return runapi.AgentRunEvent{ RunID: t.runID, - SessionID: t.sessionID, + SessionID: t.resolveSessionLocked(sessionID), Seq: seq, Time: at, Type: runapi.EventTypeUserMessage, @@ -274,13 +282,16 @@ func (t *Translator) NotifyOperationAudit(op string, params map[string]string, e // Closes any open content block first, then emits turn_end. // The current turnId is included in the event and cleared AFTER use so the // turn_end event itself carries the identifier. -func (t *Translator) NotifyTurnEnd(reason acp.StopReason) { - t.logger.Debug("turn_end", "turnID", t.currentTurnId, "reason", reason) +// +// sessionID identifies which session the turn ended on. Empty resolves to +// the agentrun's initial session. +func (t *Translator) NotifyTurnEnd(sessionID string, reason acp.StopReason) { + t.logger.Debug("turn_end", "turnID", t.currentTurnId, "sessionID", sessionID, "reason", reason) t.closeOpenBlock() t.broadcast(func(seq int, at time.Time) runapi.AgentRunEvent { ae := runapi.AgentRunEvent{ RunID: t.runID, - SessionID: t.sessionID, + SessionID: t.resolveSessionLocked(sessionID), Seq: seq, Time: at, Type: runapi.EventTypeTurnEnd, @@ -354,7 +365,10 @@ func (t *Translator) run() { t.closeOpenBlock() } - t.broadcastEvent(ev) + // ACP SessionNotification carries the session id; stamp it on + // the resulting event so watchers can filter by session. Empty + // (unusual — pre-handshake) falls back to the initial session. + t.broadcastEventForSession(string(n.SessionId), ev) t.maybeNotifyMetadata(ev) } } @@ -392,14 +406,25 @@ func (t *Translator) closeOpenBlock() { } } -// broadcastEvent builds and broadcasts an AgentRunEvent. +// broadcastEvent builds and broadcasts an AgentRunEvent stamped with the +// initial session id. Use for process-wide events (runtime_update / error) +// where session attribution doesn't apply. Per-session events should use +// broadcastEventForSession. +// // TurnID is applied to all events except runtime_update when an active turn exists. func (t *Translator) broadcastEvent(ev runapi.Event) { + t.broadcastEventForSession("", ev) +} + +// broadcastEventForSession is broadcastEvent with explicit session attribution. +// Empty sessionID resolves to the initial session — callers passing a non- +// empty id stamp it directly, allowing watchers to filter by sessionId. +func (t *Translator) broadcastEventForSession(sessionID string, ev runapi.Event) { t.broadcast(func(seq int, at time.Time) runapi.AgentRunEvent { eventType := runapi.EventTypeOf(ev) ae := runapi.AgentRunEvent{ RunID: t.runID, - SessionID: t.sessionID, + SessionID: t.resolveSessionLocked(sessionID), Seq: seq, Time: at, Type: eventType, @@ -412,6 +437,16 @@ func (t *Translator) broadcastEvent(ev runapi.Event) { }) } +// resolveSessionLocked maps empty sessionID to t.sessionID (the initial +// session). Must be called from within a broadcast callback (which already +// holds t.mu) or from a method holding t.mu. +func (t *Translator) resolveSessionLocked(sessionID string) string { + if sessionID == "" { + return t.sessionID + } + return sessionID +} + // broadcast is the single fan-out entry point. The build callback runs under // mu.Lock and receives the assigned seq and current timestamp. The lock is held // for the entire log-then-fanout sequence to guarantee: diff --git a/pkg/agentrun/server/translator_test.go b/pkg/agentrun/server/translator_test.go index 76b456a..3f31740 100644 --- a/pkg/agentrun/server/translator_test.go +++ b/pkg/agentrun/server/translator_test.go @@ -258,8 +258,8 @@ func TestNotifyTurnStartAndEnd(t *testing.T) { tr.Start() defer tr.Stop() - tr.NotifyTurnStart() - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnStart("") + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) first := drainEvent(t, ch) second := drainEvent(t, ch) @@ -400,13 +400,13 @@ func TestTurnAwareEvent_TurnIdAssigned(t *testing.T) { tr.Start() defer tr.Stop() - tr.NotifyTurnStart() + tr.NotifyTurnStart("") tsEv := drainEvent(t, ch) txt1Ev := sendAndDrainEvent(t, in, ch, "hello") txt2Ev := sendAndDrainEvent(t, in, ch, "world") - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) blockEndEv := drainEvent(t, ch) // synthetic agent_message{end} teEv := drainEvent(t, ch) // turn_end @@ -433,15 +433,15 @@ func TestTurnAwareEvent_TurnIDChangesPerTurn(t *testing.T) { defer tr.Stop() // Turn 1. - tr.NotifyTurnStart() + tr.NotifyTurnStart("") ts1 := drainEvent(t, ch) sendAndDrainEvent(t, in, ch, "turn1") - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) drainEvent(t, ch) // turn_end drainEvent(t, ch) // synthetic content end // Turn 2 — TurnID must differ. - tr.NotifyTurnStart() + tr.NotifyTurnStart("") ts2 := drainEvent(t, ch) assert.NotEqual(t, ts1.TurnID, ts2.TurnID, "turn 2 must have a different TurnID") } @@ -455,7 +455,7 @@ func TestTurnAwareEvent_StateChangeExcludesTurnFields(t *testing.T) { tr.Start() defer tr.Stop() - tr.NotifyTurnStart() + tr.NotifyTurnStart("") tsEv := drainEvent(t, ch) require.NotEmpty(t, tsEv.TurnID) @@ -477,7 +477,7 @@ func TestTurnAwareEvent_MetadataEventInTurn(t *testing.T) { tr.Start() defer tr.Stop() - tr.NotifyTurnStart() + tr.NotifyTurnStart("") tsEv := drainEvent(t, ch) require.NotEmpty(t, tsEv.TurnID) @@ -634,20 +634,20 @@ func TestTurnAwareEvent_ReplayOrdering(t *testing.T) { defer tr.Stop() // Turn 1: turn_start + 2 text events + synthetic block end + turn_end. - tr.NotifyTurnStart() + tr.NotifyTurnStart("") ts1Ev := drainEvent(t, ch) t1aEv := sendAndDrainEvent(t, in, ch, "t1-a") t1bEv := sendAndDrainEvent(t, in, ch, "t1-b") - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) t1EndEv := drainEvent(t, ch) // synthetic agent_message{end} te1Ev := drainEvent(t, ch) // turn_end turn1 := []runapi.AgentRunEvent{ts1Ev, t1aEv, t1bEv, t1EndEv, te1Ev} // Turn 2: turn_start + 1 text event + synthetic block end + turn_end. - tr.NotifyTurnStart() + tr.NotifyTurnStart("") ts2Ev := drainEvent(t, ch) t2aEv := sendAndDrainEvent(t, in, ch, "t2-a") - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) t2EndEv := drainEvent(t, ch) // synthetic agent_message{end} te2Ev := drainEvent(t, ch) // turn_end turn2 := []runapi.AgentRunEvent{ts2Ev, t2aEv, t2EndEv, te2Ev} @@ -688,11 +688,11 @@ func TestEventCounts_PromptTurn(t *testing.T) { defer tr.Stop() // turn_start - tr.NotifyTurnStart() + tr.NotifyTurnStart("") drainEvent(t, ch) // user_message - tr.NotifyUserPrompt([]runapi.ContentBlock{runapi.TextBlock("hello")}) + tr.NotifyUserPrompt("", []runapi.ContentBlock{runapi.TextBlock("hello")}) drainEvent(t, ch) // 2 text events (AgentMessageChunk) @@ -707,7 +707,7 @@ func TestEventCounts_PromptTurn(t *testing.T) { drainEvent(t, ch) // tool_call // turn_end - tr.NotifyTurnEnd(acp.StopReason("end_turn")) + tr.NotifyTurnEnd("", acp.StopReason("end_turn")) drainEvent(t, ch) // state_change @@ -899,3 +899,63 @@ func TestSessionMetadataHook_AllFourTypes(t *testing.T) { defer mu.Unlock() assert.Equal(t, []string{"runtime_update", "runtime_update", "runtime_update", "runtime_update"}, types) } + +// TestNotifyTurnStart_StampsSessionID verifies that NotifyTurnStart writes +// the caller-supplied sessionID on the emitted event. This is what lets +// CLI clients filter `--wait --session-id X` watches without cross-session +// leakage between concurrent prompts on different sessions. +func TestNotifyTurnStart_StampsSessionID(t *testing.T) { + in := make(chan acp.SessionNotification, 1) + tr := NewTranslator("run-1", in, "", slog.Default()) + tr.SetSessionID("initial-sid") + ch, _, _ := tr.Subscribe() + + tr.NotifyTurnStart("session-abc") + + ev := drainEvent(t, ch) + assert.Equal(t, "session-abc", ev.SessionID, + "caller-supplied sessionID must be stamped on the event") +} + +// TestNotifyTurnStart_EmptySessionIDResolvesToInitial verifies that an +// empty sessionID at the Translator API falls back to the initial session +// id — back-compat for legacy single-session callers that don't pass one. +func TestNotifyTurnStart_EmptySessionIDResolvesToInitial(t *testing.T) { + in := make(chan acp.SessionNotification, 1) + tr := NewTranslator("run-1", in, "", slog.Default()) + tr.SetSessionID("initial-sid") + ch, _, _ := tr.Subscribe() + + tr.NotifyTurnStart("") + + ev := drainEvent(t, ch) + assert.Equal(t, "initial-sid", ev.SessionID, + "empty sessionID must resolve to the initial session") +} + +// TestRun_StampsSessionIDFromNotification verifies that content events +// produced by translating ACP SessionNotifications carry the notification's +// session id (not just the Translator's initial-session cache). Without +// this, agent_message events from a non-initial session would all be +// labelled with the initial id, defeating per-session watch filtering. +func TestRun_StampsSessionIDFromNotification(t *testing.T) { + in := make(chan acp.SessionNotification, 1) + tr := NewTranslator("run-1", in, "", slog.Default()) + tr.SetSessionID("initial-sid") + ch, _, _ := tr.Subscribe() + tr.Start() + defer tr.Stop() + + in <- acp.SessionNotification{ + SessionId: acp.SessionId("session-xyz"), + Update: acp.SessionUpdate{ + AgentMessageChunk: &acp.SessionUpdateAgentMessageChunk{ + Content: acp.TextBlock("hi"), + }, + }, + } + + ev := drainEvent(t, ch) + assert.Equal(t, "session-xyz", ev.SessionID, + "content event should carry the notification's sessionID, not the initial") +} diff --git a/pkg/ari/api/domain.go b/pkg/ari/api/domain.go index e69d74f..909c6a5 100644 --- a/pkg/ari/api/domain.go +++ b/pkg/ari/api/domain.go @@ -72,27 +72,34 @@ type ObjectMeta struct { // ──────────────────────────────────────────────────────────────────────────── // AgentSpec describes how to launch an agent process for this named agent definition. +// +// All fields carry both `json` and `yaml` tags. Without explicit `yaml:` tags, +// gopkg.in/yaml.v3 lowercases the Go field name (e.g. `StartupTimeoutSeconds` +// → `startuptimeoutseconds`), which silently drops camelCase YAML keys that +// users / tooling write to match the JSON tag (`startupTimeoutSeconds`). Each +// camelCase field needs an explicit yaml tag for `massctl agent apply -f` to +// preserve it round-trip. type AgentSpec struct { // Disabled controls whether the agent is prevented from creating new agent runs. // nil or false means not disabled (agent is usable). true means disabled. - Disabled *bool `json:"disabled,omitempty"` + Disabled *bool `json:"disabled,omitempty" yaml:"disabled,omitempty"` // ClientProtocol selects the communication protocol adapter. // Default: "acp". - ClientProtocol apiruntime.ClientProtocol `json:"clientProtocol,omitempty"` + ClientProtocol apiruntime.ClientProtocol `json:"clientProtocol,omitempty" yaml:"clientProtocol,omitempty"` // Command is the agent executable. - Command string `json:"command"` + Command string `json:"command" yaml:"command"` // Args are the command-line arguments passed to Command. - Args []string `json:"args,omitempty"` + Args []string `json:"args,omitempty" yaml:"args,omitempty"` // Env is the list of environment variable overrides applied to the process. - Env []apiruntime.EnvVar `json:"env,omitempty"` + Env []apiruntime.EnvVar `json:"env,omitempty" yaml:"env,omitempty"` // StartupTimeoutSeconds is the maximum time (in seconds) to wait for the // agent-run to reach idle state. Nil means use the daemon default. - StartupTimeoutSeconds *int `json:"startupTimeoutSeconds,omitempty"` + StartupTimeoutSeconds *int `json:"startupTimeoutSeconds,omitempty" yaml:"startupTimeoutSeconds,omitempty"` } // IsDisabled reports whether the agent is disabled. diff --git a/pkg/ari/api/methods.go b/pkg/ari/api/methods.go index e22813c..99a0155 100644 --- a/pkg/ari/api/methods.go +++ b/pkg/ari/api/methods.go @@ -28,6 +28,16 @@ const ( MethodAgentRunTaskGet = "agentrun/task/get" MethodAgentRunTaskList = "agentrun/task/list" MethodAgentRunTaskRetry = "agentrun/task/retry" + + // Multi-session lifecycle for a single agentrun. session/new opens + // an additional ACP session (different cwd, fresh state) without + // fork+exec a new agent process; session/end releases runtime + // tracking; session/list enumerates active sessions. Names parallel + // agentrun/task/* — one resource/verb hierarchy at this layer. See + // pkg/agentrun/runtime/acp Manager.NewSession for runtime contract. + MethodAgentRunNewSession = "agentrun/session/new" + MethodAgentRunEndSession = "agentrun/session/end" + MethodAgentRunListSessions = "agentrun/session/list" ) // ARI agent definition methods. diff --git a/pkg/ari/api/types.go b/pkg/ari/api/types.go index 8aaddeb..2129fd5 100644 --- a/pkg/ari/api/types.go +++ b/pkg/ari/api/types.go @@ -89,6 +89,11 @@ type AgentRunPromptParams struct { // Name is the agent run name (required). Name string `json:"name"` + // SessionID addresses a specific session opened via agentrun/session/new. + // Empty (omitted) routes to the agentrun's initial session — preserves + // pre-multi-session caller behavior. + SessionID string `json:"sessionId,omitempty"` + // Prompt is an array of ACP ContentBlocks (text, image, audio, etc.) (required). Prompt []runapi.ContentBlock `json:"prompt"` } @@ -99,6 +104,46 @@ type AgentRunPromptResult struct { Accepted bool `json:"accepted"` } +// AgentRunNewSessionParams is the request params for agentrun/session/new. +// Opens an additional ACP session on an existing agentrun (no fork+exec). +type AgentRunNewSessionParams struct { + Workspace string `json:"workspace"` + Name string `json:"name"` + // Cwd is required — each session is scoped to its own working directory. + Cwd string `json:"cwd"` + // McpServers is optional per-session MCP overrides. Wire shape mirrors + // runapi.SessionNewMcpServer to avoid re-defining transports. + McpServers []runapi.SessionNewMcpServer `json:"mcpServers,omitempty"` +} + +// AgentRunNewSessionResult is the response for agentrun/session/new. +type AgentRunNewSessionResult struct { + // SessionID is the ACP session id the caller passes to subsequent + // agentrun/prompt (etc.) via the SessionID field. + SessionID string `json:"sessionId"` +} + +// AgentRunEndSessionParams is the request params for agentrun/session/end. +type AgentRunEndSessionParams struct { + Workspace string `json:"workspace"` + Name string `json:"name"` + SessionID string `json:"sessionId"` +} + +// AgentRunEndSessionResult is the response for agentrun/session/end. +type AgentRunEndSessionResult struct{} + +// AgentRunListSessionsParams identifies the agentrun whose sessions to list. +type AgentRunListSessionsParams struct { + Workspace string `json:"workspace"` + Name string `json:"name"` +} + +// AgentRunListSessionsResult enumerates the agentrun's active session ids. +type AgentRunListSessionsResult struct { + SessionIDs []string `json:"sessionIds"` +} + // WorkspaceSendParams is the request params for workspace/send method. // Routes a message from one agent run to another within a workspace. type WorkspaceSendParams struct { diff --git a/pkg/ari/client/client.go b/pkg/ari/client/client.go index 338186e..2ddfae0 100644 --- a/pkg/ari/client/client.go +++ b/pkg/ari/client/client.go @@ -143,9 +143,14 @@ func (c *ariClient) Delete(ctx context.Context, key pkgariapi.ObjectKey, obj pkg type agentRunOps struct{ c *jsonrpc.Client } func (o *agentRunOps) Prompt(ctx context.Context, key pkgariapi.ObjectKey, prompt []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { + return o.PromptSession(ctx, key, "", prompt) +} + +func (o *agentRunOps) PromptSession(ctx context.Context, key pkgariapi.ObjectKey, sessionID string, prompt []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) { req := pkgariapi.AgentRunPromptParams{ Workspace: key.Workspace, Name: key.Name, + SessionID: sessionID, Prompt: prompt, } var result pkgariapi.AgentRunPromptResult @@ -205,6 +210,33 @@ func (o *agentRunOps) TaskRetry(ctx context.Context, params *pkgariapi.AgentRunT return &result, nil } +func (o *agentRunOps) NewSession(ctx context.Context, params *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + var result pkgariapi.AgentRunNewSessionResult + if err := o.c.Call(ctx, pkgariapi.MethodAgentRunNewSession, params, &result); err != nil { + return nil, err + } + return &result, nil +} + +func (o *agentRunOps) EndSession(ctx context.Context, key pkgariapi.ObjectKey, sessionID string) error { + req := pkgariapi.AgentRunEndSessionParams{ + Workspace: key.Workspace, + Name: key.Name, + SessionID: sessionID, + } + var raw json.RawMessage + return o.c.Call(ctx, pkgariapi.MethodAgentRunEndSession, req, &raw) +} + +func (o *agentRunOps) ListSessions(ctx context.Context, key pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) { + req := pkgariapi.AgentRunListSessionsParams{Workspace: key.Workspace, Name: key.Name} + var result pkgariapi.AgentRunListSessionsResult + if err := o.c.Call(ctx, pkgariapi.MethodAgentRunListSessions, req, &result); err != nil { + return nil, err + } + return &result, nil +} + // ──────────────────────────────────────────────────────────────────────────── // WorkspaceOps // ──────────────────────────────────────────────────────────────────────────── diff --git a/pkg/ari/client/interfaces.go b/pkg/ari/client/interfaces.go index fb2e2cb..5fdaa0c 100644 --- a/pkg/ari/client/interfaces.go +++ b/pkg/ari/client/interfaces.go @@ -48,9 +48,13 @@ type Client interface { // AgentRunOps provides non-CRUD operations on agent runs. type AgentRunOps interface { - // Prompt sends a multimodal prompt ([]runapi.ContentBlock) to an agent run. + // Prompt sends a multimodal prompt ([]runapi.ContentBlock) to an agent run's + // initial session. For multi-session, use PromptSession. Prompt(ctx context.Context, key pkgariapi.ObjectKey, prompt []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) + // PromptSession addresses a specific session id (opened via NewSession). + PromptSession(ctx context.Context, key pkgariapi.ObjectKey, sessionID string, prompt []runapi.ContentBlock) (*pkgariapi.AgentRunPromptResult, error) + // Cancel cancels the current turn of an agent run. Cancel(ctx context.Context, key pkgariapi.ObjectKey) error @@ -71,6 +75,16 @@ type AgentRunOps interface { // TaskRetry retries an existing task by bumping its attempt count and re-prompting the agent. TaskRetry(ctx context.Context, params *pkgariapi.AgentRunTaskRetryParams) (*pkgariapi.AgentTask, error) + + // NewSession opens an additional ACP session on the running agent process + // (no fork+exec). Returns the agent-issued sessionId. + NewSession(ctx context.Context, params *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) + + // EndSession releases runtime tracking of a session id. + EndSession(ctx context.Context, key pkgariapi.ObjectKey, sessionID string) error + + // ListSessions enumerates active session ids on the agent. + ListSessions(ctx context.Context, key pkgariapi.ObjectKey) (*pkgariapi.AgentRunListSessionsResult, error) } // WorkspaceOps provides non-CRUD operations on workspaces. diff --git a/pkg/ari/server/server.go b/pkg/ari/server/server.go index 62fb343..6de127e 100644 --- a/pkg/ari/server/server.go +++ b/pkg/ari/server/server.go @@ -516,17 +516,105 @@ func (a *agentRunAdapter) Prompt(ctx context.Context, req *pkgariapi.AgentRunPro } prompt := req.Prompt - if err := client.SendPrompt(ctx, &runapi.SessionPromptParams{Prompt: prompt}); err != nil { + if err := client.SendPrompt(ctx, &runapi.SessionPromptParams{ + SessionID: req.SessionID, + Prompt: prompt, + }); err != nil { a.logger.Warn("agentrun/prompt: prompt delivery failed", "workspace", req.Workspace, "name", req.Name, "error", err) a.recordPromptDeliveryFailure(req.Workspace, req.Name, agent.Status, err, false) } a.logger.Info("agentrun/prompt: dispatched", - "workspace", req.Workspace, "name", req.Name) + "workspace", req.Workspace, "name", req.Name, "sessionId", req.SessionID) return &pkgariapi.AgentRunPromptResult{Accepted: true}, nil } +// NewSession forwards agentrun/session/new to the running agent-run's +// session/new RPC. The agent reuses its process to host the new session +// — no fork+exec. Returns the agent-issued sessionId. +// +// Rejects requests while the daemon is recovering agents — sessions opened +// then would race with recovery's view of the agent process. Doesn't use +// reserveIdleAgent because multi-session permits opening a session while +// another is mid-prompt (Status=Running); only recovery is unsafe. +func (a *agentRunAdapter) NewSession(ctx context.Context, req *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) { + if req.Workspace == "" || req.Name == "" || req.Cwd == "" { + return nil, jsonrpc.ErrInvalidParams("workspace, name, and cwd are required") + } + if err := a.rejectIfRecovering(); err != nil { + return nil, err + } + a.logger.Info("agentrun/session/new", "workspace", req.Workspace, "name", req.Name, "cwd", req.Cwd) + + client, err := a.processes.Connect(ctx, req.Workspace, req.Name) + if err != nil { + return nil, &jsonrpc.RPCError{Code: pkgariapi.CodeRecoveryBlocked, Message: "agent not running"} + } + + out, err := client.NewSession(ctx, &runapi.SessionNewParams{ + Cwd: req.Cwd, + McpServers: req.McpServers, + }) + if err != nil { + return nil, jsonrpc.ErrInternal(err.Error()) + } + a.logger.Info("agentrun/session/new: opened", + "workspace", req.Workspace, "name", req.Name, "sessionId", out.SessionID) + return &pkgariapi.AgentRunNewSessionResult{SessionID: out.SessionID}, nil +} + +// EndSession forwards agentrun/session/end to release runtime tracking. +func (a *agentRunAdapter) EndSession(ctx context.Context, req *pkgariapi.AgentRunEndSessionParams) (*pkgariapi.AgentRunEndSessionResult, error) { + if req.Workspace == "" || req.Name == "" || req.SessionID == "" { + return nil, jsonrpc.ErrInvalidParams("workspace, name, and sessionId are required") + } + if err := a.rejectIfRecovering(); err != nil { + return nil, err + } + a.logger.Info("agentrun/session/end", "workspace", req.Workspace, "name", req.Name, "sessionId", req.SessionID) + + client, err := a.processes.Connect(ctx, req.Workspace, req.Name) + if err != nil { + return nil, &jsonrpc.RPCError{Code: pkgariapi.CodeRecoveryBlocked, Message: "agent not running"} + } + if err := client.EndSession(ctx, req.SessionID); err != nil { + return nil, jsonrpc.ErrInternal(err.Error()) + } + return &pkgariapi.AgentRunEndSessionResult{}, nil +} + +// ListSessions forwards agentrun/session/list to enumerate active sessions. +func (a *agentRunAdapter) ListSessions(ctx context.Context, req *pkgariapi.AgentRunListSessionsParams) (*pkgariapi.AgentRunListSessionsResult, error) { + if req.Workspace == "" || req.Name == "" { + return nil, jsonrpc.ErrInvalidParams("workspace and name are required") + } + if err := a.rejectIfRecovering(); err != nil { + return nil, err + } + + client, err := a.processes.Connect(ctx, req.Workspace, req.Name) + if err != nil { + return nil, &jsonrpc.RPCError{Code: pkgariapi.CodeRecoveryBlocked, Message: "agent not running"} + } + out, err := client.ListSessions(ctx) + if err != nil { + return nil, jsonrpc.ErrInternal(err.Error()) + } + return &pkgariapi.AgentRunListSessionsResult{SessionIDs: out.SessionIDs}, nil +} + +// rejectIfRecovering returns CodeRecoveryBlocked when the daemon is mid- +// recovery. Used by session-lifecycle handlers that can't go through +// reserveIdleAgent (multi-session allows operations while another session +// is Running) but still must serialize against the recovery sweep. +func (a *agentRunAdapter) rejectIfRecovering() *jsonrpc.RPCError { + if a.processes.IsRecovering() { + return &jsonrpc.RPCError{Code: pkgariapi.CodeRecoveryBlocked, Message: "daemon is recovering agents"} + } + return nil +} + // Cancel handles agentrun/cancel. // // Connects to the running agent-run and calls Cancel. diff --git a/pkg/ari/server/service.go b/pkg/ari/server/service.go index 926782d..c6c0f98 100644 --- a/pkg/ari/server/service.go +++ b/pkg/ari/server/service.go @@ -36,6 +36,11 @@ type AgentRunService interface { TaskGet(ctx context.Context, params *pkgariapi.AgentRunTaskGetParams) (*pkgariapi.AgentTask, error) TaskList(ctx context.Context, params *pkgariapi.AgentRunTaskListParams) (*pkgariapi.AgentRunTaskListResult, error) TaskRetry(ctx context.Context, params *pkgariapi.AgentRunTaskRetryParams) (*pkgariapi.AgentTask, error) + + // Multi-session lifecycle. + NewSession(ctx context.Context, req *pkgariapi.AgentRunNewSessionParams) (*pkgariapi.AgentRunNewSessionResult, error) + EndSession(ctx context.Context, req *pkgariapi.AgentRunEndSessionParams) (*pkgariapi.AgentRunEndSessionResult, error) + ListSessions(ctx context.Context, req *pkgariapi.AgentRunListSessionsParams) (*pkgariapi.AgentRunListSessionsResult, error) } // AgentService defines agent definition CRUD methods. @@ -157,6 +162,12 @@ func RegisterAgentRunService(s *jsonrpc.Server, svc AgentRunService) { "task/get": jsonrpc.UnaryMethod(svc.TaskGet), "task/list": jsonrpc.UnaryMethod(svc.TaskList), "task/retry": jsonrpc.UnaryMethod(svc.TaskRetry), + // Multi-session lifecycle. Names parallel agentrun/task/* — + // agentrun/session/new / end / list — so a third-party reader + // sees one consistent resource/verb hierarchy at this layer. + "session/new": jsonrpc.UnaryMethod(svc.NewSession), + "session/end": jsonrpc.UnaryMethod(svc.EndSession), + "session/list": jsonrpc.UnaryMethod(svc.ListSessions), }, }) } diff --git a/pkg/jsonrpc/errors.go b/pkg/jsonrpc/errors.go index 8eb5f39..61ff54d 100644 --- a/pkg/jsonrpc/errors.go +++ b/pkg/jsonrpc/errors.go @@ -2,7 +2,18 @@ // framework built on top of sourcegraph/jsonrpc2. package jsonrpc -import "fmt" +import ( + "errors" + "fmt" +) + +// ErrNoParams is returned by the dispatcher's unmarshal callback when the +// JSON-RPC request has no `params` field. UnaryCommand / UnaryMethod treat +// this as InvalidParams (the typical case — handler expects params). +// OptionalUnaryCommand tolerates it (leaves Req at zero value), letting a +// method evolve from NullaryCommand → optional-params without breaking +// callers that omit the field. +var ErrNoParams = errors.New("missing params") // RPCError is a JSON-RPC 2.0 error with code, message, and optional data. // Method handlers return *RPCError to control the JSON-RPC error response; diff --git a/pkg/jsonrpc/method_helpers.go b/pkg/jsonrpc/method_helpers.go index 17031c2..ee7a4c0 100644 --- a/pkg/jsonrpc/method_helpers.go +++ b/pkg/jsonrpc/method_helpers.go @@ -1,6 +1,9 @@ package jsonrpc -import "context" +import ( + "context" + "errors" +) func UnaryMethod[Req, Res any](fn func(ctx context.Context, req *Req) (*Res, error)) Method { return func(ctx context.Context, unmarshal func(any) error) (any, error) { @@ -33,3 +36,20 @@ func NullaryCommand(fn func(ctx context.Context) error) Method { return nil, fn(ctx) } } + +// OptionalUnaryCommand is like UnaryCommand but tolerates absent params: +// callers that omit the JSON-RPC `params` field reach the handler with +// Req at zero value rather than getting InvalidParams. Use this for +// methods evolving from NullaryCommand to taking optional params — +// pre-existing callers that send no params keep working unchanged. +// +// Malformed-JSON unmarshal errors still surface as InvalidParams. +func OptionalUnaryCommand[Req any](fn func(ctx context.Context, req *Req) error) Method { + return func(ctx context.Context, unmarshal func(any) error) (any, error) { + var req Req + if err := unmarshal(&req); err != nil && !errors.Is(err, ErrNoParams) { + return nil, ErrInvalidParams(err.Error()) + } + return nil, fn(ctx, &req) + } +} diff --git a/pkg/jsonrpc/server.go b/pkg/jsonrpc/server.go index c614613..93c4451 100644 --- a/pkg/jsonrpc/server.go +++ b/pkg/jsonrpc/server.go @@ -150,7 +150,7 @@ func (h *serverHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *js unmarshal := func(dst any) error { if req.Params == nil { - return fmt.Errorf("missing params") + return ErrNoParams } return json.Unmarshal(*req.Params, dst) } diff --git a/pkg/jsonrpc/server_test.go b/pkg/jsonrpc/server_test.go index 0c3c9d1..d14617c 100644 --- a/pkg/jsonrpc/server_test.go +++ b/pkg/jsonrpc/server_test.go @@ -223,6 +223,71 @@ func TestServer_PeerNotify(t *testing.T) { } } +// TestOptionalUnaryCommand_TolerateMissingParams verifies that +// OptionalUnaryCommand tolerates a JSON-RPC request with no params field +// (or null params) by delivering a zero-value Req to the handler instead +// of returning -32602 InvalidParams. This pins the back-compat contract +// for methods evolving from NullaryCommand → optional-params. +func TestOptionalUnaryCommand_TolerateMissingParams(t *testing.T) { + type req struct { + Name string `json:"name,omitempty"` + } + var captured req + called := make(chan struct{}, 1) + + srv := jsonrpc.NewServer(slog.Default()) + srv.RegisterService("svc", &jsonrpc.ServiceDesc{ + Methods: map[string]jsonrpc.Method{ + "cancel": jsonrpc.OptionalUnaryCommand(func(_ context.Context, r *req) error { + captured = *r + called <- struct{}{} + return nil + }), + }, + }) + + addr := startTestServer(t, srv) + client := dialTestClient(t, addr) + + // Pass nil — client should send "params: null" or omit the field. Either + // way OptionalUnaryCommand must run the handler with zero-value Req. + err := client.Call(context.Background(), "svc/cancel", nil, nil) + require.NoError(t, err) + + select { + case <-called: + case <-time.After(2 * time.Second): + t.Fatal("handler not invoked") + } + assert.Equal(t, "", captured.Name, "expected zero-value Req from missing params") +} + +// TestOptionalUnaryCommand_SurfacesUnmarshalError verifies that malformed +// params (not absent params) still surface as InvalidParams — the +// tolerance is scoped to ErrNoParams, not all unmarshal failures. +func TestOptionalUnaryCommand_SurfacesUnmarshalError(t *testing.T) { + type req struct { + Value int `json:"value"` + } + + srv := jsonrpc.NewServer(slog.Default()) + srv.RegisterService("svc", &jsonrpc.ServiceDesc{ + Methods: map[string]jsonrpc.Method{ + "op": jsonrpc.OptionalUnaryCommand(func(_ context.Context, _ *req) error { + return nil + }), + }, + }) + + addr := startTestServer(t, srv) + client := dialTestClient(t, addr) + + // Send malformed params — string where int expected. + err := client.Call(context.Background(), "svc/op", map[string]string{"value": "not-a-number"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "-32602") +} + func TestServer_PeerDisconnect(t *testing.T) { disconnectDetected := make(chan struct{})