From a2f7904c87c2ff9c39f9149fe575ab42e6c12bc7 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 23 Apr 2026 12:10:30 -0400 Subject: [PATCH 1/4] feat(engine): forward channel-session registration to mod3 (ADR-082 Wave 2) Kernel-side HTTP forwarder that makes the kernel the identity authority for mod3 channel participants while keeping mod3 the communication-state owner (voice, queue, device). Implements the decision from ADR-082's layer-separation rule: "CogOS owns identity; Mod3 owns communication." Routes (namespaced under /v1/channel-sessions/): POST /v1/channel-sessions/register mint+forward POST /v1/channel-sessions/{id}/deregister forward GET /v1/channel-sessions list (kernel+mod3) GET /v1/channel-sessions/{id} detail Namespace choice: the existing /v1/sessions/* family serves agent-session state with 3-component hyphen-validated IDs tied to the handoff protocol. Channel-participant registration has an incompatible shape (short UUID IDs, participant_id/participant_type/voice/device fields). Rather than weaken ValidateSessionID (which would cascade into handoff semantics), the new concern takes its own namespace. The channel-provider RFC's guidance to unify on cogos_session_register with a participant_type discriminator remains the target at the MCP tool layer (Wave 3); at the HTTP layer the two surfaces coexist cleanly. Behavior: - Kernel mints a session_id (short UUID) if the caller omits one. - Request is forwarded to Config.Mod3URL (default http://localhost:7860) with a 5s timeout. - Response merges kernel identity record + mod3 channel state. - Mod3 unreachable -> HTTP 502 with clear error body. - Mod3 error responses are preserved and surfaced. Config: new Config.Mod3URL field, overridable via MOD3_URL env var. Tests: serve_sessions_channel_test.go covers ID minting, field passthrough, response merging, mod3-down 502, and all four sibling endpoints via httptest.Server fakes. --- internal/engine/config.go | 21 + internal/engine/serve.go | 35 +- internal/engine/serve_sessions_channel.go | 445 ++++++++++++++ .../engine/serve_sessions_channel_test.go | 566 ++++++++++++++++++ 4 files changed, 1065 insertions(+), 2 deletions(-) create mode 100644 internal/engine/serve_sessions_channel.go create mode 100644 internal/engine/serve_sessions_channel_test.go diff --git a/internal/engine/config.go b/internal/engine/config.go index 5d04648..3f91d45 100644 --- a/internal/engine/config.go +++ b/internal/engine/config.go @@ -77,6 +77,16 @@ type Config struct { // at .cog/run/kernel.log.jsonl. Leave empty for the default. KernelLogPath string + // Mod3URL is the base URL (scheme + host + port) of the mod3 voice service + // that owns per-channel communication state (voice, output device, queue) + // keyed on kernel-issued session IDs. The kernel forwards channel-session + // registration to this URL; mod3 remains the per-channel state owner while + // the kernel retains identity authority (ADR-082 split). + // + // Default: http://localhost:7860. Override via `mod3_url` in kernel.yaml + // (top-level or under v3:) or via the COGOS_MOD3_URL env var. + Mod3URL string + LocalModel string localModelConfigured bool @@ -99,6 +109,7 @@ type kernelConfigSection struct { LocalModel string `yaml:"local_model"` DigestPaths map[string]string `yaml:"digest_paths"` KernelLogPath string `yaml:"kernel_log_path"` + Mod3URL string `yaml:"mod3_url"` } // kernelConfig is the on-disk YAML shape of .cog/config/kernel.yaml. @@ -137,6 +148,7 @@ func LoadConfig(workspaceRoot string, port int) (*Config, error) { ToolCallValidationEnabled: true, LocalModel: defaultOllamaModel, DigestPaths: make(map[string]string), + Mod3URL: "http://localhost:7860", } // Load from file if present. @@ -150,6 +162,12 @@ func LoadConfig(workspaceRoot string, port int) (*Config, error) { } } + // Env override for the mod3 URL. Env wins over file; flags stay flag-only + // (we don't surface `--mod3-url` in CLI; one env var + YAML is enough). + if v := os.Getenv("COGOS_MOD3_URL"); v != "" { + cfg.Mod3URL = v + } + // Flag override. if port != 0 { cfg.Port = port @@ -211,6 +229,9 @@ func applyKernelSection(cfg *Config, s kernelConfigSection) { if s.KernelLogPath != "" { cfg.KernelLogPath = s.KernelLogPath } + if s.Mod3URL != "" { + cfg.Mod3URL = s.Mod3URL + } } // findWorkspaceRoot walks up from dir until it finds a directory containing a diff --git a/internal/engine/serve.go b/internal/engine/serve.go index c2c29d8..f5c5422 100644 --- a/internal/engine/serve.go +++ b/internal/engine/serve.go @@ -19,6 +19,15 @@ // GET /v1/constellation/fovea — current fovea state // GET /v1/constellation/adjacent?uri=… — adjacent nodes by attentional proximity // +// Channel-session forwarder (ADR-082 Wave 2, see serve_sessions_channel.go): +// +// POST /v1/channel-sessions/register — kernel mints session_id +// and forwards to mod3; +// returns merged response +// POST /v1/channel-sessions/{id}/deregister — proxy to mod3, drop record +// GET /v1/channel-sessions — kernel view + mod3 list +// GET /v1/channel-sessions/{id} — single-session detail +// // The chat endpoint routes through the inference Router when one is set, // otherwise returns 501. package engine @@ -52,8 +61,8 @@ type Server struct { process *Process router Router // nil until SetRouter is called srv *http.Server - debug debugStore // captures last request pipeline state - attentionLog *attentionLog // per-server log (avoids global write race) + debug debugStore // captures last request pipeline state + attentionLog *attentionLog // per-server log (avoids global write race) agentController AgentController // nil until SetAgentController is called mcpServer *MCPServer // so SetAgentController can propagate to tools @@ -71,6 +80,18 @@ type Server struct { // these are derived views rebuilt from bus replay at startup. sessionRegistry *SessionRegistry handoffRegistry *HandoffRegistry + + // ADR-082 Wave 2: kernel-owned identity registry for channel-participant + // sessions. The kernel mints session_id, mod3 stores per-channel state + // keyed on the kernel-issued ID. Distinct from sessionRegistry above, + // which enforces strict 3-component hyphen IDs for the agent/handoff + // protocol. See serve_sessions_channel.go for the full rationale. + channelSessionRegistry *ChannelSessionRegistry + + // mod3Client is the HTTP client used to forward channel-session calls + // to mod3. Nil in production (falls back to the package-level + // mod3HTTPClient); tests set this to an httptest-backed client. + mod3Client *http.Client } // NewServer constructs a Server bound to the configured port. @@ -94,6 +115,9 @@ func NewServer(cfg *Config, nucleus *Nucleus, process *Process) *Server { s.sessionRegistry = NewSessionRegistry() s.handoffRegistry = NewHandoffRegistry() + // ADR-082 Wave 2 kernel-owned channel-session identity. + s.channelSessionRegistry = NewChannelSessionRegistry() + mux := http.NewServeMux() mux.HandleFunc("GET /", handleDashboard) mux.HandleFunc("GET /canvas", handleCanvas) @@ -133,6 +157,13 @@ func NewServer(cfg *Config, nucleus *Nucleus, process *Process) *Server { // cleanly with the pre-existing GET /v1/sessions[/{id}] surface. s.registerSessionMgmtRoutes(mux) + // ADR-082 Wave 2: kernel-side channel-session forwarder. The four + // /v1/channel-sessions/* routes mint session_ids, record identity + // locally, and forward to mod3 at cfg.Mod3URL. Namespaced under + // /v1/channel-sessions/* to coexist with the agent-session surface + // above (incompatible session_id formats — see serve_sessions_channel.go). + s.registerChannelSessionRoutes(mux) + // Replay bus_sessions + bus_handoffs into the in-memory registries so // the kernel starts with an accurate derived view. Bus is authoritative // either way; this just warms the read path. diff --git a/internal/engine/serve_sessions_channel.go b/internal/engine/serve_sessions_channel.go new file mode 100644 index 0000000..2ec8f13 --- /dev/null +++ b/internal/engine/serve_sessions_channel.go @@ -0,0 +1,445 @@ +// serve_sessions_channel.go — kernel-side HTTP forwarder for channel-session +// registration (ADR-082, Wave 2 of the mod3-kernel integration). +// +// Design locks (Wave 2 handoff): +// +// 1. Session authority = kernel-owned. The kernel mints `session_id` if the +// caller doesn't supply one. Mod3's SessionRegistry stores per-channel +// state (voice, queue, device) keyed on the kernel-issued ID. If mod3 +// crashes, the kernel knows which sessions existed; mod3 rebuilds on +// re-register. +// 2. MCP transport = HTTP proxy. Kernel reaches mod3 over plain HTTP at +// Config.Mod3URL (default http://localhost:7860), not stdio MCP. +// +// Path namespace choice: these routes live under `/v1/channel-sessions/*`, +// NOT `/v1/sessions/*`. The kernel's existing `/v1/sessions/*` family +// (serve_sessions_mgmt.go) serves agent-session state (3-component hyphen- +// validated session IDs, workspace/role required, tied to the handoff +// protocol). Channel-participant registration has an incompatible shape +// (short UUID session IDs, participant_id/participant_type/voice/device +// fields). Rather than weaken ValidateSessionID (which would cascade into +// handoff claim semantics) we namespace the new concern. The channel-provider +// RFC's guidance to "use the same cogos_session_register primitive with a +// participant_type discriminator" remains aspirational at the MCP tool layer +// (Wave 3); at the HTTP layer the two surfaces coexist cleanly. +// +// Routes owned by this file: +// +// POST /v1/channel-sessions/register — mint+forward to mod3 +// POST /v1/channel-sessions/{id}/deregister — forward deregister +// GET /v1/channel-sessions — list (kernel view + mod3 list) +// GET /v1/channel-sessions/{id} — single-session detail +package engine + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +// ─── in-memory registry for kernel-owned channel-session identity ──────────── + +// ChannelSessionRecord is the kernel's identity-authority record for a +// channel-participant session. Distinct from SessionState in sessions.go +// (which tracks agent sessions with strict 3-component hyphen IDs); the +// channel-session concern uses short UUIDs and caller-supplied participant +// metadata that the kernel passes through to mod3. +type ChannelSessionRecord struct { + // SessionID is the kernel-authoritative ID. Either caller-supplied or + // minted by the kernel (uuid short form). + SessionID string `json:"session_id"` + + // ParticipantID / ParticipantType / PreferredVoice / PreferredOutputDevice + // / Priority mirror the mod3 SessionRegisterRequest shape. The kernel + // stores them so a post-crash re-register can replay identity cleanly. + ParticipantID string `json:"participant_id,omitempty"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` + + RegisteredAt time.Time `json:"registered_at"` + LastSeen time.Time `json:"last_seen"` + + // Source records whether the session_id came from the caller or was + // minted. Useful for audit; no functional impact. + IDSource string `json:"id_source,omitempty"` // "caller" | "minted" +} + +// ChannelSessionRegistry is the in-memory map keyed by session_id. It holds +// kernel-owned identity only; per-channel state (assigned_voice, queue, +// device) lives in mod3 and is returned as part of the merged forward +// response. If mod3 crashes the kernel knows which sessions existed and +// callers can re-register to rebuild. +type ChannelSessionRegistry struct { + mu sync.RWMutex + rows map[string]*ChannelSessionRecord +} + +// NewChannelSessionRegistry returns an empty registry. +func NewChannelSessionRegistry() *ChannelSessionRegistry { + return &ChannelSessionRegistry{rows: make(map[string]*ChannelSessionRecord)} +} + +// Put stores (or overwrites) a record by session_id. +func (r *ChannelSessionRegistry) Put(rec ChannelSessionRecord) { + r.mu.Lock() + defer r.mu.Unlock() + cp := rec + r.rows[rec.SessionID] = &cp +} + +// Get returns a copy of the record for id, or (nil, false). +func (r *ChannelSessionRegistry) Get(id string) (*ChannelSessionRecord, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + row, ok := r.rows[id] + if !ok { + return nil, false + } + cp := *row + return &cp, true +} + +// Delete removes the record for id. No-op if absent; returns whether the row +// was present before removal (handy for logging / telemetry). +func (r *ChannelSessionRegistry) Delete(id string) bool { + r.mu.Lock() + defer r.mu.Unlock() + _, ok := r.rows[id] + delete(r.rows, id) + return ok +} + +// Snapshot returns a copy of every record. +func (r *ChannelSessionRegistry) Snapshot() []*ChannelSessionRecord { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]*ChannelSessionRecord, 0, len(r.rows)) + for _, row := range r.rows { + cp := *row + out = append(out, &cp) + } + return out +} + +// Len returns the number of tracked channel sessions. +func (r *ChannelSessionRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.rows) +} + +// ─── route registration ────────────────────────────────────────────────────── + +// defaultMod3ForwardTimeout is the per-request timeout for channel-session +// HTTP forwards. 8s is deliberately shorter than WriteTimeout (300s) so a +// stalled mod3 doesn't block a kernel request the whole way through. +const defaultMod3ForwardTimeout = 8 * time.Second + +// mod3HTTPClient is the shared net/http client for forwards. Tests override +// Server.mod3Client for isolation; when nil, handlers fall back to this. +var mod3HTTPClient = &http.Client{Timeout: defaultMod3ForwardTimeout} + +// registerChannelSessionRoutes attaches the 4 channel-session routes onto mux. +// Called from NewServer. +func (s *Server) registerChannelSessionRoutes(mux *http.ServeMux) { + mux.HandleFunc("POST /v1/channel-sessions/register", s.handleChannelSessionRegister) + mux.HandleFunc("POST /v1/channel-sessions/{id}/deregister", s.handleChannelSessionDeregister) + mux.HandleFunc("GET /v1/channel-sessions", s.handleChannelSessionList) + mux.HandleFunc("GET /v1/channel-sessions/{id}", s.handleChannelSessionGet) +} + +// ─── wire types ────────────────────────────────────────────────────────────── + +// channelSessionRegisterRequest is the kernel-facing request body. Shape +// mirrors mod3's SessionRegisterRequest (see mod3/http_api.py line ~278) plus +// an optional session_id — when omitted, the kernel mints one. +type channelSessionRegisterRequest struct { + SessionID string `json:"session_id,omitempty"` + ParticipantID string `json:"participant_id"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` +} + +// channelSessionResponse is the merged shape returned from the kernel. The +// `kernel` block is the identity record (authoritative owner); `mod3` is the +// live channel state (voice pool, queue, device). Callers get everything in +// one round trip. Unknown fields from mod3 are preserved under `mod3` as a +// raw JSON object. +type channelSessionResponse struct { + Kernel *ChannelSessionRecord `json:"kernel"` + Mod3 json.RawMessage `json:"mod3,omitempty"` +} + +// channelSessionListResponse is the shape of GET /v1/channel-sessions. The +// kernel list is always present; the mod3 block is an opaque pass-through so +// clients don't need to know mod3's schema. +type channelSessionListResponse struct { + Kernel []*ChannelSessionRecord `json:"kernel"` + Mod3 json.RawMessage `json:"mod3,omitempty"` +} + +// ─── POST /v1/channel-sessions/register ────────────────────────────────────── + +func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Request) { + var req channelSessionRegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid_request", "body must be JSON") + return + } + if req.ParticipantID == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "participant_id is required") + return + } + + // Mint when absent — kernel is the session-ID authority (ADR-082). + idSource := "caller" + if req.SessionID == "" { + req.SessionID = mintChannelSessionID() + idSource = "minted" + } + if req.ParticipantType == "" { + req.ParticipantType = "agent" // matches mod3's default + } + if req.PreferredOutputDevice == "" { + req.PreferredOutputDevice = "system-default" + } + + now := time.Now().UTC() + record := ChannelSessionRecord{ + SessionID: req.SessionID, + ParticipantID: req.ParticipantID, + ParticipantType: req.ParticipantType, + PreferredVoice: req.PreferredVoice, + PreferredOutputDevice: req.PreferredOutputDevice, + Priority: req.Priority, + RegisteredAt: now, + LastSeen: now, + IDSource: idSource, + } + + // Forward to mod3 with the kernel-issued session_id. Mod3's body is the + // same shape modulo the optional session_id field (mod3 requires it; we + // always supply one). + forwardBody := map[string]any{ + "session_id": req.SessionID, + "participant_id": req.ParticipantID, + "participant_type": req.ParticipantType, + "preferred_voice": req.PreferredVoice, + "preferred_output_device": req.PreferredOutputDevice, + "priority": req.Priority, + } + body, _ := json.Marshal(forwardBody) + + mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodPost, + "/v1/sessions/register", bytes.NewReader(body)) + if err != nil { + slog.Warn("channel-sessions: forward to mod3 failed", + "session_id", req.SessionID, "err", err) + writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", + fmt.Sprintf("mod3 unreachable: %v", err)) + return + } + + // Preserve non-success status bodies so the caller can surface mod3's + // diagnostic text. Any 4xx/5xx from mod3 becomes the response status + // with the raw body attached — the kernel does NOT write its identity + // record in that case (mod3 rejected the registration). + if status < 200 || status >= 300 { + slog.Warn("channel-sessions: mod3 returned non-2xx", + "session_id", req.SessionID, "status", status) + writeJSONPassThrough(w, status, mod3Resp) + return + } + + // Mod3 accepted — commit the kernel-side identity record. + s.channelSessionRegistry.Put(record) + slog.Info("channel-sessions: registered", + "session_id", req.SessionID, "participant_id", req.ParticipantID, + "id_source", idSource) + + // Return merged response. The `mod3` field is the raw body from mod3 so + // callers get the exact {assigned_voice, voice_conflict, output_device, + // queue_depth, ...} shape mod3 emits. + resp := channelSessionResponse{ + Kernel: &record, + Mod3: mod3Resp, + } + writeJSONResp(w, http.StatusOK, resp) +} + +// ─── POST /v1/channel-sessions/{id}/deregister ─────────────────────────────── + +func (s *Server) handleChannelSessionDeregister(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "session_id required in path") + return + } + + mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodPost, + "/v1/sessions/"+id+"/deregister", nil) + if err != nil { + slog.Warn("channel-sessions: deregister forward failed", + "session_id", id, "err", err) + writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", + fmt.Sprintf("mod3 unreachable: %v", err)) + return + } + + // Kernel drops its identity record whenever mod3 successfully + // acknowledges, including 404 ("never registered" is equivalent to + // "not tracked"; clean slate in kernel matches clean slate in mod3). + if status >= 200 && status < 500 { + s.channelSessionRegistry.Delete(id) + } + + writeJSONPassThrough(w, status, mod3Resp) +} + +// ─── GET /v1/channel-sessions ──────────────────────────────────────────────── + +func (s *Server) handleChannelSessionList(w http.ResponseWriter, r *http.Request) { + mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodGet, + "/v1/sessions", nil) + if err != nil { + slog.Warn("channel-sessions: list forward failed", "err", err) + writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", + fmt.Sprintf("mod3 unreachable: %v", err)) + return + } + if status < 200 || status >= 300 { + writeJSONPassThrough(w, status, mod3Resp) + return + } + resp := channelSessionListResponse{ + Kernel: s.channelSessionRegistry.Snapshot(), + Mod3: mod3Resp, + } + writeJSONResp(w, http.StatusOK, resp) +} + +// ─── GET /v1/channel-sessions/{id} ─────────────────────────────────────────── + +func (s *Server) handleChannelSessionGet(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "session_id required in path") + return + } + + mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodGet, + "/v1/sessions/"+id, nil) + if err != nil { + slog.Warn("channel-sessions: get forward failed", + "session_id", id, "err", err) + writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", + fmt.Sprintf("mod3 unreachable: %v", err)) + return + } + if status < 200 || status >= 300 { + writeJSONPassThrough(w, status, mod3Resp) + return + } + + kernelRec, _ := s.channelSessionRegistry.Get(id) + resp := channelSessionResponse{ + Kernel: kernelRec, + Mod3: mod3Resp, + } + writeJSONResp(w, http.StatusOK, resp) +} + +// ─── forwarder + helpers ───────────────────────────────────────────────────── + +// forwardMod3 is the single HTTP egress point to mod3. Returns the raw +// response body, HTTP status, and a transport error (non-nil only when the +// request never got a status back — connection refused, DNS, TLS, timeout). +// Mod3 error bodies (4xx/5xx) come back with err == nil and caller decides +// how to surface them. +func (s *Server) forwardMod3(ctx context.Context, method, path string, body io.Reader) (json.RawMessage, int, error) { + base := strings.TrimRight(s.cfg.Mod3URL, "/") + if base == "" { + return nil, 0, errors.New("Mod3URL not configured") + } + url := base + path + + // Scope a per-request timeout on top of the shared client's. The ambient + // request context may have a longer deadline; we want the forward to + // fail fast if mod3 stalls. + reqCtx, cancel := context.WithTimeout(ctx, defaultMod3ForwardTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, method, url, body) + if err != nil { + return nil, 0, fmt.Errorf("build request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + client := s.mod3Client + if client == nil { + client = mod3HTTPClient + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1 MB cap + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err) + } + if len(raw) == 0 { + // Mod3's deregister can return a JSON object; other endpoints + // always return JSON; if the body is empty we substitute null so + // json.RawMessage doesn't marshal an empty string (invalid JSON). + raw = []byte("null") + } + if !json.Valid(raw) { + // Surface as a synthetic JSON string; don't propagate raw bytes. + wrapped, _ := json.Marshal(map[string]string{"raw": string(raw)}) + return json.RawMessage(wrapped), resp.StatusCode, nil + } + return json.RawMessage(raw), resp.StatusCode, nil +} + +// mintChannelSessionID returns a 12-char lowercase-hex short UUID. Short +// enough to be human-readable in logs, unique enough to avoid collisions +// across a single mod3 instance's lifetime. +func mintChannelSessionID() string { + u := uuid.New() + hex := strings.ReplaceAll(u.String(), "-", "") + return "cs-" + hex[:12] +} + +// writeJSONPassThrough writes the given status + raw JSON body verbatim. Used +// when mod3's response (especially errors) should flow to the caller intact. +func writeJSONPassThrough(w http.ResponseWriter, status int, body json.RawMessage) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if len(body) == 0 { + _, _ = w.Write([]byte("null")) + return + } + _, _ = w.Write(body) +} diff --git a/internal/engine/serve_sessions_channel_test.go b/internal/engine/serve_sessions_channel_test.go new file mode 100644 index 0000000..45761cd --- /dev/null +++ b/internal/engine/serve_sessions_channel_test.go @@ -0,0 +1,566 @@ +// serve_sessions_channel_test.go — end-to-end tests for the kernel-side +// channel-session forwarder. Each test stands up a fake mod3 via +// httptest.NewServer and exercises the kernel handler against it. +package engine + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" +) + +// ─── fakeMod3 — a tiny stub that canonically mirrors mod3's responses ──────── + +type fakeMod3 struct { + t *testing.T + srv *httptest.Server + mu sync.Mutex + captured []capturedRequest + + // Overrides — tests set these to control per-endpoint behavior. + registerHandler http.HandlerFunc + deregisterHandler http.HandlerFunc + listHandler http.HandlerFunc + getHandler http.HandlerFunc +} + +type capturedRequest struct { + Method string + Path string + Body []byte +} + +func newFakeMod3(t *testing.T) *fakeMod3 { + t.Helper() + fm := &fakeMod3{t: t} + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/sessions/register", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.registerHandler != nil { + fm.registerHandler(w, r) + return + } + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + sid, _ := body["session_id"].(string) + resp := map[string]any{ + "session_id": sid, + "participant_id": body["participant_id"], + "assigned_voice": "bm_lewis", + "voice_conflict": false, + "output_device": map[string]any{"name": "system-default", "live": true}, + "queue_depth": 0, + "created": true, + } + writeFakeJSON(w, http.StatusOK, resp) + }) + + mux.HandleFunc("POST /v1/sessions/{id}/deregister", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.deregisterHandler != nil { + fm.deregisterHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "session_id": r.PathValue("id"), + "released_voice": "bm_lewis", + "dropped_jobs": 0, + }) + }) + + mux.HandleFunc("GET /v1/sessions", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.listHandler != nil { + fm.listHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "sessions": []any{}, + "serializer": map[string]any{"policy": "round-robin"}, + "voice_pool": []any{"bm_lewis", "af_bella"}, + "voice_holders": map[string]any{}, + }) + }) + + mux.HandleFunc("GET /v1/sessions/{id}", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.getHandler != nil { + fm.getHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": r.PathValue("id"), + "assigned_voice": "bm_lewis", + "output_device": map[string]any{"name": "system-default"}, + }) + }) + + fm.srv = httptest.NewServer(mux) + t.Cleanup(func() { fm.srv.Close() }) + return fm +} + +func (fm *fakeMod3) capture(r *http.Request) { + body, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewReader(body)) + fm.mu.Lock() + defer fm.mu.Unlock() + fm.captured = append(fm.captured, capturedRequest{ + Method: r.Method, Path: r.URL.Path, Body: body, + }) +} + +func (fm *fakeMod3) lastCaptured() capturedRequest { + fm.mu.Lock() + defer fm.mu.Unlock() + if len(fm.captured) == 0 { + fm.t.Fatalf("expected at least one captured request") + } + return fm.captured[len(fm.captured)-1] +} + +func writeFakeJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +// ─── newChannelServer — a Server wired just enough to exercise the routes ──── + +// newChannelServer builds a Server with only the fields needed by the +// channel-session handlers: cfg.Mod3URL pointing at fake mod3, a fresh +// ChannelSessionRegistry, and mux routes registered. The httptest.Client +// used by the kernel is overridden so every call targets fake mod3 regardless +// of the Mod3URL's scheme (handy for tests that want to inject network +// failures without depending on DNS). +func newChannelServer(t *testing.T, fm *fakeMod3) (*Server, *httptest.Server) { + t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + s := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } + mux := http.NewServeMux() + s.registerChannelSessionRoutes(mux) + front := httptest.NewServer(mux) + t.Cleanup(func() { front.Close() }) + return s, front +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +func TestChannelSessionRegister_MintsIDWhenOmitted(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "participant_id": "cog", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode response: %v", err) + } + if decoded.Kernel == nil { + t.Fatalf("expected kernel block, got nil") + } + if decoded.Kernel.SessionID == "" { + t.Fatal("expected minted session_id, got empty string") + } + if !strings.HasPrefix(decoded.Kernel.SessionID, "cs-") { + t.Fatalf("expected minted ID to carry cs- prefix, got %q", decoded.Kernel.SessionID) + } + if decoded.Kernel.IDSource != "minted" { + t.Fatalf("expected id_source=minted, got %q", decoded.Kernel.IDSource) + } + + // Mod3 should have seen the kernel-minted ID. + cap := fm.lastCaptured() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("unmarshal forwarded body: %v", err) + } + if forwarded["session_id"] != decoded.Kernel.SessionID { + t.Fatalf("expected mod3 to receive minted session_id %q, got %v", + decoded.Kernel.SessionID, forwarded["session_id"]) + } + + // Kernel registry should hold the record. + if _, ok := s.channelSessionRegistry.Get(decoded.Kernel.SessionID); !ok { + t.Fatal("expected kernel registry to retain record after success") + } +} + +func TestChannelSessionRegister_UsesCallerSuppliedID(t *testing.T) { + fm := newFakeMod3(t) + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "vox-42", + "participant_id": "sandy", + "participant_type": "agent", + "preferred_voice": "af_bella", + "preferred_output_device": "AirPods", + "priority": 3, + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if decoded.Kernel.SessionID != "vox-42" { + t.Fatalf("expected caller session_id preserved, got %q", decoded.Kernel.SessionID) + } + if decoded.Kernel.IDSource != "caller" { + t.Fatalf("expected id_source=caller, got %q", decoded.Kernel.IDSource) + } + + // Verify all caller-supplied fields made it through the forward. + cap := fm.lastCaptured() + var forwarded map[string]any + _ = json.Unmarshal(cap.Body, &forwarded) + if forwarded["session_id"] != "vox-42" || + forwarded["participant_id"] != "sandy" || + forwarded["participant_type"] != "agent" || + forwarded["preferred_voice"] != "af_bella" || + forwarded["preferred_output_device"] != "AirPods" { + t.Fatalf("forwarded fields mismatch: %v", forwarded) + } + // Priority is a number in JSON; compare via float64. + if p, _ := forwarded["priority"].(float64); int(p) != 3 { + t.Fatalf("expected priority=3 forwarded, got %v", forwarded["priority"]) + } +} + +func TestChannelSessionRegister_MergesMod3Response(t *testing.T) { + fm := newFakeMod3(t) + fm.registerHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-fixed", + "assigned_voice": "bm_oxford", + "voice_conflict": true, + "queue_depth": 5, + "created": false, + }) + } + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "cs-fixed", + "participant_id": "cog", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + + // The merged response should carry mod3's assigned_voice and + // voice_conflict fields intact. + var decoded map[string]json.RawMessage + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + var mod3 map[string]any + if err := json.Unmarshal(decoded["mod3"], &mod3); err != nil { + t.Fatalf("decode mod3 block: %v", err) + } + if mod3["assigned_voice"] != "bm_oxford" { + t.Fatalf("expected assigned_voice=bm_oxford in merge, got %v", mod3["assigned_voice"]) + } + if mod3["voice_conflict"] != true { + t.Fatalf("expected voice_conflict=true in merge, got %v", mod3["voice_conflict"]) + } +} + +func TestChannelSessionRegister_ReturnsBadGatewayWhenMod3Down(t *testing.T) { + // Build a fakeMod3 and immediately close it so the URL refuses connections. + fm := newFakeMod3(t) + fm.srv.Close() + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{"participant_id": "cog"}) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadGateway { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 502, got %d; body: %s", resp.StatusCode, raw) + } + // Kernel must NOT hold a record for a failed forward. + if s.channelSessionRegistry.Len() != 0 { + t.Fatalf("expected empty kernel registry after 502, got %d rows", + s.channelSessionRegistry.Len()) + } +} + +func TestChannelSessionRegister_PropagatesMod3Error(t *testing.T) { + fm := newFakeMod3(t) + fm.registerHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusBadRequest, map[string]string{ + "error": "participant_type must be agent|user", + }) + } + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "participant_id": "cog", + "participant_type": "alien", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 400 passthrough, got %d; body: %s", resp.StatusCode, raw) + } + var decoded map[string]string + _ = json.NewDecoder(resp.Body).Decode(&decoded) + if !strings.Contains(decoded["error"], "participant_type") { + t.Fatalf("expected mod3 error body preserved, got %v", decoded) + } + if s.channelSessionRegistry.Len() != 0 { + t.Fatal("expected kernel registry empty when mod3 rejected registration") + } +} + +func TestChannelSessionRegister_RequiresParticipantID(t *testing.T) { + fm := newFakeMod3(t) + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{}) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + // No forward should have happened. + fm.mu.Lock() + if len(fm.captured) != 0 { + t.Fatalf("expected no forward to mod3 when participant_id missing, got %d", len(fm.captured)) + } + fm.mu.Unlock() +} + +func TestChannelSessionDeregister_Forwards(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + // Pre-populate the kernel registry so we can verify deletion. + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-to-drop", ParticipantID: "cog", + }) + + req, _ := http.NewRequest(http.MethodPost, + front.URL+"/v1/channel-sessions/cs-to-drop/deregister", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("deregister: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + if _, ok := s.channelSessionRegistry.Get("cs-to-drop"); ok { + t.Fatal("expected kernel registry to drop record after successful deregister") + } + cap := fm.lastCaptured() + if cap.Path != "/v1/sessions/cs-to-drop/deregister" || cap.Method != http.MethodPost { + t.Fatalf("unexpected forward: %+v", cap) + } +} + +func TestChannelSessionDeregister_Returns502WhenMod3Down(t *testing.T) { + fm := newFakeMod3(t) + fm.srv.Close() + s, front := newChannelServer(t, fm) + + s.channelSessionRegistry.Put(ChannelSessionRecord{SessionID: "keeper"}) + + req, _ := http.NewRequest(http.MethodPost, + front.URL+"/v1/channel-sessions/keeper/deregister", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("deregister: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("expected 502, got %d", resp.StatusCode) + } + // On transport failure the kernel keeps the record — nothing authoritative + // changed. The caller can retry. + if _, ok := s.channelSessionRegistry.Get("keeper"); !ok { + t.Fatal("expected kernel to preserve record on transport failure") + } +} + +func TestChannelSessionList_MergesSnapshots(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + // Seed kernel registry so we can see the kernel block in the merged response. + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-seed", ParticipantID: "cog", + }) + + resp, err := http.Get(front.URL + "/v1/channel-sessions") + if err != nil { + t.Fatalf("list: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var decoded channelSessionListResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if len(decoded.Kernel) != 1 || decoded.Kernel[0].SessionID != "cs-seed" { + t.Fatalf("expected kernel snapshot of 1, got %+v", decoded.Kernel) + } + if len(decoded.Mod3) == 0 { + t.Fatal("expected mod3 block populated from fake mod3") + } +} + +func TestChannelSessionGet_MergesKernelAndMod3(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-detail", ParticipantID: "cog", PreferredVoice: "bm_lewis", + }) + + resp, err := http.Get(front.URL + "/v1/channel-sessions/cs-detail") + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if decoded.Kernel == nil || decoded.Kernel.SessionID != "cs-detail" { + t.Fatalf("expected kernel record returned, got %+v", decoded.Kernel) + } + if len(decoded.Mod3) == 0 { + t.Fatal("expected mod3 block populated from fake mod3") + } +} + +func TestChannelSessionGet_Returns404WhenMod3NotFound(t *testing.T) { + fm := newFakeMod3(t) + fm.getHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusNotFound, map[string]string{"error": "not found"}) + } + _, front := newChannelServer(t, fm) + + resp, err := http.Get(front.URL + "/v1/channel-sessions/does-not-exist") + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404 passthrough, got %d", resp.StatusCode) + } +} + +// ─── forwardMod3 direct tests ──────────────────────────────────────────────── + +func TestForwardMod3_ErrorsWhenURLUnset(t *testing.T) { + s := &Server{cfg: &Config{Mod3URL: ""}} + _, _, err := s.forwardMod3(context.Background(), http.MethodGet, "/v1/sessions", nil) + if err == nil { + t.Fatal("expected error when Mod3URL is empty") + } +} + +func TestForwardMod3_TimeoutYieldsTransportError(t *testing.T) { + // A listener that accepts connections but never writes — guarantees the + // 8s per-request timeout fires. Use a short deadline so the test runs + // quickly. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { l.Close() }) + + s := &Server{ + cfg: &Config{Mod3URL: "http://" + l.Addr().String()}, + mod3Client: &http.Client{Timeout: 50 * 1e6}, // 50ms + } + _, _, err = s.forwardMod3(context.Background(), http.MethodGet, "/v1/sessions", nil) + if err == nil { + t.Fatal("expected transport error on stalled server") + } + // Deadline exceeded / i/o timeout / context canceled are all acceptable shapes. + msg := err.Error() + if !strings.Contains(strings.ToLower(msg), "timeout") && + !strings.Contains(strings.ToLower(msg), "deadline") && + !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected timeout-ish error, got %v", err) + } +} + +// ─── minor sanity tests for minting + helpers ──────────────────────────────── + +func TestMintChannelSessionID_ShapeAndUniqueness(t *testing.T) { + seen := map[string]bool{} + for i := 0; i < 100; i++ { + id := mintChannelSessionID() + if !strings.HasPrefix(id, "cs-") { + t.Fatalf("expected cs- prefix, got %q", id) + } + if len(id) != 3+12 { + t.Fatalf("expected 15-char ID (cs- + 12 hex), got len=%d (%q)", len(id), id) + } + if seen[id] { + t.Fatalf("mint collision after %d iterations: %q", i, id) + } + seen[id] = true + } +} From c8e2d8d9a29ca34b7febd5099f3e3ec57ec99280 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 23 Apr 2026 12:37:03 -0400 Subject: [PATCH 2/4] feat(engine): MCP proxy for mod3 tools with audio playback (Wave 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wave 3 of the mod3-kernel integration (ADR-082 + channel-provider RFC). The kernel becomes the MCP front door for mod3 voice tools via an HTTP proxy — supersedes the installed binary's OpenClaw gateway, which read mod3's metric headers but silently discarded the audio/wav payload. Tools registered (mod3_* namespace on the /mcp endpoint): * mod3_speak synthesize text + play audio locally * mod3_stop cancel current/queued speech (optional job_id) * mod3_voices list available voices * mod3_status probe mod3 /health * mod3_register_session proxy to POST /v1/sessions/register * mod3_deregister_session proxy to POST /v1/sessions/{id}/deregister * mod3_list_sessions proxy to GET /v1/sessions All tools accept an optional session_id and thread it through to mod3 — in the request body for synthesize, as a query parameter for stop/voices, in the URL for register/deregister. Absent session_id → the proxy omits the field and mod3 routes to its default session. Playback strategy: Option (A), server-side. Synthesis response bodies are written to a tempfile, then played by afplay (macOS) / aplay (Linux) via a fire-and-forget exec. Callers opt in to blocking via blocking=true, or can skip playback entirely with skip_playback=true (returns the WAV bytes base64-encoded, forward-compatible with Option B session-routed playback once the Wave 4 dashboard WebSocket lands). The player command is injectable (modalityProxy.player) so tests never spawn real audio. Metrics: mod3's X-Mod3-* response headers are parsed into a metrics map on the tool result (job_id, duration_sec, rtf, sample_rate, etc.) with numeric headers coerced to int64/float64 where applicable. Errors: mod3-unreachable returns IsError=true with "mod3 unreachable: …" text so the ledger's tool.result event records the failure (same shape as the serve_sessions_channel.go pattern). Non-2xx responses from mod3 preserve the body text in the error result — callers see mod3's own 422 / 5xx explanation intact. Fixes: the drop-audio-bytes bug observed in the installed Apr-19 binary (mcp__cogos__mod3_speak completes in ~1s, returns metrics, but plays nothing). With this proxy the kernel actually hears what mod3 makes. Timeout: 30s on the HTTP client (vs 8s on the channel-session forwarder); accounts for cold-start model loading and multi-sentence synthesis. Files: internal/engine/mcp_modality_proxy.go (new, 551 lines) internal/engine/mcp_modality_proxy_test.go (new, 686 lines) internal/engine/mcp_server.go (modified, +11) Tests: 20 new unit tests — synthesis success/error/session-threading, stop/voices/status/sessions forwarding, metric extraction, server-side playback via a stub shell-script player (proves the bytes reach the player, guarding against the drop-audio regression), non-blocking spawn. All pass. Full ./... suite green. Out of scope (deferred to later waves): * Wave 4: dashboard participant UI + session-routed playback * session-start hook auto-registration * consolidation with the existing OpenClaw-gateway mod3_speak (coexist until deprecated) --- internal/engine/mcp_modality_proxy.go | 551 +++++++++++++++++ internal/engine/mcp_modality_proxy_test.go | 686 +++++++++++++++++++++ internal/engine/mcp_server.go | 11 + 3 files changed, 1248 insertions(+) create mode 100644 internal/engine/mcp_modality_proxy.go create mode 100644 internal/engine/mcp_modality_proxy_test.go diff --git a/internal/engine/mcp_modality_proxy.go b/internal/engine/mcp_modality_proxy.go new file mode 100644 index 0000000..1ac7768 --- /dev/null +++ b/internal/engine/mcp_modality_proxy.go @@ -0,0 +1,551 @@ +// mcp_modality_proxy.go — kernel-side MCP proxy for mod3 voice tools. +// +// Wave 3 of the mod3-kernel integration (ADR-082 + channel-provider RFC). +// The kernel becomes the MCP front door for mod3; the previous OpenClaw +// gateway pattern in the installed binary read metrics but discarded audio +// bytes. This proxy fixes that: it forwards HTTP calls to mod3, captures the +// audio/wav payload, plays it locally via afplay/aplay (fire-and-forget by +// default), and returns mod3's metric headers (X-Mod3-*) to the MCP caller. +// +// Design locks: +// +// 1. MCP transport = HTTP proxy. Every tool handler here POSTs/GETs against +// cfg.Mod3URL + "/v1/*". Mod3 is NOT an MCP server to the kernel. The +// installed binary's OpenClaw gateway is a separate concern — we are the +// next kernel build and will supersede it when deployed. +// 2. Session authority = kernel-owned. session_id is threaded through every +// proxied call. Callers pass it as an optional field; absent → proxy +// omits it and mod3 routes to its default session. Present → proxy +// includes it in the request body (synthesize) or query string (stop). +// 3. Playback strategy = Option (A), server-side. Kernel receives audio/wav, +// writes to a tempfile, execs `afplay` (macOS) or `aplay` (Linux), +// fire-and-forget. Callers can opt in to blocking with blocking=true. +// Forward-compatible with Option (B) session-routed playback once the +// Wave 4 dashboard WebSocket lands — a future session-router check can +// gate this path when a browser subscriber exists. +// +// Tools registered (prefix `mod3_` to namespace against cog_* kernel tools): +// +// - mod3_speak — synthesize + (optionally) play +// - mod3_stop — cancel current/queued speech +// - mod3_voices — list available voices +// - mod3_status — mod3 /health probe + build info +// - mod3_register_session — proxy to mod3 session register (future) +// - mod3_deregister_session — proxy to mod3 session deregister +// - mod3_list_sessions — proxy to mod3 session list +// +// Note: the session-registry family forwards to mod3's /v1/sessions/* routes +// which are not yet live on every mod3 instance (see openapi.json). They +// return a clean 502 "mod3_unreachable" in that case; once mod3 implements +// the routes (ADR-082 Wave 2 target), these tools become useful. +package engine + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ─── proxy wiring on MCPServer ─────────────────────────────────────────────── + +// modalityProxy holds the HTTP client and playback helper used by the mod3_* +// MCP tools. Fields are exported-by-convention (capitalized where needed for +// tests) so test code can override the HTTP client and the player command. +type modalityProxy struct { + // client is the HTTP client used for all mod3 forwards. Nil falls back + // to defaultMod3ProxyClient. + client *http.Client + + // player is the OS command executed for server-side audio playback. + // Overridable in tests to a stub binary / /usr/bin/true. Empty means + // "autodetect via runtime.GOOS" (afplay on darwin, aplay elsewhere). + player string + + // playerArgs, when non-nil, are passed as additional command args + // before the tempfile path. Useful for tests to pipe the wav through + // a counting script. Nil means no extra args. + playerArgs []string + + // disablePlayback short-circuits the player exec entirely. Tests set + // this when they want to assert "we got the bytes" without spawning a + // real player. Production code leaves it false. + disablePlayback bool +} + +// defaultMod3ProxyTimeout is the per-request timeout for mod3 forwards. 30s +// covers the longest-plausible synthesis on the current Kokoro voice stack +// (~5-10s for multi-sentence input, with headroom for cold starts). +const defaultMod3ProxyTimeout = 30 * time.Second + +// defaultMod3ProxyClient is the shared http.Client used when modalityProxy.client +// is nil. Lazily initialised; safe for concurrent use. +var defaultMod3ProxyClient = &http.Client{Timeout: defaultMod3ProxyTimeout} + +// getModalityProxy returns the MCPServer's modality proxy, lazily creating +// one with sane defaults on first access. Tests can pre-seed m.mod3Proxy with +// their own instance before calling this. +func (m *MCPServer) getModalityProxy() *modalityProxy { + if m.mod3Proxy == nil { + m.mod3Proxy = &modalityProxy{} + } + return m.mod3Proxy +} + +// ─── tool registration ─────────────────────────────────────────────────────── + +// registerMod3Tools installs the 7 mod3_* MCP tools. Called from +// MCPServer.registerTools after the cog_* tools so the tool index stays +// stable at the front. +func (m *MCPServer) registerMod3Tools() { + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_speak", + Description: "Synthesize text to speech via mod3 and play the audio " + + "locally. Required: text. Optional: session_id, voice, speed, " + + "emotion, blocking (wait for playback to finish). Returns mod3 " + + "metrics (job_id, duration_sec, rtf, voice) and a playback_status " + + "flag. Fallback: curl -X POST http://localhost:7860/v1/synthesize " + + "-d '{\"text\":\"...\"}' -o out.wav && afplay out.wav", + }, withToolObserver(m, "mod3_speak", m.toolMod3Speak)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_stop", + Description: "Stop current mod3 speech and/or cancel queued jobs. " + + "Optional: session_id, job_id (cancel one specific job). Empty " + + "cancels current playback and clears the queue. Returns mod3's " + + "barge-in interruption context. Fallback: curl -X POST " + + "http://localhost:7860/v1/stop", + }, withToolObserver(m, "mod3_stop", m.toolMod3Stop)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_voices", + Description: "List available mod3 voices, optionally scoped to a " + + "session. Optional: session_id. Returns the voice catalogue mod3 " + + "exposes (id, name, language, gender metadata per voice). " + + "Fallback: curl http://localhost:7860/v1/voices", + }, withToolObserver(m, "mod3_voices", m.toolMod3Voices)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_status", + Description: "Probe mod3's /health endpoint. Returns the raw health " + + "payload (model_loaded, engine info, queue_depth, etc). 502 if " + + "mod3 is unreachable. Fallback: curl http://localhost:7860/health", + }, withToolObserver(m, "mod3_status", m.toolMod3Status)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_register_session", + Description: "Proxy to mod3's POST /v1/sessions/register. Required: " + + "participant_id. Optional: session_id (caller-supplied; kernel " + + "does NOT mint here — use /v1/channel-sessions/register for that), " + + "participant_type, preferred_voice, preferred_output_device, " + + "priority. Returns mod3's full SessionRegisterResponse (assigned_" + + "voice, voice_conflict, output_device, queue_depth).", + }, withToolObserver(m, "mod3_register_session", m.toolMod3RegisterSession)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_deregister_session", + Description: "Proxy to mod3's POST /v1/sessions/{session_id}/deregister. " + + "Required: session_id. Returns mod3's deregister acknowledgment " + + "(released_voice, dropped_jobs).", + }, withToolObserver(m, "mod3_deregister_session", m.toolMod3DeregisterSession)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_list_sessions", + Description: "Proxy to mod3's GET /v1/sessions. Returns the live " + + "mod3 session roster (sessions, voice_pool, voice_holders, " + + "serializer policy). No kernel-side filtering — mod3 is source " + + "of truth for per-channel state.", + }, withToolObserver(m, "mod3_list_sessions", m.toolMod3ListSessions)) +} + +// ─── input / output types ──────────────────────────────────────────────────── + +type mod3SpeakInput struct { + Text string `json:"text"` + SessionID string `json:"session_id,omitempty"` + Voice string `json:"voice,omitempty"` + Speed float64 `json:"speed,omitempty"` + Emotion float64 `json:"emotion,omitempty"` + // Blocking waits for the spawned player to exit before returning the + // tool result. Default false — fire-and-forget so multi-second audio + // doesn't block the MCP call. + Blocking bool `json:"blocking,omitempty"` + // SkipPlayback returns the wav bytes (base64) without attempting local + // playback. Useful for callers routing audio elsewhere (dashboard WS, + // file write, etc). Default false. + SkipPlayback bool `json:"skip_playback,omitempty"` +} + +type mod3StopInput struct { + SessionID string `json:"session_id,omitempty"` + JobID string `json:"job_id,omitempty"` +} + +type mod3VoicesInput struct { + SessionID string `json:"session_id,omitempty"` +} + +type mod3StatusInput struct{} + +type mod3RegisterSessionInput struct { + SessionID string `json:"session_id,omitempty"` + ParticipantID string `json:"participant_id"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` +} + +type mod3DeregisterSessionInput struct { + SessionID string `json:"session_id"` +} + +type mod3ListSessionsInput struct{} + +// ─── handlers ──────────────────────────────────────────────────────────────── + +func (m *MCPServer) toolMod3Speak(ctx context.Context, req *mcp.CallToolRequest, in mod3SpeakInput) (*mcp.CallToolResult, any, error) { + if strings.TrimSpace(in.Text) == "" { + return textResult("text is required") + } + body := map[string]any{"text": in.Text} + if in.Voice != "" { + body["voice"] = in.Voice + } + if in.Speed > 0 { + body["speed"] = in.Speed + } + if in.Emotion > 0 { + body["emotion"] = in.Emotion + } + if in.SessionID != "" { + // Session threading: mod3 ignores unknown fields on SynthesizeRequest + // today; once multi-session synthesis lands, this is the channel. + body["session_id"] = in.SessionID + } + raw, _ := json.Marshal(body) + + audio, headers, status, err := m.proxyMod3Bytes(ctx, http.MethodPost, + "/v1/synthesize", bytes.NewReader(raw), "application/json") + if err != nil { + return mod3ErrorResult(fmt.Sprintf("mod3 unreachable: %v", err)) + } + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %s", status, truncate(string(audio), 400))) + } + + metrics := extractMod3Metrics(headers) + result := map[string]any{ + "ok": true, + "bytes": len(audio), + "metrics": metrics, + "session_id": in.SessionID, // may be empty; echoed for observability + } + + // If the caller asked for raw bytes (no server-side playback), base64- + // encode and return. Forward-compatible with session-routed playback. + if in.SkipPlayback { + result["audio_base64"] = base64.StdEncoding.EncodeToString(audio) + result["playback_status"] = "skipped" + return marshalResult(result) + } + + p := m.getModalityProxy() + if p.disablePlayback { + result["playback_status"] = "disabled" + return marshalResult(result) + } + + playErr := p.playAudio(audio, in.Blocking) + switch { + case playErr == nil && in.Blocking: + result["playback_status"] = "played" + case playErr == nil: + result["playback_status"] = "spawned" + default: + result["playback_status"] = "error" + result["playback_error"] = playErr.Error() + } + return marshalResult(result) +} + +func (m *MCPServer) toolMod3Stop(ctx context.Context, req *mcp.CallToolRequest, in mod3StopInput) (*mcp.CallToolResult, any, error) { + path := "/v1/stop" + q := url.Values{} + if in.JobID != "" { + q.Set("job_id", in.JobID) + } + if in.SessionID != "" { + q.Set("session_id", in.SessionID) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, path, nil) +} + +func (m *MCPServer) toolMod3Voices(ctx context.Context, req *mcp.CallToolRequest, in mod3VoicesInput) (*mcp.CallToolResult, any, error) { + path := "/v1/voices" + if in.SessionID != "" { + path += "?session_id=" + url.QueryEscape(in.SessionID) + } + return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, path, nil) +} + +func (m *MCPServer) toolMod3Status(ctx context.Context, req *mcp.CallToolRequest, in mod3StatusInput) (*mcp.CallToolResult, any, error) { + return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, "/health", nil) +} + +func (m *MCPServer) toolMod3RegisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3RegisterSessionInput) (*mcp.CallToolResult, any, error) { + if in.ParticipantID == "" { + return textResult("participant_id is required") + } + body := map[string]any{ + "participant_id": in.ParticipantID, + } + if in.SessionID != "" { + body["session_id"] = in.SessionID + } + if in.ParticipantType != "" { + body["participant_type"] = in.ParticipantType + } + if in.PreferredVoice != "" { + body["preferred_voice"] = in.PreferredVoice + } + if in.PreferredOutputDevice != "" { + body["preferred_output_device"] = in.PreferredOutputDevice + } + if in.Priority != 0 { + body["priority"] = in.Priority + } + raw, _ := json.Marshal(body) + return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, "/v1/sessions/register", bytes.NewReader(raw)) +} + +func (m *MCPServer) toolMod3DeregisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3DeregisterSessionInput) (*mcp.CallToolResult, any, error) { + if in.SessionID == "" { + return textResult("session_id is required") + } + return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, + "/v1/sessions/"+url.PathEscape(in.SessionID)+"/deregister", nil) +} + +func (m *MCPServer) toolMod3ListSessions(ctx context.Context, req *mcp.CallToolRequest, in mod3ListSessionsInput) (*mcp.CallToolResult, any, error) { + return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, "/v1/sessions", nil) +} + +// ─── HTTP forwarder primitives ─────────────────────────────────────────────── + +// proxyMod3Bytes issues an HTTP request to mod3 and returns the raw body, +// response headers, HTTP status, and a transport error. Caller owns the body +// bytes; they may be audio/wav (mod3_speak) or JSON (everything else). +func (m *MCPServer) proxyMod3Bytes(ctx context.Context, method, path string, body io.Reader, contentType string) ([]byte, http.Header, int, error) { + if m.cfg == nil { + return nil, nil, 0, errors.New("Mod3URL not configured (cfg nil)") + } + base := strings.TrimRight(m.cfg.Mod3URL, "/") + if base == "" { + return nil, nil, 0, errors.New("Mod3URL not configured") + } + + reqCtx, cancel := context.WithTimeout(ctx, defaultMod3ProxyTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, method, base+path, body) + if err != nil { + return nil, nil, 0, fmt.Errorf("build request: %w", err) + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + // Accept both audio and JSON so a single client path covers synthesize + // (audio/wav) and the rest (application/json). + req.Header.Set("Accept", "audio/wav, application/json") + + client := m.getModalityProxy().client + if client == nil { + client = defaultMod3ProxyClient + } + resp, err := client.Do(req) + if err != nil { + return nil, nil, 0, err + } + defer resp.Body.Close() + + // 16 MB cap — more than enough for multi-minute Kokoro wav at 24kHz + // (~2 MB per 30s), safety net against an upstream that never closes. + raw, err := io.ReadAll(io.LimitReader(resp.Body, 16<<20)) + if err != nil { + return nil, resp.Header, resp.StatusCode, fmt.Errorf("read response body: %w", err) + } + return raw, resp.Header, resp.StatusCode, nil +} + +// proxyMod3JSONAsMCP is a convenience wrapper for tools whose response is +// JSON (everything except mod3_speak). Reads the body, parses it as JSON if +// possible, and returns an mcp.CallToolResult; on non-2xx status returns a +// mod3-error marshalled result so the caller sees the mod3 body intact. +func (m *MCPServer) proxyMod3JSONAsMCP(ctx context.Context, method, path string, body io.Reader) (*mcp.CallToolResult, any, error) { + contentType := "" + if body != nil { + contentType = "application/json" + } + raw, _, status, err := m.proxyMod3Bytes(ctx, method, path, body, contentType) + if err != nil { + return mod3ErrorResult(fmt.Sprintf("mod3 unreachable: %v", err)) + } + // Try to parse as JSON; if parse fails, surface the body as text so the + // caller at least sees what mod3 said. + var parsed any + if len(raw) > 0 { + if jsonErr := json.Unmarshal(raw, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(raw)} + } + } + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %v", status, parsed)) + } + return marshalResult(parsed) +} + +// mod3ErrorResult returns an IsError=true CallToolResult so the observer +// wrapper records the tool invocation as a failure in the ledger. +func mod3ErrorResult(msg string) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: msg}}, + IsError: true, + }, nil, nil +} + +// extractMod3Metrics pulls the X-Mod3-* headers into a metrics map. Numeric +// fields are parsed when possible; unknown headers pass through as strings +// so future mod3 headers surface without code changes. +func extractMod3Metrics(h http.Header) map[string]any { + out := map[string]any{} + for key, values := range h { + lk := strings.ToLower(key) + if !strings.HasPrefix(lk, "x-mod3-") || len(values) == 0 { + continue + } + short := strings.TrimPrefix(lk, "x-mod3-") + v := values[0] + // Try numeric parse for the common metric headers. + if f, err := strconv.ParseFloat(v, 64); err == nil { + // Preserve integers as int for nicer JSON output. + if i, ierr := strconv.ParseInt(v, 10, 64); ierr == nil { + out[short] = i + } else { + out[short] = f + } + continue + } + out[short] = v + } + return out +} + +// ─── playback helper ───────────────────────────────────────────────────────── + +// playAudio writes the wav bytes to a tempfile and spawns the platform's +// default player. When blocking==false the function returns immediately +// after the process starts; a goroutine waits for exit so the tempfile can +// be cleaned up. When blocking==true the function waits for exit and +// surfaces any non-zero return as an error. +// +// In tests, set modalityProxy.player to "/usr/bin/true" (or similar) to +// avoid actually playing audio. +func (p *modalityProxy) playAudio(wav []byte, blocking bool) error { + if p.disablePlayback { + return nil + } + f, err := os.CreateTemp("", "mod3-speak-*.wav") + if err != nil { + return fmt.Errorf("tempfile: %w", err) + } + path := f.Name() + if _, err := f.Write(wav); err != nil { + f.Close() + os.Remove(path) + return fmt.Errorf("write tempfile: %w", err) + } + if err := f.Close(); err != nil { + os.Remove(path) + return fmt.Errorf("close tempfile: %w", err) + } + + player := p.player + if player == "" { + player = defaultPlayerCommand() + } + if player == "" { + os.Remove(path) + return fmt.Errorf("no audio player available for GOOS=%s", runtime.GOOS) + } + + args := append([]string{}, p.playerArgs...) + args = append(args, path) + cmd := exec.Command(player, args...) + + if err := cmd.Start(); err != nil { + os.Remove(path) + return fmt.Errorf("start %s: %w", player, err) + } + + if blocking { + err := cmd.Wait() + _ = os.Remove(path) + if err != nil { + return fmt.Errorf("player %s exited: %w", player, err) + } + return nil + } + // Fire-and-forget: reap the child so the tempfile gets cleaned and the + // process isn't a zombie. Log errors; don't propagate (the MCP call + // already returned successfully). + go func() { + if werr := cmd.Wait(); werr != nil { + slog.Debug("mod3 proxy: player exited non-zero", + "player", player, "path", path, "err", werr) + } + _ = os.Remove(path) + }() + return nil +} + +// defaultPlayerCommand returns the preferred platform player, or "" when +// none is available in PATH. Resolved lazily per call so tests that change +// PATH take effect. +func defaultPlayerCommand() string { + candidates := map[string][]string{ + "darwin": {"afplay"}, + "linux": {"aplay", "paplay"}, + "freebsd": {"aplay"}, + } + for _, name := range candidates[runtime.GOOS] { + if _, err := exec.LookPath(name); err == nil { + return name + } + } + // Final fallback: if neither platform default is present, see if the + // caller has exposed one via PATH under its canonical name. + for _, name := range []string{"afplay", "aplay", "paplay", "ffplay"} { + if _, err := exec.LookPath(name); err == nil { + return name + } + } + return "" +} diff --git a/internal/engine/mcp_modality_proxy_test.go b/internal/engine/mcp_modality_proxy_test.go new file mode 100644 index 0000000..9ad0fd6 --- /dev/null +++ b/internal/engine/mcp_modality_proxy_test.go @@ -0,0 +1,686 @@ +// mcp_modality_proxy_test.go — coverage for the mod3 MCP proxy tools. +// +// Strategy: stand up a fake mod3 via httptest.NewServer, point an MCPServer's +// proxy at it, exercise each tool handler directly (handler function + +// typed input, bypassing the MCP SDK's JSON marshal layer). Playback is +// stubbed via disablePlayback or the injectable player field so tests don't +// spawn afplay. +package engine + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ─── fake mod3 with synthesize + session routes ────────────────────────────── + +type fakeMod3Proxy struct { + t *testing.T + srv *httptest.Server + mu sync.Mutex + captured []capturedProxyRequest + + // Overrides for per-endpoint behavior. + synthesizeHandler http.HandlerFunc + stopHandler http.HandlerFunc + voicesHandler http.HandlerFunc + healthHandler http.HandlerFunc + regHandler http.HandlerFunc + deregHandler http.HandlerFunc + listSessionHandler http.HandlerFunc +} + +type capturedProxyRequest struct { + Method string + Path string + Query string + Body []byte +} + +// synthWav is a tiny valid-ish WAV header + ~1KB of silence. Not a real +// audio file — the proxy doesn't parse it, it just forwards/plays bytes. +var synthWav = func() []byte { + b := make([]byte, 1024) + copy(b, []byte("RIFF\x00\x00\x00\x00WAVEfmt ")) + return b +}() + +func newFakeMod3Proxy(t *testing.T) *fakeMod3Proxy { + t.Helper() + fm := &fakeMod3Proxy{t: t} + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/synthesize", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.synthesizeHandler != nil { + fm.synthesizeHandler(w, r) + return + } + // Emit the full mod3 header surface so the proxy can extract it. + w.Header().Set("X-Mod3-Job-Id", "job-test-0001") + w.Header().Set("X-Mod3-Voice", "bm_lewis") + w.Header().Set("X-Mod3-Duration-Sec", "1.23") + w.Header().Set("X-Mod3-Sample-Rate", "24000") + w.Header().Set("X-Mod3-Rtf", "9.29") + w.Header().Set("X-Mod3-Chunks", "1") + w.Header().Set("Content-Type", "audio/wav") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(synthWav) + }) + + mux.HandleFunc("POST /v1/stop", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.stopHandler != nil { + fm.stopHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "stopped", + "dropped_jobs": 0, + "interrupted": true, + "session_id": r.URL.Query().Get("session_id"), + "job_id_target": r.URL.Query().Get("job_id"), + }) + }) + + mux.HandleFunc("GET /v1/voices", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.voicesHandler != nil { + fm.voicesHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "voices": []map[string]any{ + {"id": "bm_lewis", "language": "en-GB"}, + {"id": "af_bella", "language": "en-US"}, + }, + }) + }) + + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.healthHandler != nil { + fm.healthHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "model_loaded": true, + "engine": "kokoro", + }) + }) + + mux.HandleFunc("POST /v1/sessions/register", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.regHandler != nil { + fm.regHandler(w, r) + return + } + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": body["session_id"], + "participant_id": body["participant_id"], + "assigned_voice": "bm_lewis", + }) + }) + + mux.HandleFunc("POST /v1/sessions/{id}/deregister", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.deregHandler != nil { + fm.deregHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "session_id": r.PathValue("id"), + }) + }) + + mux.HandleFunc("GET /v1/sessions", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.listSessionHandler != nil { + fm.listSessionHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "sessions": []any{}, + "voice_pool": []string{"bm_lewis", "af_bella"}, + }) + }) + + fm.srv = httptest.NewServer(mux) + t.Cleanup(func() { fm.srv.Close() }) + return fm +} + +func (fm *fakeMod3Proxy) capture(r *http.Request) { + body, _ := io.ReadAll(r.Body) + fm.mu.Lock() + defer fm.mu.Unlock() + fm.captured = append(fm.captured, capturedProxyRequest{ + Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Body: body, + }) +} + +func (fm *fakeMod3Proxy) last() capturedProxyRequest { + fm.mu.Lock() + defer fm.mu.Unlock() + if len(fm.captured) == 0 { + fm.t.Fatalf("no captured requests") + } + return fm.captured[len(fm.captured)-1] +} + +// newProxyMCP builds a minimal MCPServer whose proxy points at fm, with +// playback fully disabled so tests don't touch the audio stack. +func newProxyMCP(t *testing.T, fm *fakeMod3Proxy) *MCPServer { + t.Helper() + m := &MCPServer{ + cfg: &Config{Mod3URL: fm.srv.URL}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + return m +} + +// decodeToolText parses the JSON text content of a CallToolResult. +func decodeToolText(t *testing.T, res *mcp.CallToolResult) map[string]any { + t.Helper() + if res == nil || len(res.Content) == 0 { + t.Fatalf("empty result") + } + tc, ok := res.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("expected *mcp.TextContent, got %T", res.Content[0]) + } + var out map[string]any + if err := json.Unmarshal([]byte(tc.Text), &out); err != nil { + t.Fatalf("decode result text: %v (raw=%q)", err, tc.Text) + } + return out +} + +// ─── mod3_speak ────────────────────────────────────────────────────────────── + +func TestMod3Speak_SuccessPath(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "hello world", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError=true: %v", res.Content) + } + + out := decodeToolText(t, res) + if ok, _ := out["ok"].(bool); !ok { + t.Fatalf("expected ok=true, got %v", out["ok"]) + } + if bytes, _ := out["bytes"].(float64); bytes < 100 { + t.Fatalf("expected bytes > 100, got %v", out["bytes"]) + } + + metrics, _ := out["metrics"].(map[string]any) + if metrics == nil { + t.Fatal("expected metrics map") + } + if jobID, _ := metrics["job-id"].(string); jobID != "job-test-0001" { + t.Fatalf("expected job-id=job-test-0001, got %v", metrics["job-id"]) + } + if dur, _ := metrics["duration-sec"].(float64); dur != 1.23 { + t.Fatalf("expected duration-sec=1.23, got %v", metrics["duration-sec"]) + } + // Integer parse path — sample-rate comes as "24000". + if sr, ok := metrics["sample-rate"].(float64); !ok || sr != 24000 { + t.Fatalf("expected sample-rate=24000, got %v (%T)", metrics["sample-rate"], metrics["sample-rate"]) + } + if got, _ := out["playback_status"].(string); got != "disabled" { + t.Fatalf("expected playback_status=disabled, got %v", out["playback_status"]) + } +} + +func TestMod3Speak_ForwardsSessionID(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "hello", + SessionID: "cs-abc123", + Voice: "af_bella", + Speed: 1.1, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cap := fm.last() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + if forwarded["session_id"] != "cs-abc123" { + t.Fatalf("expected session_id=cs-abc123, got %v", forwarded["session_id"]) + } + if forwarded["voice"] != "af_bella" { + t.Fatalf("expected voice=af_bella, got %v", forwarded["voice"]) + } + if speed, _ := forwarded["speed"].(float64); speed != 1.1 { + t.Fatalf("expected speed=1.1, got %v", forwarded["speed"]) + } +} + +func TestMod3Speak_OmitsSessionIDWhenAbsent(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "plain"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cap := fm.last() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + if _, present := forwarded["session_id"]; present { + t.Fatalf("expected no session_id key, got %v", forwarded["session_id"]) + } +} + +func TestMod3Speak_EmptyTextRejects(t *testing.T) { + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: " "}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // textResult — IsError is false because it's a validation message, but + // the content must mention the required field. + if len(res.Content) == 0 { + t.Fatal("expected content") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "required") { + t.Fatalf("expected 'required' in message, got %q", tc.Text) + } +} + +func TestMod3Speak_Mod3DownReturnsCleanError(t *testing.T) { + // Port that refuses connections. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := l.Addr().String() + _ = l.Close() // release the port so dials get ECONNREFUSED + + m := &MCPServer{ + cfg: &Config{Mod3URL: "http://" + addr}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "hi"}) + if err != nil { + t.Fatalf("handler should not return Go error, got %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true when mod3 is unreachable") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(strings.ToLower(tc.Text), "mod3 unreachable") { + t.Fatalf("expected 'mod3 unreachable' message, got %q", tc.Text) + } +} + +func TestMod3Speak_PreservesMod3ErrorBody(t *testing.T) { + fm := newFakeMod3Proxy(t) + fm.synthesizeHandler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"detail":"text must not be empty"}`)) + } + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "bad"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true on mod3 422") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "422") { + t.Fatalf("expected '422' in error text, got %q", tc.Text) + } + if !strings.Contains(tc.Text, "text must not be empty") { + t.Fatalf("expected mod3 body preserved, got %q", tc.Text) + } +} + +func TestMod3Speak_SkipPlaybackReturnsBase64(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "skip", + SkipPlayback: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "skipped" { + t.Fatalf("expected playback_status=skipped, got %v", out["playback_status"]) + } + if b64, _ := out["audio_base64"].(string); b64 == "" { + t.Fatal("expected audio_base64 populated when skip_playback=true") + } +} + +// ─── mod3_stop / voices / status / sessions ────────────────────────────────── + +func TestMod3Stop_ForwardsSessionAndJob(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Stop(context.Background(), nil, mod3StopInput{ + SessionID: "cs-abc", + JobID: "job-xyz", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + + cap := fm.last() + if !strings.Contains(cap.Query, "session_id=cs-abc") { + t.Fatalf("expected session_id in query, got %q", cap.Query) + } + if !strings.Contains(cap.Query, "job_id=job-xyz") { + t.Fatalf("expected job_id in query, got %q", cap.Query) + } + + out := decodeToolText(t, res) + if out["status"] != "stopped" { + t.Fatalf("expected status=stopped, got %v", out["status"]) + } +} + +func TestMod3Voices_ReturnsRawList(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Voices(context.Background(), nil, mod3VoicesInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + out := decodeToolText(t, res) + voices, _ := out["voices"].([]any) + if len(voices) != 2 { + t.Fatalf("expected 2 voices, got %d", len(voices)) + } +} + +func TestMod3Voices_ThreadsSessionID(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, _ = m.toolMod3Voices(context.Background(), nil, mod3VoicesInput{SessionID: "cs-qq"}) + cap := fm.last() + if !strings.Contains(cap.Query, "session_id=cs-qq") { + t.Fatalf("expected session_id in query, got %q", cap.Query) + } +} + +func TestMod3Status_HitsHealth(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Status(context.Background(), nil, mod3StatusInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + out := decodeToolText(t, res) + if out["status"] != "ok" { + t.Fatalf("expected status=ok, got %v", out["status"]) + } + + cap := fm.last() + if cap.Path != "/health" { + t.Fatalf("expected /health, got %q", cap.Path) + } +} + +func TestMod3Status_Mod3DownClean(t *testing.T) { + l, _ := net.Listen("tcp", "127.0.0.1:0") + addr := l.Addr().String() + _ = l.Close() + + m := &MCPServer{cfg: &Config{Mod3URL: "http://" + addr}} + res, _, err := m.toolMod3Status(context.Background(), nil, mod3StatusInput{}) + if err != nil { + t.Fatalf("handler should not Go-error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true") + } +} + +func TestMod3RegisterSession_ForwardsBody(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + SessionID: "cs-regtest", + ParticipantID: "cog", + ParticipantType: "agent", + PreferredVoice: "bm_lewis", + PreferredOutputDevice: "system-default", + Priority: 2, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + + cap := fm.last() + var body map[string]any + if err := json.Unmarshal(cap.Body, &body); err != nil { + t.Fatalf("decode: %v", err) + } + if body["session_id"] != "cs-regtest" { + t.Fatalf("bad session_id: %v", body["session_id"]) + } + if body["participant_id"] != "cog" { + t.Fatalf("bad participant_id: %v", body["participant_id"]) + } + if body["preferred_voice"] != "bm_lewis" { + t.Fatalf("bad preferred_voice: %v", body["preferred_voice"]) + } +} + +func TestMod3RegisterSession_RejectsWithoutParticipant(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "participant_id") { + t.Fatalf("expected validation message, got %q", tc.Text) + } +} + +func TestMod3DeregisterSession_PathEscape(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3DeregisterSession(context.Background(), nil, mod3DeregisterSessionInput{ + SessionID: "cs-drop", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + cap := fm.last() + if cap.Path != "/v1/sessions/cs-drop/deregister" { + t.Fatalf("expected /v1/sessions/cs-drop/deregister, got %q", cap.Path) + } +} + +func TestMod3DeregisterSession_RequiresID(t *testing.T) { + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3DeregisterSession(context.Background(), nil, mod3DeregisterSessionInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "session_id") { + t.Fatalf("expected session_id message, got %q", tc.Text) + } +} + +func TestMod3ListSessions_ReturnsRaw(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3ListSessions(context.Background(), nil, mod3ListSessionsInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + out := decodeToolText(t, res) + vp, _ := out["voice_pool"].([]any) + if len(vp) != 2 { + t.Fatalf("expected 2 voices in pool, got %d", len(vp)) + } +} + +// ─── metric extraction unit test ───────────────────────────────────────────── + +func TestExtractMod3Metrics_TypesByValue(t *testing.T) { + h := http.Header{} + h.Set("X-Mod3-Job-Id", "abc123") + h.Set("X-Mod3-Duration-Sec", "2.50") + h.Set("X-Mod3-Sample-Rate", "24000") + h.Set("X-Mod3-Rtf", "8.1") + h.Set("Content-Type", "audio/wav") // should be skipped + + out := extractMod3Metrics(h) + if len(out) != 4 { + t.Fatalf("expected 4 metrics, got %d: %v", len(out), out) + } + if got, _ := out["job-id"].(string); got != "abc123" { + t.Fatalf("job-id: got %v", out["job-id"]) + } + if got, ok := out["duration-sec"].(float64); !ok || got != 2.5 { + t.Fatalf("duration-sec: got %v (%T)", out["duration-sec"], out["duration-sec"]) + } + if got, ok := out["sample-rate"].(int64); !ok || got != 24000 { + t.Fatalf("sample-rate should parse to int64, got %v (%T)", + out["sample-rate"], out["sample-rate"]) + } + if _, present := out["content-type"]; present { + t.Fatal("content-type should be skipped") + } +} + +// ─── playback injection test ───────────────────────────────────────────────── + +// TestPlayAudio_StubPlayer — validate the playback plumbing by injecting a +// small shell-script player that records the file it received. Proves the +// audio bytes reach the player (the bug the installed binary has today is +// that they don't). Only runs on OSes where a shell is available. +func TestPlayAudio_StubPlayer(t *testing.T) { + dir := t.TempDir() + recPath := filepath.Join(dir, "received.log") + stubPath := filepath.Join(dir, "stub-player.sh") + + // Player writes its last-arg path and the first 4 bytes of the wav to + // received.log; proves (a) the player got a path, (b) the file exists + // at that path with our bytes. + stubBody := `#!/bin/sh +path="$1" +hdr=$(dd if="$path" bs=4 count=1 2>/dev/null) +printf 'path=%s hdr=%s\n' "$path" "$hdr" > "` + recPath + `" +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + + p := &modalityProxy{player: stubPath} + if err := p.playAudio(synthWav, true); err != nil { + t.Fatalf("playAudio: %v", err) + } + + got, err := os.ReadFile(recPath) + if err != nil { + t.Fatalf("read record: %v", err) + } + line := string(got) + if !strings.Contains(line, "hdr=RIFF") { + t.Fatalf("player did not see RIFF header; got %q", line) + } + if !strings.Contains(line, "path=") { + t.Fatalf("player did not get a path arg; got %q", line) + } +} + +// TestPlayAudio_NonBlockingSpawn — fire-and-forget. Use a sleep-style stub +// and assert playAudio returns before the stub would finish. Prevents the +// regression where speech synthesis blocks the MCP response on playback. +func TestPlayAudio_NonBlockingSpawn(t *testing.T) { + dir := t.TempDir() + stubPath := filepath.Join(dir, "sleeper.sh") + stubBody := `#!/bin/sh +sleep 5 +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + p := &modalityProxy{player: stubPath} + + done := make(chan struct{}) + go func() { + if err := p.playAudio(synthWav, false); err != nil { + t.Errorf("playAudio: %v", err) + } + close(done) + }() + + select { + case <-done: + // Expected: playAudio returns immediately in non-blocking mode. + case <-time.After(2 * time.Second): + t.Fatal("playAudio(blocking=false) did not return within 2s") + } +} diff --git a/internal/engine/mcp_server.go b/internal/engine/mcp_server.go index ffffcc2..9970294 100644 --- a/internal/engine/mcp_server.go +++ b/internal/engine/mcp_server.go @@ -47,6 +47,11 @@ type MCPServer struct { busSessions *BusSessionManager sessionRegistry *SessionRegistry handoffRegistry *HandoffRegistry + + // mod3Proxy backs the mod3_* MCP tools (Wave 3 of the mod3-kernel + // integration). Lazily initialised — tests can pre-seed it with a + // custom HTTP client + stub player to avoid real network + audio. + mod3Proxy *modalityProxy } // NewMCPServer creates and configures the MCP server with all stage-1 tools. @@ -241,6 +246,12 @@ func (m *MCPServer) registerTools() { // both surfaces coexist by design — same kernel truth, two MCP // doorways (amendment #5 of the Agent P hybrid plan). m.registerSessionTools() + + // Wave 3: mod3 proxy tools (mcp_modality_proxy.go). The kernel becomes + // the MCP front door for mod3 — HTTP-forwards synthesis/stop/voices/ + // status and plays the returned audio/wav locally. Supersedes the + // installed binary's OpenClaw gateway which silently drops audio bytes. + m.registerMod3Tools() } // registerResources registers MCP Resources — read-only addressable data. From dbe69a8854fd8d664c664410e373a2de3ca658c1 Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 23 Apr 2026 14:21:20 -0400 Subject: [PATCH 3/4] refactor(engine): route MCP session tools through kernel endpoint (Wave 3.5) Eliminates the Wave 2 / Wave 3 divergence where the mod3_register_session, mod3_deregister_session, and mod3_list_sessions MCP tools called mod3's /v1/sessions/* surface directly, bypassing Wave 2's kernel-owned session_id minting at /v1/channel-sessions/register. Session-ID authority is now in one place (ADR-082): the kernel's shared RegisterChannelSession / DeregisterChannelSession / ListChannelSessions methods on *Server. Both the HTTP handlers and the MCP tool handlers call through these methods. No self-localhost loop. Approach 2 (refactor shared logic) over Approach 1 (self-HTTP loop): the Wave 2 handler bodies factored cleanly into *Server methods returning typed errors, and wiring an MCPServer field via SetChannelSessionBackend mirrors the existing SetSessionsBackend pattern, so the public surface didn't have to change. Schema alignment with the channel-provider RFC's cogos_session_register primitive: added optional `kinds` (array) and `metadata` (map) fields to both the Wave 2 register endpoint and the MCP tool input. Both flow through to mod3 unchanged (mod3 ignores unknown fields today) and are preserved on the kernel identity record so downstream consumers can filter by capability. - internal/engine/serve_sessions_channel.go: factored handleChannelSession* bodies into RegisterChannelSession / DeregisterChannelSession / ListChannelSessions methods returning a typed channelSessionForwardError; added Kinds / Metadata fields to ChannelSessionRecord and the request wire type; handlers now thin-wrap the shared methods. - internal/engine/mcp_modality_proxy.go: the three session-family MCP tool handlers now call the channelSessionBackend interface on MCPServer; all direct HTTP calls to mod3's /v1/sessions/* removed from this file. - internal/engine/mcp_server.go: added channelSessionBackend field + interface + SetChannelSessionBackend setter. - internal/engine/serve_mcp.go: wires the live Server as the backend. - tests: updated newProxyMCP to wire a live Server so MCP session-family tests exercise the shared code path; added coverage for minting via the MCP tool, kinds/metadata pass-through (both HTTP and MCP), and the no-backend error path. --- internal/engine/mcp_modality_proxy.go | 169 +++++++---- internal/engine/mcp_modality_proxy_test.go | 223 +++++++++++++-- internal/engine/mcp_server.go | 28 ++ internal/engine/serve_mcp.go | 5 + internal/engine/serve_sessions_channel.go | 265 ++++++++++++++---- .../engine/serve_sessions_channel_test.go | 61 ++++ 6 files changed, 615 insertions(+), 136 deletions(-) diff --git a/internal/engine/mcp_modality_proxy.go b/internal/engine/mcp_modality_proxy.go index 1ac7768..1a97e9d 100644 --- a/internal/engine/mcp_modality_proxy.go +++ b/internal/engine/mcp_modality_proxy.go @@ -1,6 +1,7 @@ // mcp_modality_proxy.go — kernel-side MCP proxy for mod3 voice tools. // -// Wave 3 of the mod3-kernel integration (ADR-082 + channel-provider RFC). +// Wave 3 of the mod3-kernel integration (ADR-082 + channel-provider RFC), +// consolidated in Wave 3.5 with Wave 2's session-ID authority. // The kernel becomes the MCP front door for mod3; the previous OpenClaw // gateway pattern in the installed binary read metrics but discarded audio // bytes. This proxy fixes that: it forwards HTTP calls to mod3, captures the @@ -9,14 +10,16 @@ // // Design locks: // -// 1. MCP transport = HTTP proxy. Every tool handler here POSTs/GETs against -// cfg.Mod3URL + "/v1/*". Mod3 is NOT an MCP server to the kernel. The -// installed binary's OpenClaw gateway is a separate concern — we are the -// next kernel build and will supersede it when deployed. -// 2. Session authority = kernel-owned. session_id is threaded through every -// proxied call. Callers pass it as an optional field; absent → proxy -// omits it and mod3 routes to its default session. Present → proxy -// includes it in the request body (synthesize) or query string (stop). +// 1. MCP transport = HTTP proxy. Synthesis/control tool handlers here +// POST/GET against cfg.Mod3URL + "/v1/*". Mod3 is NOT an MCP server to +// the kernel. The installed binary's OpenClaw gateway is a separate +// concern — we are the next kernel build and will supersede it when +// deployed. +// 2. Session authority = kernel-owned (Wave 3.5). The session-family tools +// (register/deregister/list) do NOT call mod3 directly — they call the +// kernel's RegisterChannelSession / DeregisterChannelSession / +// ListChannelSessions methods on the Server, which mint the session_id +// and forward to mod3. Session ID minting happens in exactly one place. // 3. Playback strategy = Option (A), server-side. Kernel receives audio/wav, // writes to a tempfile, execs `afplay` (macOS) or `aplay` (Linux), // fire-and-forget. Callers can opt in to blocking with blocking=true. @@ -26,18 +29,13 @@ // // Tools registered (prefix `mod3_` to namespace against cog_* kernel tools): // -// - mod3_speak — synthesize + (optionally) play -// - mod3_stop — cancel current/queued speech -// - mod3_voices — list available voices -// - mod3_status — mod3 /health probe + build info -// - mod3_register_session — proxy to mod3 session register (future) -// - mod3_deregister_session — proxy to mod3 session deregister -// - mod3_list_sessions — proxy to mod3 session list -// -// Note: the session-registry family forwards to mod3's /v1/sessions/* routes -// which are not yet live on every mod3 instance (see openapi.json). They -// return a clean 502 "mod3_unreachable" in that case; once mod3 implements -// the routes (ADR-082 Wave 2 target), these tools become useful. +// - mod3_speak — synthesize + (optionally) play (direct to mod3) +// - mod3_stop — cancel current/queued speech (direct to mod3) +// - mod3_voices — list available voices (direct to mod3) +// - mod3_status — mod3 /health probe + build info (direct to mod3) +// - mod3_register_session — kernel-minted session registration (via kernel) +// - mod3_deregister_session — session deregister (via kernel) +// - mod3_list_sessions — merged kernel+mod3 session roster (via kernel) package engine import ( @@ -148,27 +146,34 @@ func (m *MCPServer) registerMod3Tools() { mcp.AddTool(m.server, &mcp.Tool{ Name: "mod3_register_session", - Description: "Proxy to mod3's POST /v1/sessions/register. Required: " + - "participant_id. Optional: session_id (caller-supplied; kernel " + - "does NOT mint here — use /v1/channel-sessions/register for that), " + - "participant_type, preferred_voice, preferred_output_device, " + - "priority. Returns mod3's full SessionRegisterResponse (assigned_" + - "voice, voice_conflict, output_device, queue_depth).", + Description: "Register a channel-participant session. Routes through " + + "the kernel's /v1/channel-sessions/register endpoint so " + + "session_id minting stays centralized (ADR-082 Wave 3.5). " + + "Required: participant_id. Optional: session_id (kernel mints " + + "a cs-* short UUID when absent), participant_type " + + "(agent|user|provider), preferred_voice, preferred_output_device, " + + "priority, kinds (e.g. [\"audio\"] per channel-provider RFC), " + + "metadata (opaque pass-through). Returns the merged {kernel, " + + "mod3} block: kernel identity record + mod3's full " + + "SessionRegisterResponse (assigned_voice, voice_conflict, " + + "output_device, queue_depth).", }, withToolObserver(m, "mod3_register_session", m.toolMod3RegisterSession)) mcp.AddTool(m.server, &mcp.Tool{ Name: "mod3_deregister_session", - Description: "Proxy to mod3's POST /v1/sessions/{session_id}/deregister. " + - "Required: session_id. Returns mod3's deregister acknowledgment " + - "(released_voice, dropped_jobs).", + Description: "Deregister a channel-participant session. Routes " + + "through the kernel's /v1/channel-sessions/{id}/deregister " + + "endpoint so the kernel drops its identity record in sync with " + + "mod3. Required: session_id. Returns mod3's deregister " + + "acknowledgment (released_voice, dropped_jobs).", }, withToolObserver(m, "mod3_deregister_session", m.toolMod3DeregisterSession)) mcp.AddTool(m.server, &mcp.Tool{ Name: "mod3_list_sessions", - Description: "Proxy to mod3's GET /v1/sessions. Returns the live " + - "mod3 session roster (sessions, voice_pool, voice_holders, " + - "serializer policy). No kernel-side filtering — mod3 is source " + - "of truth for per-channel state.", + Description: "List channel-participant sessions via the kernel's " + + "/v1/channel-sessions endpoint. Returns a merged {kernel, mod3} " + + "block: kernel identity records + mod3's live per-channel state " + + "(voice_pool, voice_holders, serializer policy).", }, withToolObserver(m, "mod3_list_sessions", m.toolMod3ListSessions)) } @@ -208,6 +213,10 @@ type mod3RegisterSessionInput struct { PreferredVoice string `json:"preferred_voice,omitempty"` PreferredOutputDevice string `json:"preferred_output_device,omitempty"` Priority int `json:"priority,omitempty"` + // Kinds / Metadata are the channel-provider RFC fields that flow + // through to mod3 unchanged. See cogos_session_register primitive. + Kinds []string `json:"kinds,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type mod3DeregisterSessionInput struct { @@ -310,42 +319,90 @@ func (m *MCPServer) toolMod3Status(ctx context.Context, req *mcp.CallToolRequest return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, "/health", nil) } +// toolMod3RegisterSession routes through the kernel's shared +// RegisterChannelSession so session_id minting happens in exactly one place +// (ADR-082 Wave 3.5). The previous Wave 3 implementation called mod3's +// /v1/sessions/register directly, which bypassed Wave 2's kernel-owned +// minting authority; that path is now gone. func (m *MCPServer) toolMod3RegisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3RegisterSessionInput) (*mcp.CallToolResult, any, error) { if in.ParticipantID == "" { return textResult("participant_id is required") } - body := map[string]any{ - "participant_id": in.ParticipantID, - } - if in.SessionID != "" { - body["session_id"] = in.SessionID + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") + } + resp, ferr := m.channelSessionBackend.RegisterChannelSession(ctx, channelSessionRegisterRequest{ + SessionID: in.SessionID, + ParticipantID: in.ParticipantID, + ParticipantType: in.ParticipantType, + PreferredVoice: in.PreferredVoice, + PreferredOutputDevice: in.PreferredOutputDevice, + Priority: in.Priority, + Kinds: in.Kinds, + Metadata: in.Metadata, + }) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) + } + return marshalResult(resp) +} + +func (m *MCPServer) toolMod3DeregisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3DeregisterSessionInput) (*mcp.CallToolResult, any, error) { + if in.SessionID == "" { + return textResult("session_id is required") } - if in.ParticipantType != "" { - body["participant_type"] = in.ParticipantType + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") } - if in.PreferredVoice != "" { - body["preferred_voice"] = in.PreferredVoice + mod3Resp, status, ferr := m.channelSessionBackend.DeregisterChannelSession(ctx, in.SessionID) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) } - if in.PreferredOutputDevice != "" { - body["preferred_output_device"] = in.PreferredOutputDevice + // Parse mod3's JSON body; surface mod3's non-2xx bodies intact as + // tool errors. The HTTP handler passes these through verbatim; the + // MCP tool wraps them so the caller sees the mod3 body text. + var parsed any + if len(mod3Resp) > 0 { + if jsonErr := json.Unmarshal(mod3Resp, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(mod3Resp)} + } } - if in.Priority != 0 { - body["priority"] = in.Priority + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %v", status, parsed)) } - raw, _ := json.Marshal(body) - return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, "/v1/sessions/register", bytes.NewReader(raw)) + return marshalResult(parsed) } -func (m *MCPServer) toolMod3DeregisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3DeregisterSessionInput) (*mcp.CallToolResult, any, error) { - if in.SessionID == "" { - return textResult("session_id is required") +func (m *MCPServer) toolMod3ListSessions(ctx context.Context, req *mcp.CallToolRequest, in mod3ListSessionsInput) (*mcp.CallToolResult, any, error) { + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") } - return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, - "/v1/sessions/"+url.PathEscape(in.SessionID)+"/deregister", nil) + resp, _, ferr := m.channelSessionBackend.ListChannelSessions(ctx) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) + } + return marshalResult(resp) } -func (m *MCPServer) toolMod3ListSessions(ctx context.Context, req *mcp.CallToolRequest, in mod3ListSessionsInput) (*mcp.CallToolResult, any, error) { - return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, "/v1/sessions", nil) +// channelSessionForwardErrorText renders a *channelSessionForwardError into +// the "mod3 unreachable" / "mod3 returned N: body" shape the legacy MCP +// tool paths used, keeping error surfaces stable for callers that previously +// matched on those strings. +func channelSessionForwardErrorText(ferr *channelSessionForwardError) string { + switch ferr.Kind { + case "mod3_unreachable": + return ferr.Message + case "mod3_rejected": + var parsed any + if len(ferr.Mod3Body) > 0 { + if jsonErr := json.Unmarshal(ferr.Mod3Body, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(ferr.Mod3Body)} + } + } + return fmt.Sprintf("mod3 returned %d: %v", ferr.HTTPStatus, parsed) + default: + return ferr.Message + } } // ─── HTTP forwarder primitives ─────────────────────────────────────────────── diff --git a/internal/engine/mcp_modality_proxy_test.go b/internal/engine/mcp_modality_proxy_test.go index 9ad0fd6..43cafef 100644 --- a/internal/engine/mcp_modality_proxy_test.go +++ b/internal/engine/mcp_modality_proxy_test.go @@ -185,16 +185,43 @@ func (fm *fakeMod3Proxy) last() capturedProxyRequest { } // newProxyMCP builds a minimal MCPServer whose proxy points at fm, with -// playback fully disabled so tests don't touch the audio stack. +// playback fully disabled so tests don't touch the audio stack. A live +// Server is wired in as the channel-session backend so the session-family +// MCP tools (register/deregister/list) flow through the kernel's shared +// minting logic — matching production (ADR-082 Wave 3.5). Synthesis / +// control tools continue calling mod3 directly via the proxy. func newProxyMCP(t *testing.T, fm *fakeMod3Proxy) *MCPServer { t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + srv := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } m := &MCPServer{ - cfg: &Config{Mod3URL: fm.srv.URL}, - mod3Proxy: &modalityProxy{disablePlayback: true}, + cfg: cfg, + mod3Proxy: &modalityProxy{disablePlayback: true}, + channelSessionBackend: srv, } return m } +// newProxyMCPWithServer is like newProxyMCP but returns the wired-in Server +// so tests that want to assert on the kernel-side registry state can do so. +func newProxyMCPWithServer(t *testing.T, fm *fakeMod3Proxy) (*MCPServer, *Server) { + t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + srv := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } + m := &MCPServer{ + cfg: cfg, + mod3Proxy: &modalityProxy{disablePlayback: true}, + channelSessionBackend: srv, + } + return m, srv +} + // decodeToolText parses the JSON text content of a CallToolResult. func decodeToolText(t *testing.T, res *mcp.CallToolResult) map[string]any { t.Helper() @@ -487,9 +514,13 @@ func TestMod3Status_Mod3DownClean(t *testing.T) { } } -func TestMod3RegisterSession_ForwardsBody(t *testing.T) { +// TestMod3RegisterSession_RoutesThroughKernel verifies that the MCP tool +// goes through the kernel's shared RegisterChannelSession backend — not +// directly to mod3 — and that the response is the merged {kernel, mod3} +// block produced by that shared code path (ADR-082 Wave 3.5). +func TestMod3RegisterSession_RoutesThroughKernel(t *testing.T) { fm := newFakeMod3Proxy(t) - m := newProxyMCP(t, fm) + m, srv := newProxyMCPWithServer(t, fm) res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ SessionID: "cs-regtest", @@ -503,22 +534,130 @@ func TestMod3RegisterSession_ForwardsBody(t *testing.T) { t.Fatalf("unexpected error: %v", err) } if res.IsError { - t.Fatal("expected success") + t.Fatalf("expected success, got IsError: %v", res.Content) } + // The forward to mod3 must have happened via the kernel's shared + // method — verify the request body mod3 saw carries the caller- + // supplied session_id and participant_id unchanged. cap := fm.last() + if cap.Path != "/v1/sessions/register" { + t.Fatalf("expected mod3 /v1/sessions/register, got %q", cap.Path) + } var body map[string]any if err := json.Unmarshal(cap.Body, &body); err != nil { t.Fatalf("decode: %v", err) } if body["session_id"] != "cs-regtest" { - t.Fatalf("bad session_id: %v", body["session_id"]) + t.Fatalf("bad session_id forwarded: %v", body["session_id"]) } if body["participant_id"] != "cog" { - t.Fatalf("bad participant_id: %v", body["participant_id"]) + t.Fatalf("bad participant_id forwarded: %v", body["participant_id"]) } if body["preferred_voice"] != "bm_lewis" { - t.Fatalf("bad preferred_voice: %v", body["preferred_voice"]) + t.Fatalf("bad preferred_voice forwarded: %v", body["preferred_voice"]) + } + + // The merged {kernel, mod3} shape must land — verify the kernel's + // identity record is present in the response. + out := decodeToolText(t, res) + kernel, ok := out["kernel"].(map[string]any) + if !ok { + t.Fatalf("expected kernel block in response, got %v", out) + } + if kernel["session_id"] != "cs-regtest" { + t.Fatalf("expected kernel.session_id=cs-regtest, got %v", kernel["session_id"]) + } + if kernel["id_source"] != "caller" { + t.Fatalf("expected id_source=caller, got %v", kernel["id_source"]) + } + + // Kernel registry must hold the committed record — proves we went + // through the shared backend and not straight to mod3. + if _, held := srv.channelSessionRegistry.Get("cs-regtest"); !held { + t.Fatal("expected kernel registry to hold record after register") + } +} + +// TestMod3RegisterSession_KernelMintsWhenAbsent exercises the minting +// path — caller omits session_id, kernel mints one, mod3 receives it. +func TestMod3RegisterSession_KernelMintsWhenAbsent(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, srv := newProxyMCPWithServer(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + ParticipantID: "cog", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError: %v", res.Content) + } + + out := decodeToolText(t, res) + kernel, _ := out["kernel"].(map[string]any) + sid, _ := kernel["session_id"].(string) + if !strings.HasPrefix(sid, "cs-") { + t.Fatalf("expected minted cs-* session_id, got %q", sid) + } + if kernel["id_source"] != "minted" { + t.Fatalf("expected id_source=minted, got %v", kernel["id_source"]) + } + + // Mod3 must have seen the kernel-minted ID verbatim. + cap := fm.last() + var body map[string]any + _ = json.Unmarshal(cap.Body, &body) + if body["session_id"] != sid { + t.Fatalf("mod3 got session_id=%v, expected %q", body["session_id"], sid) + } + + if _, held := srv.channelSessionRegistry.Get(sid); !held { + t.Fatalf("expected kernel registry to hold minted record %q", sid) + } +} + +// TestMod3RegisterSession_ForwardsKindsAndMetadata verifies the Wave 3.5 +// schema alignment with the channel-provider RFC — `kinds` and `metadata` +// flow through the kernel's register endpoint and land in mod3's request +// body unchanged. +func TestMod3RegisterSession_ForwardsKindsAndMetadata(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, _ := newProxyMCPWithServer(t, fm) + + _, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + SessionID: "cs-kinds", + ParticipantID: "mod3-provider", + ParticipantType: "provider", + Kinds: []string{"audio"}, + Metadata: map[string]any{ + "provider_id": "mod3-local", + "build": "0.5.0", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cap := fm.last() + var body map[string]any + if err := json.Unmarshal(cap.Body, &body); err != nil { + t.Fatalf("decode: %v", err) + } + kinds, ok := body["kinds"].([]any) + if !ok || len(kinds) != 1 || kinds[0] != "audio" { + t.Fatalf("expected kinds=[\"audio\"] forwarded, got %v", body["kinds"]) + } + md, ok := body["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected metadata object forwarded, got %v (%T)", body["metadata"], body["metadata"]) + } + if md["provider_id"] != "mod3-local" { + t.Fatalf("expected metadata.provider_id=mod3-local, got %v", md["provider_id"]) + } + if body["participant_type"] != "provider" { + t.Fatalf("expected participant_type=provider, got %v", body["participant_type"]) } } @@ -536,9 +675,38 @@ func TestMod3RegisterSession_RejectsWithoutParticipant(t *testing.T) { } } -func TestMod3DeregisterSession_PathEscape(t *testing.T) { +func TestMod3RegisterSession_NoBackendReturnsCleanError(t *testing.T) { + // An MCPServer with no channel-session backend wired in must surface + // a clean "not configured" error rather than a nil deref — important + // because NewMCPServer (used by tests that only care about memory + // tools) doesn't wire the backend. + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + ParticipantID: "cog", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true when backend is nil") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "backend not configured") { + t.Fatalf("expected backend-not-configured message, got %q", tc.Text) + } +} + +// TestMod3DeregisterSession_RoutesThroughKernel verifies the deregister +// tool forwards via the kernel's shared path and the kernel drops its +// identity record on success. +func TestMod3DeregisterSession_RoutesThroughKernel(t *testing.T) { fm := newFakeMod3Proxy(t) - m := newProxyMCP(t, fm) + m, srv := newProxyMCPWithServer(t, fm) + + // Seed the kernel registry so we can see it get dropped. + srv.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-drop", ParticipantID: "cog", + }) res, _, err := m.toolMod3DeregisterSession(context.Background(), nil, mod3DeregisterSessionInput{ SessionID: "cs-drop", @@ -547,11 +715,14 @@ func TestMod3DeregisterSession_PathEscape(t *testing.T) { t.Fatalf("unexpected error: %v", err) } if res.IsError { - t.Fatal("expected success") + t.Fatalf("expected success, got IsError: %v", res.Content) } cap := fm.last() if cap.Path != "/v1/sessions/cs-drop/deregister" { - t.Fatalf("expected /v1/sessions/cs-drop/deregister, got %q", cap.Path) + t.Fatalf("expected /v1/sessions/cs-drop/deregister at mod3, got %q", cap.Path) + } + if _, held := srv.channelSessionRegistry.Get("cs-drop"); held { + t.Fatal("expected kernel registry to drop record after successful deregister") } } @@ -567,21 +738,35 @@ func TestMod3DeregisterSession_RequiresID(t *testing.T) { } } -func TestMod3ListSessions_ReturnsRaw(t *testing.T) { +// TestMod3ListSessions_RoutesThroughKernel verifies list merges the +// kernel snapshot with mod3's per-channel state (the new Wave 3.5 +// merged shape, not mod3's raw payload). +func TestMod3ListSessions_RoutesThroughKernel(t *testing.T) { fm := newFakeMod3Proxy(t) - m := newProxyMCP(t, fm) + m, srv := newProxyMCPWithServer(t, fm) + + srv.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-seed", ParticipantID: "cog", + }) res, _, err := m.toolMod3ListSessions(context.Background(), nil, mod3ListSessionsInput{}) if err != nil { t.Fatalf("unexpected error: %v", err) } if res.IsError { - t.Fatal("expected success") + t.Fatalf("expected success, got IsError: %v", res.Content) } out := decodeToolText(t, res) - vp, _ := out["voice_pool"].([]any) - if len(vp) != 2 { - t.Fatalf("expected 2 voices in pool, got %d", len(vp)) + kernel, ok := out["kernel"].([]any) + if !ok { + t.Fatalf("expected kernel array in merged response, got %v", out) + } + if len(kernel) != 1 { + t.Fatalf("expected 1 kernel record, got %d", len(kernel)) + } + // Mod3 block must be present (from the fake list handler). + if _, present := out["mod3"]; !present { + t.Fatal("expected mod3 block in merged response") } } diff --git a/internal/engine/mcp_server.go b/internal/engine/mcp_server.go index 9970294..d5c33a9 100644 --- a/internal/engine/mcp_server.go +++ b/internal/engine/mcp_server.go @@ -52,6 +52,25 @@ type MCPServer struct { // integration). Lazily initialised — tests can pre-seed it with a // custom HTTP client + stub player to avoid real network + audio. mod3Proxy *modalityProxy + + // channelSessionBackend is the kernel Server whose + // RegisterChannelSession / DeregisterChannelSession / + // ListChannelSessions methods own session-ID authority (ADR-082 + // Wave 2). Wave 3.5 routes the three mod3 session-family MCP tools + // through these shared methods so minting happens in exactly one + // place. Nil when the MCP server is built outside a live kernel — + // the tools return a clean "not configured" error in that case. + channelSessionBackend channelSessionBackend +} + +// channelSessionBackend is the narrow surface the mod3 session-family MCP +// tools need to forward through the kernel's shared minting/forwarding +// logic. Interface rather than concrete *Server so tests can inject a fake +// without building a whole Server. +type channelSessionBackend interface { + RegisterChannelSession(ctx context.Context, req channelSessionRegisterRequest) (*channelSessionResponse, *channelSessionForwardError) + DeregisterChannelSession(ctx context.Context, sessionID string) (json.RawMessage, int, *channelSessionForwardError) + ListChannelSessions(ctx context.Context) (*channelSessionListResponse, int, *channelSessionForwardError) } // NewMCPServer creates and configures the MCP server with all stage-1 tools. @@ -98,6 +117,15 @@ func (m *MCPServer) SetAgentController(ctrl AgentController) { m.agentController = ctrl } +// SetChannelSessionBackend wires the kernel-owned channel-session minting +// logic into the MCP server so the mod3_register_session / _deregister / +// _list tools call through the same shared methods the HTTP surface uses +// (ADR-082 Wave 3.5). Safe to pass nil; the tools surface a clean "not +// configured" error in that case. +func (m *MCPServer) SetChannelSessionBackend(b channelSessionBackend) { + m.channelSessionBackend = b +} + // Handler returns the http.Handler for mounting at /mcp. func (m *MCPServer) Handler() http.Handler { return m.handler diff --git a/internal/engine/serve_mcp.go b/internal/engine/serve_mcp.go index 397398b..148a980 100644 --- a/internal/engine/serve_mcp.go +++ b/internal/engine/serve_mcp.go @@ -17,6 +17,11 @@ func (s *Server) registerMCPRoutes(mux *http.ServeMux) { // these are nil — NewMCPServer (used by tests that only care about // memory tools) doesn't call this, which is fine. mcpSrv.SetSessionsBackend(s.busSessions, s.sessionRegistry, s.handoffRegistry) + // ADR-082 Wave 3.5: route the mod3 session-family MCP tools through + // the kernel's shared channel-session methods so session-ID minting + // happens in exactly one place (this Server). Handlers dispatching + // to mod3 directly was the Wave 3 divergence this removes. + mcpSrv.SetChannelSessionBackend(s) s.mcpServer = mcpSrv h := mcpSrv.Handler() mux.Handle("GET /mcp", h) diff --git a/internal/engine/serve_sessions_channel.go b/internal/engine/serve_sessions_channel.go index 2ec8f13..fadaf96 100644 --- a/internal/engine/serve_sessions_channel.go +++ b/internal/engine/serve_sessions_channel.go @@ -68,6 +68,17 @@ type ChannelSessionRecord struct { PreferredOutputDevice string `json:"preferred_output_device,omitempty"` Priority int `json:"priority,omitempty"` + // Kinds mirrors the channel-provider RFC's `kinds` metadata field + // (e.g. ["audio"] for a mod3-provider registration). Mod3 ignores it + // today; kept on the kernel record for downstream consumers that want + // to filter by capability. + Kinds []string `json:"kinds,omitempty"` + + // Metadata is an opaque pass-through map the RFC describes as + // "provider_id/kinds in the metadata" — preserved verbatim on the + // kernel record and forwarded to mod3 (which ignores unknown fields). + Metadata map[string]any `json:"metadata,omitempty"` + RegisteredAt time.Time `json:"registered_at"` LastSeen time.Time `json:"last_seen"` @@ -165,13 +176,22 @@ func (s *Server) registerChannelSessionRoutes(mux *http.ServeMux) { // channelSessionRegisterRequest is the kernel-facing request body. Shape // mirrors mod3's SessionRegisterRequest (see mod3/http_api.py line ~278) plus // an optional session_id — when omitted, the kernel mints one. +// +// Wave 3.5 schema alignment with the channel-provider RFC's +// `cogos_session_register` primitive: `kinds` (array of adapter kinds the +// registrant participates in, e.g. ["audio"]) and `metadata` (opaque +// pass-through blob — RFC calls for provider_id/kinds to live in metadata) +// are both optional and flow through to mod3 unchanged (mod3 ignores +// unknown fields). type channelSessionRegisterRequest struct { - SessionID string `json:"session_id,omitempty"` - ParticipantID string `json:"participant_id"` - ParticipantType string `json:"participant_type,omitempty"` - PreferredVoice string `json:"preferred_voice,omitempty"` - PreferredOutputDevice string `json:"preferred_output_device,omitempty"` - Priority int `json:"priority,omitempty"` + SessionID string `json:"session_id,omitempty"` + ParticipantID string `json:"participant_id"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` + Kinds []string `json:"kinds,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // channelSessionResponse is the merged shape returned from the kernel. The @@ -192,18 +212,52 @@ type channelSessionListResponse struct { Mod3 json.RawMessage `json:"mod3,omitempty"` } -// ─── POST /v1/channel-sessions/register ────────────────────────────────────── +// ─── shared register/deregister/list logic (Wave 3.5) ──────────────────────── +// +// These methods are the single place session-ID minting, kernel registry +// commits, and mod3 forwarding happen. Both the HTTP handlers below and the +// mod3_register_session / mod3_deregister_session / mod3_list_sessions MCP +// tools (see mcp_modality_proxy.go) call through here so session-ID +// authority stays centralized — nobody reaches mod3's /v1/sessions/* surface +// directly except this one codepath. + +// channelSessionForwardError classifies a failure in RegisterChannelSession / +// DeregisterChannelSession / ListChannelSessions so the caller (HTTP handler +// or MCP tool) can surface the right status/message without re-parsing. +type channelSessionForwardError struct { + // Kind: "invalid_request" | "mod3_unreachable" | "mod3_rejected" + Kind string + // HTTPStatus is the status the caller should emit (400 for + // invalid_request, 502 for mod3_unreachable, mod3's own status for + // mod3_rejected). + HTTPStatus int + // Message is a human-readable description, safe to surface to callers. + Message string + // Mod3Body carries the raw JSON body mod3 returned when Kind == + // "mod3_rejected", so the HTTP handler can pass it through verbatim. + Mod3Body json.RawMessage +} -func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Request) { - var req channelSessionRegisterRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSONError(w, http.StatusBadRequest, "invalid_request", "body must be JSON") - return - } +func (e *channelSessionForwardError) Error() string { return e.Message } + +// RegisterChannelSession is the Wave 2+3.5 shared entry point for channel- +// session registration. It mints a session_id when absent, forwards to mod3, +// and commits the kernel-side identity record on success. +// +// Callers: +// - HTTP: POST /v1/channel-sessions/register (handleChannelSessionRegister) +// - MCP: mod3_register_session tool (toolMod3RegisterSession) +// +// Returning a (resp, nil) pair means the caller should surface the merged +// response with 200 OK. Returning a non-nil *channelSessionForwardError +// tells the caller which status + body shape to surface. +func (s *Server) RegisterChannelSession(ctx context.Context, req channelSessionRegisterRequest) (*channelSessionResponse, *channelSessionForwardError) { if req.ParticipantID == "" { - writeJSONError(w, http.StatusBadRequest, "invalid_request", - "participant_id is required") - return + return nil, &channelSessionForwardError{ + Kind: "invalid_request", + HTTPStatus: http.StatusBadRequest, + Message: "participant_id is required", + } } // Mint when absent — kernel is the session-ID authority (ADR-082). @@ -227,6 +281,8 @@ func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Req PreferredVoice: req.PreferredVoice, PreferredOutputDevice: req.PreferredOutputDevice, Priority: req.Priority, + Kinds: req.Kinds, + Metadata: req.Metadata, RegisteredAt: now, LastSeen: now, IDSource: idSource, @@ -234,7 +290,9 @@ func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Req // Forward to mod3 with the kernel-issued session_id. Mod3's body is the // same shape modulo the optional session_id field (mod3 requires it; we - // always supply one). + // always supply one). `kinds` and `metadata` are RFC-level fields mod3 + // currently ignores; we still forward them so mod3 can start consuming + // them without a kernel change when it's ready. forwardBody := map[string]any{ "session_id": req.SessionID, "participant_id": req.ParticipantID, @@ -243,27 +301,35 @@ func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Req "preferred_output_device": req.PreferredOutputDevice, "priority": req.Priority, } + if len(req.Kinds) > 0 { + forwardBody["kinds"] = req.Kinds + } + if len(req.Metadata) > 0 { + forwardBody["metadata"] = req.Metadata + } body, _ := json.Marshal(forwardBody) - mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodPost, + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodPost, "/v1/sessions/register", bytes.NewReader(body)) if err != nil { slog.Warn("channel-sessions: forward to mod3 failed", "session_id", req.SessionID, "err", err) - writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", - fmt.Sprintf("mod3 unreachable: %v", err)) - return + return nil, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } } - // Preserve non-success status bodies so the caller can surface mod3's - // diagnostic text. Any 4xx/5xx from mod3 becomes the response status - // with the raw body attached — the kernel does NOT write its identity - // record in that case (mod3 rejected the registration). if status < 200 || status >= 300 { slog.Warn("channel-sessions: mod3 returned non-2xx", "session_id", req.SessionID, "status", status) - writeJSONPassThrough(w, status, mod3Resp) - return + return nil, &channelSessionForwardError{ + Kind: "mod3_rejected", + HTTPStatus: status, + Message: fmt.Sprintf("mod3 returned %d", status), + Mod3Body: mod3Resp, + } } // Mod3 accepted — commit the kernel-side identity record. @@ -272,68 +338,145 @@ func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Req "session_id", req.SessionID, "participant_id", req.ParticipantID, "id_source", idSource) - // Return merged response. The `mod3` field is the raw body from mod3 so - // callers get the exact {assigned_voice, voice_conflict, output_device, - // queue_depth, ...} shape mod3 emits. - resp := channelSessionResponse{ - Kernel: &record, - Mod3: mod3Resp, - } - writeJSONResp(w, http.StatusOK, resp) + return &channelSessionResponse{Kernel: &record, Mod3: mod3Resp}, nil } -// ─── POST /v1/channel-sessions/{id}/deregister ─────────────────────────────── - -func (s *Server) handleChannelSessionDeregister(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - if id == "" { - writeJSONError(w, http.StatusBadRequest, "invalid_request", - "session_id required in path") - return +// DeregisterChannelSession is the shared entry point for deregistration. +// Forwards to mod3 and drops the kernel registry row on any non-5xx mod3 +// response (including 404 — "mod3 forgot" is equivalent to "kernel should +// forget too"). On transport failure the kernel keeps the record so the +// caller can retry. +// +// Returns the raw mod3 body + status on success. Callers should surface +// both verbatim (writeJSONPassThrough on the HTTP side; the MCP tool +// wraps it as a JSON result). +func (s *Server) DeregisterChannelSession(ctx context.Context, sessionID string) (json.RawMessage, int, *channelSessionForwardError) { + if sessionID == "" { + return nil, 0, &channelSessionForwardError{ + Kind: "invalid_request", + HTTPStatus: http.StatusBadRequest, + Message: "session_id is required", + } } - mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodPost, - "/v1/sessions/"+id+"/deregister", nil) + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodPost, + "/v1/sessions/"+sessionID+"/deregister", nil) if err != nil { slog.Warn("channel-sessions: deregister forward failed", - "session_id", id, "err", err) - writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", - fmt.Sprintf("mod3 unreachable: %v", err)) - return + "session_id", sessionID, "err", err) + return nil, 0, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } } // Kernel drops its identity record whenever mod3 successfully // acknowledges, including 404 ("never registered" is equivalent to // "not tracked"; clean slate in kernel matches clean slate in mod3). if status >= 200 && status < 500 { - s.channelSessionRegistry.Delete(id) + s.channelSessionRegistry.Delete(sessionID) } - writeJSONPassThrough(w, status, mod3Resp) + return mod3Resp, status, nil } -// ─── GET /v1/channel-sessions ──────────────────────────────────────────────── - -func (s *Server) handleChannelSessionList(w http.ResponseWriter, r *http.Request) { - mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodGet, +// ListChannelSessions is the shared entry point for the merged list query. +// Returns the kernel snapshot plus mod3's raw `GET /v1/sessions` body. +// +// On mod3 transport failure: error (502). On mod3 non-2xx: returns the +// raw mod3 body with its status attached so the caller can surface intact. +func (s *Server) ListChannelSessions(ctx context.Context) (*channelSessionListResponse, int, *channelSessionForwardError) { + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodGet, "/v1/sessions", nil) if err != nil { slog.Warn("channel-sessions: list forward failed", "err", err) - writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", - fmt.Sprintf("mod3 unreachable: %v", err)) - return + return nil, 0, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } } if status < 200 || status >= 300 { - writeJSONPassThrough(w, status, mod3Resp) - return + return nil, status, &channelSessionForwardError{ + Kind: "mod3_rejected", + HTTPStatus: status, + Message: fmt.Sprintf("mod3 returned %d", status), + Mod3Body: mod3Resp, + } } - resp := channelSessionListResponse{ + return &channelSessionListResponse{ Kernel: s.channelSessionRegistry.Snapshot(), Mod3: mod3Resp, + }, http.StatusOK, nil +} + +// ─── POST /v1/channel-sessions/register ────────────────────────────────────── + +func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Request) { + var req channelSessionRegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid_request", "body must be JSON") + return + } + + resp, ferr := s.RegisterChannelSession(r.Context(), req) + if ferr != nil { + switch ferr.Kind { + case "invalid_request": + writeJSONError(w, ferr.HTTPStatus, "invalid_request", ferr.Message) + case "mod3_unreachable": + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + case "mod3_rejected": + writeJSONPassThrough(w, ferr.HTTPStatus, ferr.Mod3Body) + default: + writeJSONError(w, http.StatusInternalServerError, "internal", ferr.Message) + } + return } writeJSONResp(w, http.StatusOK, resp) } +// ─── POST /v1/channel-sessions/{id}/deregister ─────────────────────────────── + +func (s *Server) handleChannelSessionDeregister(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "session_id required in path") + return + } + + mod3Resp, status, ferr := s.DeregisterChannelSession(r.Context(), id) + if ferr != nil { + if ferr.Kind == "mod3_unreachable" { + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + return + } + writeJSONError(w, ferr.HTTPStatus, ferr.Kind, ferr.Message) + return + } + writeJSONPassThrough(w, status, mod3Resp) +} + +// ─── GET /v1/channel-sessions ──────────────────────────────────────────────── + +func (s *Server) handleChannelSessionList(w http.ResponseWriter, r *http.Request) { + resp, status, ferr := s.ListChannelSessions(r.Context()) + if ferr != nil { + switch ferr.Kind { + case "mod3_unreachable": + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + case "mod3_rejected": + writeJSONPassThrough(w, ferr.HTTPStatus, ferr.Mod3Body) + default: + writeJSONError(w, http.StatusInternalServerError, "internal", ferr.Message) + } + return + } + writeJSONResp(w, status, resp) +} + // ─── GET /v1/channel-sessions/{id} ─────────────────────────────────────────── func (s *Server) handleChannelSessionGet(w http.ResponseWriter, r *http.Request) { diff --git a/internal/engine/serve_sessions_channel_test.go b/internal/engine/serve_sessions_channel_test.go index 45761cd..1d9f41e 100644 --- a/internal/engine/serve_sessions_channel_test.go +++ b/internal/engine/serve_sessions_channel_test.go @@ -361,6 +361,67 @@ func TestChannelSessionRegister_PropagatesMod3Error(t *testing.T) { } } +// TestChannelSessionRegister_ForwardsKindsAndMetadata verifies the Wave 3.5 +// schema alignment: the optional `kinds` array and `metadata` object from +// the channel-provider RFC's cogos_session_register primitive flow through +// Wave 2's register endpoint and land in mod3's request body unchanged. +func TestChannelSessionRegister_ForwardsKindsAndMetadata(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "cs-kinds-http", + "participant_id": "mod3-provider", + "participant_type": "provider", + "kinds": []string{"audio"}, + "metadata": map[string]any{ + "provider_id": "mod3-local", + "build": "0.5.0", + }, + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + cap := fm.lastCaptured() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + kinds, ok := forwarded["kinds"].([]any) + if !ok || len(kinds) != 1 || kinds[0] != "audio" { + t.Fatalf("expected mod3 to receive kinds=[\"audio\"], got %v", forwarded["kinds"]) + } + md, ok := forwarded["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected mod3 to receive metadata object, got %v (%T)", + forwarded["metadata"], forwarded["metadata"]) + } + if md["provider_id"] != "mod3-local" { + t.Fatalf("expected metadata.provider_id=mod3-local forwarded, got %v", md["provider_id"]) + } + + // Kernel record must retain the RFC fields so downstream consumers + // can filter by capability even when mod3 ignores them. + rec, ok := s.channelSessionRegistry.Get("cs-kinds-http") + if !ok { + t.Fatal("expected kernel registry to hold record") + } + if len(rec.Kinds) != 1 || rec.Kinds[0] != "audio" { + t.Fatalf("expected record.Kinds=[audio], got %v", rec.Kinds) + } + if rec.Metadata["provider_id"] != "mod3-local" { + t.Fatalf("expected record.Metadata.provider_id=mod3-local, got %v", rec.Metadata) + } +} + func TestChannelSessionRegister_RequiresParticipantID(t *testing.T) { fm := newFakeMod3(t) _, front := newChannelServer(t, fm) From a31396ec5e2c01dbefabef724fa8a95ba9eb4e3c Mon Sep 17 00:00:00 2001 From: Test User Date: Thu, 23 Apr 2026 14:46:29 -0400 Subject: [PATCH 4/4] feat(engine): mod3_speak skips afplay when dashboard subscriber exists MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wave 4.3 kernel side — close the double-play window between server-side afplay and the dashboard's /ws/audio/{session_id} WebSocket (mod3 side, committed separately). Before spawning the platform player on a session-tagged speak, the kernel asks mod3 whether that session has a live subscriber. If yes, skip afplay entirely — mod3 is already pushing the WAV to the dashboard — and report playback_status=routed_ws. modalityProxy grows an injectable subscriberCheck function field. The default implementation is an HTTP GET against {Mod3URL}/v1/sessions/{id}/subscribers with a 1.5s timeout. Transport errors surface as subscriber_check_error in the tool result AND fall through to the normal afplay path so a flaky mod3 never orphans audio — a key safety property carried over from the Wave 3 fire-and-forget playback design. session_id="" always bypasses the check. CLI invocations of mod3_speak (no session_id) keep the exact afplay behavior as before, so this is a purely additive change scoped to kernel-minted sessions. Tests (mcp_modality_proxy_test.go, +237 lines): - TestMod3Speak_NoSessionAlwaysSpawnsPlayer — session_id="" → stub player invoked once, playback_status=played - TestMod3Speak_SessionWithSubscriberSkipsPlayer — subscriberCheck returns true → stub player NOT invoked, playback_status=routed_ws - TestMod3Speak_SessionWithoutSubscriberSpawnsPlayer — subscriberCheck returns false → stub player invoked once, playback_status=played - TestMod3Speak_SubscriberCheckErrorFallsBackToPlayer — transient probe error → stub player invoked once, subscriber_check_error surfaced in result, playback_status=played - TestCheckSessionSubscriber_DefaultImplementationHitsMod3 — default HTTP path parses cs-yes/cs-no responses correctly - TestCheckSessionSubscriber_Mod3UnreachableReturnsError — ECONNREFUSED surfaces as a non-nil error (so toolMod3Speak falls back) Replaced writeStubPlayer's mutex+int32 pair with a closure-over-getter pattern. Cleaner signature, no polling goroutine, fewer moving parts. Full engine test suite passes; golangci-lint clean. Branch: feat/kernel-wave4, stacked on feat/kernel-session-forwarder (the branch that currently carries Wave 3.5 session-id minting and the mcp_modality_proxy.go baseline this commit extends). --- internal/engine/mcp_modality_proxy.go | 64 ++++++ internal/engine/mcp_modality_proxy_test.go | 237 +++++++++++++++++++++ 2 files changed, 301 insertions(+) diff --git a/internal/engine/mcp_modality_proxy.go b/internal/engine/mcp_modality_proxy.go index 1a97e9d..1e00e03 100644 --- a/internal/engine/mcp_modality_proxy.go +++ b/internal/engine/mcp_modality_proxy.go @@ -83,6 +83,14 @@ type modalityProxy struct { // this when they want to assert "we got the bytes" without spawning a // real player. Production code leaves it false. disablePlayback bool + + // subscriberCheck, when non-nil, is consulted before spawning the local + // player in mod3_speak. If it returns (true, nil) the kernel skips + // afplay — mod3's /ws/audio/{session_id} WebSocket is already pushing + // the WAV to a dashboard subscriber (Wave 4.3). Errors and false return + // values fall through to the normal playback path. Nil means "use the + // default HTTP implementation" (GET {Mod3URL}/v1/sessions/{id}/subscribers). + subscriberCheck func(ctx context.Context, sessionID string) (bool, error) } // defaultMod3ProxyTimeout is the per-request timeout for mod3 forwards. 30s @@ -279,6 +287,25 @@ func (m *MCPServer) toolMod3Speak(ctx context.Context, req *mcp.CallToolRequest, return marshalResult(result) } + // Wave 4.3 — if the session has a live dashboard WebSocket subscriber, + // mod3 is already routing the WAV there. Skip the kernel's local player + // so we don't double-play. The check is scoped to sessions that were + // actually named on the speak call; session_id="" always falls through + // to the normal afplay path so CLI invocations keep working. + if in.SessionID != "" { + subscribed, checkErr := m.checkSessionSubscriber(ctx, in.SessionID) + if checkErr != nil { + // Log-worthy but not fatal — fall back to local playback. + slog.Debug("mod3 proxy: subscriber check failed", + "session_id", in.SessionID, "err", checkErr) + result["subscriber_check_error"] = checkErr.Error() + } + if subscribed { + result["playback_status"] = "routed_ws" + return marshalResult(result) + } + } + playErr := p.playAudio(audio, in.Blocking) switch { case playErr == nil && in.Blocking: @@ -292,6 +319,43 @@ func (m *MCPServer) toolMod3Speak(ctx context.Context, req *mcp.CallToolRequest, return marshalResult(result) } +// checkSessionSubscriber asks mod3 whether ``sessionID`` has at least one +// active dashboard WebSocket subscriber for audio playback. Returns +// ``(subscribed, nil)`` on success, ``(false, err)`` on transport failure. +// ``(false, nil)`` — the default when the proxy has no check configured — +// also suppresses the routing path, so legacy callers see the exact same +// afplay behavior as before. +// +// Injectable via modalityProxy.subscriberCheck for tests. The default is a +// GET against mod3's /v1/sessions/{id}/subscribers endpoint with a 1.5s +// timeout inherited from defaultMod3ProxyTimeout. +func (m *MCPServer) checkSessionSubscriber(ctx context.Context, sessionID string) (bool, error) { + p := m.getModalityProxy() + if p.subscriberCheck != nil { + return p.subscriberCheck(ctx, sessionID) + } + // Default implementation — HTTP GET. Scoped to 1.5s so a wedged mod3 + // can't block a speak for more than that; falls back to afplay on timeout. + checkCtx, cancel := context.WithTimeout(ctx, 1500*time.Millisecond) + defer cancel() + raw, _, status, err := m.proxyMod3Bytes(checkCtx, http.MethodGet, + "/v1/sessions/"+url.PathEscape(sessionID)+"/subscribers", nil, "") + if err != nil { + return false, err + } + if status < 200 || status >= 300 { + return false, fmt.Errorf("mod3 returned %d: %s", status, truncate(string(raw), 200)) + } + var body struct { + Subscribed bool `json:"subscribed"` + Count int `json:"count"` + } + if unmarshalErr := json.Unmarshal(raw, &body); unmarshalErr != nil { + return false, fmt.Errorf("decode subscribers response: %w", unmarshalErr) + } + return body.Subscribed, nil +} + func (m *MCPServer) toolMod3Stop(ctx context.Context, req *mcp.CallToolRequest, in mod3StopInput) (*mcp.CallToolResult, any, error) { path := "/v1/stop" q := url.Values{} diff --git a/internal/engine/mcp_modality_proxy_test.go b/internal/engine/mcp_modality_proxy_test.go index 43cafef..d12b702 100644 --- a/internal/engine/mcp_modality_proxy_test.go +++ b/internal/engine/mcp_modality_proxy_test.go @@ -10,6 +10,7 @@ package engine import ( "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -869,3 +870,239 @@ sleep 5 t.Fatal("playAudio(blocking=false) did not return within 2s") } } + +// ─── Wave 4.3 — subscriber-check / afplay skip ─────────────────────────────── + +// TestMod3Speak_NoSessionAlwaysSpawnsPlayer — session_id="" bypasses the +// subscriber check entirely so CLI invocations of mod3_speak still play +// audio through afplay as they always did. +func TestMod3Speak_NoSessionAlwaysSpawnsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{player: stubPath} + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "no session", + Blocking: true, // wait for stub to finish so the test can assert invocation count + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once, got %d", got) + } +} + +// TestMod3Speak_SessionWithSubscriberSkipsPlayer — when the injected +// subscriber-check returns true, the kernel skips afplay entirely and +// returns playback_status=routed_ws. The stub player must NOT be invoked. +func TestMod3Speak_SessionWithSubscriberSkipsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + if sessionID != "cs-with-sub" { + t.Errorf("unexpected session_id=%q", sessionID) + } + return true, nil + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "skip me", + SessionID: "cs-with-sub", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "routed_ws" { + t.Fatalf("expected playback_status=routed_ws, got %v", out["playback_status"]) + } + // Give any stray goroutine a moment to trip the stub — proving non-invocation. + time.Sleep(100 * time.Millisecond) + if got := count(); got != 0 { + t.Fatalf("expected stub player NOT invoked, got %d", got) + } +} + +// TestMod3Speak_SessionWithoutSubscriberSpawnsPlayer — subscriber-check +// returns false: kernel falls back to the normal afplay path and the stub +// player IS invoked. +func TestMod3Speak_SessionWithoutSubscriberSpawnsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + return false, nil + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "no subscriber", + SessionID: "cs-no-sub", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once, got %d", got) + } +} + +// TestMod3Speak_SubscriberCheckErrorFallsBackToPlayer — transient +// check error (mod3 flaky, timeout, etc.) must not orphan the audio. +// The kernel logs the error, records subscriber_check_error in the result, +// and still spawns the player so the user hears the reply. +func TestMod3Speak_SubscriberCheckErrorFallsBackToPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + return false, fmt.Errorf("mod3 probe timed out") + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "check failed", + SessionID: "cs-flaky", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success despite check error, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if checkErr, _ := out["subscriber_check_error"].(string); !strings.Contains(checkErr, "probe timed out") { + t.Fatalf("expected subscriber_check_error to surface, got %v", out["subscriber_check_error"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once on check error, got %d", got) + } +} + +// TestCheckSessionSubscriber_DefaultImplementationHitsMod3 — wire up a +// stand-alone fake HTTP server that answers the +// /v1/sessions/{id}/subscribers probe and verify the default implementation +// (no injected subscriberCheck) parses its response correctly. +func TestCheckSessionSubscriber_DefaultImplementationHitsMod3(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/sessions/cs-yes/subscribers", func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-yes", "subscribed": true, "count": 1, + }) + }) + mux.HandleFunc("GET /v1/sessions/cs-no/subscribers", func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-no", "subscribed": false, "count": 0, + }) + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + m := &MCPServer{ + cfg: &Config{Mod3URL: srv.URL}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + + yes, err := m.checkSessionSubscriber(context.Background(), "cs-yes") + if err != nil { + t.Fatalf("cs-yes: %v", err) + } + if !yes { + t.Fatal("cs-yes: expected subscribed=true") + } + + no, err := m.checkSessionSubscriber(context.Background(), "cs-no") + if err != nil { + t.Fatalf("cs-no: %v", err) + } + if no { + t.Fatal("cs-no: expected subscribed=false") + } +} + +// TestCheckSessionSubscriber_Mod3UnreachableReturnsError — transport +// error (connection refused, timeout) must surface as a non-nil error so +// toolMod3Speak records subscriber_check_error and falls back to afplay. +func TestCheckSessionSubscriber_Mod3UnreachableReturnsError(t *testing.T) { + // Bind a port, close it so dials get ECONNREFUSED. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := l.Addr().String() + _ = l.Close() + + m := &MCPServer{ + cfg: &Config{Mod3URL: "http://" + addr}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + subscribed, err := m.checkSessionSubscriber(context.Background(), "cs-anything") + if err == nil { + t.Fatal("expected transport error") + } + if subscribed { + t.Fatal("expected subscribed=false on error") + } +} + +// writeStubPlayer builds a temp shell script that records each invocation +// by appending to a log file. Returns the path of the executable AND a +// getter that returns the current invocation count. The stub exits +// immediately so blocking=true still works in tests. +func writeStubPlayer(t *testing.T) (stubPath string, count func() int) { + t.Helper() + dir := t.TempDir() + logPath := filepath.Join(dir, "invocations.log") + stubPath = filepath.Join(dir, "stub-player.sh") + stubBody := `#!/bin/sh +echo "invoked" >> "` + logPath + `" +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + count = func() int { + data, err := os.ReadFile(logPath) + if err != nil { + return 0 + } + return strings.Count(string(data), "invoked\n") + } + return stubPath, count +}