Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
partial_transcript, transcript,
trace_event — kernel cycle-trace events (ADR-083), fanned out via
BrowserChannel.broadcast_trace_event().

The MOD3_USE_COGOS_AGENT kernel-bridged path emits response_text AND
response_complete via BrowserChannel.broadcast_response_{text,complete}
so the dashboard UI's turn-done signal fires on every turn, matching the
local-inference path's behavior.
"""

from __future__ import annotations
Expand Down Expand Up @@ -493,6 +498,40 @@ def broadcast_response_text(cls, text: str, session_id: str | None = None) -> No
except Exception as exc: # noqa: BLE001 — disconnected clients are expected
logger.debug("response_text send failed for %s: %s", ch.channel_id, exc)

@classmethod
def broadcast_response_complete(
cls,
metrics: dict | None = None,
session_id: str | None = None,
) -> None:
"""Push a `response_complete` frame to dashboard WebSocket clients.

Companion to :meth:`broadcast_response_text`: the MOD3_USE_COGOS_AGENT
response bridge emits exactly one complete-frame per kernel
`agent_response` event so the dashboard UI's per-turn `isResponding`
state gets cleared (otherwise the chat panel spinner hangs forever).

Routing and threading match `broadcast_response_text` 1:1 — pass the
same `session_id` so the completion frame lands on the same channel
that received the text frames for this turn. `metrics` follows the
local-path convention from `agent_loop._process` (`{"llm_ms": ...,
"provider": ...}`); the kernel path populates it with
`{"provider": "cogos-agent", ...}`.
"""
frame = {"type": "response_complete", "metrics": metrics or {}}
expected_channel = None
if session_id and session_id.startswith("mod3:"):
expected_channel = session_id[len("mod3:") :]
for ch in list(cls._active_channels):
if not ch._active:
continue
if expected_channel and ch.channel_id != expected_channel:
continue
try:
asyncio.run_coroutine_threadsafe(ch.ws.send_json(frame), ch._loop)
except Exception as exc: # noqa: BLE001 — disconnected clients are expected
logger.debug("response_complete send failed for %s: %s", ch.channel_id, exc)

# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
Expand Down
31 changes: 28 additions & 3 deletions cogos_agent_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,24 @@ def _extract_response_text(payload: dict) -> Optional[str]:
async def run_response_bridge(subscriber: KernelBusSubscriber) -> None:
"""Consume `subscriber` and broadcast agent replies to dashboard clients.

`BrowserChannel.broadcast_response_text()` is thread-safe via
`run_coroutine_threadsafe`, matching the existing trace-event pattern.
Malformed events (no recoverable text) are logged at debug and skipped.
Each kernel `agent_response` event on `bus_dashboard_response` is a
complete per-turn reply (see `apps/cogos/agent_tools_respond.go` — the
`respond` tool is documented as "call at most once per user turn" and
the auto-fallback publishes once if the model skipped the tool call).
We therefore emit two dashboard frames per kernel event:

* ``broadcast_response_text`` — the reply body (chat panel render)
* ``broadcast_response_complete`` — the turn-done signal so the UI's
per-turn spinner clears. Without this, the dashboard hangs
awaiting completion because the kernel path never reaches the
``send_response_complete`` call that the local-inference branch
emits at ``agent_loop._process`` ~L300.

