diff --git a/internal/engine/config.go b/internal/engine/config.go index 5d04648..3f91d45 100644 --- a/internal/engine/config.go +++ b/internal/engine/config.go @@ -77,6 +77,16 @@ type Config struct { // at .cog/run/kernel.log.jsonl. Leave empty for the default. KernelLogPath string + // Mod3URL is the base URL (scheme + host + port) of the mod3 voice service + // that owns per-channel communication state (voice, output device, queue) + // keyed on kernel-issued session IDs. The kernel forwards channel-session + // registration to this URL; mod3 remains the per-channel state owner while + // the kernel retains identity authority (ADR-082 split). + // + // Default: http://localhost:7860. Override via `mod3_url` in kernel.yaml + // (top-level or under v3:) or via the COGOS_MOD3_URL env var. + Mod3URL string + LocalModel string localModelConfigured bool @@ -99,6 +109,7 @@ type kernelConfigSection struct { LocalModel string `yaml:"local_model"` DigestPaths map[string]string `yaml:"digest_paths"` KernelLogPath string `yaml:"kernel_log_path"` + Mod3URL string `yaml:"mod3_url"` } // kernelConfig is the on-disk YAML shape of .cog/config/kernel.yaml. @@ -137,6 +148,7 @@ func LoadConfig(workspaceRoot string, port int) (*Config, error) { ToolCallValidationEnabled: true, LocalModel: defaultOllamaModel, DigestPaths: make(map[string]string), + Mod3URL: "http://localhost:7860", } // Load from file if present. @@ -150,6 +162,12 @@ func LoadConfig(workspaceRoot string, port int) (*Config, error) { } } + // Env override for the mod3 URL. Env wins over file; flags stay flag-only + // (we don't surface `--mod3-url` in CLI; one env var + YAML is enough). + if v := os.Getenv("COGOS_MOD3_URL"); v != "" { + cfg.Mod3URL = v + } + // Flag override. if port != 0 { cfg.Port = port @@ -211,6 +229,9 @@ func applyKernelSection(cfg *Config, s kernelConfigSection) { if s.KernelLogPath != "" { cfg.KernelLogPath = s.KernelLogPath } + if s.Mod3URL != "" { + cfg.Mod3URL = s.Mod3URL + } } // findWorkspaceRoot walks up from dir until it finds a directory containing a diff --git a/internal/engine/mcp_modality_proxy.go b/internal/engine/mcp_modality_proxy.go new file mode 100644 index 0000000..1e00e03 --- /dev/null +++ b/internal/engine/mcp_modality_proxy.go @@ -0,0 +1,672 @@ +// mcp_modality_proxy.go — kernel-side MCP proxy for mod3 voice tools. +// +// Wave 3 of the mod3-kernel integration (ADR-082 + channel-provider RFC), +// consolidated in Wave 3.5 with Wave 2's session-ID authority. +// The kernel becomes the MCP front door for mod3; the previous OpenClaw +// gateway pattern in the installed binary read metrics but discarded audio +// bytes. This proxy fixes that: it forwards HTTP calls to mod3, captures the +// audio/wav payload, plays it locally via afplay/aplay (fire-and-forget by +// default), and returns mod3's metric headers (X-Mod3-*) to the MCP caller. +// +// Design locks: +// +// 1. MCP transport = HTTP proxy. Synthesis/control tool handlers here +// POST/GET against cfg.Mod3URL + "/v1/*". Mod3 is NOT an MCP server to +// the kernel. The installed binary's OpenClaw gateway is a separate +// concern — we are the next kernel build and will supersede it when +// deployed. +// 2. Session authority = kernel-owned (Wave 3.5). The session-family tools +// (register/deregister/list) do NOT call mod3 directly — they call the +// kernel's RegisterChannelSession / DeregisterChannelSession / +// ListChannelSessions methods on the Server, which mint the session_id +// and forward to mod3. Session ID minting happens in exactly one place. +// 3. Playback strategy = Option (A), server-side. Kernel receives audio/wav, +// writes to a tempfile, execs `afplay` (macOS) or `aplay` (Linux), +// fire-and-forget. Callers can opt in to blocking with blocking=true. +// Forward-compatible with Option (B) session-routed playback once the +// Wave 4 dashboard WebSocket lands — a future session-router check can +// gate this path when a browser subscriber exists. +// +// Tools registered (prefix `mod3_` to namespace against cog_* kernel tools): +// +// - mod3_speak — synthesize + (optionally) play (direct to mod3) +// - mod3_stop — cancel current/queued speech (direct to mod3) +// - mod3_voices — list available voices (direct to mod3) +// - mod3_status — mod3 /health probe + build info (direct to mod3) +// - mod3_register_session — kernel-minted session registration (via kernel) +// - mod3_deregister_session — session deregister (via kernel) +// - mod3_list_sessions — merged kernel+mod3 session roster (via kernel) +package engine + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ─── proxy wiring on MCPServer ─────────────────────────────────────────────── + +// modalityProxy holds the HTTP client and playback helper used by the mod3_* +// MCP tools. Fields are exported-by-convention (capitalized where needed for +// tests) so test code can override the HTTP client and the player command. +type modalityProxy struct { + // client is the HTTP client used for all mod3 forwards. Nil falls back + // to defaultMod3ProxyClient. + client *http.Client + + // player is the OS command executed for server-side audio playback. + // Overridable in tests to a stub binary / /usr/bin/true. Empty means + // "autodetect via runtime.GOOS" (afplay on darwin, aplay elsewhere). + player string + + // playerArgs, when non-nil, are passed as additional command args + // before the tempfile path. Useful for tests to pipe the wav through + // a counting script. Nil means no extra args. + playerArgs []string + + // disablePlayback short-circuits the player exec entirely. Tests set + // this when they want to assert "we got the bytes" without spawning a + // real player. Production code leaves it false. + disablePlayback bool + + // subscriberCheck, when non-nil, is consulted before spawning the local + // player in mod3_speak. If it returns (true, nil) the kernel skips + // afplay — mod3's /ws/audio/{session_id} WebSocket is already pushing + // the WAV to a dashboard subscriber (Wave 4.3). Errors and false return + // values fall through to the normal playback path. Nil means "use the + // default HTTP implementation" (GET {Mod3URL}/v1/sessions/{id}/subscribers). + subscriberCheck func(ctx context.Context, sessionID string) (bool, error) +} + +// defaultMod3ProxyTimeout is the per-request timeout for mod3 forwards. 30s +// covers the longest-plausible synthesis on the current Kokoro voice stack +// (~5-10s for multi-sentence input, with headroom for cold starts). +const defaultMod3ProxyTimeout = 30 * time.Second + +// defaultMod3ProxyClient is the shared http.Client used when modalityProxy.client +// is nil. Lazily initialised; safe for concurrent use. +var defaultMod3ProxyClient = &http.Client{Timeout: defaultMod3ProxyTimeout} + +// getModalityProxy returns the MCPServer's modality proxy, lazily creating +// one with sane defaults on first access. Tests can pre-seed m.mod3Proxy with +// their own instance before calling this. +func (m *MCPServer) getModalityProxy() *modalityProxy { + if m.mod3Proxy == nil { + m.mod3Proxy = &modalityProxy{} + } + return m.mod3Proxy +} + +// ─── tool registration ─────────────────────────────────────────────────────── + +// registerMod3Tools installs the 7 mod3_* MCP tools. Called from +// MCPServer.registerTools after the cog_* tools so the tool index stays +// stable at the front. +func (m *MCPServer) registerMod3Tools() { + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_speak", + Description: "Synthesize text to speech via mod3 and play the audio " + + "locally. Required: text. Optional: session_id, voice, speed, " + + "emotion, blocking (wait for playback to finish). Returns mod3 " + + "metrics (job_id, duration_sec, rtf, voice) and a playback_status " + + "flag. Fallback: curl -X POST http://localhost:7860/v1/synthesize " + + "-d '{\"text\":\"...\"}' -o out.wav && afplay out.wav", + }, withToolObserver(m, "mod3_speak", m.toolMod3Speak)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_stop", + Description: "Stop current mod3 speech and/or cancel queued jobs. " + + "Optional: session_id, job_id (cancel one specific job). Empty " + + "cancels current playback and clears the queue. Returns mod3's " + + "barge-in interruption context. Fallback: curl -X POST " + + "http://localhost:7860/v1/stop", + }, withToolObserver(m, "mod3_stop", m.toolMod3Stop)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_voices", + Description: "List available mod3 voices, optionally scoped to a " + + "session. Optional: session_id. Returns the voice catalogue mod3 " + + "exposes (id, name, language, gender metadata per voice). " + + "Fallback: curl http://localhost:7860/v1/voices", + }, withToolObserver(m, "mod3_voices", m.toolMod3Voices)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_status", + Description: "Probe mod3's /health endpoint. Returns the raw health " + + "payload (model_loaded, engine info, queue_depth, etc). 502 if " + + "mod3 is unreachable. Fallback: curl http://localhost:7860/health", + }, withToolObserver(m, "mod3_status", m.toolMod3Status)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_register_session", + Description: "Register a channel-participant session. Routes through " + + "the kernel's /v1/channel-sessions/register endpoint so " + + "session_id minting stays centralized (ADR-082 Wave 3.5). " + + "Required: participant_id. Optional: session_id (kernel mints " + + "a cs-* short UUID when absent), participant_type " + + "(agent|user|provider), preferred_voice, preferred_output_device, " + + "priority, kinds (e.g. [\"audio\"] per channel-provider RFC), " + + "metadata (opaque pass-through). Returns the merged {kernel, " + + "mod3} block: kernel identity record + mod3's full " + + "SessionRegisterResponse (assigned_voice, voice_conflict, " + + "output_device, queue_depth).", + }, withToolObserver(m, "mod3_register_session", m.toolMod3RegisterSession)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_deregister_session", + Description: "Deregister a channel-participant session. Routes " + + "through the kernel's /v1/channel-sessions/{id}/deregister " + + "endpoint so the kernel drops its identity record in sync with " + + "mod3. Required: session_id. Returns mod3's deregister " + + "acknowledgment (released_voice, dropped_jobs).", + }, withToolObserver(m, "mod3_deregister_session", m.toolMod3DeregisterSession)) + + mcp.AddTool(m.server, &mcp.Tool{ + Name: "mod3_list_sessions", + Description: "List channel-participant sessions via the kernel's " + + "/v1/channel-sessions endpoint. Returns a merged {kernel, mod3} " + + "block: kernel identity records + mod3's live per-channel state " + + "(voice_pool, voice_holders, serializer policy).", + }, withToolObserver(m, "mod3_list_sessions", m.toolMod3ListSessions)) +} + +// ─── input / output types ──────────────────────────────────────────────────── + +type mod3SpeakInput struct { + Text string `json:"text"` + SessionID string `json:"session_id,omitempty"` + Voice string `json:"voice,omitempty"` + Speed float64 `json:"speed,omitempty"` + Emotion float64 `json:"emotion,omitempty"` + // Blocking waits for the spawned player to exit before returning the + // tool result. Default false — fire-and-forget so multi-second audio + // doesn't block the MCP call. + Blocking bool `json:"blocking,omitempty"` + // SkipPlayback returns the wav bytes (base64) without attempting local + // playback. Useful for callers routing audio elsewhere (dashboard WS, + // file write, etc). Default false. + SkipPlayback bool `json:"skip_playback,omitempty"` +} + +type mod3StopInput struct { + SessionID string `json:"session_id,omitempty"` + JobID string `json:"job_id,omitempty"` +} + +type mod3VoicesInput struct { + SessionID string `json:"session_id,omitempty"` +} + +type mod3StatusInput struct{} + +type mod3RegisterSessionInput struct { + SessionID string `json:"session_id,omitempty"` + ParticipantID string `json:"participant_id"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` + // Kinds / Metadata are the channel-provider RFC fields that flow + // through to mod3 unchanged. See cogos_session_register primitive. + Kinds []string `json:"kinds,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type mod3DeregisterSessionInput struct { + SessionID string `json:"session_id"` +} + +type mod3ListSessionsInput struct{} + +// ─── handlers ──────────────────────────────────────────────────────────────── + +func (m *MCPServer) toolMod3Speak(ctx context.Context, req *mcp.CallToolRequest, in mod3SpeakInput) (*mcp.CallToolResult, any, error) { + if strings.TrimSpace(in.Text) == "" { + return textResult("text is required") + } + body := map[string]any{"text": in.Text} + if in.Voice != "" { + body["voice"] = in.Voice + } + if in.Speed > 0 { + body["speed"] = in.Speed + } + if in.Emotion > 0 { + body["emotion"] = in.Emotion + } + if in.SessionID != "" { + // Session threading: mod3 ignores unknown fields on SynthesizeRequest + // today; once multi-session synthesis lands, this is the channel. + body["session_id"] = in.SessionID + } + raw, _ := json.Marshal(body) + + audio, headers, status, err := m.proxyMod3Bytes(ctx, http.MethodPost, + "/v1/synthesize", bytes.NewReader(raw), "application/json") + if err != nil { + return mod3ErrorResult(fmt.Sprintf("mod3 unreachable: %v", err)) + } + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %s", status, truncate(string(audio), 400))) + } + + metrics := extractMod3Metrics(headers) + result := map[string]any{ + "ok": true, + "bytes": len(audio), + "metrics": metrics, + "session_id": in.SessionID, // may be empty; echoed for observability + } + + // If the caller asked for raw bytes (no server-side playback), base64- + // encode and return. Forward-compatible with session-routed playback. + if in.SkipPlayback { + result["audio_base64"] = base64.StdEncoding.EncodeToString(audio) + result["playback_status"] = "skipped" + return marshalResult(result) + } + + p := m.getModalityProxy() + if p.disablePlayback { + result["playback_status"] = "disabled" + return marshalResult(result) + } + + // Wave 4.3 — if the session has a live dashboard WebSocket subscriber, + // mod3 is already routing the WAV there. Skip the kernel's local player + // so we don't double-play. The check is scoped to sessions that were + // actually named on the speak call; session_id="" always falls through + // to the normal afplay path so CLI invocations keep working. + if in.SessionID != "" { + subscribed, checkErr := m.checkSessionSubscriber(ctx, in.SessionID) + if checkErr != nil { + // Log-worthy but not fatal — fall back to local playback. + slog.Debug("mod3 proxy: subscriber check failed", + "session_id", in.SessionID, "err", checkErr) + result["subscriber_check_error"] = checkErr.Error() + } + if subscribed { + result["playback_status"] = "routed_ws" + return marshalResult(result) + } + } + + playErr := p.playAudio(audio, in.Blocking) + switch { + case playErr == nil && in.Blocking: + result["playback_status"] = "played" + case playErr == nil: + result["playback_status"] = "spawned" + default: + result["playback_status"] = "error" + result["playback_error"] = playErr.Error() + } + return marshalResult(result) +} + +// checkSessionSubscriber asks mod3 whether ``sessionID`` has at least one +// active dashboard WebSocket subscriber for audio playback. Returns +// ``(subscribed, nil)`` on success, ``(false, err)`` on transport failure. +// ``(false, nil)`` — the default when the proxy has no check configured — +// also suppresses the routing path, so legacy callers see the exact same +// afplay behavior as before. +// +// Injectable via modalityProxy.subscriberCheck for tests. The default is a +// GET against mod3's /v1/sessions/{id}/subscribers endpoint with a 1.5s +// timeout inherited from defaultMod3ProxyTimeout. +func (m *MCPServer) checkSessionSubscriber(ctx context.Context, sessionID string) (bool, error) { + p := m.getModalityProxy() + if p.subscriberCheck != nil { + return p.subscriberCheck(ctx, sessionID) + } + // Default implementation — HTTP GET. Scoped to 1.5s so a wedged mod3 + // can't block a speak for more than that; falls back to afplay on timeout. + checkCtx, cancel := context.WithTimeout(ctx, 1500*time.Millisecond) + defer cancel() + raw, _, status, err := m.proxyMod3Bytes(checkCtx, http.MethodGet, + "/v1/sessions/"+url.PathEscape(sessionID)+"/subscribers", nil, "") + if err != nil { + return false, err + } + if status < 200 || status >= 300 { + return false, fmt.Errorf("mod3 returned %d: %s", status, truncate(string(raw), 200)) + } + var body struct { + Subscribed bool `json:"subscribed"` + Count int `json:"count"` + } + if unmarshalErr := json.Unmarshal(raw, &body); unmarshalErr != nil { + return false, fmt.Errorf("decode subscribers response: %w", unmarshalErr) + } + return body.Subscribed, nil +} + +func (m *MCPServer) toolMod3Stop(ctx context.Context, req *mcp.CallToolRequest, in mod3StopInput) (*mcp.CallToolResult, any, error) { + path := "/v1/stop" + q := url.Values{} + if in.JobID != "" { + q.Set("job_id", in.JobID) + } + if in.SessionID != "" { + q.Set("session_id", in.SessionID) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + return m.proxyMod3JSONAsMCP(ctx, http.MethodPost, path, nil) +} + +func (m *MCPServer) toolMod3Voices(ctx context.Context, req *mcp.CallToolRequest, in mod3VoicesInput) (*mcp.CallToolResult, any, error) { + path := "/v1/voices" + if in.SessionID != "" { + path += "?session_id=" + url.QueryEscape(in.SessionID) + } + return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, path, nil) +} + +func (m *MCPServer) toolMod3Status(ctx context.Context, req *mcp.CallToolRequest, in mod3StatusInput) (*mcp.CallToolResult, any, error) { + return m.proxyMod3JSONAsMCP(ctx, http.MethodGet, "/health", nil) +} + +// toolMod3RegisterSession routes through the kernel's shared +// RegisterChannelSession so session_id minting happens in exactly one place +// (ADR-082 Wave 3.5). The previous Wave 3 implementation called mod3's +// /v1/sessions/register directly, which bypassed Wave 2's kernel-owned +// minting authority; that path is now gone. +func (m *MCPServer) toolMod3RegisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3RegisterSessionInput) (*mcp.CallToolResult, any, error) { + if in.ParticipantID == "" { + return textResult("participant_id is required") + } + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") + } + resp, ferr := m.channelSessionBackend.RegisterChannelSession(ctx, channelSessionRegisterRequest{ + SessionID: in.SessionID, + ParticipantID: in.ParticipantID, + ParticipantType: in.ParticipantType, + PreferredVoice: in.PreferredVoice, + PreferredOutputDevice: in.PreferredOutputDevice, + Priority: in.Priority, + Kinds: in.Kinds, + Metadata: in.Metadata, + }) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) + } + return marshalResult(resp) +} + +func (m *MCPServer) toolMod3DeregisterSession(ctx context.Context, req *mcp.CallToolRequest, in mod3DeregisterSessionInput) (*mcp.CallToolResult, any, error) { + if in.SessionID == "" { + return textResult("session_id is required") + } + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") + } + mod3Resp, status, ferr := m.channelSessionBackend.DeregisterChannelSession(ctx, in.SessionID) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) + } + // Parse mod3's JSON body; surface mod3's non-2xx bodies intact as + // tool errors. The HTTP handler passes these through verbatim; the + // MCP tool wraps them so the caller sees the mod3 body text. + var parsed any + if len(mod3Resp) > 0 { + if jsonErr := json.Unmarshal(mod3Resp, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(mod3Resp)} + } + } + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %v", status, parsed)) + } + return marshalResult(parsed) +} + +func (m *MCPServer) toolMod3ListSessions(ctx context.Context, req *mcp.CallToolRequest, in mod3ListSessionsInput) (*mcp.CallToolResult, any, error) { + if m.channelSessionBackend == nil { + return mod3ErrorResult("channel-session backend not configured") + } + resp, _, ferr := m.channelSessionBackend.ListChannelSessions(ctx) + if ferr != nil { + return mod3ErrorResult(channelSessionForwardErrorText(ferr)) + } + return marshalResult(resp) +} + +// channelSessionForwardErrorText renders a *channelSessionForwardError into +// the "mod3 unreachable" / "mod3 returned N: body" shape the legacy MCP +// tool paths used, keeping error surfaces stable for callers that previously +// matched on those strings. +func channelSessionForwardErrorText(ferr *channelSessionForwardError) string { + switch ferr.Kind { + case "mod3_unreachable": + return ferr.Message + case "mod3_rejected": + var parsed any + if len(ferr.Mod3Body) > 0 { + if jsonErr := json.Unmarshal(ferr.Mod3Body, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(ferr.Mod3Body)} + } + } + return fmt.Sprintf("mod3 returned %d: %v", ferr.HTTPStatus, parsed) + default: + return ferr.Message + } +} + +// ─── HTTP forwarder primitives ─────────────────────────────────────────────── + +// proxyMod3Bytes issues an HTTP request to mod3 and returns the raw body, +// response headers, HTTP status, and a transport error. Caller owns the body +// bytes; they may be audio/wav (mod3_speak) or JSON (everything else). +func (m *MCPServer) proxyMod3Bytes(ctx context.Context, method, path string, body io.Reader, contentType string) ([]byte, http.Header, int, error) { + if m.cfg == nil { + return nil, nil, 0, errors.New("Mod3URL not configured (cfg nil)") + } + base := strings.TrimRight(m.cfg.Mod3URL, "/") + if base == "" { + return nil, nil, 0, errors.New("Mod3URL not configured") + } + + reqCtx, cancel := context.WithTimeout(ctx, defaultMod3ProxyTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, method, base+path, body) + if err != nil { + return nil, nil, 0, fmt.Errorf("build request: %w", err) + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + // Accept both audio and JSON so a single client path covers synthesize + // (audio/wav) and the rest (application/json). + req.Header.Set("Accept", "audio/wav, application/json") + + client := m.getModalityProxy().client + if client == nil { + client = defaultMod3ProxyClient + } + resp, err := client.Do(req) + if err != nil { + return nil, nil, 0, err + } + defer resp.Body.Close() + + // 16 MB cap — more than enough for multi-minute Kokoro wav at 24kHz + // (~2 MB per 30s), safety net against an upstream that never closes. + raw, err := io.ReadAll(io.LimitReader(resp.Body, 16<<20)) + if err != nil { + return nil, resp.Header, resp.StatusCode, fmt.Errorf("read response body: %w", err) + } + return raw, resp.Header, resp.StatusCode, nil +} + +// proxyMod3JSONAsMCP is a convenience wrapper for tools whose response is +// JSON (everything except mod3_speak). Reads the body, parses it as JSON if +// possible, and returns an mcp.CallToolResult; on non-2xx status returns a +// mod3-error marshalled result so the caller sees the mod3 body intact. +func (m *MCPServer) proxyMod3JSONAsMCP(ctx context.Context, method, path string, body io.Reader) (*mcp.CallToolResult, any, error) { + contentType := "" + if body != nil { + contentType = "application/json" + } + raw, _, status, err := m.proxyMod3Bytes(ctx, method, path, body, contentType) + if err != nil { + return mod3ErrorResult(fmt.Sprintf("mod3 unreachable: %v", err)) + } + // Try to parse as JSON; if parse fails, surface the body as text so the + // caller at least sees what mod3 said. + var parsed any + if len(raw) > 0 { + if jsonErr := json.Unmarshal(raw, &parsed); jsonErr != nil { + parsed = map[string]any{"raw": string(raw)} + } + } + if status < 200 || status >= 300 { + return mod3ErrorResult(fmt.Sprintf("mod3 returned %d: %v", status, parsed)) + } + return marshalResult(parsed) +} + +// mod3ErrorResult returns an IsError=true CallToolResult so the observer +// wrapper records the tool invocation as a failure in the ledger. +func mod3ErrorResult(msg string) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: msg}}, + IsError: true, + }, nil, nil +} + +// extractMod3Metrics pulls the X-Mod3-* headers into a metrics map. Numeric +// fields are parsed when possible; unknown headers pass through as strings +// so future mod3 headers surface without code changes. +func extractMod3Metrics(h http.Header) map[string]any { + out := map[string]any{} + for key, values := range h { + lk := strings.ToLower(key) + if !strings.HasPrefix(lk, "x-mod3-") || len(values) == 0 { + continue + } + short := strings.TrimPrefix(lk, "x-mod3-") + v := values[0] + // Try numeric parse for the common metric headers. + if f, err := strconv.ParseFloat(v, 64); err == nil { + // Preserve integers as int for nicer JSON output. + if i, ierr := strconv.ParseInt(v, 10, 64); ierr == nil { + out[short] = i + } else { + out[short] = f + } + continue + } + out[short] = v + } + return out +} + +// ─── playback helper ───────────────────────────────────────────────────────── + +// playAudio writes the wav bytes to a tempfile and spawns the platform's +// default player. When blocking==false the function returns immediately +// after the process starts; a goroutine waits for exit so the tempfile can +// be cleaned up. When blocking==true the function waits for exit and +// surfaces any non-zero return as an error. +// +// In tests, set modalityProxy.player to "/usr/bin/true" (or similar) to +// avoid actually playing audio. +func (p *modalityProxy) playAudio(wav []byte, blocking bool) error { + if p.disablePlayback { + return nil + } + f, err := os.CreateTemp("", "mod3-speak-*.wav") + if err != nil { + return fmt.Errorf("tempfile: %w", err) + } + path := f.Name() + if _, err := f.Write(wav); err != nil { + f.Close() + os.Remove(path) + return fmt.Errorf("write tempfile: %w", err) + } + if err := f.Close(); err != nil { + os.Remove(path) + return fmt.Errorf("close tempfile: %w", err) + } + + player := p.player + if player == "" { + player = defaultPlayerCommand() + } + if player == "" { + os.Remove(path) + return fmt.Errorf("no audio player available for GOOS=%s", runtime.GOOS) + } + + args := append([]string{}, p.playerArgs...) + args = append(args, path) + cmd := exec.Command(player, args...) + + if err := cmd.Start(); err != nil { + os.Remove(path) + return fmt.Errorf("start %s: %w", player, err) + } + + if blocking { + err := cmd.Wait() + _ = os.Remove(path) + if err != nil { + return fmt.Errorf("player %s exited: %w", player, err) + } + return nil + } + // Fire-and-forget: reap the child so the tempfile gets cleaned and the + // process isn't a zombie. Log errors; don't propagate (the MCP call + // already returned successfully). + go func() { + if werr := cmd.Wait(); werr != nil { + slog.Debug("mod3 proxy: player exited non-zero", + "player", player, "path", path, "err", werr) + } + _ = os.Remove(path) + }() + return nil +} + +// defaultPlayerCommand returns the preferred platform player, or "" when +// none is available in PATH. Resolved lazily per call so tests that change +// PATH take effect. +func defaultPlayerCommand() string { + candidates := map[string][]string{ + "darwin": {"afplay"}, + "linux": {"aplay", "paplay"}, + "freebsd": {"aplay"}, + } + for _, name := range candidates[runtime.GOOS] { + if _, err := exec.LookPath(name); err == nil { + return name + } + } + // Final fallback: if neither platform default is present, see if the + // caller has exposed one via PATH under its canonical name. + for _, name := range []string{"afplay", "aplay", "paplay", "ffplay"} { + if _, err := exec.LookPath(name); err == nil { + return name + } + } + return "" +} diff --git a/internal/engine/mcp_modality_proxy_test.go b/internal/engine/mcp_modality_proxy_test.go new file mode 100644 index 0000000..d12b702 --- /dev/null +++ b/internal/engine/mcp_modality_proxy_test.go @@ -0,0 +1,1108 @@ +// mcp_modality_proxy_test.go — coverage for the mod3 MCP proxy tools. +// +// Strategy: stand up a fake mod3 via httptest.NewServer, point an MCPServer's +// proxy at it, exercise each tool handler directly (handler function + +// typed input, bypassing the MCP SDK's JSON marshal layer). Playback is +// stubbed via disablePlayback or the injectable player field so tests don't +// spawn afplay. +package engine + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// ─── fake mod3 with synthesize + session routes ────────────────────────────── + +type fakeMod3Proxy struct { + t *testing.T + srv *httptest.Server + mu sync.Mutex + captured []capturedProxyRequest + + // Overrides for per-endpoint behavior. + synthesizeHandler http.HandlerFunc + stopHandler http.HandlerFunc + voicesHandler http.HandlerFunc + healthHandler http.HandlerFunc + regHandler http.HandlerFunc + deregHandler http.HandlerFunc + listSessionHandler http.HandlerFunc +} + +type capturedProxyRequest struct { + Method string + Path string + Query string + Body []byte +} + +// synthWav is a tiny valid-ish WAV header + ~1KB of silence. Not a real +// audio file — the proxy doesn't parse it, it just forwards/plays bytes. +var synthWav = func() []byte { + b := make([]byte, 1024) + copy(b, []byte("RIFF\x00\x00\x00\x00WAVEfmt ")) + return b +}() + +func newFakeMod3Proxy(t *testing.T) *fakeMod3Proxy { + t.Helper() + fm := &fakeMod3Proxy{t: t} + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/synthesize", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.synthesizeHandler != nil { + fm.synthesizeHandler(w, r) + return + } + // Emit the full mod3 header surface so the proxy can extract it. + w.Header().Set("X-Mod3-Job-Id", "job-test-0001") + w.Header().Set("X-Mod3-Voice", "bm_lewis") + w.Header().Set("X-Mod3-Duration-Sec", "1.23") + w.Header().Set("X-Mod3-Sample-Rate", "24000") + w.Header().Set("X-Mod3-Rtf", "9.29") + w.Header().Set("X-Mod3-Chunks", "1") + w.Header().Set("Content-Type", "audio/wav") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(synthWav) + }) + + mux.HandleFunc("POST /v1/stop", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.stopHandler != nil { + fm.stopHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "stopped", + "dropped_jobs": 0, + "interrupted": true, + "session_id": r.URL.Query().Get("session_id"), + "job_id_target": r.URL.Query().Get("job_id"), + }) + }) + + mux.HandleFunc("GET /v1/voices", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.voicesHandler != nil { + fm.voicesHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "voices": []map[string]any{ + {"id": "bm_lewis", "language": "en-GB"}, + {"id": "af_bella", "language": "en-US"}, + }, + }) + }) + + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.healthHandler != nil { + fm.healthHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "model_loaded": true, + "engine": "kokoro", + }) + }) + + mux.HandleFunc("POST /v1/sessions/register", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.regHandler != nil { + fm.regHandler(w, r) + return + } + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": body["session_id"], + "participant_id": body["participant_id"], + "assigned_voice": "bm_lewis", + }) + }) + + mux.HandleFunc("POST /v1/sessions/{id}/deregister", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.deregHandler != nil { + fm.deregHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "session_id": r.PathValue("id"), + }) + }) + + mux.HandleFunc("GET /v1/sessions", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.listSessionHandler != nil { + fm.listSessionHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "sessions": []any{}, + "voice_pool": []string{"bm_lewis", "af_bella"}, + }) + }) + + fm.srv = httptest.NewServer(mux) + t.Cleanup(func() { fm.srv.Close() }) + return fm +} + +func (fm *fakeMod3Proxy) capture(r *http.Request) { + body, _ := io.ReadAll(r.Body) + fm.mu.Lock() + defer fm.mu.Unlock() + fm.captured = append(fm.captured, capturedProxyRequest{ + Method: r.Method, Path: r.URL.Path, Query: r.URL.RawQuery, Body: body, + }) +} + +func (fm *fakeMod3Proxy) last() capturedProxyRequest { + fm.mu.Lock() + defer fm.mu.Unlock() + if len(fm.captured) == 0 { + fm.t.Fatalf("no captured requests") + } + return fm.captured[len(fm.captured)-1] +} + +// newProxyMCP builds a minimal MCPServer whose proxy points at fm, with +// playback fully disabled so tests don't touch the audio stack. A live +// Server is wired in as the channel-session backend so the session-family +// MCP tools (register/deregister/list) flow through the kernel's shared +// minting logic — matching production (ADR-082 Wave 3.5). Synthesis / +// control tools continue calling mod3 directly via the proxy. +func newProxyMCP(t *testing.T, fm *fakeMod3Proxy) *MCPServer { + t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + srv := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } + m := &MCPServer{ + cfg: cfg, + mod3Proxy: &modalityProxy{disablePlayback: true}, + channelSessionBackend: srv, + } + return m +} + +// newProxyMCPWithServer is like newProxyMCP but returns the wired-in Server +// so tests that want to assert on the kernel-side registry state can do so. +func newProxyMCPWithServer(t *testing.T, fm *fakeMod3Proxy) (*MCPServer, *Server) { + t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + srv := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } + m := &MCPServer{ + cfg: cfg, + mod3Proxy: &modalityProxy{disablePlayback: true}, + channelSessionBackend: srv, + } + return m, srv +} + +// decodeToolText parses the JSON text content of a CallToolResult. +func decodeToolText(t *testing.T, res *mcp.CallToolResult) map[string]any { + t.Helper() + if res == nil || len(res.Content) == 0 { + t.Fatalf("empty result") + } + tc, ok := res.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("expected *mcp.TextContent, got %T", res.Content[0]) + } + var out map[string]any + if err := json.Unmarshal([]byte(tc.Text), &out); err != nil { + t.Fatalf("decode result text: %v (raw=%q)", err, tc.Text) + } + return out +} + +// ─── mod3_speak ────────────────────────────────────────────────────────────── + +func TestMod3Speak_SuccessPath(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "hello world", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError=true: %v", res.Content) + } + + out := decodeToolText(t, res) + if ok, _ := out["ok"].(bool); !ok { + t.Fatalf("expected ok=true, got %v", out["ok"]) + } + if bytes, _ := out["bytes"].(float64); bytes < 100 { + t.Fatalf("expected bytes > 100, got %v", out["bytes"]) + } + + metrics, _ := out["metrics"].(map[string]any) + if metrics == nil { + t.Fatal("expected metrics map") + } + if jobID, _ := metrics["job-id"].(string); jobID != "job-test-0001" { + t.Fatalf("expected job-id=job-test-0001, got %v", metrics["job-id"]) + } + if dur, _ := metrics["duration-sec"].(float64); dur != 1.23 { + t.Fatalf("expected duration-sec=1.23, got %v", metrics["duration-sec"]) + } + // Integer parse path — sample-rate comes as "24000". + if sr, ok := metrics["sample-rate"].(float64); !ok || sr != 24000 { + t.Fatalf("expected sample-rate=24000, got %v (%T)", metrics["sample-rate"], metrics["sample-rate"]) + } + if got, _ := out["playback_status"].(string); got != "disabled" { + t.Fatalf("expected playback_status=disabled, got %v", out["playback_status"]) + } +} + +func TestMod3Speak_ForwardsSessionID(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "hello", + SessionID: "cs-abc123", + Voice: "af_bella", + Speed: 1.1, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cap := fm.last() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + if forwarded["session_id"] != "cs-abc123" { + t.Fatalf("expected session_id=cs-abc123, got %v", forwarded["session_id"]) + } + if forwarded["voice"] != "af_bella" { + t.Fatalf("expected voice=af_bella, got %v", forwarded["voice"]) + } + if speed, _ := forwarded["speed"].(float64); speed != 1.1 { + t.Fatalf("expected speed=1.1, got %v", forwarded["speed"]) + } +} + +func TestMod3Speak_OmitsSessionIDWhenAbsent(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "plain"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + cap := fm.last() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + if _, present := forwarded["session_id"]; present { + t.Fatalf("expected no session_id key, got %v", forwarded["session_id"]) + } +} + +func TestMod3Speak_EmptyTextRejects(t *testing.T) { + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: " "}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // textResult — IsError is false because it's a validation message, but + // the content must mention the required field. + if len(res.Content) == 0 { + t.Fatal("expected content") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "required") { + t.Fatalf("expected 'required' in message, got %q", tc.Text) + } +} + +func TestMod3Speak_Mod3DownReturnsCleanError(t *testing.T) { + // Port that refuses connections. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := l.Addr().String() + _ = l.Close() // release the port so dials get ECONNREFUSED + + m := &MCPServer{ + cfg: &Config{Mod3URL: "http://" + addr}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "hi"}) + if err != nil { + t.Fatalf("handler should not return Go error, got %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true when mod3 is unreachable") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(strings.ToLower(tc.Text), "mod3 unreachable") { + t.Fatalf("expected 'mod3 unreachable' message, got %q", tc.Text) + } +} + +func TestMod3Speak_PreservesMod3ErrorBody(t *testing.T) { + fm := newFakeMod3Proxy(t) + fm.synthesizeHandler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = w.Write([]byte(`{"detail":"text must not be empty"}`)) + } + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{Text: "bad"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true on mod3 422") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "422") { + t.Fatalf("expected '422' in error text, got %q", tc.Text) + } + if !strings.Contains(tc.Text, "text must not be empty") { + t.Fatalf("expected mod3 body preserved, got %q", tc.Text) + } +} + +func TestMod3Speak_SkipPlaybackReturnsBase64(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "skip", + SkipPlayback: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "skipped" { + t.Fatalf("expected playback_status=skipped, got %v", out["playback_status"]) + } + if b64, _ := out["audio_base64"].(string); b64 == "" { + t.Fatal("expected audio_base64 populated when skip_playback=true") + } +} + +// ─── mod3_stop / voices / status / sessions ────────────────────────────────── + +func TestMod3Stop_ForwardsSessionAndJob(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Stop(context.Background(), nil, mod3StopInput{ + SessionID: "cs-abc", + JobID: "job-xyz", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + + cap := fm.last() + if !strings.Contains(cap.Query, "session_id=cs-abc") { + t.Fatalf("expected session_id in query, got %q", cap.Query) + } + if !strings.Contains(cap.Query, "job_id=job-xyz") { + t.Fatalf("expected job_id in query, got %q", cap.Query) + } + + out := decodeToolText(t, res) + if out["status"] != "stopped" { + t.Fatalf("expected status=stopped, got %v", out["status"]) + } +} + +func TestMod3Voices_ReturnsRawList(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Voices(context.Background(), nil, mod3VoicesInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + out := decodeToolText(t, res) + voices, _ := out["voices"].([]any) + if len(voices) != 2 { + t.Fatalf("expected 2 voices, got %d", len(voices)) + } +} + +func TestMod3Voices_ThreadsSessionID(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + _, _, _ = m.toolMod3Voices(context.Background(), nil, mod3VoicesInput{SessionID: "cs-qq"}) + cap := fm.last() + if !strings.Contains(cap.Query, "session_id=cs-qq") { + t.Fatalf("expected session_id in query, got %q", cap.Query) + } +} + +func TestMod3Status_HitsHealth(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3Status(context.Background(), nil, mod3StatusInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatal("expected success") + } + out := decodeToolText(t, res) + if out["status"] != "ok" { + t.Fatalf("expected status=ok, got %v", out["status"]) + } + + cap := fm.last() + if cap.Path != "/health" { + t.Fatalf("expected /health, got %q", cap.Path) + } +} + +func TestMod3Status_Mod3DownClean(t *testing.T) { + l, _ := net.Listen("tcp", "127.0.0.1:0") + addr := l.Addr().String() + _ = l.Close() + + m := &MCPServer{cfg: &Config{Mod3URL: "http://" + addr}} + res, _, err := m.toolMod3Status(context.Background(), nil, mod3StatusInput{}) + if err != nil { + t.Fatalf("handler should not Go-error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true") + } +} + +// TestMod3RegisterSession_RoutesThroughKernel verifies that the MCP tool +// goes through the kernel's shared RegisterChannelSession backend — not +// directly to mod3 — and that the response is the merged {kernel, mod3} +// block produced by that shared code path (ADR-082 Wave 3.5). +func TestMod3RegisterSession_RoutesThroughKernel(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, srv := newProxyMCPWithServer(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + SessionID: "cs-regtest", + ParticipantID: "cog", + ParticipantType: "agent", + PreferredVoice: "bm_lewis", + PreferredOutputDevice: "system-default", + Priority: 2, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError: %v", res.Content) + } + + // The forward to mod3 must have happened via the kernel's shared + // method — verify the request body mod3 saw carries the caller- + // supplied session_id and participant_id unchanged. + cap := fm.last() + if cap.Path != "/v1/sessions/register" { + t.Fatalf("expected mod3 /v1/sessions/register, got %q", cap.Path) + } + var body map[string]any + if err := json.Unmarshal(cap.Body, &body); err != nil { + t.Fatalf("decode: %v", err) + } + if body["session_id"] != "cs-regtest" { + t.Fatalf("bad session_id forwarded: %v", body["session_id"]) + } + if body["participant_id"] != "cog" { + t.Fatalf("bad participant_id forwarded: %v", body["participant_id"]) + } + if body["preferred_voice"] != "bm_lewis" { + t.Fatalf("bad preferred_voice forwarded: %v", body["preferred_voice"]) + } + + // The merged {kernel, mod3} shape must land — verify the kernel's + // identity record is present in the response. + out := decodeToolText(t, res) + kernel, ok := out["kernel"].(map[string]any) + if !ok { + t.Fatalf("expected kernel block in response, got %v", out) + } + if kernel["session_id"] != "cs-regtest" { + t.Fatalf("expected kernel.session_id=cs-regtest, got %v", kernel["session_id"]) + } + if kernel["id_source"] != "caller" { + t.Fatalf("expected id_source=caller, got %v", kernel["id_source"]) + } + + // Kernel registry must hold the committed record — proves we went + // through the shared backend and not straight to mod3. + if _, held := srv.channelSessionRegistry.Get("cs-regtest"); !held { + t.Fatal("expected kernel registry to hold record after register") + } +} + +// TestMod3RegisterSession_KernelMintsWhenAbsent exercises the minting +// path — caller omits session_id, kernel mints one, mod3 receives it. +func TestMod3RegisterSession_KernelMintsWhenAbsent(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, srv := newProxyMCPWithServer(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + ParticipantID: "cog", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError: %v", res.Content) + } + + out := decodeToolText(t, res) + kernel, _ := out["kernel"].(map[string]any) + sid, _ := kernel["session_id"].(string) + if !strings.HasPrefix(sid, "cs-") { + t.Fatalf("expected minted cs-* session_id, got %q", sid) + } + if kernel["id_source"] != "minted" { + t.Fatalf("expected id_source=minted, got %v", kernel["id_source"]) + } + + // Mod3 must have seen the kernel-minted ID verbatim. + cap := fm.last() + var body map[string]any + _ = json.Unmarshal(cap.Body, &body) + if body["session_id"] != sid { + t.Fatalf("mod3 got session_id=%v, expected %q", body["session_id"], sid) + } + + if _, held := srv.channelSessionRegistry.Get(sid); !held { + t.Fatalf("expected kernel registry to hold minted record %q", sid) + } +} + +// TestMod3RegisterSession_ForwardsKindsAndMetadata verifies the Wave 3.5 +// schema alignment with the channel-provider RFC — `kinds` and `metadata` +// flow through the kernel's register endpoint and land in mod3's request +// body unchanged. +func TestMod3RegisterSession_ForwardsKindsAndMetadata(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, _ := newProxyMCPWithServer(t, fm) + + _, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + SessionID: "cs-kinds", + ParticipantID: "mod3-provider", + ParticipantType: "provider", + Kinds: []string{"audio"}, + Metadata: map[string]any{ + "provider_id": "mod3-local", + "build": "0.5.0", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cap := fm.last() + var body map[string]any + if err := json.Unmarshal(cap.Body, &body); err != nil { + t.Fatalf("decode: %v", err) + } + kinds, ok := body["kinds"].([]any) + if !ok || len(kinds) != 1 || kinds[0] != "audio" { + t.Fatalf("expected kinds=[\"audio\"] forwarded, got %v", body["kinds"]) + } + md, ok := body["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected metadata object forwarded, got %v (%T)", body["metadata"], body["metadata"]) + } + if md["provider_id"] != "mod3-local" { + t.Fatalf("expected metadata.provider_id=mod3-local, got %v", md["provider_id"]) + } + if body["participant_type"] != "provider" { + t.Fatalf("expected participant_type=provider, got %v", body["participant_type"]) + } +} + +func TestMod3RegisterSession_RejectsWithoutParticipant(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "participant_id") { + t.Fatalf("expected validation message, got %q", tc.Text) + } +} + +func TestMod3RegisterSession_NoBackendReturnsCleanError(t *testing.T) { + // An MCPServer with no channel-session backend wired in must surface + // a clean "not configured" error rather than a nil deref — important + // because NewMCPServer (used by tests that only care about memory + // tools) doesn't wire the backend. + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3RegisterSession(context.Background(), nil, mod3RegisterSessionInput{ + ParticipantID: "cog", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true when backend is nil") + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "backend not configured") { + t.Fatalf("expected backend-not-configured message, got %q", tc.Text) + } +} + +// TestMod3DeregisterSession_RoutesThroughKernel verifies the deregister +// tool forwards via the kernel's shared path and the kernel drops its +// identity record on success. +func TestMod3DeregisterSession_RoutesThroughKernel(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, srv := newProxyMCPWithServer(t, fm) + + // Seed the kernel registry so we can see it get dropped. + srv.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-drop", ParticipantID: "cog", + }) + + res, _, err := m.toolMod3DeregisterSession(context.Background(), nil, mod3DeregisterSessionInput{ + SessionID: "cs-drop", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError: %v", res.Content) + } + cap := fm.last() + if cap.Path != "/v1/sessions/cs-drop/deregister" { + t.Fatalf("expected /v1/sessions/cs-drop/deregister at mod3, got %q", cap.Path) + } + if _, held := srv.channelSessionRegistry.Get("cs-drop"); held { + t.Fatal("expected kernel registry to drop record after successful deregister") + } +} + +func TestMod3DeregisterSession_RequiresID(t *testing.T) { + m := &MCPServer{cfg: &Config{Mod3URL: "http://unused"}} + res, _, err := m.toolMod3DeregisterSession(context.Background(), nil, mod3DeregisterSessionInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tc := res.Content[0].(*mcp.TextContent) + if !strings.Contains(tc.Text, "session_id") { + t.Fatalf("expected session_id message, got %q", tc.Text) + } +} + +// TestMod3ListSessions_RoutesThroughKernel verifies list merges the +// kernel snapshot with mod3's per-channel state (the new Wave 3.5 +// merged shape, not mod3's raw payload). +func TestMod3ListSessions_RoutesThroughKernel(t *testing.T) { + fm := newFakeMod3Proxy(t) + m, srv := newProxyMCPWithServer(t, fm) + + srv.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-seed", ParticipantID: "cog", + }) + + res, _, err := m.toolMod3ListSessions(context.Background(), nil, mod3ListSessionsInput{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got IsError: %v", res.Content) + } + out := decodeToolText(t, res) + kernel, ok := out["kernel"].([]any) + if !ok { + t.Fatalf("expected kernel array in merged response, got %v", out) + } + if len(kernel) != 1 { + t.Fatalf("expected 1 kernel record, got %d", len(kernel)) + } + // Mod3 block must be present (from the fake list handler). + if _, present := out["mod3"]; !present { + t.Fatal("expected mod3 block in merged response") + } +} + +// ─── metric extraction unit test ───────────────────────────────────────────── + +func TestExtractMod3Metrics_TypesByValue(t *testing.T) { + h := http.Header{} + h.Set("X-Mod3-Job-Id", "abc123") + h.Set("X-Mod3-Duration-Sec", "2.50") + h.Set("X-Mod3-Sample-Rate", "24000") + h.Set("X-Mod3-Rtf", "8.1") + h.Set("Content-Type", "audio/wav") // should be skipped + + out := extractMod3Metrics(h) + if len(out) != 4 { + t.Fatalf("expected 4 metrics, got %d: %v", len(out), out) + } + if got, _ := out["job-id"].(string); got != "abc123" { + t.Fatalf("job-id: got %v", out["job-id"]) + } + if got, ok := out["duration-sec"].(float64); !ok || got != 2.5 { + t.Fatalf("duration-sec: got %v (%T)", out["duration-sec"], out["duration-sec"]) + } + if got, ok := out["sample-rate"].(int64); !ok || got != 24000 { + t.Fatalf("sample-rate should parse to int64, got %v (%T)", + out["sample-rate"], out["sample-rate"]) + } + if _, present := out["content-type"]; present { + t.Fatal("content-type should be skipped") + } +} + +// ─── playback injection test ───────────────────────────────────────────────── + +// TestPlayAudio_StubPlayer — validate the playback plumbing by injecting a +// small shell-script player that records the file it received. Proves the +// audio bytes reach the player (the bug the installed binary has today is +// that they don't). Only runs on OSes where a shell is available. +func TestPlayAudio_StubPlayer(t *testing.T) { + dir := t.TempDir() + recPath := filepath.Join(dir, "received.log") + stubPath := filepath.Join(dir, "stub-player.sh") + + // Player writes its last-arg path and the first 4 bytes of the wav to + // received.log; proves (a) the player got a path, (b) the file exists + // at that path with our bytes. + stubBody := `#!/bin/sh +path="$1" +hdr=$(dd if="$path" bs=4 count=1 2>/dev/null) +printf 'path=%s hdr=%s\n' "$path" "$hdr" > "` + recPath + `" +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + + p := &modalityProxy{player: stubPath} + if err := p.playAudio(synthWav, true); err != nil { + t.Fatalf("playAudio: %v", err) + } + + got, err := os.ReadFile(recPath) + if err != nil { + t.Fatalf("read record: %v", err) + } + line := string(got) + if !strings.Contains(line, "hdr=RIFF") { + t.Fatalf("player did not see RIFF header; got %q", line) + } + if !strings.Contains(line, "path=") { + t.Fatalf("player did not get a path arg; got %q", line) + } +} + +// TestPlayAudio_NonBlockingSpawn — fire-and-forget. Use a sleep-style stub +// and assert playAudio returns before the stub would finish. Prevents the +// regression where speech synthesis blocks the MCP response on playback. +func TestPlayAudio_NonBlockingSpawn(t *testing.T) { + dir := t.TempDir() + stubPath := filepath.Join(dir, "sleeper.sh") + stubBody := `#!/bin/sh +sleep 5 +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + p := &modalityProxy{player: stubPath} + + done := make(chan struct{}) + go func() { + if err := p.playAudio(synthWav, false); err != nil { + t.Errorf("playAudio: %v", err) + } + close(done) + }() + + select { + case <-done: + // Expected: playAudio returns immediately in non-blocking mode. + case <-time.After(2 * time.Second): + t.Fatal("playAudio(blocking=false) did not return within 2s") + } +} + +// ─── Wave 4.3 — subscriber-check / afplay skip ─────────────────────────────── + +// TestMod3Speak_NoSessionAlwaysSpawnsPlayer — session_id="" bypasses the +// subscriber check entirely so CLI invocations of mod3_speak still play +// audio through afplay as they always did. +func TestMod3Speak_NoSessionAlwaysSpawnsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{player: stubPath} + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "no session", + Blocking: true, // wait for stub to finish so the test can assert invocation count + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once, got %d", got) + } +} + +// TestMod3Speak_SessionWithSubscriberSkipsPlayer — when the injected +// subscriber-check returns true, the kernel skips afplay entirely and +// returns playback_status=routed_ws. The stub player must NOT be invoked. +func TestMod3Speak_SessionWithSubscriberSkipsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + if sessionID != "cs-with-sub" { + t.Errorf("unexpected session_id=%q", sessionID) + } + return true, nil + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "skip me", + SessionID: "cs-with-sub", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "routed_ws" { + t.Fatalf("expected playback_status=routed_ws, got %v", out["playback_status"]) + } + // Give any stray goroutine a moment to trip the stub — proving non-invocation. + time.Sleep(100 * time.Millisecond) + if got := count(); got != 0 { + t.Fatalf("expected stub player NOT invoked, got %d", got) + } +} + +// TestMod3Speak_SessionWithoutSubscriberSpawnsPlayer — subscriber-check +// returns false: kernel falls back to the normal afplay path and the stub +// player IS invoked. +func TestMod3Speak_SessionWithoutSubscriberSpawnsPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + return false, nil + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "no subscriber", + SessionID: "cs-no-sub", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once, got %d", got) + } +} + +// TestMod3Speak_SubscriberCheckErrorFallsBackToPlayer — transient +// check error (mod3 flaky, timeout, etc.) must not orphan the audio. +// The kernel logs the error, records subscriber_check_error in the result, +// and still spawns the player so the user hears the reply. +func TestMod3Speak_SubscriberCheckErrorFallsBackToPlayer(t *testing.T) { + fm := newFakeMod3Proxy(t) + m := newProxyMCP(t, fm) + + stubPath, count := writeStubPlayer(t) + m.mod3Proxy = &modalityProxy{ + player: stubPath, + subscriberCheck: func(ctx context.Context, sessionID string) (bool, error) { + return false, fmt.Errorf("mod3 probe timed out") + }, + } + + res, _, err := m.toolMod3Speak(context.Background(), nil, mod3SpeakInput{ + Text: "check failed", + SessionID: "cs-flaky", + Blocking: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if res.IsError { + t.Fatalf("expected success despite check error, got %v", res.Content) + } + out := decodeToolText(t, res) + if got, _ := out["playback_status"].(string); got != "played" { + t.Fatalf("expected playback_status=played, got %v", out["playback_status"]) + } + if checkErr, _ := out["subscriber_check_error"].(string); !strings.Contains(checkErr, "probe timed out") { + t.Fatalf("expected subscriber_check_error to surface, got %v", out["subscriber_check_error"]) + } + if got := count(); got != 1 { + t.Fatalf("expected stub player invoked once on check error, got %d", got) + } +} + +// TestCheckSessionSubscriber_DefaultImplementationHitsMod3 — wire up a +// stand-alone fake HTTP server that answers the +// /v1/sessions/{id}/subscribers probe and verify the default implementation +// (no injected subscriberCheck) parses its response correctly. +func TestCheckSessionSubscriber_DefaultImplementationHitsMod3(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/sessions/cs-yes/subscribers", func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-yes", "subscribed": true, "count": 1, + }) + }) + mux.HandleFunc("GET /v1/sessions/cs-no/subscribers", func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-no", "subscribed": false, "count": 0, + }) + }) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + m := &MCPServer{ + cfg: &Config{Mod3URL: srv.URL}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + + yes, err := m.checkSessionSubscriber(context.Background(), "cs-yes") + if err != nil { + t.Fatalf("cs-yes: %v", err) + } + if !yes { + t.Fatal("cs-yes: expected subscribed=true") + } + + no, err := m.checkSessionSubscriber(context.Background(), "cs-no") + if err != nil { + t.Fatalf("cs-no: %v", err) + } + if no { + t.Fatal("cs-no: expected subscribed=false") + } +} + +// TestCheckSessionSubscriber_Mod3UnreachableReturnsError — transport +// error (connection refused, timeout) must surface as a non-nil error so +// toolMod3Speak records subscriber_check_error and falls back to afplay. +func TestCheckSessionSubscriber_Mod3UnreachableReturnsError(t *testing.T) { + // Bind a port, close it so dials get ECONNREFUSED. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := l.Addr().String() + _ = l.Close() + + m := &MCPServer{ + cfg: &Config{Mod3URL: "http://" + addr}, + mod3Proxy: &modalityProxy{disablePlayback: true}, + } + subscribed, err := m.checkSessionSubscriber(context.Background(), "cs-anything") + if err == nil { + t.Fatal("expected transport error") + } + if subscribed { + t.Fatal("expected subscribed=false on error") + } +} + +// writeStubPlayer builds a temp shell script that records each invocation +// by appending to a log file. Returns the path of the executable AND a +// getter that returns the current invocation count. The stub exits +// immediately so blocking=true still works in tests. +func writeStubPlayer(t *testing.T) (stubPath string, count func() int) { + t.Helper() + dir := t.TempDir() + logPath := filepath.Join(dir, "invocations.log") + stubPath = filepath.Join(dir, "stub-player.sh") + stubBody := `#!/bin/sh +echo "invoked" >> "` + logPath + `" +` + if err := os.WriteFile(stubPath, []byte(stubBody), 0o755); err != nil { + t.Fatalf("write stub: %v", err) + } + count = func() int { + data, err := os.ReadFile(logPath) + if err != nil { + return 0 + } + return strings.Count(string(data), "invoked\n") + } + return stubPath, count +} diff --git a/internal/engine/mcp_server.go b/internal/engine/mcp_server.go index ffffcc2..d5c33a9 100644 --- a/internal/engine/mcp_server.go +++ b/internal/engine/mcp_server.go @@ -47,6 +47,30 @@ type MCPServer struct { busSessions *BusSessionManager sessionRegistry *SessionRegistry handoffRegistry *HandoffRegistry + + // mod3Proxy backs the mod3_* MCP tools (Wave 3 of the mod3-kernel + // integration). Lazily initialised — tests can pre-seed it with a + // custom HTTP client + stub player to avoid real network + audio. + mod3Proxy *modalityProxy + + // channelSessionBackend is the kernel Server whose + // RegisterChannelSession / DeregisterChannelSession / + // ListChannelSessions methods own session-ID authority (ADR-082 + // Wave 2). Wave 3.5 routes the three mod3 session-family MCP tools + // through these shared methods so minting happens in exactly one + // place. Nil when the MCP server is built outside a live kernel — + // the tools return a clean "not configured" error in that case. + channelSessionBackend channelSessionBackend +} + +// channelSessionBackend is the narrow surface the mod3 session-family MCP +// tools need to forward through the kernel's shared minting/forwarding +// logic. Interface rather than concrete *Server so tests can inject a fake +// without building a whole Server. +type channelSessionBackend interface { + RegisterChannelSession(ctx context.Context, req channelSessionRegisterRequest) (*channelSessionResponse, *channelSessionForwardError) + DeregisterChannelSession(ctx context.Context, sessionID string) (json.RawMessage, int, *channelSessionForwardError) + ListChannelSessions(ctx context.Context) (*channelSessionListResponse, int, *channelSessionForwardError) } // NewMCPServer creates and configures the MCP server with all stage-1 tools. @@ -93,6 +117,15 @@ func (m *MCPServer) SetAgentController(ctrl AgentController) { m.agentController = ctrl } +// SetChannelSessionBackend wires the kernel-owned channel-session minting +// logic into the MCP server so the mod3_register_session / _deregister / +// _list tools call through the same shared methods the HTTP surface uses +// (ADR-082 Wave 3.5). Safe to pass nil; the tools surface a clean "not +// configured" error in that case. +func (m *MCPServer) SetChannelSessionBackend(b channelSessionBackend) { + m.channelSessionBackend = b +} + // Handler returns the http.Handler for mounting at /mcp. func (m *MCPServer) Handler() http.Handler { return m.handler @@ -241,6 +274,12 @@ func (m *MCPServer) registerTools() { // both surfaces coexist by design — same kernel truth, two MCP // doorways (amendment #5 of the Agent P hybrid plan). m.registerSessionTools() + + // Wave 3: mod3 proxy tools (mcp_modality_proxy.go). The kernel becomes + // the MCP front door for mod3 — HTTP-forwards synthesis/stop/voices/ + // status and plays the returned audio/wav locally. Supersedes the + // installed binary's OpenClaw gateway which silently drops audio bytes. + m.registerMod3Tools() } // registerResources registers MCP Resources — read-only addressable data. diff --git a/internal/engine/serve.go b/internal/engine/serve.go index c2c29d8..f5c5422 100644 --- a/internal/engine/serve.go +++ b/internal/engine/serve.go @@ -19,6 +19,15 @@ // GET /v1/constellation/fovea — current fovea state // GET /v1/constellation/adjacent?uri=… — adjacent nodes by attentional proximity // +// Channel-session forwarder (ADR-082 Wave 2, see serve_sessions_channel.go): +// +// POST /v1/channel-sessions/register — kernel mints session_id +// and forwards to mod3; +// returns merged response +// POST /v1/channel-sessions/{id}/deregister — proxy to mod3, drop record +// GET /v1/channel-sessions — kernel view + mod3 list +// GET /v1/channel-sessions/{id} — single-session detail +// // The chat endpoint routes through the inference Router when one is set, // otherwise returns 501. package engine @@ -52,8 +61,8 @@ type Server struct { process *Process router Router // nil until SetRouter is called srv *http.Server - debug debugStore // captures last request pipeline state - attentionLog *attentionLog // per-server log (avoids global write race) + debug debugStore // captures last request pipeline state + attentionLog *attentionLog // per-server log (avoids global write race) agentController AgentController // nil until SetAgentController is called mcpServer *MCPServer // so SetAgentController can propagate to tools @@ -71,6 +80,18 @@ type Server struct { // these are derived views rebuilt from bus replay at startup. sessionRegistry *SessionRegistry handoffRegistry *HandoffRegistry + + // ADR-082 Wave 2: kernel-owned identity registry for channel-participant + // sessions. The kernel mints session_id, mod3 stores per-channel state + // keyed on the kernel-issued ID. Distinct from sessionRegistry above, + // which enforces strict 3-component hyphen IDs for the agent/handoff + // protocol. See serve_sessions_channel.go for the full rationale. + channelSessionRegistry *ChannelSessionRegistry + + // mod3Client is the HTTP client used to forward channel-session calls + // to mod3. Nil in production (falls back to the package-level + // mod3HTTPClient); tests set this to an httptest-backed client. + mod3Client *http.Client } // NewServer constructs a Server bound to the configured port. @@ -94,6 +115,9 @@ func NewServer(cfg *Config, nucleus *Nucleus, process *Process) *Server { s.sessionRegistry = NewSessionRegistry() s.handoffRegistry = NewHandoffRegistry() + // ADR-082 Wave 2 kernel-owned channel-session identity. + s.channelSessionRegistry = NewChannelSessionRegistry() + mux := http.NewServeMux() mux.HandleFunc("GET /", handleDashboard) mux.HandleFunc("GET /canvas", handleCanvas) @@ -133,6 +157,13 @@ func NewServer(cfg *Config, nucleus *Nucleus, process *Process) *Server { // cleanly with the pre-existing GET /v1/sessions[/{id}] surface. s.registerSessionMgmtRoutes(mux) + // ADR-082 Wave 2: kernel-side channel-session forwarder. The four + // /v1/channel-sessions/* routes mint session_ids, record identity + // locally, and forward to mod3 at cfg.Mod3URL. Namespaced under + // /v1/channel-sessions/* to coexist with the agent-session surface + // above (incompatible session_id formats — see serve_sessions_channel.go). + s.registerChannelSessionRoutes(mux) + // Replay bus_sessions + bus_handoffs into the in-memory registries so // the kernel starts with an accurate derived view. Bus is authoritative // either way; this just warms the read path. diff --git a/internal/engine/serve_mcp.go b/internal/engine/serve_mcp.go index 397398b..148a980 100644 --- a/internal/engine/serve_mcp.go +++ b/internal/engine/serve_mcp.go @@ -17,6 +17,11 @@ func (s *Server) registerMCPRoutes(mux *http.ServeMux) { // these are nil — NewMCPServer (used by tests that only care about // memory tools) doesn't call this, which is fine. mcpSrv.SetSessionsBackend(s.busSessions, s.sessionRegistry, s.handoffRegistry) + // ADR-082 Wave 3.5: route the mod3 session-family MCP tools through + // the kernel's shared channel-session methods so session-ID minting + // happens in exactly one place (this Server). Handlers dispatching + // to mod3 directly was the Wave 3 divergence this removes. + mcpSrv.SetChannelSessionBackend(s) s.mcpServer = mcpSrv h := mcpSrv.Handler() mux.Handle("GET /mcp", h) diff --git a/internal/engine/serve_sessions_channel.go b/internal/engine/serve_sessions_channel.go new file mode 100644 index 0000000..fadaf96 --- /dev/null +++ b/internal/engine/serve_sessions_channel.go @@ -0,0 +1,588 @@ +// serve_sessions_channel.go — kernel-side HTTP forwarder for channel-session +// registration (ADR-082, Wave 2 of the mod3-kernel integration). +// +// Design locks (Wave 2 handoff): +// +// 1. Session authority = kernel-owned. The kernel mints `session_id` if the +// caller doesn't supply one. Mod3's SessionRegistry stores per-channel +// state (voice, queue, device) keyed on the kernel-issued ID. If mod3 +// crashes, the kernel knows which sessions existed; mod3 rebuilds on +// re-register. +// 2. MCP transport = HTTP proxy. Kernel reaches mod3 over plain HTTP at +// Config.Mod3URL (default http://localhost:7860), not stdio MCP. +// +// Path namespace choice: these routes live under `/v1/channel-sessions/*`, +// NOT `/v1/sessions/*`. The kernel's existing `/v1/sessions/*` family +// (serve_sessions_mgmt.go) serves agent-session state (3-component hyphen- +// validated session IDs, workspace/role required, tied to the handoff +// protocol). Channel-participant registration has an incompatible shape +// (short UUID session IDs, participant_id/participant_type/voice/device +// fields). Rather than weaken ValidateSessionID (which would cascade into +// handoff claim semantics) we namespace the new concern. The channel-provider +// RFC's guidance to "use the same cogos_session_register primitive with a +// participant_type discriminator" remains aspirational at the MCP tool layer +// (Wave 3); at the HTTP layer the two surfaces coexist cleanly. +// +// Routes owned by this file: +// +// POST /v1/channel-sessions/register — mint+forward to mod3 +// POST /v1/channel-sessions/{id}/deregister — forward deregister +// GET /v1/channel-sessions — list (kernel view + mod3 list) +// GET /v1/channel-sessions/{id} — single-session detail +package engine + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +// ─── in-memory registry for kernel-owned channel-session identity ──────────── + +// ChannelSessionRecord is the kernel's identity-authority record for a +// channel-participant session. Distinct from SessionState in sessions.go +// (which tracks agent sessions with strict 3-component hyphen IDs); the +// channel-session concern uses short UUIDs and caller-supplied participant +// metadata that the kernel passes through to mod3. +type ChannelSessionRecord struct { + // SessionID is the kernel-authoritative ID. Either caller-supplied or + // minted by the kernel (uuid short form). + SessionID string `json:"session_id"` + + // ParticipantID / ParticipantType / PreferredVoice / PreferredOutputDevice + // / Priority mirror the mod3 SessionRegisterRequest shape. The kernel + // stores them so a post-crash re-register can replay identity cleanly. + ParticipantID string `json:"participant_id,omitempty"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` + + // Kinds mirrors the channel-provider RFC's `kinds` metadata field + // (e.g. ["audio"] for a mod3-provider registration). Mod3 ignores it + // today; kept on the kernel record for downstream consumers that want + // to filter by capability. + Kinds []string `json:"kinds,omitempty"` + + // Metadata is an opaque pass-through map the RFC describes as + // "provider_id/kinds in the metadata" — preserved verbatim on the + // kernel record and forwarded to mod3 (which ignores unknown fields). + Metadata map[string]any `json:"metadata,omitempty"` + + RegisteredAt time.Time `json:"registered_at"` + LastSeen time.Time `json:"last_seen"` + + // Source records whether the session_id came from the caller or was + // minted. Useful for audit; no functional impact. + IDSource string `json:"id_source,omitempty"` // "caller" | "minted" +} + +// ChannelSessionRegistry is the in-memory map keyed by session_id. It holds +// kernel-owned identity only; per-channel state (assigned_voice, queue, +// device) lives in mod3 and is returned as part of the merged forward +// response. If mod3 crashes the kernel knows which sessions existed and +// callers can re-register to rebuild. +type ChannelSessionRegistry struct { + mu sync.RWMutex + rows map[string]*ChannelSessionRecord +} + +// NewChannelSessionRegistry returns an empty registry. +func NewChannelSessionRegistry() *ChannelSessionRegistry { + return &ChannelSessionRegistry{rows: make(map[string]*ChannelSessionRecord)} +} + +// Put stores (or overwrites) a record by session_id. +func (r *ChannelSessionRegistry) Put(rec ChannelSessionRecord) { + r.mu.Lock() + defer r.mu.Unlock() + cp := rec + r.rows[rec.SessionID] = &cp +} + +// Get returns a copy of the record for id, or (nil, false). +func (r *ChannelSessionRegistry) Get(id string) (*ChannelSessionRecord, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + row, ok := r.rows[id] + if !ok { + return nil, false + } + cp := *row + return &cp, true +} + +// Delete removes the record for id. No-op if absent; returns whether the row +// was present before removal (handy for logging / telemetry). +func (r *ChannelSessionRegistry) Delete(id string) bool { + r.mu.Lock() + defer r.mu.Unlock() + _, ok := r.rows[id] + delete(r.rows, id) + return ok +} + +// Snapshot returns a copy of every record. +func (r *ChannelSessionRegistry) Snapshot() []*ChannelSessionRecord { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]*ChannelSessionRecord, 0, len(r.rows)) + for _, row := range r.rows { + cp := *row + out = append(out, &cp) + } + return out +} + +// Len returns the number of tracked channel sessions. +func (r *ChannelSessionRegistry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.rows) +} + +// ─── route registration ────────────────────────────────────────────────────── + +// defaultMod3ForwardTimeout is the per-request timeout for channel-session +// HTTP forwards. 8s is deliberately shorter than WriteTimeout (300s) so a +// stalled mod3 doesn't block a kernel request the whole way through. +const defaultMod3ForwardTimeout = 8 * time.Second + +// mod3HTTPClient is the shared net/http client for forwards. Tests override +// Server.mod3Client for isolation; when nil, handlers fall back to this. +var mod3HTTPClient = &http.Client{Timeout: defaultMod3ForwardTimeout} + +// registerChannelSessionRoutes attaches the 4 channel-session routes onto mux. +// Called from NewServer. +func (s *Server) registerChannelSessionRoutes(mux *http.ServeMux) { + mux.HandleFunc("POST /v1/channel-sessions/register", s.handleChannelSessionRegister) + mux.HandleFunc("POST /v1/channel-sessions/{id}/deregister", s.handleChannelSessionDeregister) + mux.HandleFunc("GET /v1/channel-sessions", s.handleChannelSessionList) + mux.HandleFunc("GET /v1/channel-sessions/{id}", s.handleChannelSessionGet) +} + +// ─── wire types ────────────────────────────────────────────────────────────── + +// channelSessionRegisterRequest is the kernel-facing request body. Shape +// mirrors mod3's SessionRegisterRequest (see mod3/http_api.py line ~278) plus +// an optional session_id — when omitted, the kernel mints one. +// +// Wave 3.5 schema alignment with the channel-provider RFC's +// `cogos_session_register` primitive: `kinds` (array of adapter kinds the +// registrant participates in, e.g. ["audio"]) and `metadata` (opaque +// pass-through blob — RFC calls for provider_id/kinds to live in metadata) +// are both optional and flow through to mod3 unchanged (mod3 ignores +// unknown fields). +type channelSessionRegisterRequest struct { + SessionID string `json:"session_id,omitempty"` + ParticipantID string `json:"participant_id"` + ParticipantType string `json:"participant_type,omitempty"` + PreferredVoice string `json:"preferred_voice,omitempty"` + PreferredOutputDevice string `json:"preferred_output_device,omitempty"` + Priority int `json:"priority,omitempty"` + Kinds []string `json:"kinds,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// channelSessionResponse is the merged shape returned from the kernel. The +// `kernel` block is the identity record (authoritative owner); `mod3` is the +// live channel state (voice pool, queue, device). Callers get everything in +// one round trip. Unknown fields from mod3 are preserved under `mod3` as a +// raw JSON object. +type channelSessionResponse struct { + Kernel *ChannelSessionRecord `json:"kernel"` + Mod3 json.RawMessage `json:"mod3,omitempty"` +} + +// channelSessionListResponse is the shape of GET /v1/channel-sessions. The +// kernel list is always present; the mod3 block is an opaque pass-through so +// clients don't need to know mod3's schema. +type channelSessionListResponse struct { + Kernel []*ChannelSessionRecord `json:"kernel"` + Mod3 json.RawMessage `json:"mod3,omitempty"` +} + +// ─── shared register/deregister/list logic (Wave 3.5) ──────────────────────── +// +// These methods are the single place session-ID minting, kernel registry +// commits, and mod3 forwarding happen. Both the HTTP handlers below and the +// mod3_register_session / mod3_deregister_session / mod3_list_sessions MCP +// tools (see mcp_modality_proxy.go) call through here so session-ID +// authority stays centralized — nobody reaches mod3's /v1/sessions/* surface +// directly except this one codepath. + +// channelSessionForwardError classifies a failure in RegisterChannelSession / +// DeregisterChannelSession / ListChannelSessions so the caller (HTTP handler +// or MCP tool) can surface the right status/message without re-parsing. +type channelSessionForwardError struct { + // Kind: "invalid_request" | "mod3_unreachable" | "mod3_rejected" + Kind string + // HTTPStatus is the status the caller should emit (400 for + // invalid_request, 502 for mod3_unreachable, mod3's own status for + // mod3_rejected). + HTTPStatus int + // Message is a human-readable description, safe to surface to callers. + Message string + // Mod3Body carries the raw JSON body mod3 returned when Kind == + // "mod3_rejected", so the HTTP handler can pass it through verbatim. + Mod3Body json.RawMessage +} + +func (e *channelSessionForwardError) Error() string { return e.Message } + +// RegisterChannelSession is the Wave 2+3.5 shared entry point for channel- +// session registration. It mints a session_id when absent, forwards to mod3, +// and commits the kernel-side identity record on success. +// +// Callers: +// - HTTP: POST /v1/channel-sessions/register (handleChannelSessionRegister) +// - MCP: mod3_register_session tool (toolMod3RegisterSession) +// +// Returning a (resp, nil) pair means the caller should surface the merged +// response with 200 OK. Returning a non-nil *channelSessionForwardError +// tells the caller which status + body shape to surface. +func (s *Server) RegisterChannelSession(ctx context.Context, req channelSessionRegisterRequest) (*channelSessionResponse, *channelSessionForwardError) { + if req.ParticipantID == "" { + return nil, &channelSessionForwardError{ + Kind: "invalid_request", + HTTPStatus: http.StatusBadRequest, + Message: "participant_id is required", + } + } + + // Mint when absent — kernel is the session-ID authority (ADR-082). + idSource := "caller" + if req.SessionID == "" { + req.SessionID = mintChannelSessionID() + idSource = "minted" + } + if req.ParticipantType == "" { + req.ParticipantType = "agent" // matches mod3's default + } + if req.PreferredOutputDevice == "" { + req.PreferredOutputDevice = "system-default" + } + + now := time.Now().UTC() + record := ChannelSessionRecord{ + SessionID: req.SessionID, + ParticipantID: req.ParticipantID, + ParticipantType: req.ParticipantType, + PreferredVoice: req.PreferredVoice, + PreferredOutputDevice: req.PreferredOutputDevice, + Priority: req.Priority, + Kinds: req.Kinds, + Metadata: req.Metadata, + RegisteredAt: now, + LastSeen: now, + IDSource: idSource, + } + + // Forward to mod3 with the kernel-issued session_id. Mod3's body is the + // same shape modulo the optional session_id field (mod3 requires it; we + // always supply one). `kinds` and `metadata` are RFC-level fields mod3 + // currently ignores; we still forward them so mod3 can start consuming + // them without a kernel change when it's ready. + forwardBody := map[string]any{ + "session_id": req.SessionID, + "participant_id": req.ParticipantID, + "participant_type": req.ParticipantType, + "preferred_voice": req.PreferredVoice, + "preferred_output_device": req.PreferredOutputDevice, + "priority": req.Priority, + } + if len(req.Kinds) > 0 { + forwardBody["kinds"] = req.Kinds + } + if len(req.Metadata) > 0 { + forwardBody["metadata"] = req.Metadata + } + body, _ := json.Marshal(forwardBody) + + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodPost, + "/v1/sessions/register", bytes.NewReader(body)) + if err != nil { + slog.Warn("channel-sessions: forward to mod3 failed", + "session_id", req.SessionID, "err", err) + return nil, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } + } + + if status < 200 || status >= 300 { + slog.Warn("channel-sessions: mod3 returned non-2xx", + "session_id", req.SessionID, "status", status) + return nil, &channelSessionForwardError{ + Kind: "mod3_rejected", + HTTPStatus: status, + Message: fmt.Sprintf("mod3 returned %d", status), + Mod3Body: mod3Resp, + } + } + + // Mod3 accepted — commit the kernel-side identity record. + s.channelSessionRegistry.Put(record) + slog.Info("channel-sessions: registered", + "session_id", req.SessionID, "participant_id", req.ParticipantID, + "id_source", idSource) + + return &channelSessionResponse{Kernel: &record, Mod3: mod3Resp}, nil +} + +// DeregisterChannelSession is the shared entry point for deregistration. +// Forwards to mod3 and drops the kernel registry row on any non-5xx mod3 +// response (including 404 — "mod3 forgot" is equivalent to "kernel should +// forget too"). On transport failure the kernel keeps the record so the +// caller can retry. +// +// Returns the raw mod3 body + status on success. Callers should surface +// both verbatim (writeJSONPassThrough on the HTTP side; the MCP tool +// wraps it as a JSON result). +func (s *Server) DeregisterChannelSession(ctx context.Context, sessionID string) (json.RawMessage, int, *channelSessionForwardError) { + if sessionID == "" { + return nil, 0, &channelSessionForwardError{ + Kind: "invalid_request", + HTTPStatus: http.StatusBadRequest, + Message: "session_id is required", + } + } + + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodPost, + "/v1/sessions/"+sessionID+"/deregister", nil) + if err != nil { + slog.Warn("channel-sessions: deregister forward failed", + "session_id", sessionID, "err", err) + return nil, 0, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } + } + + // Kernel drops its identity record whenever mod3 successfully + // acknowledges, including 404 ("never registered" is equivalent to + // "not tracked"; clean slate in kernel matches clean slate in mod3). + if status >= 200 && status < 500 { + s.channelSessionRegistry.Delete(sessionID) + } + + return mod3Resp, status, nil +} + +// ListChannelSessions is the shared entry point for the merged list query. +// Returns the kernel snapshot plus mod3's raw `GET /v1/sessions` body. +// +// On mod3 transport failure: error (502). On mod3 non-2xx: returns the +// raw mod3 body with its status attached so the caller can surface intact. +func (s *Server) ListChannelSessions(ctx context.Context) (*channelSessionListResponse, int, *channelSessionForwardError) { + mod3Resp, status, err := s.forwardMod3(ctx, http.MethodGet, + "/v1/sessions", nil) + if err != nil { + slog.Warn("channel-sessions: list forward failed", "err", err) + return nil, 0, &channelSessionForwardError{ + Kind: "mod3_unreachable", + HTTPStatus: http.StatusBadGateway, + Message: fmt.Sprintf("mod3 unreachable: %v", err), + } + } + if status < 200 || status >= 300 { + return nil, status, &channelSessionForwardError{ + Kind: "mod3_rejected", + HTTPStatus: status, + Message: fmt.Sprintf("mod3 returned %d", status), + Mod3Body: mod3Resp, + } + } + return &channelSessionListResponse{ + Kernel: s.channelSessionRegistry.Snapshot(), + Mod3: mod3Resp, + }, http.StatusOK, nil +} + +// ─── POST /v1/channel-sessions/register ────────────────────────────────────── + +func (s *Server) handleChannelSessionRegister(w http.ResponseWriter, r *http.Request) { + var req channelSessionRegisterRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid_request", "body must be JSON") + return + } + + resp, ferr := s.RegisterChannelSession(r.Context(), req) + if ferr != nil { + switch ferr.Kind { + case "invalid_request": + writeJSONError(w, ferr.HTTPStatus, "invalid_request", ferr.Message) + case "mod3_unreachable": + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + case "mod3_rejected": + writeJSONPassThrough(w, ferr.HTTPStatus, ferr.Mod3Body) + default: + writeJSONError(w, http.StatusInternalServerError, "internal", ferr.Message) + } + return + } + writeJSONResp(w, http.StatusOK, resp) +} + +// ─── POST /v1/channel-sessions/{id}/deregister ─────────────────────────────── + +func (s *Server) handleChannelSessionDeregister(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "session_id required in path") + return + } + + mod3Resp, status, ferr := s.DeregisterChannelSession(r.Context(), id) + if ferr != nil { + if ferr.Kind == "mod3_unreachable" { + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + return + } + writeJSONError(w, ferr.HTTPStatus, ferr.Kind, ferr.Message) + return + } + writeJSONPassThrough(w, status, mod3Resp) +} + +// ─── GET /v1/channel-sessions ──────────────────────────────────────────────── + +func (s *Server) handleChannelSessionList(w http.ResponseWriter, r *http.Request) { + resp, status, ferr := s.ListChannelSessions(r.Context()) + if ferr != nil { + switch ferr.Kind { + case "mod3_unreachable": + writeJSONError(w, ferr.HTTPStatus, "mod3_unreachable", ferr.Message) + case "mod3_rejected": + writeJSONPassThrough(w, ferr.HTTPStatus, ferr.Mod3Body) + default: + writeJSONError(w, http.StatusInternalServerError, "internal", ferr.Message) + } + return + } + writeJSONResp(w, status, resp) +} + +// ─── GET /v1/channel-sessions/{id} ─────────────────────────────────────────── + +func (s *Server) handleChannelSessionGet(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if id == "" { + writeJSONError(w, http.StatusBadRequest, "invalid_request", + "session_id required in path") + return + } + + mod3Resp, status, err := s.forwardMod3(r.Context(), http.MethodGet, + "/v1/sessions/"+id, nil) + if err != nil { + slog.Warn("channel-sessions: get forward failed", + "session_id", id, "err", err) + writeJSONError(w, http.StatusBadGateway, "mod3_unreachable", + fmt.Sprintf("mod3 unreachable: %v", err)) + return + } + if status < 200 || status >= 300 { + writeJSONPassThrough(w, status, mod3Resp) + return + } + + kernelRec, _ := s.channelSessionRegistry.Get(id) + resp := channelSessionResponse{ + Kernel: kernelRec, + Mod3: mod3Resp, + } + writeJSONResp(w, http.StatusOK, resp) +} + +// ─── forwarder + helpers ───────────────────────────────────────────────────── + +// forwardMod3 is the single HTTP egress point to mod3. Returns the raw +// response body, HTTP status, and a transport error (non-nil only when the +// request never got a status back — connection refused, DNS, TLS, timeout). +// Mod3 error bodies (4xx/5xx) come back with err == nil and caller decides +// how to surface them. +func (s *Server) forwardMod3(ctx context.Context, method, path string, body io.Reader) (json.RawMessage, int, error) { + base := strings.TrimRight(s.cfg.Mod3URL, "/") + if base == "" { + return nil, 0, errors.New("Mod3URL not configured") + } + url := base + path + + // Scope a per-request timeout on top of the shared client's. The ambient + // request context may have a longer deadline; we want the forward to + // fail fast if mod3 stalls. + reqCtx, cancel := context.WithTimeout(ctx, defaultMod3ForwardTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, method, url, body) + if err != nil { + return nil, 0, fmt.Errorf("build request: %w", err) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + client := s.mod3Client + if client == nil { + client = mod3HTTPClient + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1 MB cap + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err) + } + if len(raw) == 0 { + // Mod3's deregister can return a JSON object; other endpoints + // always return JSON; if the body is empty we substitute null so + // json.RawMessage doesn't marshal an empty string (invalid JSON). + raw = []byte("null") + } + if !json.Valid(raw) { + // Surface as a synthetic JSON string; don't propagate raw bytes. + wrapped, _ := json.Marshal(map[string]string{"raw": string(raw)}) + return json.RawMessage(wrapped), resp.StatusCode, nil + } + return json.RawMessage(raw), resp.StatusCode, nil +} + +// mintChannelSessionID returns a 12-char lowercase-hex short UUID. Short +// enough to be human-readable in logs, unique enough to avoid collisions +// across a single mod3 instance's lifetime. +func mintChannelSessionID() string { + u := uuid.New() + hex := strings.ReplaceAll(u.String(), "-", "") + return "cs-" + hex[:12] +} + +// writeJSONPassThrough writes the given status + raw JSON body verbatim. Used +// when mod3's response (especially errors) should flow to the caller intact. +func writeJSONPassThrough(w http.ResponseWriter, status int, body json.RawMessage) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if len(body) == 0 { + _, _ = w.Write([]byte("null")) + return + } + _, _ = w.Write(body) +} diff --git a/internal/engine/serve_sessions_channel_test.go b/internal/engine/serve_sessions_channel_test.go new file mode 100644 index 0000000..1d9f41e --- /dev/null +++ b/internal/engine/serve_sessions_channel_test.go @@ -0,0 +1,627 @@ +// serve_sessions_channel_test.go — end-to-end tests for the kernel-side +// channel-session forwarder. Each test stands up a fake mod3 via +// httptest.NewServer and exercises the kernel handler against it. +package engine + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" +) + +// ─── fakeMod3 — a tiny stub that canonically mirrors mod3's responses ──────── + +type fakeMod3 struct { + t *testing.T + srv *httptest.Server + mu sync.Mutex + captured []capturedRequest + + // Overrides — tests set these to control per-endpoint behavior. + registerHandler http.HandlerFunc + deregisterHandler http.HandlerFunc + listHandler http.HandlerFunc + getHandler http.HandlerFunc +} + +type capturedRequest struct { + Method string + Path string + Body []byte +} + +func newFakeMod3(t *testing.T) *fakeMod3 { + t.Helper() + fm := &fakeMod3{t: t} + mux := http.NewServeMux() + + mux.HandleFunc("POST /v1/sessions/register", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.registerHandler != nil { + fm.registerHandler(w, r) + return + } + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + sid, _ := body["session_id"].(string) + resp := map[string]any{ + "session_id": sid, + "participant_id": body["participant_id"], + "assigned_voice": "bm_lewis", + "voice_conflict": false, + "output_device": map[string]any{"name": "system-default", "live": true}, + "queue_depth": 0, + "created": true, + } + writeFakeJSON(w, http.StatusOK, resp) + }) + + mux.HandleFunc("POST /v1/sessions/{id}/deregister", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.deregisterHandler != nil { + fm.deregisterHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "session_id": r.PathValue("id"), + "released_voice": "bm_lewis", + "dropped_jobs": 0, + }) + }) + + mux.HandleFunc("GET /v1/sessions", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.listHandler != nil { + fm.listHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "sessions": []any{}, + "serializer": map[string]any{"policy": "round-robin"}, + "voice_pool": []any{"bm_lewis", "af_bella"}, + "voice_holders": map[string]any{}, + }) + }) + + mux.HandleFunc("GET /v1/sessions/{id}", func(w http.ResponseWriter, r *http.Request) { + fm.capture(r) + if fm.getHandler != nil { + fm.getHandler(w, r) + return + } + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": r.PathValue("id"), + "assigned_voice": "bm_lewis", + "output_device": map[string]any{"name": "system-default"}, + }) + }) + + fm.srv = httptest.NewServer(mux) + t.Cleanup(func() { fm.srv.Close() }) + return fm +} + +func (fm *fakeMod3) capture(r *http.Request) { + body, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewReader(body)) + fm.mu.Lock() + defer fm.mu.Unlock() + fm.captured = append(fm.captured, capturedRequest{ + Method: r.Method, Path: r.URL.Path, Body: body, + }) +} + +func (fm *fakeMod3) lastCaptured() capturedRequest { + fm.mu.Lock() + defer fm.mu.Unlock() + if len(fm.captured) == 0 { + fm.t.Fatalf("expected at least one captured request") + } + return fm.captured[len(fm.captured)-1] +} + +func writeFakeJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +// ─── newChannelServer — a Server wired just enough to exercise the routes ──── + +// newChannelServer builds a Server with only the fields needed by the +// channel-session handlers: cfg.Mod3URL pointing at fake mod3, a fresh +// ChannelSessionRegistry, and mux routes registered. The httptest.Client +// used by the kernel is overridden so every call targets fake mod3 regardless +// of the Mod3URL's scheme (handy for tests that want to inject network +// failures without depending on DNS). +func newChannelServer(t *testing.T, fm *fakeMod3) (*Server, *httptest.Server) { + t.Helper() + cfg := &Config{Mod3URL: fm.srv.URL} + s := &Server{ + cfg: cfg, + channelSessionRegistry: NewChannelSessionRegistry(), + } + mux := http.NewServeMux() + s.registerChannelSessionRoutes(mux) + front := httptest.NewServer(mux) + t.Cleanup(func() { front.Close() }) + return s, front +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +func TestChannelSessionRegister_MintsIDWhenOmitted(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "participant_id": "cog", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode response: %v", err) + } + if decoded.Kernel == nil { + t.Fatalf("expected kernel block, got nil") + } + if decoded.Kernel.SessionID == "" { + t.Fatal("expected minted session_id, got empty string") + } + if !strings.HasPrefix(decoded.Kernel.SessionID, "cs-") { + t.Fatalf("expected minted ID to carry cs- prefix, got %q", decoded.Kernel.SessionID) + } + if decoded.Kernel.IDSource != "minted" { + t.Fatalf("expected id_source=minted, got %q", decoded.Kernel.IDSource) + } + + // Mod3 should have seen the kernel-minted ID. + cap := fm.lastCaptured() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("unmarshal forwarded body: %v", err) + } + if forwarded["session_id"] != decoded.Kernel.SessionID { + t.Fatalf("expected mod3 to receive minted session_id %q, got %v", + decoded.Kernel.SessionID, forwarded["session_id"]) + } + + // Kernel registry should hold the record. + if _, ok := s.channelSessionRegistry.Get(decoded.Kernel.SessionID); !ok { + t.Fatal("expected kernel registry to retain record after success") + } +} + +func TestChannelSessionRegister_UsesCallerSuppliedID(t *testing.T) { + fm := newFakeMod3(t) + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "vox-42", + "participant_id": "sandy", + "participant_type": "agent", + "preferred_voice": "af_bella", + "preferred_output_device": "AirPods", + "priority": 3, + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if decoded.Kernel.SessionID != "vox-42" { + t.Fatalf("expected caller session_id preserved, got %q", decoded.Kernel.SessionID) + } + if decoded.Kernel.IDSource != "caller" { + t.Fatalf("expected id_source=caller, got %q", decoded.Kernel.IDSource) + } + + // Verify all caller-supplied fields made it through the forward. + cap := fm.lastCaptured() + var forwarded map[string]any + _ = json.Unmarshal(cap.Body, &forwarded) + if forwarded["session_id"] != "vox-42" || + forwarded["participant_id"] != "sandy" || + forwarded["participant_type"] != "agent" || + forwarded["preferred_voice"] != "af_bella" || + forwarded["preferred_output_device"] != "AirPods" { + t.Fatalf("forwarded fields mismatch: %v", forwarded) + } + // Priority is a number in JSON; compare via float64. + if p, _ := forwarded["priority"].(float64); int(p) != 3 { + t.Fatalf("expected priority=3 forwarded, got %v", forwarded["priority"]) + } +} + +func TestChannelSessionRegister_MergesMod3Response(t *testing.T) { + fm := newFakeMod3(t) + fm.registerHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusOK, map[string]any{ + "session_id": "cs-fixed", + "assigned_voice": "bm_oxford", + "voice_conflict": true, + "queue_depth": 5, + "created": false, + }) + } + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "cs-fixed", + "participant_id": "cog", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + + // The merged response should carry mod3's assigned_voice and + // voice_conflict fields intact. + var decoded map[string]json.RawMessage + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + var mod3 map[string]any + if err := json.Unmarshal(decoded["mod3"], &mod3); err != nil { + t.Fatalf("decode mod3 block: %v", err) + } + if mod3["assigned_voice"] != "bm_oxford" { + t.Fatalf("expected assigned_voice=bm_oxford in merge, got %v", mod3["assigned_voice"]) + } + if mod3["voice_conflict"] != true { + t.Fatalf("expected voice_conflict=true in merge, got %v", mod3["voice_conflict"]) + } +} + +func TestChannelSessionRegister_ReturnsBadGatewayWhenMod3Down(t *testing.T) { + // Build a fakeMod3 and immediately close it so the URL refuses connections. + fm := newFakeMod3(t) + fm.srv.Close() + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{"participant_id": "cog"}) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadGateway { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 502, got %d; body: %s", resp.StatusCode, raw) + } + // Kernel must NOT hold a record for a failed forward. + if s.channelSessionRegistry.Len() != 0 { + t.Fatalf("expected empty kernel registry after 502, got %d rows", + s.channelSessionRegistry.Len()) + } +} + +func TestChannelSessionRegister_PropagatesMod3Error(t *testing.T) { + fm := newFakeMod3(t) + fm.registerHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusBadRequest, map[string]string{ + "error": "participant_type must be agent|user", + }) + } + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "participant_id": "cog", + "participant_type": "alien", + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 400 passthrough, got %d; body: %s", resp.StatusCode, raw) + } + var decoded map[string]string + _ = json.NewDecoder(resp.Body).Decode(&decoded) + if !strings.Contains(decoded["error"], "participant_type") { + t.Fatalf("expected mod3 error body preserved, got %v", decoded) + } + if s.channelSessionRegistry.Len() != 0 { + t.Fatal("expected kernel registry empty when mod3 rejected registration") + } +} + +// TestChannelSessionRegister_ForwardsKindsAndMetadata verifies the Wave 3.5 +// schema alignment: the optional `kinds` array and `metadata` object from +// the channel-provider RFC's cogos_session_register primitive flow through +// Wave 2's register endpoint and land in mod3's request body unchanged. +func TestChannelSessionRegister_ForwardsKindsAndMetadata(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{ + "session_id": "cs-kinds-http", + "participant_id": "mod3-provider", + "participant_type": "provider", + "kinds": []string{"audio"}, + "metadata": map[string]any{ + "provider_id": "mod3-local", + "build": "0.5.0", + }, + }) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + + cap := fm.lastCaptured() + var forwarded map[string]any + if err := json.Unmarshal(cap.Body, &forwarded); err != nil { + t.Fatalf("decode forwarded body: %v", err) + } + kinds, ok := forwarded["kinds"].([]any) + if !ok || len(kinds) != 1 || kinds[0] != "audio" { + t.Fatalf("expected mod3 to receive kinds=[\"audio\"], got %v", forwarded["kinds"]) + } + md, ok := forwarded["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected mod3 to receive metadata object, got %v (%T)", + forwarded["metadata"], forwarded["metadata"]) + } + if md["provider_id"] != "mod3-local" { + t.Fatalf("expected metadata.provider_id=mod3-local forwarded, got %v", md["provider_id"]) + } + + // Kernel record must retain the RFC fields so downstream consumers + // can filter by capability even when mod3 ignores them. + rec, ok := s.channelSessionRegistry.Get("cs-kinds-http") + if !ok { + t.Fatal("expected kernel registry to hold record") + } + if len(rec.Kinds) != 1 || rec.Kinds[0] != "audio" { + t.Fatalf("expected record.Kinds=[audio], got %v", rec.Kinds) + } + if rec.Metadata["provider_id"] != "mod3-local" { + t.Fatalf("expected record.Metadata.provider_id=mod3-local, got %v", rec.Metadata) + } +} + +func TestChannelSessionRegister_RequiresParticipantID(t *testing.T) { + fm := newFakeMod3(t) + _, front := newChannelServer(t, fm) + + body, _ := json.Marshal(map[string]any{}) + resp, err := http.Post(front.URL+"/v1/channel-sessions/register", + "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("POST register: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + // No forward should have happened. + fm.mu.Lock() + if len(fm.captured) != 0 { + t.Fatalf("expected no forward to mod3 when participant_id missing, got %d", len(fm.captured)) + } + fm.mu.Unlock() +} + +func TestChannelSessionDeregister_Forwards(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + // Pre-populate the kernel registry so we can verify deletion. + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-to-drop", ParticipantID: "cog", + }) + + req, _ := http.NewRequest(http.MethodPost, + front.URL+"/v1/channel-sessions/cs-to-drop/deregister", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("deregister: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d; body: %s", resp.StatusCode, raw) + } + if _, ok := s.channelSessionRegistry.Get("cs-to-drop"); ok { + t.Fatal("expected kernel registry to drop record after successful deregister") + } + cap := fm.lastCaptured() + if cap.Path != "/v1/sessions/cs-to-drop/deregister" || cap.Method != http.MethodPost { + t.Fatalf("unexpected forward: %+v", cap) + } +} + +func TestChannelSessionDeregister_Returns502WhenMod3Down(t *testing.T) { + fm := newFakeMod3(t) + fm.srv.Close() + s, front := newChannelServer(t, fm) + + s.channelSessionRegistry.Put(ChannelSessionRecord{SessionID: "keeper"}) + + req, _ := http.NewRequest(http.MethodPost, + front.URL+"/v1/channel-sessions/keeper/deregister", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("deregister: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("expected 502, got %d", resp.StatusCode) + } + // On transport failure the kernel keeps the record — nothing authoritative + // changed. The caller can retry. + if _, ok := s.channelSessionRegistry.Get("keeper"); !ok { + t.Fatal("expected kernel to preserve record on transport failure") + } +} + +func TestChannelSessionList_MergesSnapshots(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + + // Seed kernel registry so we can see the kernel block in the merged response. + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-seed", ParticipantID: "cog", + }) + + resp, err := http.Get(front.URL + "/v1/channel-sessions") + if err != nil { + t.Fatalf("list: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var decoded channelSessionListResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if len(decoded.Kernel) != 1 || decoded.Kernel[0].SessionID != "cs-seed" { + t.Fatalf("expected kernel snapshot of 1, got %+v", decoded.Kernel) + } + if len(decoded.Mod3) == 0 { + t.Fatal("expected mod3 block populated from fake mod3") + } +} + +func TestChannelSessionGet_MergesKernelAndMod3(t *testing.T) { + fm := newFakeMod3(t) + s, front := newChannelServer(t, fm) + s.channelSessionRegistry.Put(ChannelSessionRecord{ + SessionID: "cs-detail", ParticipantID: "cog", PreferredVoice: "bm_lewis", + }) + + resp, err := http.Get(front.URL + "/v1/channel-sessions/cs-detail") + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var decoded channelSessionResponse + if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { + t.Fatalf("decode: %v", err) + } + if decoded.Kernel == nil || decoded.Kernel.SessionID != "cs-detail" { + t.Fatalf("expected kernel record returned, got %+v", decoded.Kernel) + } + if len(decoded.Mod3) == 0 { + t.Fatal("expected mod3 block populated from fake mod3") + } +} + +func TestChannelSessionGet_Returns404WhenMod3NotFound(t *testing.T) { + fm := newFakeMod3(t) + fm.getHandler = func(w http.ResponseWriter, r *http.Request) { + writeFakeJSON(w, http.StatusNotFound, map[string]string{"error": "not found"}) + } + _, front := newChannelServer(t, fm) + + resp, err := http.Get(front.URL + "/v1/channel-sessions/does-not-exist") + if err != nil { + t.Fatalf("get: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404 passthrough, got %d", resp.StatusCode) + } +} + +// ─── forwardMod3 direct tests ──────────────────────────────────────────────── + +func TestForwardMod3_ErrorsWhenURLUnset(t *testing.T) { + s := &Server{cfg: &Config{Mod3URL: ""}} + _, _, err := s.forwardMod3(context.Background(), http.MethodGet, "/v1/sessions", nil) + if err == nil { + t.Fatal("expected error when Mod3URL is empty") + } +} + +func TestForwardMod3_TimeoutYieldsTransportError(t *testing.T) { + // A listener that accepts connections but never writes — guarantees the + // 8s per-request timeout fires. Use a short deadline so the test runs + // quickly. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { l.Close() }) + + s := &Server{ + cfg: &Config{Mod3URL: "http://" + l.Addr().String()}, + mod3Client: &http.Client{Timeout: 50 * 1e6}, // 50ms + } + _, _, err = s.forwardMod3(context.Background(), http.MethodGet, "/v1/sessions", nil) + if err == nil { + t.Fatal("expected transport error on stalled server") + } + // Deadline exceeded / i/o timeout / context canceled are all acceptable shapes. + msg := err.Error() + if !strings.Contains(strings.ToLower(msg), "timeout") && + !strings.Contains(strings.ToLower(msg), "deadline") && + !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected timeout-ish error, got %v", err) + } +} + +// ─── minor sanity tests for minting + helpers ──────────────────────────────── + +func TestMintChannelSessionID_ShapeAndUniqueness(t *testing.T) { + seen := map[string]bool{} + for i := 0; i < 100; i++ { + id := mintChannelSessionID() + if !strings.HasPrefix(id, "cs-") { + t.Fatalf("expected cs- prefix, got %q", id) + } + if len(id) != 3+12 { + t.Fatalf("expected 15-char ID (cs- + 12 hex), got len=%d (%q)", len(id), id) + } + if seen[id] { + t.Fatalf("mint collision after %d iterations: %q", i, id) + } + seen[id] = true + } +}