``BrowserChannel.broadcast_response_{text,complete}()`` are thread-safe
via ``run_coroutine_threadsafe``, matching the existing trace-event
pattern. Malformed events (no recoverable text) are logged at debug
and skipped — we do NOT emit a completion frame for skipped events
(keeps the 1:1 pairing with what the UI actually rendered).
"""
first_event_logged = False
forwarded = 0
Expand All @@ -221,6 +236,16 @@ async def run_response_bridge(subscriber: KernelBusSubscriber) -> None:
session_id = _extract_session_id(env.payload)
try:
BrowserChannel.broadcast_response_text(text, session_id=session_id)
# Pair the text frame with a completion frame so the dashboard's
# per-turn "awaiting response" state clears. Kernel emits exactly
# one agent_response per user turn, so one complete per event is
# the correct cardinality.
metrics: dict = {"provider": "cogos-agent"}
if env.event_id:
metrics["event_id"] = env.event_id
if env.ts:
metrics["kernel_ts"] = env.ts
BrowserChannel.broadcast_response_complete(metrics, session_id=session_id)
forwarded += 1
logger.debug(
"cogos-agent: forwarded response event_id=%s session=%s (total=%d)",
Expand Down
2 changes: 1 addition & 1 deletion demo/e2e_dashboard_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def deferred_interrupt() -> None:
audio_pcm.append(wav_bytes)
except Exception as e:
print(f" (audio decode err: {e})", flush=True)
elif t == "response_done" or t == "turn_complete":
elif t == "response_complete":
got_done = True
done_ts = time.time()
elif t == "trace_event":
Expand Down
161 changes: 161 additions & 0 deletions http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
from modality import EncodedOutput, ModalityType
from modules.text import TextModule
from modules.voice import VoiceModule
from session_registry import (
get_default_registry,
resolve_output_device,
)
from vad import detect_speech_file, is_hallucination
from vad import is_model_loaded as vad_loaded

Expand Down Expand Up @@ -244,6 +248,11 @@ class SynthesizeRequest(BaseModel):
speed: float = Field(default=1.25)
emotion: float = Field(default=0.5)
format: str = Field(default="wav", pattern="^(wav|pcm)$")
# ADR-082 Phase 1: optional session routing. When present, the
# session's assigned_voice overrides ``voice`` (unless an explicit
# non-default was passed), and the session is advanced in the global
# serializer's round-robin.
session_id: str | None = Field(default=None)


class SpeechRequest(BaseModel):
Expand All @@ -254,6 +263,9 @@ class SpeechRequest(BaseModel):
voice: str = Field(default="af_heart")
response_format: str = Field(default="mp3")
speed: float = Field(default=1.0)
# ADR-082 Phase 1 extension — not part of the OpenAI schema but harmless
# to accept. When absent, behavior is identical to before Phase 1.
session_id: str | None = Field(default=None)


class ShutdownRequest(BaseModel):
Expand All @@ -263,6 +275,17 @@ class ShutdownRequest(BaseModel):
reason: str = Field(default="shutdown-requested")


class SessionRegisterRequest(BaseModel):
"""Register a session with the Mod3 communication bus (ADR-082)."""

session_id: str
participant_id: str
participant_type: str = Field(default="agent")
preferred_voice: str | None = Field(default=None)
preferred_output_device: str = Field(default="system-default")
priority: int = Field(default=0)


# ---------------------------------------------------------------------------
# Shutdown middleware — reject new requests once shutdown is initiated
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -290,6 +313,39 @@ def synthesize(req: SynthesizeRequest):
import numpy as np

t_request = time.perf_counter()

# ADR-082 Phase 1: session routing. If the request names a session, we
# honor the session's assigned voice (unless the caller explicitly
# picked a non-default voice) and account the job against the session's
# queue + serializer so multi-session callers can see round-robin.
session_id = req.session_id
session_payload: dict | None = None
if session_id:
registry = get_default_registry()
session = registry.get(session_id)
if session is None:
return JSONResponse(
status_code=404,
content={
"error": f"session '{session_id}' is not registered — POST /v1/sessions/register first",
},
)
if req.voice == "bm_lewis" and session.assigned_voice != "bm_lewis":
req.voice = session.assigned_voice
# Register the submission with the serializer for accounting only.
# The synthesize endpoint is non-blocking on the audio side (we
# return bytes synchronously), so we do not run the registry's
# dispatcher here — we just record the submission.
try:
registry.submit(session_id, {"type": "synthesize", "text": req.text[:200]})
except Exception as exc: # noqa: BLE001
logger.debug("session submit accounting failed: %s", exc)
session_payload = {
"session_id": session.session_id,
"assigned_voice": session.assigned_voice,
"preferred_output_device": session.preferred_output_device,
}

job_id = _record_job(
{
"type": "synthesize",
Expand All @@ -301,6 +357,7 @@ def synthesize(req: SynthesizeRequest):
"emotion": req.emotion,
"format": req.format,
"engine": None,
"session_id": session_id,
"timeline": [{"event": "request_received", "t": 0.0}],
}
)
Expand Down Expand Up @@ -389,6 +446,8 @@ def synthesize(req: SynthesizeRequest):
"X-Mod3-RTF": f"{duration / gen_time:.2f}" if gen_time > 0 else "0",
"X-Mod3-Chunks": str(len(chunk_metrics)),
}
if session_payload is not None:
headers["X-Mod3-Session-Id"] = session_payload["session_id"]

return Response(content=audio_bytes, media_type=media_type, headers=headers)

Expand All @@ -400,6 +459,30 @@ def audio_speech(req: SpeechRequest):

t_request = time.perf_counter()

# ADR-082 Phase 1: optional session routing. Same semantics as
# /v1/synthesize — the session's assigned voice overrides ``voice`` when
# the caller passed the default, and the submission is accounted against
# the session's queue.
session_id = req.session_id
if session_id:
registry = get_default_registry()
session = registry.get(session_id)
if session is None:
return JSONResponse(
status_code=404,
content={
"error": f"session '{session_id}' is not registered — POST /v1/sessions/register first",
},
)
# OpenAI default is af_heart; if the caller left it at the default,
# prefer the session's voice.
if req.voice == "af_heart" and session.assigned_voice != "af_heart":
req.voice = session.assigned_voice
try:
registry.submit(session_id, {"type": "audio_speech", "text": req.input[:200]})
except Exception as exc: # noqa: BLE001
logger.debug("session submit accounting failed: %s", exc)

voice = req.voice
try:
voice = _resolve_voice_via_bus(voice)
Expand All @@ -414,6 +497,7 @@ def audio_speech(req: SpeechRequest):
"text": req.input[:200],
"voice": voice,
"speed": req.speed,
"session_id": session_id,
"timeline": [{"event": "request_received", "t": 0.0}],
}
)
Expand Down Expand Up @@ -464,6 +548,8 @@ def audio_speech(req: SpeechRequest):
"X-Mod3-Gen-Time-Sec": f"{gen_time:.3f}",
"X-Mod3-Total-Time-Sec": f"{total_time:.3f}",
}
if session_id:
headers["X-Mod3-Session-Id"] = session_id

return Response(content=audio_bytes, media_type="audio/wav", headers=headers)

Expand Down Expand Up @@ -650,6 +736,81 @@ def stop_speech(job_id: str = ""):
return JSONResponse(status_code=503, content={"error": "Speech queue not available in HTTP-only mode"})


# ---------------------------------------------------------------------------
# Sessions — ADR-082 Phase 1
# ---------------------------------------------------------------------------


@app.post("/v1/sessions/register")
def session_register(req: SessionRegisterRequest):
"""Register a session with the Mod3 communication bus (ADR-082).

Body:
{
"session_id": "...",
"participant_id": "cog" | "sandy" | "alice" | ...,
"participant_type": "agent" | "user",
"preferred_voice": "bm_lewis" | ... | null,
"preferred_output_device": "system-default" | "<device-name>"
}

Returns the SessionChannel with a live-resolved output_device.
"""
registry = get_default_registry()
try:
result = registry.register(
session_id=req.session_id,
participant_id=req.participant_id,
participant_type=req.participant_type,
preferred_voice=req.preferred_voice,
preferred_output_device=req.preferred_output_device or "system-default",
priority=req.priority,
)
except Exception as exc: # noqa: BLE001 — surface the error verbatim
return JSONResponse(status_code=400, content={"error": str(exc)})

payload = result.session.to_dict(device_resolver=resolve_output_device)
payload["created"] = result.created
# Top-level live device snapshot so the caller does not have to
# round-trip; the nested one is available for debugging.
payload["output_device"] = registry.resolve_device(result.session.session_id).to_dict()
return payload


@app.post("/v1/sessions/{session_id}/deregister")
def session_deregister(session_id: str):
"""Drop a session — drains/cancels pending jobs, returns voice to pool."""
registry = get_default_registry()
result = registry.deregister(session_id)
if result.get("status") == "not_found":
return JSONResponse(status_code=404, content=result)
return result


@app.get("/v1/sessions")
def session_list():
"""List all registered sessions plus a serializer snapshot."""
registry = get_default_registry()
return {
"sessions": registry.list_serialized(),
"serializer": registry.serializer.snapshot(),
"voice_pool": registry.voice_pool(),
"voice_holders": registry.voice_holder_snapshot(),
}


@app.get("/v1/sessions/{session_id}")
def session_get(session_id: str):
"""Get a single session's current state (with live device resolution)."""
registry = get_default_registry()
session = registry.get(session_id)
if session is None:
return JSONResponse(status_code=404, content={"error": f"session '{session_id}' not found"})
payload = session.to_dict(device_resolver=resolve_output_device)
payload["output_device"] = registry.resolve_device(session_id).to_dict()
return payload


@app.get("/health")
def health():
"""Health check — standardized CogOS service format."""
Expand Down
Loading
Loading