diff --git a/config/dev.yaml b/config/dev.yaml index b239d0b..b049005 100644 --- a/config/dev.yaml +++ b/config/dev.yaml @@ -62,6 +62,7 @@ enabled_handlers: # Add more handlers as they're implemented: # scripter: "role_play.scripter.handler.ScripterHandler" voice: "role_play.voice.handler.VoiceHandler" + voice_v2: "role_play.voice.handler_v2_spike.VoiceHandlerV2" # Language configuration supported_languages: diff --git a/docs/superpowers/plans/2026-03-20-voice-handler-v2-spike.md b/docs/superpowers/plans/2026-03-20-voice-handler-v2-spike.md new file mode 100644 index 0000000..89c8f0b --- /dev/null +++ b/docs/superpowers/plans/2026-03-20-voice-handler-v2-spike.md @@ -0,0 +1,2213 @@ +# Voice Handler V2 Spike Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement `handler_v2_spike.py` — a prototype voice handler for 45-minute behavioral interviews addressing all gaps from the voice handler gap report. + +**Architecture:** Single file with dataclasses (`VoiceSessionState`, `EventClassification`) and one handler class (`VoiceHandlerV2`). Four concurrent coroutines (receive, send, timer, heartbeat) coordinated by a shared `asyncio.Event` stop signal. Storage injected at construction. Sync event classifier + async router split. + +**Tech Stack:** Python 3.11+, FastAPI/Starlette WebSocket, google-adk (`Runner`, `LiveRequestQueue`, `RunConfig`), google-genai types, pytest/pytest-asyncio + +**Spec:** `docs/superpowers/specs/2026-03-20-voice-handler-v2-spike-design.md` + +**Constraints:** +- Do NOT modify any existing files +- All output goes to new files only +- Use existing `StorageBackend` interface (no new abstract methods) +- Use `google.genai` types (not `google.generativeai`) +- All new async methods must be `async def` + +--- + +## File Structure + +| File | Purpose | +|---|---| +| **Create:** `src/python/role_play/voice/handler_v2_spike.py` | Main handler — dataclasses, `VoiceHandlerV2` class | +| **Create:** `src/python/role_play/voice/voice_config_v2.py` | V2-specific constants (extends VoiceConfig without modifying it) | +| **Create:** `test/python/unit/voice/test_handler_v2_spike.py` | Unit tests for all testable methods | +| **Read (reference only):** `src/python/role_play/voice/handler.py` | V1 handler — reference for auth methods, patterns | +| **Read (reference only):** `src/python/role_play/voice/voice_config.py` | V1 config — inheritable constants | +| **Read (reference only):** `src/python/role_play/voice/models.py` | `VoiceRequest` model — reused as-is | +| **Read (reference only):** `src/python/role_play/server/base_handler.py` | `BaseHandler` ABC — must implement `router`, `prefix` | +| **Read (reference only):** `src/python/role_play/common/storage.py` | `StorageBackend` interface — `read`, `write`, `exists`, `read_bytes`, `write_bytes` | + +--- + +## Task 1: VoiceConfigV2 Constants + +**Files:** +- Create: `src/python/role_play/voice/voice_config_v2.py` +- Create: `test/python/unit/voice/test_voice_config_v2.py` + +- [ ] **Step 1: Write the test file** + +```python +"""Tests for VoiceConfigV2 constants.""" +import pytest +from role_play.voice.voice_config_v2 import VoiceConfigV2 + + +def test_session_timeout_default(): + assert VoiceConfigV2.DEFAULT_SESSION_TIMEOUT_SECONDS == 2700 + + +def test_session_warning_default(): + assert VoiceConfigV2.DEFAULT_SESSION_WARNING_SECONDS == 300 + + +def test_warning_floor(): + assert VoiceConfigV2.MIN_SESSION_WARNING_SECONDS == 10 + + +def test_heartbeat_interval(): + assert VoiceConfigV2.HEARTBEAT_INTERVAL_SECONDS == 30 + + +def test_context_window_trigger_tokens(): + assert VoiceConfigV2.CONTEXT_WINDOW_TRIGGER_TOKENS == 100_000 + + +def test_sentinel_patterns(): + assert "RPS_SESSION_COMPLETE" in VoiceConfigV2.SENTINEL_SESSION_COMPLETE + assert "RPS_END_EARLY" in VoiceConfigV2.SENTINEL_END_EARLY + + +def test_termination_reasons_are_strings(): + reasons = [ + VoiceConfigV2.REASON_USER_ENDED, + VoiceConfigV2.REASON_TIME_LIMIT, + VoiceConfigV2.REASON_AI_CONCLUDED, + VoiceConfigV2.REASON_AI_EARLY_TERMINATION, + VoiceConfigV2.REASON_DISCONNECTED, + ] + assert all(isinstance(r, str) for r in reasons) + + +def test_inherits_v1_audio_constants(): + """V2 re-exports V1 audio constants for compatibility.""" + assert VoiceConfigV2.AUDIO_SAMPLE_RATE == 16000 + assert VoiceConfigV2.AUDIO_CHANNELS == 1 + assert VoiceConfigV2.AUDIO_BIT_DEPTH == 16 + assert VoiceConfigV2.AUDIO_FORMAT == "pcm" +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_voice_config_v2.py` +Expected: FAIL — `ModuleNotFoundError: No module named 'role_play.voice.voice_config_v2'` + +- [ ] **Step 3: Write the implementation** + +```python +"""V2 voice configuration constants. + +Extends VoiceConfig with constants for the v2 spike handler: +session timer, heartbeat, sentinels, termination reasons. +Does not modify the original voice_config.py. +""" +from .voice_config import VoiceConfig + + +class VoiceConfigV2(VoiceConfig): + """Constants for voice handler v2 spike.""" + + # Session timer (per-session overridable via query params) + DEFAULT_SESSION_TIMEOUT_SECONDS = 2700 # 45 min + DEFAULT_SESSION_WARNING_SECONDS = 300 # 5 min + MIN_SESSION_WARNING_SECONDS = 10 # floor for sanity + + # Keepalive + HEARTBEAT_INTERVAL_SECONDS = 30 + + # Context window compression + CONTEXT_WINDOW_TRIGGER_TOKENS = 100_000 + CONTEXT_WINDOW_STRATEGY = "sliding_window" + + # AI termination sentinels (RPS_ prefix avoids false positives) + SENTINEL_SESSION_COMPLETE = "[RPS_SESSION_COMPLETE]" + SENTINEL_END_EARLY = "[RPS_END_EARLY:" # followed by reason] + + # Termination reasons + REASON_USER_ENDED = "USER_ENDED" + REASON_TIME_LIMIT = "TIME_LIMIT" + REASON_AI_CONCLUDED = "AI_CONCLUDED" + REASON_AI_EARLY_TERMINATION = "AI_EARLY_TERMINATION" + REASON_DISCONNECTED = "DISCONNECTED" + + # Storage key templates + HANDLE_KEY_TEMPLATE = "users/{user_id}/voice_sessions/{session_id}/gemini_handle" + META_KEY_TEMPLATE = "users/{user_id}/voice_sessions/{session_id}/session_meta" +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_voice_config_v2.py` +Expected: All 8 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/voice_config_v2.py test/python/unit/voice/test_voice_config_v2.py +git commit -m "feat(voice): add VoiceConfigV2 constants for v2 spike" +``` + +--- + +## Task 2: Dataclasses — VoiceSessionState & EventClassification + +**Files:** +- Create: `src/python/role_play/voice/handler_v2_spike.py` (initial scaffold — dataclasses + imports only) +- Create: `test/python/unit/voice/test_handler_v2_spike.py` (initial scaffold) + +- [ ] **Step 1: Write the test file for dataclasses** + +```python +"""Tests for voice handler v2 spike — dataclasses and event classification.""" +import asyncio +import base64 +import pytest +from datetime import datetime, timezone +from unittest.mock import MagicMock, AsyncMock, patch + +from starlette.websockets import WebSocketDisconnect + +from role_play.voice.handler_v2_spike import ( + VoiceSessionState, + EventClassification, +) + + +class TestVoiceSessionState: + def test_creation_with_required_fields(self): + state = VoiceSessionState( + session_id="sess-1", + user_id="user-1", + runner=MagicMock(), + live_events=MagicMock(), + live_request_queue=MagicMock(), + adk_session=MagicMock(), + stop_event=asyncio.Event(), + termination_reason=None, + started_at=datetime.now(timezone.utc), + session_timeout=60, + warning_seconds=10, + chat_logger=MagicMock(), + transcript_buffer=[], + stats={"audio_chunks_sent": 0, "errors": 0}, + ) + assert state.session_id == "sess-1" + assert state.termination_reason is None + assert state.transcript_buffer == [] + assert not state.stop_event.is_set() + + def test_started_at_is_datetime(self): + now = datetime.now(timezone.utc) + state = VoiceSessionState( + session_id="s", user_id="u", runner=MagicMock(), + live_events=MagicMock(), live_request_queue=MagicMock(), + adk_session=MagicMock(), stop_event=asyncio.Event(), + termination_reason=None, started_at=now, + session_timeout=60, warning_seconds=10, + chat_logger=MagicMock(), transcript_buffer=[], stats={}, + ) + assert isinstance(state.started_at, datetime) + assert state.started_at.tzinfo is not None + + def test_stop_event_can_be_set(self): + event = asyncio.Event() + state = VoiceSessionState( + session_id="s", user_id="u", runner=MagicMock(), + live_events=MagicMock(), live_request_queue=MagicMock(), + adk_session=MagicMock(), stop_event=event, + termination_reason=None, started_at=datetime.now(timezone.utc), + session_timeout=60, warning_seconds=10, + chat_logger=MagicMock(), transcript_buffer=[], stats={}, + ) + assert not state.stop_event.is_set() + state.stop_event.set() + assert state.stop_event.is_set() + + +class TestEventClassification: + def test_audio_event(self): + ec = EventClassification(kind="audio", data=b"\x01\x02") + assert ec.kind == "audio" + assert ec.role is None + assert ec.is_terminal is False + + def test_transcript_partial(self): + ec = EventClassification( + kind="transcript", data="Hello", role="user", is_partial=True + ) + assert ec.is_partial is True + assert ec.role == "user" + + def test_terminal_event(self): + ec = EventClassification( + kind="transcript", data="Goodbye", + role="assistant", is_terminal=True, + terminal_reason="AI_CONCLUDED", + ) + assert ec.is_terminal is True + assert ec.terminal_reason == "AI_CONCLUDED" + + def test_unknown_event(self): + ec = EventClassification(kind="unknown", data={"raw": "stuff"}) + assert ec.kind == "unknown" +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py` +Expected: FAIL — `ImportError` + +- [ ] **Step 3: Write the initial handler file with dataclasses** + +Create `src/python/role_play/voice/handler_v2_spike.py` with: + +```python +"""Voice Handler V2 Spike. + +Prototype handler for 45-minute behavioral interviews. +Addresses all gaps from docs/voice_handler_gap_report.md. + +This is a spike — not registered via standard handler registration. +Test directly via unit tests and manual WebSocket testing. +""" +import asyncio +import base64 +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional, Dict, Any, AsyncGenerator, List + +from fastapi import WebSocket, HTTPException, APIRouter +from google.adk import Runner +from google.adk.agents import RunConfig, LiveRequestQueue +from google.adk.sessions import BaseSessionService +from google.genai.types import ( + AudioTranscriptionConfig, Blob, Part, Content, Modality, +) +from starlette.websockets import WebSocketDisconnect + +from .models import VoiceRequest +from .voice_config_v2 import VoiceConfigV2 +from ..chat.chat_logger import ChatLogger +from ..common.exceptions import TokenExpiredError, AuthenticationError +from ..common.models import User, EnvironmentInfo +from ..common.storage import StorageBackend +from ..common.time_utils import utc_now_isoformat +from ..dev_agents.roleplay_agent.agent import get_production_agent +from ..server.base_handler import BaseHandler +from ..server.dependencies import ( + get_chat_logger, + get_adk_session_service, + get_auth_manager, + get_environment_info, +) + +logger = logging.getLogger(__name__) + + +# region Dataclasses + +@dataclass +class EventClassification: + """Result of classifying an ADK live event. + + Pure data — no side effects. Built by _classify_adk_event (sync). + Routed by _send_to_client (async). + """ + kind: str # "audio", "transcript", "turn_status", + # "session_resumption", "go_away", "unknown" + data: Any # parsed payload + role: Optional[str] = None # "user" or "assistant" + is_partial: bool = False + is_terminal: bool = False + terminal_reason: Optional[str] = None # AI_CONCLUDED or AI_EARLY_TERMINATION + + +@dataclass +class VoiceSessionState: + """Typed session state replacing v1's Dict[str, Any]. + + Holds all per-session data needed by the four streaming coroutines. + """ + session_id: str + user_id: str + runner: Runner + live_events: AsyncGenerator + live_request_queue: LiveRequestQueue + adk_session: Any + stop_event: asyncio.Event + termination_reason: Optional[str] + started_at: datetime + session_timeout: int + warning_seconds: int + chat_logger: ChatLogger + transcript_buffer: List[Dict] = field(default_factory=list) + stats: Dict[str, int] = field(default_factory=lambda: { + "audio_chunks_sent": 0, + "audio_chunks_received": 0, + "text_chunks_sent": 0, + "transcripts_processed": 0, + "errors": 0, + }) + +# endregion +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py` +Expected: All 7 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add v2 spike dataclasses — VoiceSessionState, EventClassification" +``` + +--- + +## Task 3: Event Classifier — `_classify_adk_event` + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` (add classifier tests) +- Modify: `src/python/role_play/voice/handler_v2_spike.py` (add classifier as module-level function) + +The classifier is a pure sync function — no class needed yet. It takes an event and returns an `EventClassification`. This makes it trivially testable. + +**Spec deviation:** The spec's method inventory (Section 10) lists `_classify_adk_event` as a class method. We deliberately implement it as a module-level `classify_adk_event` (public, no `self`) because it's a pure function with no dependency on handler state. This improves testability — tests call it directly without instantiating the handler class. + +- [ ] **Step 1: Write the tests** + +Add to `test/python/unit/voice/test_handler_v2_spike.py`: + +```python +from role_play.voice.handler_v2_spike import classify_adk_event +from role_play.voice.voice_config_v2 import VoiceConfigV2 + + +class TestClassifyAdkEvent: + """Tests for the sync event classifier.""" + + def _make_event(self, **kwargs): + """Build a mock ADK event with specified attributes.""" + event = MagicMock() + # Remove all optional attributes by default + for attr in ["content", "turn_complete", "interrupted", "partial", + "session_resumption_update", "go_away"]: + if attr not in kwargs: + # Use spec to control which attributes exist + delattr(event, attr) if hasattr(event, attr) else None + for k, v in kwargs.items(): + setattr(event, k, v) + return event + + def test_transcript_final_from_model(self): + part = MagicMock() + part.text = "Hello there" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = self._make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.data == "Hello there" + assert result.role == "assistant" + assert result.is_partial is False + assert result.is_terminal is False + + def test_transcript_partial(self): + part = MagicMock() + part.text = "Hel" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = self._make_event(content=content, partial=True) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.is_partial is True + + def test_transcript_from_user(self): + part = MagicMock() + part.text = "My answer" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "user" + event = self._make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.role == "user" + + def test_audio_event(self): + part = MagicMock() + del part.text # no text attribute + part.inline_data = MagicMock(data=b"\x01\x02", mime_type="audio/pcm") + content = MagicMock() + content.parts = [part] + event = self._make_event(content=content) + + result = classify_adk_event(event) + assert result.kind == "audio" + assert result.data["audio_data"] == b"\x01\x02" + assert result.data["mime_type"] == "audio/pcm" + + def test_sentinel_session_complete(self): + part = MagicMock() + part.text = "Thank you for your time. [RPS_SESSION_COMPLETE]" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = self._make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.is_terminal is True + assert result.terminal_reason == VoiceConfigV2.REASON_AI_CONCLUDED + assert VoiceConfigV2.SENTINEL_SESSION_COMPLETE not in result.data + + def test_sentinel_end_early(self): + part = MagicMock() + part.text = "I need to stop. [RPS_END_EARLY:off_topic]" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = self._make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.is_terminal is True + assert result.terminal_reason == VoiceConfigV2.REASON_AI_EARLY_TERMINATION + + def test_turn_status_complete(self): + event = self._make_event(turn_complete=True, interrupted=False) + # Ensure no content + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "turn_status" + assert result.data["turn_complete"] is True + + def test_turn_status_interrupted(self): + event = self._make_event(turn_complete=False, interrupted=True) + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "turn_status" + assert result.data["interrupted"] is True + + def test_session_resumption_event(self): + event = self._make_event(session_resumption_update="handle-abc-123") + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "session_resumption" + assert result.data == "handle-abc-123" + + def test_go_away_event(self): + event = self._make_event(go_away=MagicMock()) + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "go_away" + + def test_unknown_event(self): + """Event with no recognizable attributes → unknown.""" + event = MagicMock(spec=[]) # empty spec = no attributes + result = classify_adk_event(event) + assert result.kind == "unknown" + + def test_empty_content_parts(self): + content = MagicMock() + content.parts = [] + event = self._make_event(content=content) + + result = classify_adk_event(event) + # Empty content falls through to turn_status or unknown + assert result.kind in ("turn_status", "unknown") + + def test_none_content(self): + event = self._make_event(content=None, turn_complete=True) + result = classify_adk_event(event) + assert result.kind == "turn_status" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestClassifyAdkEvent` +Expected: FAIL — `ImportError: cannot import name 'classify_adk_event'` + +- [ ] **Step 3: Implement `classify_adk_event`** + +Add to `handler_v2_spike.py` after the dataclass definitions: + +```python +# region Event Classification + +def classify_adk_event(event: Any) -> EventClassification: + """Classify an ADK live event into a routing category. + + Pure sync function — no side effects, no I/O. + Unknown events are classified as "unknown" with full attribute logging. + + See spec Section 3.3 for detection heuristics. + """ + # Check for session resumption (best-effort detection) + if hasattr(event, "session_resumption_update"): + handle = event.session_resumption_update + # Handle could be string, bytes, or complex object + if not isinstance(handle, str): + try: + handle = json.dumps(handle) if not isinstance(handle, bytes) else handle.decode() + except (TypeError, UnicodeDecodeError): + handle = str(handle) + logger.info(f"Session resumption event detected, handle type: {type(event.session_resumption_update).__name__}") + return EventClassification(kind="session_resumption", data=handle) + + # Check for GoAway (best-effort detection) + if hasattr(event, "go_away"): + logger.warning(f"GoAway event received: {event.go_away}") + return EventClassification(kind="go_away", data=event.go_away) + + # Check for content (audio or transcript) + content = getattr(event, "content", None) + if content is not None and hasattr(content, "parts") and content.parts: + for part in content.parts: + # Text content → transcript + if hasattr(part, "text") and part.text: + is_partial = getattr(event, "partial", False) + role = "assistant" if getattr(content, "role", None) == "model" else "user" + text = part.text + is_terminal = False + terminal_reason = None + + # Sentinel detection + if VoiceConfigV2.SENTINEL_SESSION_COMPLETE in text: + is_terminal = True + terminal_reason = VoiceConfigV2.REASON_AI_CONCLUDED + text = text.replace(VoiceConfigV2.SENTINEL_SESSION_COMPLETE, "").strip() + elif VoiceConfigV2.SENTINEL_END_EARLY in text: + is_terminal = True + terminal_reason = VoiceConfigV2.REASON_AI_EARLY_TERMINATION + # Strip sentinel: [RPS_END_EARLY:reason] + idx = text.find(VoiceConfigV2.SENTINEL_END_EARLY) + end_idx = text.find("]", idx) + if end_idx != -1: + text = (text[:idx] + text[end_idx + 1:]).strip() + + return EventClassification( + kind="transcript", data=text, role=role, + is_partial=is_partial, is_terminal=is_terminal, + terminal_reason=terminal_reason, + ) + + # Audio content → audio + inline_data = getattr(part, "inline_data", None) + if inline_data is not None: + audio_data = getattr(inline_data, "data", None) + mime_type = getattr(inline_data, "mime_type", "audio/pcm") + if audio_data and len(audio_data) > 0: + return EventClassification( + kind="audio", + data={"audio_data": audio_data, "mime_type": mime_type}, + ) + + # Turn status (check after content — content events may also have these flags) + if hasattr(event, "turn_complete") or hasattr(event, "interrupted"): + return EventClassification( + kind="turn_status", + data={ + "turn_complete": getattr(event, "turn_complete", None), + "interrupted": getattr(event, "interrupted", None), + "partial": getattr(event, "partial", None), + }, + ) + + # Unknown — log everything for spike discovery + event_type = type(event).__name__ + event_attrs = {attr: str(getattr(event, attr, None)) + for attr in dir(event) if not attr.startswith("_")} + logger.debug(f"Unknown ADK event: type={event_type}, attrs={event_attrs}") + return EventClassification(kind="unknown", data={"type": event_type, "attrs": event_attrs}) + +# endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestClassifyAdkEvent` +Expected: All 13 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add sync event classifier for v2 spike" +``` + +--- + +## Task 4: Termination Handler — `_handle_termination` + +This is the single convergence point. Must be tested before the coroutines that call it. + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` (add termination tests) +- Modify: `src/python/role_play/voice/handler_v2_spike.py` (add `VoiceHandlerV2` class with `_handle_termination`) + +- [ ] **Step 1: Write the tests** + +Add to `test/python/unit/voice/test_handler_v2_spike.py`: + +```python +from role_play.voice.handler_v2_spike import VoiceHandlerV2 + + +def _make_state(**overrides) -> VoiceSessionState: + """Factory for VoiceSessionState with sensible defaults.""" + defaults = dict( + session_id="sess-1", user_id="user-1", + runner=MagicMock(), live_events=MagicMock(), + live_request_queue=MagicMock(), adk_session=MagicMock(), + stop_event=asyncio.Event(), termination_reason=None, + started_at=datetime.now(timezone.utc), + session_timeout=60, warning_seconds=10, + chat_logger=MagicMock(), transcript_buffer=[], stats={}, + ) + defaults.update(overrides) + return VoiceSessionState(**defaults) + + +class TestHandleTermination: + @pytest.fixture + def handler(self): + storage = AsyncMock() + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_sets_reason_and_stop_event(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "USER_ENDED") + assert state.termination_reason == "USER_ENDED" + assert state.stop_event.is_set() + + @pytest.mark.asyncio + async def test_sends_session_ended_to_frontend(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "TIME_LIMIT") + ws.send_json.assert_called_once() + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "session_ended" + assert msg["reason"] == "TIME_LIMIT" + + @pytest.mark.asyncio + async def test_closes_queue_for_non_disconnect(self, handler): + queue = MagicMock() + state = _make_state(live_request_queue=queue) + await handler._handle_termination(AsyncMock(), state, "USER_ENDED") + queue.close.assert_called_once() + + @pytest.mark.asyncio + async def test_keeps_queue_open_for_disconnected(self, handler): + queue = MagicMock() + state = _make_state(live_request_queue=queue) + await handler._handle_termination(AsyncMock(), state, "DISCONNECTED") + queue.close.assert_not_called() + + @pytest.mark.asyncio + async def test_guard_prevents_double_fire(self, handler): + ws = AsyncMock() + state = _make_state() + state.stop_event.set() # already set + state.termination_reason = "USER_ENDED" # already has reason + await handler._handle_termination(ws, state, "TIME_LIMIT") + # Should not overwrite reason or send another message + assert state.termination_reason == "USER_ENDED" + ws.send_json.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_websocket_send_failure(self, handler): + ws = AsyncMock() + ws.send_json.side_effect = ConnectionError("closed") + state = _make_state() + # Should not raise — swallows the connection error + await handler._handle_termination(ws, state, "DISCONNECTED") + assert state.termination_reason == "DISCONNECTED" + assert state.stop_event.is_set() + + @pytest.mark.asyncio + async def test_detail_included_in_message(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "AI_EARLY_TERMINATION", detail="off_topic") + msg = ws.send_json.call_args[0][0] + assert msg["detail"] == "off_topic" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHandleTermination` +Expected: FAIL — `ImportError: cannot import name 'VoiceHandlerV2'` + +- [ ] **Step 3: Implement `VoiceHandlerV2` with `_handle_termination`** + +Add to `handler_v2_spike.py`: + +```python +# region Handler + +class VoiceHandlerV2(BaseHandler): + """Voice handler v2 spike for 45-minute behavioral interviews. + + Not registered via standard handler registration. + Uses /voice-v2 prefix to avoid conflicts with v1. + Storage injected at construction. + """ + + def __init__(self, storage: StorageBackend): + super().__init__() + self._storage = storage + + @property + def router(self) -> APIRouter: + if self._router is None: + self._router = APIRouter() + # WebSocket route defined here — tested manually, not via handler registration + return self._router + + @property + def prefix(self) -> str: + return "/voice-v2" + + # region Termination + + async def _handle_termination( + self, + websocket: WebSocket, + state: VoiceSessionState, + reason: str, + detail: Optional[str] = None, + ) -> None: + """Single convergence point for all session exit paths. + + Guard: check-and-set stop_event atomically (no await between + check and set — safe in single-threaded event loop). + """ + # Guard: prevent double-fire + if state.stop_event.is_set(): + return + state.stop_event.set() + state.termination_reason = reason + + # Notify frontend (may fail for DISCONNECTED — that's fine) + try: + await websocket.send_json({ + "type": "session_ended", + "reason": reason, + "detail": detail, + "timestamp": utc_now_isoformat(), + }) + except (ConnectionError, WebSocketDisconnect, RuntimeError): + pass # Socket may already be closed + + # Close queue unless resumable + if reason != VoiceConfigV2.REASON_DISCONNECTED: + state.live_request_queue.close() + + logger.info( + f"Session {state.session_id} terminated: reason={reason}, detail={detail}" + ) + + # endregion + +# endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHandleTermination` +Expected: All 7 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add VoiceHandlerV2 class with _handle_termination" +``` + +--- + +## Task 5: Session Timer & Warning Injection + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` +- Modify: `src/python/role_play/voice/handler_v2_spike.py` + +- [ ] **Step 1: Write the tests** + +Add to test file: + +```python +class TestSessionTimer: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_timer_fires_termination_after_timeout(self, handler): + state = _make_state(session_timeout=1, warning_seconds=0) + # Patch _inject_time_warning and _handle_termination + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_TIME_LIMIT + ) + + @pytest.mark.asyncio + async def test_timer_exits_early_when_stop_event_set(self, handler): + state = _make_state(session_timeout=10, warning_seconds=2) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + # Set stop event after a very short delay + async def set_stop(): + await asyncio.sleep(0.05) + state.stop_event.set() + asyncio.create_task(set_stop()) + + await handler._session_timer(ws, state) + handler._handle_termination.assert_not_called() + + @pytest.mark.asyncio + async def test_timer_calls_inject_warning(self, handler): + state = _make_state(session_timeout=1, warning_seconds=0) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._inject_time_warning.assert_called_once_with(state) + + +class TestInjectTimeWarning: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_injects_content_into_queue(self, handler): + queue = MagicMock() + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + await handler._inject_time_warning(state) + queue.send_content.assert_called_once() + content_arg = queue.send_content.call_args[0][0] + assert "5 minutes remaining" in content_arg.parts[0].text + + @pytest.mark.asyncio + async def test_sets_session_state_time_warning(self, handler): + queue = MagicMock() + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + await handler._inject_time_warning(state) + assert session.state.get("time_warning") is True + + @pytest.mark.asyncio + async def test_survives_queue_failure(self, handler): + queue = MagicMock() + queue.send_content.side_effect = RuntimeError("queue closed") + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + # Should not raise + await handler._inject_time_warning(state) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestSessionTimer` +Expected: FAIL — `AttributeError: 'VoiceHandlerV2' object has no attribute '_session_timer'` + +- [ ] **Step 3: Implement `_session_timer` and `_inject_time_warning`** + +Add to `VoiceHandlerV2` class: + +```python + # region Timer + + async def _session_timer( + self, websocket: WebSocket, state: VoiceSessionState, + ) -> None: + """Wall clock timer with warning injection. + + Two phases: sleep until warning, fire warning, sleep until termination. + Uses wait_for(stop_event.wait()) so it responds immediately to session end. + """ + elapsed = (datetime.now(timezone.utc) - state.started_at).total_seconds() + warning_at = state.session_timeout - state.warning_seconds + + # Phase 1: wait until warning + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=max(0, warning_at - elapsed), + ) + return # stop_event set before warning + except asyncio.TimeoutError: + pass + + await self._inject_time_warning(state) + + # Phase 2: wait until termination + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=state.warning_seconds, + ) + return # stop_event set during warning period + except asyncio.TimeoutError: + pass + + await self._handle_termination(websocket, state, VoiceConfigV2.REASON_TIME_LIMIT) + + async def _inject_time_warning(self, state: VoiceSessionState) -> None: + """Inject time warning into the live session. + + Tries both approaches (state update + content injection). + Failures are logged as spike discovery data, not propagated. + """ + # Approach 1: State update (likely won't propagate mid-stream) + try: + state.adk_session.state["time_warning"] = True + logger.info(f"Session {state.session_id}: set time_warning=True in session state") + except Exception as e: + logger.warning(f"Session {state.session_id}: state update failed: {e}") + + # Approach 2: Content injection (primary approach) + try: + warning_content = Content( + parts=[Part(text="[SYSTEM] 5 minutes remaining. Begin wrapping up the interview.")] + ) + state.live_request_queue.send_content(warning_content) + logger.info(f"Session {state.session_id}: injected time warning content") + except Exception as e: + logger.warning(f"Session {state.session_id}: content injection failed: {e}") + + # endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py -k "TestSessionTimer or TestInjectTimeWarning"` +Expected: All 6 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add session timer with dual warning injection" +``` + +--- + +## Task 6: Heartbeat Coroutine + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` +- Modify: `src/python/role_play/voice/handler_v2_spike.py` + +- [ ] **Step 1: Write the tests** + +```python +class TestHeartbeat: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_heartbeat_exits_when_stop_event_set(self, handler): + ws = AsyncMock() + state = _make_state() + state.stop_event.set() + + await handler._heartbeat(ws, state) + ws.send_json.assert_not_called() + + @pytest.mark.asyncio + async def test_heartbeat_sends_json(self, handler): + ws = AsyncMock() + state = _make_state() + + async def stop_after_one(): + await asyncio.sleep(0.05) + state.stop_event.set() + + # Use very short interval for testing + with patch.object(VoiceConfigV2, "HEARTBEAT_INTERVAL_SECONDS", 0.01): + asyncio.create_task(stop_after_one()) + await handler._heartbeat(ws, state) + + assert ws.send_json.call_count >= 1 + msg = ws.send_json.call_args_list[0][0][0] + assert msg["type"] == "heartbeat" + assert "timestamp" in msg + + @pytest.mark.asyncio + async def test_heartbeat_calls_termination_on_connection_error(self, handler): + ws = AsyncMock() + ws.send_json.side_effect = ConnectionError("closed") + state = _make_state() + handler._handle_termination = AsyncMock() + + with patch.object(VoiceConfigV2, "HEARTBEAT_INTERVAL_SECONDS", 0.01): + with pytest.raises(ConnectionError): + await handler._heartbeat(ws, state) + + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_DISCONNECTED + ) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHeartbeat` +Expected: FAIL + +- [ ] **Step 3: Implement `_heartbeat`** + +Add to `VoiceHandlerV2`: + +```python + # region Keepalive + + async def _heartbeat(self, websocket: WebSocket, state: VoiceSessionState) -> None: + """Send periodic heartbeat JSON to keep the WebSocket alive. + + Exits when stop_event is set. Calls _handle_termination(DISCONNECTED) + if the WebSocket send fails. + """ + while not state.stop_event.is_set(): + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=VoiceConfigV2.HEARTBEAT_INTERVAL_SECONDS, + ) + break # stop_event was set + except asyncio.TimeoutError: + pass # interval elapsed + + if state.stop_event.is_set(): + break + + try: + await websocket.send_json({ + "type": "heartbeat", + "timestamp": utc_now_isoformat(), + }) + except (ConnectionError, WebSocketDisconnect): + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + raise + + # endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHeartbeat` +Expected: All 3 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add heartbeat coroutine with disconnect detection" +``` + +--- + +## Task 7: Event Router — `_send_to_client` + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` +- Modify: `src/python/role_play/voice/handler_v2_spike.py` + +- [ ] **Step 1: Write the tests** + +```python +class TestSendToClient: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.write = AsyncMock() + return VoiceHandlerV2(storage) + + def _make_live_events(self, events): + """Create an async generator from a list of events.""" + async def gen(): + for e in events: + yield e + return gen() + + @pytest.mark.asyncio + async def test_audio_sent_to_frontend(self, handler): + ws = AsyncMock() + part = MagicMock() + del part.text + part.inline_data = MagicMock(data=b"\x01\x02", mime_type="audio/pcm") + event = MagicMock() + event.content = MagicMock(parts=[part]) + for attr in ["turn_complete", "interrupted", "session_resumption_update", "go_away"]: + if hasattr(event, attr): + delattr(event, attr) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + assert ws.send_json.call_count == 1 + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "audio" + + @pytest.mark.asyncio + async def test_transcript_not_sent_to_frontend(self, handler): + ws = AsyncMock() + part = MagicMock() + part.text = "Hello" + part.inline_data = None + event = MagicMock() + event.content = MagicMock(parts=[part], role="model") + event.partial = False + for attr in ["turn_complete", "interrupted", "session_resumption_update", "go_away"]: + if hasattr(event, attr): + delattr(event, attr) + + chat_logger = AsyncMock() + state = _make_state( + live_events=self._make_live_events([event]), + chat_logger=chat_logger, + ) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + # Transcript must NOT be sent to frontend + ws.send_json.assert_not_called() + # But must be logged + chat_logger.log_voice_message.assert_called_once() + + @pytest.mark.asyncio + async def test_transcript_appended_to_buffer(self, handler): + ws = AsyncMock() + part = MagicMock() + part.text = "Test transcript" + part.inline_data = None + event = MagicMock() + event.content = MagicMock(parts=[part], role="model") + event.partial = False + for attr in ["turn_complete", "interrupted", "session_resumption_update", "go_away"]: + if hasattr(event, attr): + delattr(event, attr) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + assert len(state.transcript_buffer) == 1 + assert state.transcript_buffer[0]["text"] == "Test transcript" + + @pytest.mark.asyncio + async def test_session_resumption_writes_to_storage(self, handler): + ws = AsyncMock() + event = MagicMock(spec=[]) + event.session_resumption_update = "handle-xyz" + event.content = None + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + handler._storage.write.assert_called_once() + key = handler._storage.write.call_args[0][0] + assert "gemini_handle" in key + + @pytest.mark.asyncio + async def test_terminal_event_triggers_termination(self, handler): + ws = AsyncMock() + part = MagicMock() + part.text = "Goodbye [RPS_SESSION_COMPLETE]" + part.inline_data = None + event = MagicMock() + event.content = MagicMock(parts=[part], role="model") + event.partial = False + for attr in ["turn_complete", "interrupted", "session_resumption_update", "go_away"]: + if hasattr(event, attr): + delattr(event, attr) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + handler._handle_termination.assert_called_once() + call_args = handler._handle_termination.call_args + assert call_args[0][2] == VoiceConfigV2.REASON_AI_CONCLUDED + + @pytest.mark.asyncio + async def test_turn_status_sent_to_frontend(self, handler): + ws = AsyncMock() + event = MagicMock() + event.content = None + event.turn_complete = True + event.interrupted = False + for attr in ["session_resumption_update", "go_away"]: + if hasattr(event, attr): + delattr(event, attr) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + ws.send_json.assert_called_once() + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "turn_status" +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestSendToClient` +Expected: FAIL + +- [ ] **Step 3: Implement `_send_to_client`** + +Add to `VoiceHandlerV2`: + +```python + # region Event Routing + + async def _send_to_client( + self, websocket: WebSocket, state: VoiceSessionState, + ) -> None: + """Route ADK events to frontend, storage, or logs. + + Uses classify_adk_event (sync) for classification, then routes async. + Transcripts are logged but NOT sent to frontend (REQ-7). + """ + try: + async for event in state.live_events: + if state.stop_event.is_set(): + break + + classification = classify_adk_event(event) + + if classification.kind == "audio": + audio_data = classification.data["audio_data"] + state.stats["audio_chunks_received"] += 1 + await websocket.send_json({ + "type": "audio", + "data": base64.b64encode(audio_data).decode("utf-8"), + "mime_type": classification.data["mime_type"], + "timestamp": utc_now_isoformat(), + }) + + elif classification.kind == "transcript": + # REQ-7: log but do NOT send to frontend + state.stats["transcripts_processed"] += 1 + entry = { + "text": classification.data, + "role": classification.role, + "is_partial": classification.is_partial, + "timestamp": utc_now_isoformat(), + } + state.transcript_buffer.append(entry) + + if not classification.is_partial: + await state.chat_logger.log_voice_message( + user_id=state.user_id, + session_id=state.session_id, + role=classification.role, + transcript_text=classification.data, + duration_ms=0, + message_number=-1, + confidence=1.0, + voice_metadata=entry, + ) + + # Check for AI termination sentinel + if classification.is_terminal: + logger.info( + f"Session {state.session_id}: AI terminal signal " + f"reason={classification.terminal_reason}" + ) + if classification.terminal_reason == VoiceConfigV2.REASON_AI_EARLY_TERMINATION: + # Log extra context for future human review + recent = state.transcript_buffer[-5:] + logger.warning( + f"Session {state.session_id}: AI_EARLY_TERMINATION, " + f"recent transcripts: {recent}" + ) + await self._handle_termination( + websocket, state, classification.terminal_reason, + ) + return + + elif classification.kind == "turn_status": + # Log turn_complete correlation for analysis + if classification.data.get("turn_complete") and state.transcript_buffer: + logger.debug( + f"Session {state.session_id}: turn_complete after " + f"'{state.transcript_buffer[-1].get('text', '')[:50]}'" + ) + await websocket.send_json({ + "type": "turn_status", + **classification.data, + "timestamp": utc_now_isoformat(), + }) + + elif classification.kind == "session_resumption": + handle_key = VoiceConfigV2.HANDLE_KEY_TEMPLATE.format( + user_id=state.user_id, session_id=state.session_id, + ) + await self._storage.write(handle_key, classification.data) + logger.info(f"Session {state.session_id}: persisted resumption handle") + + elif classification.kind == "go_away": + logger.warning(f"Session {state.session_id}: GoAway received (suppressed)") + + else: # unknown + logger.debug(f"Session {state.session_id}: unknown event: {classification.data}") + + except asyncio.CancelledError: + logger.info(f"Event processing cancelled for session {state.session_id}") + except (ConnectionError, WebSocketDisconnect) as e: + logger.error(f"Connection error in _send_to_client: {e}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + except Exception as e: + logger.error(f"Unexpected error in _send_to_client: {e}", exc_info=True) + state.stats["errors"] += 1 + + # endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestSendToClient` +Expected: All 6 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add event router with transcript suppression and handle persistence" +``` + +--- + +## Task 8: Receive From Client — `_receive_from_client` + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` +- Modify: `src/python/role_play/voice/handler_v2_spike.py` + +- [ ] **Step 1: Write the tests** + +```python +class TestReceiveFromClient: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_end_session_triggers_user_ended(self, handler): + ws = AsyncMock() + request_json = '{"mime_type": "text/plain", "data": "dGVzdA==", "end_session": true}' + ws.receive_text = AsyncMock(return_value=request_json) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_USER_ENDED, + ) + + @pytest.mark.asyncio + async def test_audio_forwarded_to_adk(self, handler): + ws = AsyncMock() + audio_b64 = base64.b64encode(b"\x01\x02").decode() + request_json = f'{{"mime_type": "audio/pcm", "data": "{audio_b64}", "end_session": false}}' + + call_count = 0 + async def receive_then_stop(self_=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return request_json + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=receive_then_stop) + queue = MagicMock() + state = _make_state(live_request_queue=queue) + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + queue.send_realtime.assert_called_once() + assert state.stats["audio_chunks_sent"] == 1 + + @pytest.mark.asyncio + async def test_text_forwarded_to_adk(self, handler): + ws = AsyncMock() + text_b64 = base64.b64encode(b"Hello").decode() + request_json = f'{{"mime_type": "text/plain", "data": "{text_b64}", "end_session": false}}' + + call_count = 0 + async def receive_then_stop(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return request_json + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=receive_then_stop) + queue = MagicMock() + chat_logger = AsyncMock() + state = _make_state(live_request_queue=queue, chat_logger=chat_logger) + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + queue.send_content.assert_called_once() + chat_logger.log_message.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect_triggers_disconnected(self, handler): + ws = AsyncMock() + ws.receive_text = AsyncMock(side_effect=WebSocketDisconnect(code=1000)) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + + @pytest.mark.asyncio + async def test_invalid_json_increments_errors(self, handler): + ws = AsyncMock() + call_count = 0 + async def bad_then_disconnect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "not valid json" + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=bad_then_disconnect) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + assert state.stats["errors"] >= 1 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestReceiveFromClient` +Expected: FAIL + +- [ ] **Step 3: Implement `_receive_from_client`** + +Add to `VoiceHandlerV2`: + +```python + # region Client Input + + async def _receive_from_client( + self, + websocket: WebSocket, + state: VoiceSessionState, + env_info: EnvironmentInfo, + ) -> None: + """Receive from client and forward to ADK. + + Handles audio/pcm and text/plain. end_session triggers USER_ENDED. + WebSocketDisconnect triggers DISCONNECTED. + """ + try: + while not state.stop_event.is_set(): + data = await websocket.receive_text() + + try: + request = VoiceRequest.model_validate_json(data) + except ValueError as e: + logger.warning(f"Invalid JSON from client: {e}") + state.stats["errors"] += 1 + continue + + if request.end_session: + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_USER_ENDED, + ) + return + + try: + if request.mime_type == "audio/pcm": + audio_data = request.decode_data() + if audio_data is None: + state.stats["errors"] += 1 + continue + + # PCM logging in non-production (same as v1) + if not env_info.is_production: + try: + await state.chat_logger.log_pcm_audio( + user_id=state.user_id, + session_id=state.session_id, + audio_data=audio_data, + ) + except (AttributeError, Exception) as e: + logger.debug(f"PCM logging skipped: {e}") + + blob = Blob(mime_type=request.mime_type, data=audio_data) + state.live_request_queue.send_realtime(blob) + state.stats["audio_chunks_sent"] += 1 + + elif request.mime_type == "text/plain": + text_data = request.decode_data() + if text_data is None: + state.stats["errors"] += 1 + continue + + await state.chat_logger.log_message( + user_id=state.user_id, + session_id=state.session_id, + role="user", + content=text_data, + message_number=-1, + ) + content = Content(parts=[Part(text=text_data)]) + state.live_request_queue.send_content(content) + state.stats["text_chunks_sent"] += 1 + + except Exception as e: + logger.error(f"Error processing client input: {e}") + state.stats["errors"] += 1 + + except WebSocketDisconnect: + logger.info(f"Client disconnected from session {state.session_id}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + except Exception as e: + logger.error(f"Error in _receive_from_client: {e}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + + # endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestReceiveFromClient` +Expected: All 5 tests PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add client input receiver with termination routing" +``` + +--- + +## Task 9: Streaming Orchestrator & Session Lifecycle + +**Files:** +- Modify: `test/python/unit/voice/test_handler_v2_spike.py` +- Modify: `src/python/role_play/voice/handler_v2_spike.py` + +This is the final assembly — `_handle_streaming`, `_initialize_adk`, `_cleanup_session`, `handle_voice_session`, and the WebSocket route. + +- [ ] **Step 1: Write the tests for `_handle_streaming`** + +```python +class TestHandleStreaming: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_orchestrator_cancels_remaining_on_stop(self, handler): + ws = AsyncMock() + state = _make_state() + env_info = MagicMock(is_production=True) + + # Mock all four coroutines + async def fast_receive(*a, **kw): + await self_handler._handle_termination(ws, state, "USER_ENDED") + self_handler = handler + + handler._receive_from_client = AsyncMock(side_effect=fast_receive) + handler._send_to_client = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + handler._session_timer = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + handler._heartbeat = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + + await handler._handle_streaming(ws, state, env_info) + assert state.stop_event.is_set() + + +class TestCleanupSession: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.write = AsyncMock() + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_logs_voice_session_end(self, handler): + chat_logger = AsyncMock() + state = _make_state(chat_logger=chat_logger, termination_reason="USER_ENDED") + + await handler._cleanup_session(state) + chat_logger.log_voice_session_end.assert_called_once() + + @pytest.mark.asyncio + async def test_persists_meta_on_disconnect(self, handler): + state = _make_state(termination_reason="DISCONNECTED") + + await handler._cleanup_session(state) + # Should write session_meta for timer continuity + handler._storage.write.assert_called() + key = handler._storage.write.call_args[0][0] + assert "session_meta" in key +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py -k "TestHandleStreaming or TestCleanupSession"` +Expected: FAIL + +- [ ] **Step 3: Implement orchestrator, cleanup, and initialize_adk** + +Add to `VoiceHandlerV2`: + +```python + # region Streaming Orchestrator + + async def _handle_streaming( + self, + websocket: WebSocket, + state: VoiceSessionState, + env_info: EnvironmentInfo, + ) -> None: + """Orchestrate four concurrent coroutines with stop_event coordination.""" + tasks = [ + asyncio.create_task( + self._receive_from_client(websocket, state, env_info), + name="receive", + ), + asyncio.create_task( + self._send_to_client(websocket, state), + name="send", + ), + asyncio.create_task( + self._session_timer(websocket, state), + name="timer", + ), + asyncio.create_task( + self._heartbeat(websocket, state), + name="heartbeat", + ), + ] + + try: + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED, + ) + + # Check if stop_event was set by the completed task + if not state.stop_event.is_set(): + for task in done: + exc = task.exception() + if exc: + logger.warning( + f"Task {task.get_name()} failed unexpectedly: {exc}" + ) + state.stop_event.set() + + finally: + for task in tasks: + if not task.done(): + task.cancel() + # Wait for cancellation to complete + await asyncio.gather(*tasks, return_exceptions=True) + + # endregion + + # region Session Lifecycle + + async def _initialize_adk( + self, + session_id: str, + user: User, + adk_session: Any, + adk_session_service: BaseSessionService, + session_timeout: int, + warning_seconds: int, + ) -> VoiceSessionState: + """Initialize ADK components and build VoiceSessionState. + + Checks for existing resumption handle and session metadata. + Writes initial session_meta for timer continuity. + """ + # Check for resumption handle + handle_key = VoiceConfigV2.HANDLE_KEY_TEMPLATE.format( + user_id=user.id, session_id=session_id, + ) + meta_key = VoiceConfigV2.META_KEY_TEMPLATE.format( + user_id=user.id, session_id=session_id, + ) + + existing_handle = None + if await self._storage.exists(handle_key): + existing_handle = await self._storage.read(handle_key) + logger.info(f"Session {session_id}: found resumption handle") + + # Read session meta for timer continuity + started_at = datetime.now(timezone.utc) + if await self._storage.exists(meta_key): + try: + meta = json.loads(await self._storage.read(meta_key)) + started_at = datetime.fromisoformat(meta["started_at"]) + session_timeout = meta.get("session_timeout", session_timeout) + warning_seconds = meta.get("warning_seconds", warning_seconds) + logger.info(f"Session {session_id}: resumed with original started_at") + except Exception as e: + logger.warning(f"Session {session_id}: failed to read meta: {e}") + + # Create agent + agent = await get_production_agent( + character_id=adk_session.state.get("character_id"), + scenario_id=adk_session.state.get("scenario_id"), + language=getattr(user, "preferred_language", "en"), + scripted=bool(adk_session.state.get("script_data")), + agent_model="gemini-2.5-flash-live-preview", + ) + if not agent: + raise ValueError("Failed to create roleplay agent") + + # Build RunConfig + run_config_kwargs = dict( + response_modalities=[Modality.AUDIO], + output_audio_transcription=AudioTranscriptionConfig(), + input_audio_transcription=AudioTranscriptionConfig(), + ) + + # Context window compression (verify SDK support at runtime) + try: + from google.genai.types import ContextWindowCompressionConfig + run_config_kwargs["context_window_compression"] = ContextWindowCompressionConfig( + trigger_tokens=VoiceConfigV2.CONTEXT_WINDOW_TRIGGER_TOKENS, + sliding_window=True, + ) + logger.info("Context window compression enabled") + except (ImportError, AttributeError, TypeError) as e: + logger.warning(f"Context window compression not available: {e}") + + # Session resumption config (verify SDK support at runtime) + if existing_handle: + try: + from google.genai.types import SessionResumptionConfig + run_config_kwargs["session_resumption_config"] = SessionResumptionConfig( + handle=existing_handle, + ) + logger.info("Session resumption handle applied") + except (ImportError, AttributeError, TypeError) as e: + logger.warning(f"Session resumption not available in SDK: {e}") + + run_config = RunConfig(**run_config_kwargs) + + # Create runner and start live streaming + runner = Runner( + app_name="roleplay_chat", agent=agent, + session_service=adk_session_service, + ) + live_request_queue = LiveRequestQueue() + live_events = runner.run_live( + session=adk_session, + live_request_queue=live_request_queue, + run_config=run_config, + ) + + chat_logger = get_chat_logger(self._storage) + + state = VoiceSessionState( + session_id=session_id, + user_id=user.id, + runner=runner, + live_events=live_events, + live_request_queue=live_request_queue, + adk_session=adk_session, + stop_event=asyncio.Event(), + termination_reason=None, + started_at=started_at, + session_timeout=session_timeout, + warning_seconds=warning_seconds, + chat_logger=chat_logger, + ) + + # Write initial session_meta for timer continuity on reconnect + meta_data = json.dumps({ + "started_at": started_at.isoformat(), + "session_timeout": session_timeout, + "warning_seconds": warning_seconds, + }) + await self._storage.write(meta_key, meta_data) + + return state + + async def _cleanup_session(self, state: VoiceSessionState) -> None: + """Clean up session resources and log final stats.""" + duration = (datetime.now(timezone.utc) - state.started_at).total_seconds() + stats = { + **state.stats, + "duration_seconds": duration, + "termination_reason": state.termination_reason, + "ended_at": utc_now_isoformat(), + } + + logger.info(f"Session {state.session_id} final stats: {stats}") + + try: + await state.chat_logger.log_voice_session_end( + state.user_id, state.session_id, voice_stats=stats, + ) + except Exception as e: + logger.error(f"Failed to log session end: {e}") + + # Persist meta for timer continuity on reconnect + if state.termination_reason == VoiceConfigV2.REASON_DISCONNECTED: + try: + meta_key = VoiceConfigV2.META_KEY_TEMPLATE.format( + user_id=state.user_id, session_id=state.session_id, + ) + meta_data = json.dumps({ + "started_at": state.started_at.isoformat(), + "session_timeout": state.session_timeout, + "warning_seconds": state.warning_seconds, + }) + await self._storage.write(meta_key, meta_data) + except Exception as e: + logger.error(f"Failed to persist session meta: {e}") + + # endregion + + # region Auth (unchanged from v1) + + @staticmethod + async def _validate_jwt_token(token: str, storage: StorageBackend) -> Optional[User]: + """Validate JWT token and return user.""" + try: + auth_manager = get_auth_manager(storage) + token_data = auth_manager.verify_token(token) + user = await storage.get_user(token_data.user_id) + if user is None: + raise HTTPException(status_code=401, detail="User not found") + return user + except TokenExpiredError as exc: + raise HTTPException(status_code=401, detail="Token expired") from exc + except AuthenticationError as exc: + raise HTTPException(status_code=401, detail="Invalid token") from exc + except Exception as e: + logger.error(f"JWT validation error: {e}") + raise HTTPException(status_code=401, detail="Unknown error during validation") from e + + def _check_session_limit(self, user_id: str) -> bool: + """Placeholder — always returns True. See VoiceConfig docstring.""" + return True + + # endregion +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py -k "TestHandleStreaming or TestCleanupSession"` +Expected: All 3 tests PASS + +- [ ] **Step 5: Run the full test suite** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py -v` +Expected: All tests PASS (dataclasses + classifier + termination + timer + heartbeat + router + orchestrator + cleanup) + +- [ ] **Step 6: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): add streaming orchestrator, session lifecycle, and cleanup" +``` + +--- + +## Task 10: WebSocket Route & Entry Point + +**Files:** +- Modify: `src/python/role_play/voice/handler_v2_spike.py` (flesh out `router` property and `handle_voice_session`) + +- [ ] **Step 1: Write test for `handle_voice_session` auth flow** + +```python +class TestHandleVoiceSession: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.exists = AsyncMock(return_value=False) + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_missing_token_closes_websocket(self, handler): + ws = AsyncMock() + ws.query_params = {} + env_info = MagicMock() + + await handler.handle_voice_session(ws, "sess-1", env_info) + ws.close.assert_called_once() + + @pytest.mark.asyncio + async def test_invalid_token_closes_websocket(self, handler): + ws = AsyncMock() + ws.query_params = {"token": "bad", "timeout": "60", "warning": "10"} + + with patch.object( + VoiceHandlerV2, "_validate_jwt_token", + new_callable=AsyncMock, return_value=None, + ): + env_info = MagicMock() + await handler.handle_voice_session(ws, "sess-1", env_info) + ws.close.assert_called_once() +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHandleVoiceSession` +Expected: FAIL + +- [ ] **Step 3: Implement `handle_voice_session` and `router`** + +Update the `router` property and add `handle_voice_session`: + +```python + @property + def router(self) -> APIRouter: + if self._router is None: + self._router = APIRouter() + + @self._router.websocket("/ws/{session_id}") + async def voice_v2_websocket( + websocket: WebSocket, + session_id: str, + environment_info: Annotated[EnvironmentInfo, Depends(get_environment_info)], + ): + await websocket.accept() + await self.handle_voice_session(websocket, session_id, environment_info) + + return self._router + + async def handle_voice_session( + self, + websocket: WebSocket, + session_id: str, + env_info: EnvironmentInfo, + ) -> None: + """Entry point for a voice WebSocket connection. + + Auth → session lookup → ADK init → streaming → cleanup. + """ + state = None + try: + # Extract query params + token = websocket.query_params.get("token") + if not token: + await websocket.close( + code=VoiceConfigV2.WS_MISSING_TOKEN, reason="Missing token", + ) + return + + timeout = int(websocket.query_params.get( + "timeout", VoiceConfigV2.DEFAULT_SESSION_TIMEOUT_SECONDS, + )) + warning = int(websocket.query_params.get( + "warning", VoiceConfigV2.DEFAULT_SESSION_WARNING_SECONDS, + )) + warning = max(warning, VoiceConfigV2.MIN_SESSION_WARNING_SECONDS) + + # Auth + user = await self._validate_jwt_token(token, self._storage) + if user is None: + await websocket.close( + code=VoiceConfigV2.WS_INVALID_TOKEN, reason="Invalid token", + ) + return + + if not self._check_session_limit(user.id): + await websocket.close( + code=VoiceConfigV2.WS_INVALID_TOKEN, reason="Session limit exceeded", + ) + return + + # Session lookup + adk_session_service = get_adk_session_service() + adk_session = await adk_session_service.get_session( + app_name="roleplay_chat", user_id=user.id, session_id=session_id, + ) + if not adk_session: + await websocket.close( + code=VoiceConfigV2.WS_SESSION_NOT_FOUND, reason="Session not found", + ) + return + + logger.info(f"Voice v2 session {session_id} starting for user {user.id}") + + # Send initial status + user_lang = getattr(user, "preferred_language", "en") + await websocket.send_json({ + "type": "status", "status": "connecting", + "message": "Initializing voice session", + }) + + # Initialize ADK + state = await self._initialize_adk( + session_id=session_id, user=user, + adk_session=adk_session, + adk_session_service=adk_session_service, + session_timeout=timeout, warning_seconds=warning, + ) + + # Send config + await websocket.send_json({ + "type": "config", + "audio_format": VoiceConfigV2.AUDIO_FORMAT, + "sample_rate": VoiceConfigV2.AUDIO_SAMPLE_RATE, + "channels": VoiceConfigV2.AUDIO_CHANNELS, + "bit_depth": VoiceConfigV2.AUDIO_BIT_DEPTH, + "language": user_lang, + "session_timeout": timeout, + "warning_seconds": warning, + }) + + await state.chat_logger.log_voice_session_start( + user.id, session_id, + voice_config={"language": user_lang, "timeout": timeout}, + ) + + await websocket.send_json({ + "type": "status", "status": "ready", + "message": "Voice session ready", + }) + + # Run streaming + await self._handle_streaming(websocket, state, env_info) + + except WebSocketDisconnect: + logger.info(f"WebSocket disconnected for session {session_id}") + except Exception as e: + logger.error(f"Unexpected error for session {session_id}: {e}", exc_info=True) + try: + await websocket.send_json({ + "type": "error", "error": str(e), + "timestamp": utc_now_isoformat(), + }) + except Exception: + pass + finally: + if state is not None: + await self._cleanup_session(state) + logger.info(f"Voice v2 session {session_id} cleanup completed") +``` + +Add the missing import at the top of the file: + +```python +from typing import Optional, Dict, Any, AsyncGenerator, List, Annotated +from fastapi import WebSocket, HTTPException, APIRouter, Depends +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py::TestHandleVoiceSession` +Expected: All 2 tests PASS + +- [ ] **Step 5: Run the full v2 test suite** + +Run: `make test-specific TEST_PATH=test/python/unit/voice/test_handler_v2_spike.py -v` +Expected: ALL tests PASS + +- [ ] **Step 6: Run existing v1 tests to confirm no breakage** + +Run: `make test-voice` +Expected: All existing voice tests still PASS + +- [ ] **Step 7: Commit** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py test/python/unit/voice/test_handler_v2_spike.py +git commit -m "feat(voice): complete v2 spike — WebSocket route, session lifecycle, full test suite" +``` + +--- + +## Task 11: Final Verification & Smoke Test + +- [ ] **Step 1: Run full project test suite** + +Run: `make test` +Expected: All 260+ tests PASS, no regressions from v2 spike files + +- [ ] **Step 2: Verify no existing files were modified** + +Run: `git diff --name-only HEAD~10 -- src/python/role_play/` +Expected: Only new files appear: +- `src/python/role_play/voice/handler_v2_spike.py` +- `src/python/role_play/voice/voice_config_v2.py` + +- [ ] **Step 3: Check import health** + +Run: `python -c "from role_play.voice.handler_v2_spike import VoiceHandlerV2, classify_adk_event, VoiceSessionState, EventClassification; print('imports OK')"` +Expected: `imports OK` + +- [ ] **Step 4: Final commit (if any fixups needed)** + +```bash +git add src/python/role_play/voice/handler_v2_spike.py src/python/role_play/voice/voice_config_v2.py test/python/unit/voice/test_handler_v2_spike.py test/python/unit/voice/test_voice_config_v2.py +git commit -m "chore(voice): final v2 spike verification pass" +``` diff --git a/docs/superpowers/specs/2026-03-20-voice-handler-v2-spike-design.md b/docs/superpowers/specs/2026-03-20-voice-handler-v2-spike-design.md new file mode 100644 index 0000000..c994c02 --- /dev/null +++ b/docs/superpowers/specs/2026-03-20-voice-handler-v2-spike-design.md @@ -0,0 +1,418 @@ +# Voice Handler V2 Spike — Design Spec + +**Date:** 2026-03-20 +**Status:** Approved +**Output file:** `src/python/role_play/voice/handler_v2_spike.py` + +--- + +## 1. Purpose + +Prototype a voice handler for 45-minute behavioral interviews that addresses all gaps identified in `docs/voice_handler_gap_report.md`. The spike validates feasibility of session resumption, wall clock enforcement, and Live API event handling before production investment. + +## 2. Non-Negotiables + +These constraints are fixed and not subject to design trade-offs: + +- Use existing `StorageBackend` interface — no new abstract methods +- Use `google.genai` types (not `google.generativeai`) +- All new async methods must be `async def` +- Output goes to `handler_v2_spike.py` — do not modify existing files +- Inject `StorageBackend` at construction time, do not call `get_storage_backend()` in loops + +### 2.1 Registration & Route Path + +The spike is **not registered via the standard `server.register_handler()` path**. It lives in the same package as v1 but is imported and tested directly (unit tests, manual WebSocket testing). It uses the route prefix `/voice-v2` to avoid conflicts with v1's `/voice` prefix. Production integration into the handler registration system is post-spike work. + +## 3. Architecture + +### 3.1 Class Shape & Construction + +`VoiceHandlerV2` extends `BaseHandler`. Storage is injected at construction, not fetched per-call. + +```python +class VoiceHandlerV2(BaseHandler): + def __init__(self, storage: StorageBackend): + super().__init__() + self._storage = storage +``` + +The v1 untyped `adk: Dict[str, Any]` is replaced by a dataclass: + +```python +@dataclass +class VoiceSessionState: + session_id: str + user_id: str + runner: Runner + live_events: AsyncGenerator + live_request_queue: LiveRequestQueue + adk_session: Any + stop_event: asyncio.Event # replaces adk["active"] boolean + termination_reason: Optional[str] # set by _handle_termination + started_at: datetime # datetime object, not ISO string + session_timeout: int # per-session, from query param or default + warning_seconds: int # per-session, from query param or default + chat_logger: ChatLogger # injected alongside storage + transcript_buffer: List[Dict] # accumulated transcripts for evaluation + stats: Dict[str, int] # numeric counters only (audio_chunks_sent, etc.) +``` + +`chat_logger` is obtained once via `get_chat_logger(self._storage)` in `handle_voice_session` and passed into the state, not fetched per-call. + +`transcript_buffer` accumulates all transcript entries (both partial and final, with role and timestamp) during the session. `_send_to_client` appends to it when routing transcript events. `_cleanup_session` can access it for evaluation data. Initialized as an empty list. + +`stats` contains only numeric counters (`audio_chunks_sent`, `audio_chunks_received`, `text_chunks_sent`, `transcripts_processed`, `errors`). Non-numeric session data like `termination_reason` lives on dedicated `VoiceSessionState` fields. + +Key changes from v1: +- `stop_event` (`asyncio.Event`) replaces the `active` boolean — any coroutine can set it, any coroutine can `await` it +- `started_at` is a `datetime` object for elapsed-time arithmetic +- `termination_reason` is captured before cleanup for the `session_ended` message +- `session_timeout` and `warning_seconds` are per-session configurable +- Storage lives on `self._storage`, not fetched per-call + +### 3.2 Event Classification + +v1's `_process_adk_event` (sync) did classification AND built frontend messages. v2 splits responsibilities: + +- `_classify_adk_event(event) -> EventClassification` — sync, pure classification, no side effects +- `_send_to_client` — async, routes based on classification + +```python +@dataclass +class EventClassification: + kind: str # "audio", "transcript", "turn_status", + # "session_resumption", "go_away", "unknown" + data: Any # parsed payload + role: Optional[str] = None # "user" or "assistant" for transcripts + is_partial: bool = False # partial vs final transcript + is_terminal: bool = False # AI sentinel detected + terminal_reason: Optional[str] = None # AI_CONCLUDED or AI_EARLY_TERMINATION +``` + +### 3.3 ADK Event Detection Strategy + +The ADK Live API event shape is not fully documented. Detection uses a **discovery-first approach**: log all event attributes for unrecognized events, and use best-effort attribute checks for known types. + +**Detection heuristics for `_classify_adk_event`:** + +| Kind | Detection logic | +|---|---| +| `audio` | `event.content.parts[*].inline_data` is not None (existing v1 logic) | +| `transcript` | `event.content.parts[*].text` is not None (existing v1 logic) | +| `turn_status` | `hasattr(event, "turn_complete")` or `hasattr(event, "interrupted")` (existing v1 logic) | +| `session_resumption` | Check `hasattr(event, "session_resumption_update")` or `type(event).__name__` contains `SessionResumption`. **If neither works, the spike logs the raw event and classifies as `unknown`.** | +| `go_away` | Check `hasattr(event, "go_away")` or `type(event).__name__` contains `GoAway`. **Same fallback to `unknown` with logging.** | +| `unknown` | Catch-all. Log `type(event).__name__` and all attributes at DEBUG level for discovery. | + +**Important:** The `session_resumption` and `go_away` detection is explicitly best-effort. If the ADK sends these events in a shape we don't anticipate, they land in `unknown` with full attribute logging. The spike's purpose is to discover the actual event shapes. The routing table (Section 5) handles `unknown` safely (log only, no frontend delivery). + +The resumption handle itself is expected to be a string token (based on ADK documentation references). If it turns out to be bytes or a complex object, serialize to JSON string via `json.dumps()` before `storage.write()`. Log the actual type on first encounter. + +## 4. Coroutine Structure & Task Racing + +`_handle_streaming` races four coroutines with a shared `asyncio.Event` (stop signal). The key improvement over v1: `asyncio.wait(FIRST_COMPLETED)` still drives the orchestrator, but individual coroutines use `stop_event` to coordinate graceful wind-down. In v1, the first task to complete caused immediate cancellation of its sibling — notably, `_receive_from_client` going quiet killed `_send_to_client`. In v2, only `_handle_termination` (or the orchestrator's fallback) sets the stop event, so silence on one side doesn't kill the other. + +``` +_handle_streaming (orchestrator) + ├── _receive_from_client — reads WebSocket input, forwards to ADK + ├── _send_to_client — reads ADK events, routes to frontend/storage/logs + ├── _session_timer — enforces wall clock, injects warning + └── _heartbeat — sends keepalive JSON every N seconds +``` + +**Lifecycle:** + +1. Orchestrator creates all four tasks, then `asyncio.wait(FIRST_COMPLETED)` +2. When a task completes, orchestrator checks `stop_event` +3. If set: cancel remaining tasks, proceed to cleanup +4. If not set (unexpected early return): log warning, set `stop_event`, cancel remaining + +**Each coroutine's exit behavior:** + +| Coroutine | Normal exit trigger | Calls `_handle_termination`? | +|---|---|---| +| `_receive_from_client` | `end_session=True` or `WebSocketDisconnect` | Yes — `USER_ENDED` or `DISCONNECTED` | +| `_send_to_client` | Detects AI sentinel or `stop_event` | Yes — `AI_CONCLUDED` or `AI_EARLY_TERMINATION` | +| `_session_timer` | Clock hits T=0 | Yes — `TIME_LIMIT` | +| `_heartbeat` | `stop_event` is set | No — just exits | +| `_heartbeat` | `ConnectionError` on send | Yes — `DISCONNECTED` (before re-raising) | + +**Heartbeat disconnection:** If `_heartbeat` catches a `ConnectionError` when sending, it calls `_handle_termination(state, "DISCONNECTED")` before re-raising the exception. This ensures `termination_reason` is set correctly rather than relying on the orchestrator's fallback path (which would set `stop_event` but not the reason). + +**Key difference from v1:** `_receive_from_client` ending due to silence does NOT kill `_send_to_client`. The receive coroutine only terminates the session on explicit `end_session` or disconnect. + +## 5. Event Routing + +Routing table in `_send_to_client`, based on `EventClassification.kind`: + +| `kind` | Sent to frontend? | Backend action | +|---|---|---| +| `audio` | Yes (base64 JSON) | Update stats | +| `transcript` | **No** (REQ-7) | Log to `chat_logger`, accumulate in buffer | +| `turn_status` | Yes | Update stats; log turn_complete correlation for AI termination analysis | +| `session_resumption` | No | `await self._storage.write(handle_key, handle)` | +| `go_away` | No (REQ-3) | Log at WARNING level | +| `unknown` | No | Log at DEBUG for spike discovery | + +### 5.1 AI Termination Detection (Hybrid) + +**Primary — sentinel text scanning:** +`_classify_adk_event` checks transcript text for `[RPS_SESSION_COMPLETE]` or `[RPS_END_EARLY:reason]`. If found: +- `is_terminal = True` +- `terminal_reason` = `AI_CONCLUDED` or `AI_EARLY_TERMINATION` +- Sentinel is stripped from `data` before storage + +**Secondary — turn_complete correlation:** +When `turn_status` has `turn_complete=True`, log it with the last transcript buffer entry for post-hoc analysis. For the spike, only the sentinel triggers actual termination. Turn-complete correlation is logged for future heuristic development. + +`AI_EARLY_TERMINATION` events log extra context (last N transcript lines, agent state) to support future human review pipeline. + +## 6. Session Timer & 5-Minute Warning + +### 6.1 Timer Coroutine + +Two phases: sleep until warning time, fire warning, sleep until termination. + +```python +async def _session_timer(self, state: VoiceSessionState): + elapsed = (datetime.now(timezone.utc) - state.started_at).total_seconds() + warning_at = state.session_timeout - state.warning_seconds + + # Phase 1: wait until warning (interruptible via stop_event) + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=max(0, warning_at - elapsed) + ) + return # stop_event was set, session ended before warning + except asyncio.TimeoutError: + pass # timer reached warning point + + await self._inject_time_warning(state) + + # Phase 2: wait until termination (interruptible via stop_event) + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=state.warning_seconds + ) + return # stop_event was set during warning period + except asyncio.TimeoutError: + pass # timer reached termination point + + await self._handle_termination(state, "TIME_LIMIT") +``` + +This uses the same `wait_for(stop_event.wait(), timeout=...)` pattern as `_heartbeat` (Section 9), ensuring the timer responds immediately to session termination rather than sleeping through it. If the orchestrator cancels the task, `CancelledError` propagates naturally from `wait_for`. + +### 6.2 Warning Injection (Dual Approach) + +`_inject_time_warning` fires both approaches to test which the Live API responds to: + +1. **State update (likely won't work, here for discovery):** `state.adk_session.state["time_warning"] = True` — ADK session state mutation via direct dict access may not propagate to the live agent mid-session. The runner likely only reads state at turn boundaries, not mid-stream. Included to confirm this hypothesis; content injection is the real approach. +2. **Content injection (primary):** `state.live_request_queue.send_content(Content(parts=[Part(text="[SYSTEM] 5 minutes remaining. Begin wrapping up the interview.")]))` + +Both fire. Each is wrapped in try/except — failures are logged as valuable spike data, not propagated to the timer coroutine. Spike logging will reveal which (or both) triggers observable agent behavior. + +### 6.3 Configuration + +**VoiceConfig defaults:** + +```python +DEFAULT_SESSION_TIMEOUT_SECONDS = 2700 # 45 min +DEFAULT_SESSION_WARNING_SECONDS = 300 # 5 min +MIN_SESSION_WARNING_SECONDS = 10 # floor +HEARTBEAT_INTERVAL_SECONDS = 30 # keepalive +``` + +**Per-session override via WebSocket query params** (for testing): + +``` +/voice-v2/ws/{session_id}?token=xxx&timeout=60&warning=10 +``` + +Production will eventually pull values from session metadata or session-type config. + +### 6.4 Reconnection Timer Continuity + +On reconnect, `started_at` is read from persisted session metadata, not reset. The wall clock is absolute from original session start. + +## 7. Termination & Cleanup + +### 7.1 `_handle_termination(state, reason, detail=None)` + +Single convergence point for all exit paths: + +1. **Guard:** Check-and-set `stop_event` atomically (no `await` between check and set). Since all coroutines run on a single event loop, `if not stop_event.is_set(): stop_event.set()` is safe — there are no await points between the check and set to create a race. If `stop_event` is already set, return immediately. +2. Set `state.termination_reason = reason` +3. (stop_event already set in step 1) +4. Send frontend message, guarded by try/except `ConnectionError`/`WebSocketDisconnect` (for `DISCONNECTED` the socket is already dead, so the send silently fails): `{"type": "session_ended", "reason": reason, "detail": detail, "timestamp": "..."}` +5. Close `live_request_queue` — **unless** reason is `DISCONNECTED` (keep alive for resume) +6. Log termination with full context + +### 7.2 Five Termination Reasons + +| Reason | Close queue? | Resumable? | Notes | +|---|---|---|---| +| `USER_ENDED` | Yes | No | Client sent `end_session` | +| `TIME_LIMIT` | Yes | No | Wall clock expired | +| `AI_CONCLUDED` | Yes | No | Agent signaled interview complete | +| `AI_EARLY_TERMINATION` | Yes | No | Agent ended early; log extra context for future human review | +| `DISCONNECTED` | No | Yes | WebSocket died; resumption handle in storage | + +### 7.3 Cleanup (`_cleanup_session`) + +Called from `finally` block after `_handle_streaming` completes: + +1. Cancel any remaining tasks +2. Compute final stats (duration, chunk counts, termination reason) +3. `await chat_logger.log_voice_session_end(user_id, session_id, voice_stats=stats)` +4. If `DISCONNECTED`: persist `started_at` alongside resumption handle for timer continuity + +## 8. Session Resumption + +### 8.1 Handle Storage (REQ-1, REQ-4) + +Storage key: `users/{user_id}/voice_sessions/{session_id}/gemini_handle` (single overwrite for spike; production will version). + +Session metadata key: `users/{user_id}/voice_sessions/{session_id}/session_meta` (JSON with `started_at`, `session_timeout`, `warning_seconds`). + +Handle writes happen in `_send_to_client` when `EventClassification.kind == "session_resumption"`. + +**Initial metadata write:** `session_meta` is written at the end of `_initialize_adk` (not only on `DISCONNECTED` cleanup) to ensure timer continuity from session start. This way, if a disconnect occurs before any `session_resumption` event arrives, the reconnection path still has a valid `started_at` and timeout config. + +### 8.2 Reconnection Path in `_initialize_adk` + +```python +handle_key = f"users/{user_id}/voice_sessions/{session_id}/gemini_handle" +meta_key = f"users/{user_id}/voice_sessions/{session_id}/session_meta" + +existing_handle = None +if await self._storage.exists(handle_key): + existing_handle = await self._storage.read(handle_key) + +session_meta = None +if await self._storage.exists(meta_key): + session_meta = json.loads(await self._storage.read(meta_key)) +``` + +If Gemini rejects a stale handle, fall back to a fresh session and log the failure. `started_at` from `session_meta` is still used so the wall clock doesn't reset. + +### 8.3 Context Window Compression (REQ-2) + +Set unconditionally on every `RunConfig`, not just reconnections. + +**Verify at implementation time:** The exact type name and `RunConfig` field for context window compression must be confirmed against the installed `google-adk` and `google-genai` package versions. The ADK documentation references `ContextWindowCompressionConfig` on `RunConfig`, but the actual import path and parameter names may differ. The implementer should: + +1. Check `google.adk.agents.RunConfig` for a `context_window_compression` field +2. Check `google.genai.types` for a `ContextWindowCompressionConfig` class +3. If neither exists, check for `SlidingWindowConfig`, `ContextCompressionConfig`, or similar +4. If no compression config exists in the current SDK version, log a WARNING and skip — the spike should still run without it. A 45-minute session during testing won't hit the context limit, so the spike remains valid + +**Expected usage (pending verification):** + +```python +context_window_compression=ContextWindowCompressionConfig( + trigger_tokens=CONTEXT_WINDOW_TRIGGER_TOKENS, # 100_000 + strategy="sliding_window", +) +``` + +**VoiceConfig constants:** + +```python +CONTEXT_WINDOW_TRIGGER_TOKENS = 100_000 +CONTEXT_WINDOW_STRATEGY = "sliding_window" +``` + +## 9. Keepalive + +A fourth coroutine sends application-level heartbeat JSON: + +```python +async def _heartbeat(self, websocket, state): + while not state.stop_event.is_set(): + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=HEARTBEAT_INTERVAL_SECONDS + ) + break # stop_event was set + except asyncio.TimeoutError: + pass # interval elapsed, send heartbeat + try: + await websocket.send_json({ + "type": "heartbeat", + "timestamp": utc_now_isoformat() + }) + except (ConnectionError, WebSocketDisconnect): + await self._handle_termination(state, "DISCONNECTED") + raise +``` + +## 10. Method Inventory + +All methods in `VoiceHandlerV2`, grouped by region: + +### Session Lifecycle +| Method | Async? | Purpose | +|---|---|---| +| `handle_voice_session` | Yes | Entry point, auth, setup, finally cleanup | +| `_initialize_adk` | Yes | Create agent, runner, handle reconnection | +| `_cleanup_session` | Yes | Stats, logging, persisting disconnect state | + +### Streaming +| Method | Async? | Purpose | +|---|---|---| +| `_handle_streaming` | Yes | Orchestrate four coroutines | +| `_receive_from_client` | Yes | WebSocket → ADK input | +| `_send_to_client` | Yes | ADK events → frontend/storage/logs | +| `_heartbeat` | Yes | Periodic keepalive | +| `_session_timer` | Yes | Wall clock enforcement | + +### Event Processing +| Method | Async? | Purpose | +|---|---|---| +| `_classify_adk_event` | No | Pure event classification | +| `_inject_time_warning` | Yes | Dual warning approach | + +### Termination +| Method | Async? | Purpose | +|---|---|---| +| `_handle_termination` | Yes | Single convergence point for all exits | + +### Auth (unchanged from v1) +| Method | Async? | Purpose | +|---|---|---| +| `_validate_jwt_token` | Yes | JWT verification | +| `_check_session_limit` | No | Placeholder, returns True | + +## 11. Gap Report Coverage + +| Gap | Section | Resolution | +|---|---|---| +| REQ-1: Session resumption | 8.1, 8.2 | Handle persisted on every update, read on reconnect | +| REQ-2: Context compression | 8.3 | Always-on in RunConfig | +| REQ-3: GoAway handling | 5 | Classify → log at WARNING, suppress frontend | +| REQ-4: Handle persistence | 8.1 | Write in `_send_to_client` routing | +| REQ-5: Termination reasons | 7.1, 7.2 | Five reasons (expanded from gap report's four to add `DISCONNECTED`), single `_handle_termination()` | +| REQ-6: Wall clock timer | 6 | Per-session configurable, dual warning | +| REQ-7: Transcript suppression | 5 | Routing table skips frontend for transcripts | +| O-1: No reconnection | 8.2 | Resumption handle + session meta | +| O-2: FIRST_COMPLETED kills sibling | 4 | Four-task racing with stop_event | +| O-3: No keepalive | 9 | Heartbeat coroutine | +| O-4: Storage not threaded | 3.1 | `self._storage` injected at construction; streaming methods access it via `self._storage` | +| O-5: Sync process method | 3.2 | Split into sync classifier + async router | +| O-6: started_at is string | 3.1 | `datetime` object in `VoiceSessionState` | +| O-7: Timeout never enforced | 6.1 | Timer coroutine enforces it | + +## 12. Out of Scope (Post-Spike) + +- Human review pipeline for `AI_EARLY_TERMINATION` events +- Versioned handle storage (production upgrade from single-key) +- Protocol-level WebSocket ping/pong (server config) +- Per-session-type timeout configuration (session metadata) +- `_check_session_limit` real implementation +- Frontend changes (frontend only sees `session_ended` + `heartbeat` as new message types) +- Ticket-based auth replacing JWT in query params diff --git a/pytest.ini b/pytest.ini index 4337411..7bf5a8b 100644 --- a/pytest.ini +++ b/pytest.ini @@ -13,6 +13,7 @@ addopts = --cov-report=html:test/python/htmlcov --cov-fail-under=25 --asyncio-mode=auto + -m "not smoke" asyncio_default_fixture_loop_scope = function markers = unit: Unit tests @@ -22,6 +23,7 @@ markers = auth: Authentication related tests storage: Storage backend tests cloud: Cloud storage integration tests + smoke: Smoke tests requiring external services (Gemini API key) filterwarnings = ignore::DeprecationWarning ignore::PendingDeprecationWarning diff --git a/src/python/role_play/voice/handler_v2_spike.py b/src/python/role_play/voice/handler_v2_spike.py new file mode 100644 index 0000000..8175a60 --- /dev/null +++ b/src/python/role_play/voice/handler_v2_spike.py @@ -0,0 +1,837 @@ +"""Voice Handler V2 Spike. + +Prototype handler for 45-minute behavioral interviews. +Addresses all gaps from docs/voice_handler_gap_report.md. + +This is a spike — not registered via standard handler registration. +Test directly via unit tests and manual WebSocket testing. +""" +import asyncio +import base64 +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional, Dict, Any, AsyncGenerator, List + +from starlette.websockets import WebSocketDisconnect + +from fastapi import WebSocket, Depends, APIRouter +from typing import Annotated + +from .models import VoiceRequest +from .voice_config_v2 import VoiceConfigV2 +from ..common.models import EnvironmentInfo +from ..common.time_utils import utc_now_isoformat +from ..server.base_handler import BaseHandler +from ..server.dependencies import ( + get_storage_backend, + get_chat_logger, + get_adk_session_service, + get_environment_info, +) + +logger = logging.getLogger(__name__) + + +# region Dataclasses + +@dataclass +class EventClassification: + """Result of classifying an ADK live event. + + Pure data — no side effects. Built by classify_adk_event (sync). + Routed by _send_to_client (async). + """ + kind: str # "audio", "transcript", "turn_status", + # "session_resumption", "go_away", "unknown" + data: Any # parsed payload + role: Optional[str] = None # "user" or "assistant" + is_partial: bool = False + is_terminal: bool = False + terminal_reason: Optional[str] = None # AI_CONCLUDED or AI_EARLY_TERMINATION + + +@dataclass +class VoiceSessionState: + """Typed session state replacing v1's Dict[str, Any]. + + Holds all per-session data needed by the four streaming coroutines. + """ + session_id: str + user_id: str + runner: Any # google.adk.Runner — typed as Any to avoid import in tests + live_events: AsyncGenerator + live_request_queue: Any # google.adk.agents.LiveRequestQueue + adk_session: Any + stop_event: asyncio.Event + termination_reason: Optional[str] + started_at: datetime + session_timeout: int + warning_seconds: int + chat_logger: Any # ChatLogger + transcript_buffer: List[Dict] = field(default_factory=list) + stats: Dict[str, int] = field(default_factory=lambda: { + "audio_chunks_sent": 0, + "audio_chunks_received": 0, + "text_chunks_sent": 0, + "transcripts_processed": 0, + "errors": 0, + }) + +# endregion + + +# region ADK Type Stubs (avoid importing google.genai in tests) + +class _Blob: + """Minimal Blob stub matching google.genai.types.Blob interface.""" + def __init__(self, mime_type: str, data: bytes): + self.mime_type = mime_type + self.data = data + + +class _Content: + """Minimal Content stub matching google.genai.types.Content interface.""" + def __init__(self, parts: list): + self.parts = parts + + +class _Part: + """Minimal Part stub matching google.genai.types.Part interface.""" + def __init__(self, text: str): + self.text = text + +# endregion + + +# region Event Classification + +def classify_adk_event(event: Any) -> EventClassification: + """Classify an ADK live event into a routing category. + + Pure sync function — no side effects, no I/O. + Unknown events are classified as "unknown" with full attribute logging. + + See spec Section 3.3 for detection heuristics. + """ + # Check for session resumption (best-effort detection) + if hasattr(event, "session_resumption_update"): + handle = event.session_resumption_update + if not isinstance(handle, str): + try: + handle = json.dumps(handle) if not isinstance(handle, bytes) else handle.decode() + except (TypeError, UnicodeDecodeError): + handle = str(handle) + logger.info(f"Session resumption event detected, handle type: {type(event.session_resumption_update).__name__}") + return EventClassification(kind="session_resumption", data=handle) + + # Check for GoAway (best-effort detection) + if hasattr(event, "go_away"): + logger.warning(f"GoAway event received: {event.go_away}") + return EventClassification(kind="go_away", data=event.go_away) + + # Check for content (audio or transcript) + content = getattr(event, "content", None) + if content is not None and hasattr(content, "parts") and content.parts: + for part in content.parts: + # Text content -> transcript + if hasattr(part, "text") and part.text: + is_partial = getattr(event, "partial", False) + role = "assistant" if getattr(content, "role", None) == "model" else "user" + text = part.text + is_terminal = False + terminal_reason = None + + # Sentinel detection + if VoiceConfigV2.SENTINEL_SESSION_COMPLETE in text: + is_terminal = True + terminal_reason = VoiceConfigV2.REASON_AI_CONCLUDED + text = text.replace(VoiceConfigV2.SENTINEL_SESSION_COMPLETE, "").strip() + elif VoiceConfigV2.SENTINEL_END_EARLY in text: + is_terminal = True + terminal_reason = VoiceConfigV2.REASON_AI_EARLY_TERMINATION + idx = text.find(VoiceConfigV2.SENTINEL_END_EARLY) + end_idx = text.find("]", idx) + if end_idx != -1: + text = (text[:idx] + text[end_idx + 1:]).strip() + + return EventClassification( + kind="transcript", data=text, role=role, + is_partial=is_partial, is_terminal=is_terminal, + terminal_reason=terminal_reason, + ) + + # Audio content -> audio + inline_data = getattr(part, "inline_data", None) + if inline_data is not None: + audio_data = getattr(inline_data, "data", None) + mime_type = getattr(inline_data, "mime_type", "audio/pcm") + if audio_data and len(audio_data) > 0: + return EventClassification( + kind="audio", + data={"audio_data": audio_data, "mime_type": mime_type}, + ) + + # Turn status (check after content — content events may also have these flags) + if hasattr(event, "turn_complete") or hasattr(event, "interrupted"): + return EventClassification( + kind="turn_status", + data={ + "turn_complete": getattr(event, "turn_complete", None), + "interrupted": getattr(event, "interrupted", None), + "partial": getattr(event, "partial", None), + }, + ) + + # Unknown — log everything for spike discovery + event_type = type(event).__name__ + event_attrs = {attr: str(getattr(event, attr, None)) + for attr in dir(event) if not attr.startswith("_")} + logger.debug(f"Unknown ADK event: type={event_type}, attrs={event_attrs}") + return EventClassification(kind="unknown", data={"type": event_type, "attrs": event_attrs}) + +# endregion + + +# region Handler + +class VoiceHandlerV2(BaseHandler): + """Voice handler v2 spike for 45-minute behavioral interviews. + + Uses /voice-v2 prefix to avoid conflicts with v1. + Storage injected at construction or resolved via get_storage_backend(). + """ + + def __init__(self, storage=None): + super().__init__() + self._storage = storage or get_storage_backend() + + @property + def prefix(self) -> str: + return "/voice-v2" + + @property + def router(self) -> APIRouter: + if self._router is None: + self._router = APIRouter() + + @self._router.websocket("/ws/{session_id}") + async def voice_v2_websocket_endpoint( + websocket: WebSocket, + session_id: str, + environment_info: Annotated[EnvironmentInfo, Depends(get_environment_info)], + ): + await websocket.accept() + await self.handle_voice_session(websocket, session_id, environment_info) + + return self._router + + # region ADK Initialization + + async def _initialize_adk( + self, + session_id: str, + user, + timeout: int, + warning: int, + ) -> VoiceSessionState: + """Initialize ADK components and return a VoiceSessionState.""" + from google.adk import Runner + from google.adk.agents import RunConfig, LiveRequestQueue + from google.genai import types + from google.genai.types import AudioTranscriptionConfig + + adk_session_service = get_adk_session_service() + adk_session = await adk_session_service.get_session( + app_name="roleplay_chat", user_id=user.id, session_id=session_id, + ) + if not adk_session: + raise ValueError(f"Session {session_id} not found in ADK session service") + + # Create agent + from ..dev_agents.roleplay_agent.agent import get_production_agent + agent = await get_production_agent( + character_id=adk_session.state.get("character_id"), + scenario_id=adk_session.state.get("scenario_id"), + language=getattr(user, "preferred_language", "en"), + scripted=bool(adk_session.state.get("script_data")), + agent_model="gemini-2.5-flash-native-audio-preview-12-2025", + ) + if not agent: + raise ValueError("Failed to create roleplay agent") + + runner = Runner( + app_name="roleplay_chat", agent=agent, + session_service=adk_session_service, + ) + run_config = RunConfig( + response_modalities=[types.Modality.AUDIO], + output_audio_transcription=AudioTranscriptionConfig(), + input_audio_transcription=AudioTranscriptionConfig(), + ) + live_request_queue = LiveRequestQueue() + live_events = runner.run_live( + session=adk_session, + live_request_queue=live_request_queue, + run_config=run_config, + ) + + chat_logger = get_chat_logger(get_storage_backend()) + + return VoiceSessionState( + session_id=session_id, + user_id=user.id, + runner=runner, + live_events=live_events, + live_request_queue=live_request_queue, + adk_session=adk_session, + stop_event=asyncio.Event(), + termination_reason=None, + started_at=datetime.now(timezone.utc), + session_timeout=timeout, + warning_seconds=warning, + chat_logger=chat_logger, + ) + + # endregion + + # region Termination + + async def _handle_termination( + self, + websocket, + state: VoiceSessionState, + reason: str, + detail: Optional[str] = None, + ) -> None: + """Single convergence point for all session exit paths.""" + # Guard: prevent double-fire + if state.stop_event.is_set(): + return + state.stop_event.set() + state.termination_reason = reason + + # Notify frontend + try: + await websocket.send_json({ + "type": "session_ended", + "reason": reason, + "detail": detail, + "timestamp": utc_now_isoformat(), + }) + except Exception: + pass # Socket may already be closed + + # Close queue unless resumable + if reason != VoiceConfigV2.REASON_DISCONNECTED: + state.live_request_queue.close() + + logger.info( + f"Session {state.session_id} terminated: reason={reason}, detail={detail}" + ) + + # endregion + + # region Timer + + async def _session_timer( + self, websocket, state: VoiceSessionState, + ) -> None: + """Wall clock timer with warning injection. + + Two phases: sleep until warning, fire warning, sleep until termination. + Uses wait_for(stop_event.wait()) so it responds immediately to session end. + """ + elapsed = (datetime.now(timezone.utc) - state.started_at).total_seconds() + remaining = state.session_timeout - elapsed + + # Already past the deadline (e.g. reconnect to an expired session) + if remaining <= 0: + await self._handle_termination(websocket, state, VoiceConfigV2.REASON_TIME_LIMIT) + return + + time_until_warning = max(0, remaining - state.warning_seconds) + + # Phase 1: wait until warning + if time_until_warning > 0: + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=time_until_warning, + ) + return # stop_event set before warning + except asyncio.TimeoutError: + pass + + await self._inject_time_warning(state) + + # Phase 2: wait remaining time after warning + time_after_warning = remaining - time_until_warning + if time_after_warning > 0: + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=time_after_warning, + ) + return # stop_event set during warning period + except asyncio.TimeoutError: + pass + + await self._handle_termination(websocket, state, VoiceConfigV2.REASON_TIME_LIMIT) + + async def _inject_time_warning(self, state: VoiceSessionState) -> None: + """Inject time warning into the live session. + + Tries both approaches (state update + content injection). + Failures are logged, not propagated. + """ + # Approach 1: State update (likely won't propagate mid-stream) + try: + state.adk_session.state["time_warning"] = True + logger.info(f"Session {state.session_id}: set time_warning=True in session state") + except Exception as e: + logger.warning(f"Session {state.session_id}: state update failed: {e}") + + # Approach 2: Content injection (primary approach) + try: + warning_content = _Content( + parts=[_Part(text="[SYSTEM] 5 minutes remaining. Begin wrapping up the interview.")] + ) + state.live_request_queue.send_content(warning_content) + logger.info(f"Session {state.session_id}: injected time warning content") + except Exception as e: + logger.warning(f"Session {state.session_id}: content injection failed: {e}") + + # endregion + + # region Keepalive + + async def _heartbeat(self, websocket, state: VoiceSessionState) -> None: + """Send periodic heartbeat JSON to keep the WebSocket alive. + + Exits when stop_event is set. Calls _handle_termination(DISCONNECTED) + if the WebSocket send fails. + """ + while not state.stop_event.is_set(): + try: + await asyncio.wait_for( + state.stop_event.wait(), + timeout=VoiceConfigV2.HEARTBEAT_INTERVAL_SECONDS, + ) + break # stop_event was set + except asyncio.TimeoutError: + pass # interval elapsed + + if state.stop_event.is_set(): + break + + try: + await websocket.send_json({ + "type": "heartbeat", + "timestamp": utc_now_isoformat(), + }) + except ConnectionError: + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + raise + + # endregion + + # region Event Routing + + async def _send_to_client( + self, websocket, state: VoiceSessionState, + ) -> None: + """Route ADK events to frontend, storage, or logs. + + Uses classify_adk_event (sync) for classification, then routes async. + Transcripts are logged but NOT sent to frontend (REQ-7). + """ + try: + async for event in state.live_events: + if state.stop_event.is_set(): + break + + classification = classify_adk_event(event) + + if classification.kind == "audio": + audio_data = classification.data["audio_data"] + state.stats.setdefault("audio_chunks_received", 0) + state.stats["audio_chunks_received"] += 1 + await websocket.send_json({ + "type": "audio", + "data": base64.b64encode(audio_data).decode("utf-8"), + "mime_type": classification.data["mime_type"], + "timestamp": utc_now_isoformat(), + }) + + elif classification.kind == "transcript": + # REQ-7: log but do NOT send to frontend + state.stats.setdefault("transcripts_processed", 0) + state.stats["transcripts_processed"] += 1 + entry = { + "text": classification.data, + "role": classification.role, + "is_partial": classification.is_partial, + "timestamp": utc_now_isoformat(), + } + state.transcript_buffer.append(entry) + + if not classification.is_partial: + await state.chat_logger.log_voice_message( + user_id=state.user_id, + session_id=state.session_id, + role=classification.role, + transcript_text=classification.data, + duration_ms=0, + message_number=-1, + confidence=1.0, + voice_metadata=entry, + ) + + # Check for AI termination sentinel + if classification.is_terminal: + logger.info( + f"Session {state.session_id}: AI terminal signal " + f"reason={classification.terminal_reason}" + ) + await self._handle_termination( + websocket, state, classification.terminal_reason, + ) + return + + elif classification.kind == "turn_status": + await websocket.send_json({ + "type": "turn_status", + **classification.data, + "timestamp": utc_now_isoformat(), + }) + + elif classification.kind == "session_resumption": + handle_key = VoiceConfigV2.HANDLE_KEY_TEMPLATE.format( + user_id=state.user_id, session_id=state.session_id, + ) + await self._storage.write(handle_key, classification.data) + logger.info(f"Session {state.session_id}: persisted resumption handle") + + elif classification.kind == "go_away": + logger.warning(f"Session {state.session_id}: GoAway received (suppressed)") + + else: # unknown + logger.debug(f"Session {state.session_id}: unknown event: {classification.data}") + + except asyncio.CancelledError: + logger.info(f"Event processing cancelled for session {state.session_id}") + except ConnectionError as e: + logger.error(f"Connection error in _send_to_client: {e}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + except Exception as e: + logger.error(f"Unexpected error in _send_to_client: {e}", exc_info=True) + state.stats.setdefault("errors", 0) + state.stats["errors"] += 1 + + # endregion + + # region Client Input + + async def _receive_from_client( + self, + websocket, + state: VoiceSessionState, + env_info, + ) -> None: + """Receive from client and forward to ADK.""" + try: + while not state.stop_event.is_set(): + data = await websocket.receive_text() + + try: + request = VoiceRequest.model_validate_json(data) + except (json.JSONDecodeError, ValueError) as e: + logger.warning(f"Invalid request from client: {e}") + state.stats.setdefault("errors", 0) + state.stats["errors"] += 1 + continue + + if request.end_session: + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_USER_ENDED, + ) + return + + try: + decoded = request.decode_data() + except ValueError as e: + logger.warning(f"Data decoding error: {e}") + state.stats.setdefault("errors", 0) + state.stats["errors"] += 1 + continue + + try: + if request.mime_type == "audio/pcm": + # PCM logging in non-production + if not getattr(env_info, "is_production", True): + try: + await state.chat_logger.log_pcm_audio( + user_id=state.user_id, + session_id=state.session_id, + audio_data=decoded, + ) + except (AttributeError, Exception) as e: + logger.debug(f"PCM logging skipped: {e}") + + blob = _Blob(mime_type=request.mime_type, data=decoded) + state.live_request_queue.send_realtime(blob) + state.stats.setdefault("audio_chunks_sent", 0) + state.stats["audio_chunks_sent"] += 1 + + elif request.mime_type == "text/plain": + text_data = decoded if isinstance(decoded, str) else decoded.decode("utf-8") + await state.chat_logger.log_message( + user_id=state.user_id, + session_id=state.session_id, + role="user", + content=text_data, + message_number=-1, + ) + content = _Content(parts=[_Part(text=text_data)]) + state.live_request_queue.send_content(content) + state.stats.setdefault("text_chunks_sent", 0) + state.stats["text_chunks_sent"] += 1 + + except Exception as e: + logger.error(f"Error processing client input: {e}") + state.stats.setdefault("errors", 0) + state.stats["errors"] += 1 + + except WebSocketDisconnect: + logger.info(f"Client disconnected from session {state.session_id}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + except Exception as e: + logger.error(f"Error in _receive_from_client: {e}") + await self._handle_termination( + websocket, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + + # endregion + + # region Streaming Orchestrator + + async def _handle_streaming( + self, + websocket, + state: VoiceSessionState, + env_info, + ) -> None: + """Orchestrate four concurrent coroutines with stop_event coordination.""" + tasks = [ + asyncio.create_task( + self._receive_from_client(websocket, state, env_info), + name="receive", + ), + asyncio.create_task( + self._send_to_client(websocket, state), + name="send", + ), + asyncio.create_task( + self._session_timer(websocket, state), + name="timer", + ), + asyncio.create_task( + self._heartbeat(websocket, state), + name="heartbeat", + ), + ] + + try: + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED, + ) + + # Check if stop_event was set by the completed task + if not state.stop_event.is_set(): + for task in done: + exc = task.exception() + if exc: + logger.warning( + f"Task {task.get_name()} failed unexpectedly: {exc}" + ) + state.stop_event.set() + + finally: + for task in tasks: + if not task.done(): + task.cancel() + # Wait for cancellation to complete + await asyncio.gather(*tasks, return_exceptions=True) + + # endregion + + # region Auth + + @staticmethod + async def _validate_jwt_token(token: str, storage) -> Optional[Any]: + """Validate JWT token and return user. Placeholder for spike.""" + # In production, this would use auth_manager.verify_token + # For the spike, this is just a stub + try: + from ..common.exceptions import TokenExpiredError, AuthenticationError + from ..server.dependencies import get_auth_manager + auth_manager = get_auth_manager(storage) + token_data = auth_manager.verify_token(token) + user = await storage.get_user(token_data.user_id) + return user + except Exception as e: + logger.error(f"JWT validation error: {e}") + return None + + def _check_session_limit(self, user_id: str) -> bool: + """Placeholder — always returns True.""" + return True + + # endregion + + # region Entry Point + + async def handle_voice_session( + self, + websocket, + session_id: str, + env_info, + ) -> None: + """Entry point for a voice WebSocket connection. + + Auth -> session lookup -> ADK init -> streaming -> cleanup. + """ + state = None + try: + # Extract query params + token = websocket.query_params.get("token") + if not token: + await websocket.close( + code=VoiceConfigV2.WS_MISSING_TOKEN, reason="Missing token", + ) + return + + timeout = int(websocket.query_params.get( + "timeout", VoiceConfigV2.DEFAULT_SESSION_TIMEOUT_SECONDS, + )) + warning = int(websocket.query_params.get( + "warning", VoiceConfigV2.DEFAULT_SESSION_WARNING_SECONDS, + )) + warning = max(warning, VoiceConfigV2.MIN_SESSION_WARNING_SECONDS) + + # Auth + user = await self._validate_jwt_token(token, self._storage) + if user is None: + await websocket.close( + code=VoiceConfigV2.WS_INVALID_TOKEN, reason="Invalid token", + ) + return + + if not self._check_session_limit(user.id if hasattr(user, 'id') else str(user)): + await websocket.close( + code=VoiceConfigV2.WS_INVALID_TOKEN, reason="Session limit exceeded", + ) + return + + logger.info(f"Voice v2 session {session_id} starting for user") + + # Send initial status + await websocket.send_json({ + "type": "status", "status": "connecting", + "message": "Initializing voice session", + }) + + # Initialize ADK and start streaming + state = await self._initialize_adk( + session_id=session_id, + user=user, + timeout=timeout, + warning=warning, + ) + + user_lang = getattr(user, "preferred_language", "en") + await websocket.send_json({ + "type": "config", + "audio_format": "pcm", + "sample_rate": 16000, + "channels": 1, + "bit_depth": 16, + "language": user_lang, + "session_timeout": timeout, + "warning_seconds": warning, + }) + + await state.chat_logger.log_voice_session_start( + user.id, session_id, voice_config={"language": user_lang}, + ) + + await websocket.send_json({ + "type": "status", "status": "ready", + "message": "Voice session ready", + }) + + await self._handle_streaming(websocket, state, env_info) + + except Exception as e: + logger.error(f"Unexpected error for session {session_id}: {e}", exc_info=True) + try: + await websocket.send_json({ + "type": "error", "error": str(e), + "timestamp": utc_now_isoformat(), + }) + except Exception: + pass + finally: + if state is not None: + await self._cleanup_session(state) + logger.info(f"Voice v2 session {session_id} cleanup completed") + + # endregion + + # region Session Lifecycle + + async def _cleanup_session(self, state: VoiceSessionState) -> None: + """Clean up session resources and log final stats.""" + duration = (datetime.now(timezone.utc) - state.started_at).total_seconds() + stats = { + **state.stats, + "duration_seconds": duration, + "termination_reason": state.termination_reason, + "ended_at": utc_now_isoformat(), + } + + logger.info(f"Session {state.session_id} final stats: {stats}") + + try: + await state.chat_logger.log_voice_session_end( + state.user_id, state.session_id, voice_stats=stats, + ) + except Exception as e: + logger.error(f"Failed to log session end: {e}") + + # Persist meta for timer continuity on reconnect + if state.termination_reason == VoiceConfigV2.REASON_DISCONNECTED: + try: + meta_key = VoiceConfigV2.META_KEY_TEMPLATE.format( + user_id=state.user_id, session_id=state.session_id, + ) + meta_data = json.dumps({ + "started_at": state.started_at.isoformat(), + "session_timeout": state.session_timeout, + "warning_seconds": state.warning_seconds, + }) + await self._storage.write(meta_key, meta_data) + except Exception as e: + logger.error(f"Failed to persist session meta: {e}") + + # endregion + +# endregion diff --git a/src/python/role_play/voice/voice_config_v2.py b/src/python/role_play/voice/voice_config_v2.py new file mode 100644 index 0000000..8b1012f --- /dev/null +++ b/src/python/role_play/voice/voice_config_v2.py @@ -0,0 +1,38 @@ +"""V2 voice configuration constants. + +Extends VoiceConfig with constants for the v2 spike handler: +session timer, heartbeat, sentinels, termination reasons. +Does not modify the original voice_config.py. +""" +from .voice_config import VoiceConfig + + +class VoiceConfigV2(VoiceConfig): + """Constants for voice handler v2 spike.""" + + # Session timer (per-session overridable via query params) + DEFAULT_SESSION_TIMEOUT_SECONDS = 2700 # 45 min + DEFAULT_SESSION_WARNING_SECONDS = 300 # 5 min + MIN_SESSION_WARNING_SECONDS = 10 # floor for sanity + + # Keepalive + HEARTBEAT_INTERVAL_SECONDS = 30 + + # Context window compression + CONTEXT_WINDOW_TRIGGER_TOKENS = 100_000 + CONTEXT_WINDOW_STRATEGY = "sliding_window" + + # AI termination sentinels (RPS_ prefix avoids false positives) + SENTINEL_SESSION_COMPLETE = "[RPS_SESSION_COMPLETE]" + SENTINEL_END_EARLY = "[RPS_END_EARLY:" # followed by reason] + + # Termination reasons + REASON_USER_ENDED = "USER_ENDED" + REASON_TIME_LIMIT = "TIME_LIMIT" + REASON_AI_CONCLUDED = "AI_CONCLUDED" + REASON_AI_EARLY_TERMINATION = "AI_EARLY_TERMINATION" + REASON_DISCONNECTED = "DISCONNECTED" + + # Storage key templates + HANDLE_KEY_TEMPLATE = "users/{user_id}/voice_sessions/{session_id}/gemini_handle" + META_KEY_TEMPLATE = "users/{user_id}/voice_sessions/{session_id}/session_meta" diff --git a/test/python/smoke/test_gemini_live_smoke.py b/test/python/smoke/test_gemini_live_smoke.py new file mode 100644 index 0000000..bf3d3c7 --- /dev/null +++ b/test/python/smoke/test_gemini_live_smoke.py @@ -0,0 +1,439 @@ +"""Smoke test: Gemini Live API round-trip. + +Validates that the v2 spike's assumptions about the Gemini Live API hold: + 1. google.genai types (Blob, Content, Part) match our _Blob/_Content/_Part stubs + 2. Runner.run_live yields events with the attributes we classify + 3. Session resumption handles are actually emitted + 4. LiveRequestQueue.send_content / .close work as expected + +Requires: + GOOGLE_API_KEY env var (Gemini API key) + +Usage: + # From worktree root: + GOOGLE_API_KEY= python -m pytest test/python/smoke/test_gemini_live_smoke.py -v -s + + # Or run directly: + GOOGLE_API_KEY= python test/python/smoke/test_gemini_live_smoke.py +""" +import asyncio +import os +import sys +import logging + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Skip early if no API key +# --------------------------------------------------------------------------- +API_KEY = os.environ.get("GOOGLE_API_KEY", "") +LIVE_MODEL = os.environ.get("GEMINI_LIVE_MODEL", "gemini-2.5-flash-native-audio-preview-12-2025") + +try: + import pytest + smoke = pytest.mark.smoke + skip_no_key = pytest.mark.skipif(not API_KEY, reason="GOOGLE_API_KEY not set") +except ImportError: + # Running directly (not via pytest) + def smoke(fn): + return fn + def skip_no_key(fn): + return fn + + +# --------------------------------------------------------------------------- +# Imports — these validate that google-genai + google-adk are importable +# --------------------------------------------------------------------------- +def _import_deps(): + """Import and return all ADK/genai dependencies. Raises ImportError if missing.""" + from google.adk import Runner + from google.adk.agents import LiveRequestQueue, RunConfig, Agent + from google.adk.sessions import InMemorySessionService + from google.genai import types + from google.genai.types import ( + AudioTranscriptionConfig, + Blob, + Content, + Part, + ) + return { + "Runner": Runner, + "LiveRequestQueue": LiveRequestQueue, + "RunConfig": RunConfig, + "Agent": Agent, + "InMemorySessionService": InMemorySessionService, + "types": types, + "AudioTranscriptionConfig": AudioTranscriptionConfig, + "Blob": Blob, + "Content": Content, + "Part": Part, + } + + +# --------------------------------------------------------------------------- +# 1. Type compatibility check (no API call) +# --------------------------------------------------------------------------- +@smoke +@skip_no_key +def test_type_stubs_match_real_types(): + """Verify our _Blob/_Content/_Part stubs have the same construction interface.""" + deps = _import_deps() + Blob = deps["Blob"] + Content = deps["Content"] + Part = deps["Part"] + + # Our stubs construct with these signatures — real types must accept them too + blob = Blob(mime_type="audio/pcm", data=b"\x00\x01\x02") + assert blob.mime_type == "audio/pcm" + assert blob.data == b"\x00\x01\x02" + + part = Part(text="hello") + assert part.text == "hello" + + content = Content(parts=[part]) + assert len(content.parts) == 1 + assert content.parts[0].text == "hello" + + logger.info("PASS: Type stubs match real google.genai types") + + +# --------------------------------------------------------------------------- +# 2. Live session round-trip (requires API key) +# --------------------------------------------------------------------------- +@smoke +@skip_no_key +def test_live_session_round_trip(): + """Send one text turn to Gemini Live, read back at least one event.""" + asyncio.run(_live_round_trip()) + + +async def _live_round_trip(): + deps = _import_deps() + Runner = deps["Runner"] + LiveRequestQueue = deps["LiveRequestQueue"] + RunConfig = deps["RunConfig"] + Agent = deps["Agent"] + InMemorySessionService = deps["InMemorySessionService"] + types = deps["types"] + AudioTranscriptionConfig = deps["AudioTranscriptionConfig"] + Content = deps["Content"] + Part = deps["Part"] + + # Set up a minimal agent + agent = Agent( + name="smoke_test_agent", + model=LIVE_MODEL, + instruction="You are a test assistant. Respond briefly to confirm you received the message.", + ) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="smoke_test", user_id="smoke_user", + ) + + runner = Runner( + app_name="smoke_test", agent=agent, session_service=session_service, + ) + + run_config = RunConfig( + response_modalities=[types.Modality.AUDIO], + output_audio_transcription=AudioTranscriptionConfig(), + input_audio_transcription=AudioTranscriptionConfig(), + ) + + live_queue = LiveRequestQueue() + live_events = runner.run_live( + session=session, + live_request_queue=live_queue, + run_config=run_config, + ) + + # Send a single text message + content = Content(parts=[Part(text="Say the word 'pineapple' and nothing else.")]) + live_queue.send_content(content) + + # Collect events with timeout + events_received = [] + event_kinds = set() + + async def _collect(): + async for event in live_events: + info = _describe_event(event) + events_received.append(info) + event_kinds.add(info["kind"]) + logger.info(f" Event: {info}") + + # Stop after we get a turn_complete or enough events + if info.get("turn_complete") or len(events_received) > 50: + break + + try: + await asyncio.wait_for(_collect(), timeout=30.0) + except asyncio.TimeoutError: + logger.warning(f"Timed out after 30s with {len(events_received)} events") + + live_queue.close() + + # Assertions + assert len(events_received) > 0, "Expected at least one event from Gemini Live" + logger.info(f"PASS: Received {len(events_received)} events, kinds: {event_kinds}") + + # Verify we got at least audio or transcript back + has_content = any(k in event_kinds for k in ("audio", "transcript", "text")) + assert has_content, f"Expected audio/transcript/text event, got kinds: {event_kinds}" + logger.info("PASS: Got content back from Gemini Live") + + +def _describe_event(event) -> dict: + """Describe an ADK live event in terms our classifier cares about.""" + info = { + "kind": "unknown", + "author": getattr(event, "author", None), + "turn_complete": getattr(event, "turn_complete", False), + "partial": getattr(event, "partial", False), + "interrupted": getattr(event, "interrupted", False), + } + + content = getattr(event, "content", None) + if content is None: + # Check for server_content (some event types) + server_content = getattr(event, "server_content", None) + if server_content: + content = getattr(server_content, "model_turn", None) + + if content and hasattr(content, "parts") and content.parts: + for part in content.parts: + if hasattr(part, "inline_data") and part.inline_data: + info["kind"] = "audio" + info["mime_type"] = getattr(part.inline_data, "mime_type", "unknown") + info["data_size"] = len(getattr(part.inline_data, "data", b"")) + break + elif hasattr(part, "text") and part.text: + info["kind"] = "transcript" if info.get("partial") or "transcript" in str(type(part)).lower() else "text" + info["text_preview"] = part.text[:100] + break + + # Session resumption handle + if hasattr(event, "session_resumption_update"): + update = event.session_resumption_update + if update and hasattr(update, "new_handle") and update.new_handle: + info["kind"] = "session_resumption" + info["handle_length"] = len(update.new_handle) + + # GoAway + if hasattr(event, "go_away") and event.go_away: + info["kind"] = "go_away" + + return info + + +# --------------------------------------------------------------------------- +# 3. Session resumption handle check +# --------------------------------------------------------------------------- +@smoke +@skip_no_key +def test_session_resumption_handle_emitted(): + """Verify Gemini Live emits a session_resumption_update event.""" + asyncio.run(_check_resumption_handle()) + + +async def _check_resumption_handle(): + deps = _import_deps() + Runner = deps["Runner"] + LiveRequestQueue = deps["LiveRequestQueue"] + RunConfig = deps["RunConfig"] + Agent = deps["Agent"] + InMemorySessionService = deps["InMemorySessionService"] + types = deps["types"] + AudioTranscriptionConfig = deps["AudioTranscriptionConfig"] + Content = deps["Content"] + Part = deps["Part"] + + agent = Agent( + name="handle_test_agent", + model=LIVE_MODEL, + instruction="Respond briefly.", + ) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="handle_test", user_id="handle_user", + ) + + runner = Runner( + app_name="handle_test", agent=agent, session_service=session_service, + ) + + run_config = RunConfig( + response_modalities=[types.Modality.AUDIO], + output_audio_transcription=AudioTranscriptionConfig(), + ) + + live_queue = LiveRequestQueue() + live_events = runner.run_live( + session=session, + live_request_queue=live_queue, + run_config=run_config, + ) + + # Send a message to trigger activity + content = Content(parts=[Part(text="Hello")]) + live_queue.send_content(content) + + resumption_handle = None + + async def _collect(): + nonlocal resumption_handle + async for event in live_events: + if hasattr(event, "session_resumption_update"): + update = event.session_resumption_update + if update and hasattr(update, "new_handle") and update.new_handle: + resumption_handle = update.new_handle + logger.info(f" Got resumption handle ({len(resumption_handle)} bytes)") + break + + if getattr(event, "turn_complete", False): + break + + try: + await asyncio.wait_for(_collect(), timeout=30.0) + except asyncio.TimeoutError: + logger.warning("Timed out waiting for resumption handle") + + live_queue.close() + + # Resumption handle is not guaranteed on every session, but log the result + if resumption_handle: + logger.info(f"PASS: Session resumption handle received ({len(resumption_handle)} bytes)") + else: + logger.warning( + "SKIP: No resumption handle emitted — this is expected for short sessions. " + "The handle typically arrives after the first turn exchange." + ) + + +# --------------------------------------------------------------------------- +# 4. Classify real events with our classifier +# --------------------------------------------------------------------------- +@smoke +@skip_no_key +def test_classify_real_events(): + """Run real Gemini events through classify_adk_event and verify no crashes.""" + asyncio.run(_classify_real()) + + +async def _classify_real(): + """Collect events from Gemini and run them through the v2 spike classifier.""" + # Add source to path + src_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "src", "python") + sys.path.insert(0, os.path.abspath(src_path)) + + from role_play.voice.handler_v2_spike import classify_adk_event + + deps = _import_deps() + Runner = deps["Runner"] + LiveRequestQueue = deps["LiveRequestQueue"] + RunConfig = deps["RunConfig"] + Agent = deps["Agent"] + InMemorySessionService = deps["InMemorySessionService"] + types = deps["types"] + AudioTranscriptionConfig = deps["AudioTranscriptionConfig"] + Content = deps["Content"] + Part = deps["Part"] + + agent = Agent( + name="classify_test_agent", + model=LIVE_MODEL, + instruction="Say hello briefly.", + ) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="classify_test", user_id="classify_user", + ) + + runner = Runner( + app_name="classify_test", agent=agent, session_service=session_service, + ) + + run_config = RunConfig( + response_modalities=[types.Modality.AUDIO], + output_audio_transcription=AudioTranscriptionConfig(), + input_audio_transcription=AudioTranscriptionConfig(), + ) + + live_queue = LiveRequestQueue() + live_events = runner.run_live( + session=session, + live_request_queue=live_queue, + run_config=run_config, + ) + + content = Content(parts=[Part(text="Say hello")]) + live_queue.send_content(content) + + classifications = [] + + async def _collect(): + async for event in live_events: + try: + classification = classify_adk_event(event) + classifications.append(classification) + logger.info( + f" {classification.kind:20s} | " + f"role={classification.role or '-':10s} | " + f"partial={classification.is_partial} | " + f"terminal={classification.is_terminal}" + ) + except Exception as e: + logger.error(f" classify_adk_event CRASHED on event: {e}") + logger.error(f" Event type: {type(event).__name__}, attrs: {dir(event)}") + raise + + if getattr(event, "turn_complete", False): + break + + try: + await asyncio.wait_for(_collect(), timeout=30.0) + except asyncio.TimeoutError: + logger.warning(f"Timed out with {len(classifications)} classifications") + + live_queue.close() + + assert len(classifications) > 0, "Expected at least one classified event" + + # Count by kind + kinds = {} + for c in classifications: + kinds[c.kind] = kinds.get(c.kind, 0) + 1 + logger.info(f"PASS: Classified {len(classifications)} events: {kinds}") + + +# --------------------------------------------------------------------------- +# Direct execution +# --------------------------------------------------------------------------- +if __name__ == "__main__": + if not API_KEY: + print("ERROR: Set GOOGLE_API_KEY environment variable") + sys.exit(1) + + print("=" * 60) + print("Gemini Live API Smoke Tests") + print("=" * 60) + + print("\n--- Test 1: Type stub compatibility ---") + test_type_stubs_match_real_types() + + print("\n--- Test 2: Live session round-trip ---") + test_live_session_round_trip() + + print("\n--- Test 3: Session resumption handle ---") + test_session_resumption_handle_emitted() + + print("\n--- Test 4: Classify real events ---") + test_classify_real_events() + + print("\n" + "=" * 60) + print("All smoke tests complete!") + print("=" * 60) diff --git a/test/python/unit/voice/test_handler_v2_spike.py b/test/python/unit/voice/test_handler_v2_spike.py new file mode 100644 index 0000000..ee8160c --- /dev/null +++ b/test/python/unit/voice/test_handler_v2_spike.py @@ -0,0 +1,778 @@ +"""Tests for voice handler v2 spike — dataclasses and event classification.""" +import asyncio +import base64 +import pytest +from datetime import datetime, timezone +from unittest.mock import MagicMock, AsyncMock, patch + +from starlette.websockets import WebSocketDisconnect + +from role_play.voice.handler_v2_spike import ( + VoiceSessionState, + EventClassification, + classify_adk_event, + VoiceHandlerV2, +) +from role_play.voice.voice_config_v2 import VoiceConfigV2 + + +class TestVoiceSessionState: + def test_creation_with_required_fields(self): + state = VoiceSessionState( + session_id="sess-1", + user_id="user-1", + runner=MagicMock(), + live_events=MagicMock(), + live_request_queue=MagicMock(), + adk_session=MagicMock(), + stop_event=asyncio.Event(), + termination_reason=None, + started_at=datetime.now(timezone.utc), + session_timeout=60, + warning_seconds=10, + chat_logger=MagicMock(), + transcript_buffer=[], + stats={"audio_chunks_sent": 0, "errors": 0}, + ) + assert state.session_id == "sess-1" + assert state.termination_reason is None + assert state.transcript_buffer == [] + assert not state.stop_event.is_set() + + def test_started_at_is_datetime(self): + now = datetime.now(timezone.utc) + state = VoiceSessionState( + session_id="s", user_id="u", runner=MagicMock(), + live_events=MagicMock(), live_request_queue=MagicMock(), + adk_session=MagicMock(), stop_event=asyncio.Event(), + termination_reason=None, started_at=now, + session_timeout=60, warning_seconds=10, + chat_logger=MagicMock(), transcript_buffer=[], stats={}, + ) + assert isinstance(state.started_at, datetime) + assert state.started_at.tzinfo is not None + + def test_stop_event_can_be_set(self): + event = asyncio.Event() + state = VoiceSessionState( + session_id="s", user_id="u", runner=MagicMock(), + live_events=MagicMock(), live_request_queue=MagicMock(), + adk_session=MagicMock(), stop_event=event, + termination_reason=None, started_at=datetime.now(timezone.utc), + session_timeout=60, warning_seconds=10, + chat_logger=MagicMock(), transcript_buffer=[], stats={}, + ) + assert not state.stop_event.is_set() + state.stop_event.set() + assert state.stop_event.is_set() + + +class TestEventClassification: + def test_audio_event(self): + ec = EventClassification(kind="audio", data=b"\x01\x02") + assert ec.kind == "audio" + assert ec.role is None + assert ec.is_terminal is False + + def test_transcript_partial(self): + ec = EventClassification( + kind="transcript", data="Hello", role="user", is_partial=True + ) + assert ec.is_partial is True + assert ec.role == "user" + + def test_terminal_event(self): + ec = EventClassification( + kind="transcript", data="Goodbye", + role="assistant", is_terminal=True, + terminal_reason="AI_CONCLUDED", + ) + assert ec.is_terminal is True + assert ec.terminal_reason == "AI_CONCLUDED" + + def test_unknown_event(self): + ec = EventClassification(kind="unknown", data={"raw": "stuff"}) + assert ec.kind == "unknown" + + +def _make_event(**kwargs): + """Build a mock ADK event with specified attributes.""" + event = MagicMock() + for attr in ["content", "turn_complete", "interrupted", "partial", + "session_resumption_update", "go_away"]: + if attr not in kwargs: + delattr(event, attr) if hasattr(event, attr) else None + for k, v in kwargs.items(): + setattr(event, k, v) + return event + + +class TestClassifyAdkEvent: + """Tests for the sync event classifier.""" + + def test_transcript_final_from_model(self): + part = MagicMock() + part.text = "Hello there" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = _make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.data == "Hello there" + assert result.role == "assistant" + assert result.is_partial is False + assert result.is_terminal is False + + def test_transcript_partial(self): + part = MagicMock() + part.text = "Hel" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = _make_event(content=content, partial=True) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.is_partial is True + + def test_transcript_from_user(self): + part = MagicMock() + part.text = "My answer" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "user" + event = _make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.role == "user" + + def test_audio_event(self): + part = MagicMock() + del part.text # no text attribute + part.inline_data = MagicMock(data=b"\x01\x02", mime_type="audio/pcm") + content = MagicMock() + content.parts = [part] + event = _make_event(content=content) + + result = classify_adk_event(event) + assert result.kind == "audio" + assert result.data["audio_data"] == b"\x01\x02" + assert result.data["mime_type"] == "audio/pcm" + + def test_sentinel_session_complete(self): + part = MagicMock() + part.text = "Thank you for your time. [RPS_SESSION_COMPLETE]" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = _make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.kind == "transcript" + assert result.is_terminal is True + assert result.terminal_reason == VoiceConfigV2.REASON_AI_CONCLUDED + assert VoiceConfigV2.SENTINEL_SESSION_COMPLETE not in result.data + + def test_sentinel_end_early(self): + part = MagicMock() + part.text = "I need to stop. [RPS_END_EARLY:off_topic]" + part.inline_data = None + content = MagicMock() + content.parts = [part] + content.role = "model" + event = _make_event(content=content, partial=False) + + result = classify_adk_event(event) + assert result.is_terminal is True + assert result.terminal_reason == VoiceConfigV2.REASON_AI_EARLY_TERMINATION + + def test_turn_status_complete(self): + event = _make_event(turn_complete=True, interrupted=False) + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "turn_status" + assert result.data["turn_complete"] is True + + def test_turn_status_interrupted(self): + event = _make_event(turn_complete=False, interrupted=True) + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "turn_status" + assert result.data["interrupted"] is True + + def test_session_resumption_event(self): + event = _make_event(session_resumption_update="handle-abc-123") + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "session_resumption" + assert result.data == "handle-abc-123" + + def test_go_away_event(self): + event = _make_event(go_away=MagicMock()) + if hasattr(event, "content"): + event.content = None + + result = classify_adk_event(event) + assert result.kind == "go_away" + + def test_unknown_event(self): + """Event with no recognizable attributes -> unknown.""" + event = MagicMock(spec=[]) # empty spec = no attributes + result = classify_adk_event(event) + assert result.kind == "unknown" + + def test_empty_content_parts(self): + content = MagicMock() + content.parts = [] + event = _make_event(content=content) + + result = classify_adk_event(event) + assert result.kind in ("turn_status", "unknown") + + def test_none_content(self): + event = _make_event(content=None, turn_complete=True) + result = classify_adk_event(event) + assert result.kind == "turn_status" + + +def _make_state(**overrides) -> VoiceSessionState: + """Factory for VoiceSessionState with sensible defaults.""" + defaults = dict( + session_id="sess-1", user_id="user-1", + runner=MagicMock(), live_events=MagicMock(), + live_request_queue=MagicMock(), adk_session=MagicMock(), + stop_event=asyncio.Event(), termination_reason=None, + started_at=datetime.now(timezone.utc), + session_timeout=60, warning_seconds=10, + chat_logger=AsyncMock(), transcript_buffer=[], stats={}, + ) + defaults.update(overrides) + return VoiceSessionState(**defaults) + + +class TestHandleTermination: + @pytest.fixture + def handler(self): + storage = AsyncMock() + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_sets_reason_and_stop_event(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "USER_ENDED") + assert state.termination_reason == "USER_ENDED" + assert state.stop_event.is_set() + + @pytest.mark.asyncio + async def test_sends_session_ended_to_frontend(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "TIME_LIMIT") + ws.send_json.assert_called_once() + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "session_ended" + assert msg["reason"] == "TIME_LIMIT" + + @pytest.mark.asyncio + async def test_closes_queue_for_non_disconnect(self, handler): + queue = MagicMock() + state = _make_state(live_request_queue=queue) + await handler._handle_termination(AsyncMock(), state, "USER_ENDED") + queue.close.assert_called_once() + + @pytest.mark.asyncio + async def test_keeps_queue_open_for_disconnected(self, handler): + queue = MagicMock() + state = _make_state(live_request_queue=queue) + await handler._handle_termination(AsyncMock(), state, "DISCONNECTED") + queue.close.assert_not_called() + + @pytest.mark.asyncio + async def test_guard_prevents_double_fire(self, handler): + ws = AsyncMock() + state = _make_state() + state.stop_event.set() # already set + state.termination_reason = "USER_ENDED" # already has reason + await handler._handle_termination(ws, state, "TIME_LIMIT") + assert state.termination_reason == "USER_ENDED" + ws.send_json.assert_not_called() + + @pytest.mark.asyncio + async def test_handles_websocket_send_failure(self, handler): + ws = AsyncMock() + ws.send_json.side_effect = ConnectionError("closed") + state = _make_state() + await handler._handle_termination(ws, state, "DISCONNECTED") + assert state.termination_reason == "DISCONNECTED" + assert state.stop_event.is_set() + + @pytest.mark.asyncio + async def test_detail_included_in_message(self, handler): + ws = AsyncMock() + state = _make_state() + await handler._handle_termination(ws, state, "AI_EARLY_TERMINATION", detail="off_topic") + msg = ws.send_json.call_args[0][0] + assert msg["detail"] == "off_topic" + + +class TestSessionTimer: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_timer_terminates_immediately_when_already_expired(self, handler): + """Reconnect to a session that's already past its deadline.""" + from datetime import timedelta + past = datetime.now(timezone.utc) - timedelta(seconds=3600) + state = _make_state(session_timeout=60, warning_seconds=10, started_at=past) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_TIME_LIMIT + ) + handler._inject_time_warning.assert_not_called() + + @pytest.mark.asyncio + async def test_timer_fires_warning_immediately_when_past_warning_threshold(self, handler): + """Reconnect to a session that's past the warning point but not expired.""" + from datetime import timedelta + # 55s into a 60s session with 10s warning — past the 50s warning point + past = datetime.now(timezone.utc) - timedelta(seconds=55) + state = _make_state(session_timeout=60, warning_seconds=10, started_at=past) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._inject_time_warning.assert_called_once_with(state) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_TIME_LIMIT + ) + + @pytest.mark.asyncio + async def test_timer_fires_termination_after_timeout(self, handler): + state = _make_state(session_timeout=1, warning_seconds=0) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_TIME_LIMIT + ) + + @pytest.mark.asyncio + async def test_timer_exits_early_when_stop_event_set(self, handler): + state = _make_state(session_timeout=10, warning_seconds=2) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + async def set_stop(): + await asyncio.sleep(0.05) + state.stop_event.set() + asyncio.create_task(set_stop()) + + await handler._session_timer(ws, state) + handler._handle_termination.assert_not_called() + + @pytest.mark.asyncio + async def test_timer_calls_inject_warning(self, handler): + state = _make_state(session_timeout=1, warning_seconds=0) + handler._inject_time_warning = AsyncMock() + handler._handle_termination = AsyncMock() + ws = AsyncMock() + + await handler._session_timer(ws, state) + handler._inject_time_warning.assert_called_once_with(state) + + +class TestInjectTimeWarning: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_injects_content_into_queue(self, handler): + queue = MagicMock() + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + await handler._inject_time_warning(state) + queue.send_content.assert_called_once() + content_arg = queue.send_content.call_args[0][0] + assert "5 minutes remaining" in content_arg.parts[0].text + + @pytest.mark.asyncio + async def test_sets_session_state_time_warning(self, handler): + queue = MagicMock() + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + await handler._inject_time_warning(state) + assert session.state.get("time_warning") is True + + @pytest.mark.asyncio + async def test_survives_queue_failure(self, handler): + queue = MagicMock() + queue.send_content.side_effect = RuntimeError("queue closed") + session = MagicMock() + session.state = {} + state = _make_state(live_request_queue=queue, adk_session=session) + + # Should not raise + await handler._inject_time_warning(state) + + +class TestHeartbeat: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_heartbeat_exits_when_stop_event_set(self, handler): + ws = AsyncMock() + state = _make_state() + state.stop_event.set() + + await handler._heartbeat(ws, state) + ws.send_json.assert_not_called() + + @pytest.mark.asyncio + async def test_heartbeat_sends_json(self, handler): + ws = AsyncMock() + state = _make_state() + + async def stop_after_one(): + await asyncio.sleep(0.05) + state.stop_event.set() + + with patch.object(VoiceConfigV2, "HEARTBEAT_INTERVAL_SECONDS", 0.01): + asyncio.create_task(stop_after_one()) + await handler._heartbeat(ws, state) + + assert ws.send_json.call_count >= 1 + msg = ws.send_json.call_args_list[0][0][0] + assert msg["type"] == "heartbeat" + assert "timestamp" in msg + + @pytest.mark.asyncio + async def test_heartbeat_calls_termination_on_connection_error(self, handler): + ws = AsyncMock() + ws.send_json.side_effect = ConnectionError("closed") + state = _make_state() + handler._handle_termination = AsyncMock() + + with patch.object(VoiceConfigV2, "HEARTBEAT_INTERVAL_SECONDS", 0.01): + with pytest.raises(ConnectionError): + await handler._heartbeat(ws, state) + + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_DISCONNECTED + ) + + +class TestSendToClient: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.write = AsyncMock() + return VoiceHandlerV2(storage) + + def _make_live_events(self, events): + """Create an async generator from a list of events.""" + async def gen(): + for e in events: + yield e + return gen() + + @pytest.mark.asyncio + async def test_audio_sent_to_frontend(self, handler): + ws = AsyncMock() + part = MagicMock() + del part.text + part.inline_data = MagicMock(data=b"\x01\x02", mime_type="audio/pcm") + event = _make_event(content=MagicMock(parts=[part])) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + assert ws.send_json.call_count == 1 + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "audio" + + @pytest.mark.asyncio + async def test_transcript_not_sent_to_frontend(self, handler): + ws = AsyncMock() + part = MagicMock(text="Hello", inline_data=None) + event = _make_event(content=MagicMock(parts=[part], role="model"), partial=False) + + chat_logger = AsyncMock() + state = _make_state( + live_events=self._make_live_events([event]), + chat_logger=chat_logger, + ) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + # Transcript must NOT be sent to frontend + ws.send_json.assert_not_called() + # But must be logged + chat_logger.log_voice_message.assert_called_once() + + @pytest.mark.asyncio + async def test_transcript_appended_to_buffer(self, handler): + ws = AsyncMock() + part = MagicMock(text="Test transcript", inline_data=None) + event = _make_event(content=MagicMock(parts=[part], role="model"), partial=False) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + assert len(state.transcript_buffer) == 1 + assert state.transcript_buffer[0]["text"] == "Test transcript" + + @pytest.mark.asyncio + async def test_session_resumption_writes_to_storage(self, handler): + ws = AsyncMock() + event = MagicMock(spec=[]) + event.session_resumption_update = "handle-xyz" + event.content = None + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + handler._storage.write.assert_called_once() + key = handler._storage.write.call_args[0][0] + assert "gemini_handle" in key + + @pytest.mark.asyncio + async def test_terminal_event_triggers_termination(self, handler): + ws = AsyncMock() + part = MagicMock(text="Goodbye [RPS_SESSION_COMPLETE]", inline_data=None) + event = _make_event(content=MagicMock(parts=[part], role="model"), partial=False) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + handler._handle_termination.assert_called_once() + call_args = handler._handle_termination.call_args + assert call_args[0][2] == VoiceConfigV2.REASON_AI_CONCLUDED + + @pytest.mark.asyncio + async def test_turn_status_sent_to_frontend(self, handler): + ws = AsyncMock() + event = _make_event(content=None, turn_complete=True, interrupted=False) + + state = _make_state(live_events=self._make_live_events([event])) + handler._handle_termination = AsyncMock() + + await handler._send_to_client(ws, state) + ws.send_json.assert_called_once() + msg = ws.send_json.call_args[0][0] + assert msg["type"] == "turn_status" + + +class TestReceiveFromClient: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_end_session_triggers_user_ended(self, handler): + ws = AsyncMock() + request_json = '{"mime_type": "text/plain", "data": "dGVzdA==", "end_session": true}' + ws.receive_text = AsyncMock(return_value=request_json) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_USER_ENDED, + ) + + @pytest.mark.asyncio + async def test_audio_forwarded_to_adk(self, handler): + ws = AsyncMock() + audio_b64 = base64.b64encode(b"\x01\x02").decode() + request_json = f'{{"mime_type": "audio/pcm", "data": "{audio_b64}", "end_session": false}}' + + call_count = 0 + async def receive_then_stop(self_=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return request_json + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=receive_then_stop) + queue = MagicMock() + state = _make_state(live_request_queue=queue) + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + queue.send_realtime.assert_called_once() + assert state.stats.get("audio_chunks_sent", 0) == 1 + + @pytest.mark.asyncio + async def test_text_forwarded_to_adk(self, handler): + ws = AsyncMock() + text_b64 = base64.b64encode(b"Hello").decode() + request_json = f'{{"mime_type": "text/plain", "data": "{text_b64}", "end_session": false}}' + + call_count = 0 + async def receive_then_stop(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return request_json + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=receive_then_stop) + queue = MagicMock() + chat_logger = AsyncMock() + state = _make_state(live_request_queue=queue, chat_logger=chat_logger) + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + queue.send_content.assert_called_once() + chat_logger.log_message.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect_triggers_disconnected(self, handler): + ws = AsyncMock() + ws.receive_text = AsyncMock(side_effect=WebSocketDisconnect(code=1000)) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + handler._handle_termination.assert_called_once_with( + ws, state, VoiceConfigV2.REASON_DISCONNECTED, + ) + + @pytest.mark.asyncio + async def test_invalid_json_increments_errors(self, handler): + ws = AsyncMock() + call_count = 0 + async def bad_then_disconnect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "not valid json" + raise WebSocketDisconnect(code=1000) + + ws.receive_text = AsyncMock(side_effect=bad_then_disconnect) + state = _make_state() + handler._handle_termination = AsyncMock() + env_info = MagicMock(is_production=True) + + await handler._receive_from_client(ws, state, env_info) + assert state.stats.get("errors", 0) >= 1 + + +class TestHandleStreaming: + @pytest.fixture + def handler(self): + return VoiceHandlerV2(AsyncMock()) + + @pytest.mark.asyncio + async def test_orchestrator_cancels_remaining_on_stop(self, handler): + ws = AsyncMock() + state = _make_state() + env_info = MagicMock(is_production=True) + + # Mock all four coroutines — receive terminates immediately + async def fast_receive(*a, **kw): + state.stop_event.set() + state.termination_reason = "USER_ENDED" + + handler._receive_from_client = AsyncMock(side_effect=fast_receive) + handler._send_to_client = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + handler._session_timer = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + handler._heartbeat = AsyncMock(side_effect=lambda *a, **kw: asyncio.sleep(10)) + + await handler._handle_streaming(ws, state, env_info) + assert state.stop_event.is_set() + + +class TestCleanupSession: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.write = AsyncMock() + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_logs_voice_session_end(self, handler): + chat_logger = AsyncMock() + state = _make_state(chat_logger=chat_logger, termination_reason="USER_ENDED") + + await handler._cleanup_session(state) + chat_logger.log_voice_session_end.assert_called_once() + + @pytest.mark.asyncio + async def test_persists_meta_on_disconnect(self, handler): + state = _make_state(termination_reason="DISCONNECTED") + + await handler._cleanup_session(state) + # Should write session_meta for timer continuity + handler._storage.write.assert_called() + key = handler._storage.write.call_args[0][0] + assert "session_meta" in key + + +class TestHandleVoiceSession: + @pytest.fixture + def handler(self): + storage = AsyncMock() + storage.exists = AsyncMock(return_value=False) + return VoiceHandlerV2(storage) + + @pytest.mark.asyncio + async def test_missing_token_closes_websocket(self, handler): + ws = AsyncMock() + ws.query_params = {} + env_info = MagicMock() + + await handler.handle_voice_session(ws, "sess-1", env_info) + ws.close.assert_called_once() + + @pytest.mark.asyncio + async def test_invalid_token_closes_websocket(self, handler): + ws = AsyncMock() + ws.query_params = {"token": "bad", "timeout": "60", "warning": "10"} + + with patch.object( + VoiceHandlerV2, "_validate_jwt_token", + new_callable=AsyncMock, return_value=None, + ): + env_info = MagicMock() + await handler.handle_voice_session(ws, "sess-1", env_info) + ws.close.assert_called_once() diff --git a/test/python/unit/voice/test_voice_config_v2.py b/test/python/unit/voice/test_voice_config_v2.py new file mode 100644 index 0000000..9d41bcf --- /dev/null +++ b/test/python/unit/voice/test_voice_config_v2.py @@ -0,0 +1,47 @@ +"""Tests for VoiceConfigV2 constants.""" +import pytest +from role_play.voice.voice_config_v2 import VoiceConfigV2 + + +def test_session_timeout_default(): + assert VoiceConfigV2.DEFAULT_SESSION_TIMEOUT_SECONDS == 2700 + + +def test_session_warning_default(): + assert VoiceConfigV2.DEFAULT_SESSION_WARNING_SECONDS == 300 + + +def test_warning_floor(): + assert VoiceConfigV2.MIN_SESSION_WARNING_SECONDS == 10 + + +def test_heartbeat_interval(): + assert VoiceConfigV2.HEARTBEAT_INTERVAL_SECONDS == 30 + + +def test_context_window_trigger_tokens(): + assert VoiceConfigV2.CONTEXT_WINDOW_TRIGGER_TOKENS == 100_000 + + +def test_sentinel_patterns(): + assert "RPS_SESSION_COMPLETE" in VoiceConfigV2.SENTINEL_SESSION_COMPLETE + assert "RPS_END_EARLY" in VoiceConfigV2.SENTINEL_END_EARLY + + +def test_termination_reasons_are_strings(): + reasons = [ + VoiceConfigV2.REASON_USER_ENDED, + VoiceConfigV2.REASON_TIME_LIMIT, + VoiceConfigV2.REASON_AI_CONCLUDED, + VoiceConfigV2.REASON_AI_EARLY_TERMINATION, + VoiceConfigV2.REASON_DISCONNECTED, + ] + assert all(isinstance(r, str) for r in reasons) + + +def test_inherits_v1_audio_constants(): + """V2 re-exports V1 audio constants for compatibility.""" + assert VoiceConfigV2.AUDIO_SAMPLE_RATE == 16000 + assert VoiceConfigV2.AUDIO_CHANNELS == 1 + assert VoiceConfigV2.AUDIO_BIT_DEPTH == 16 + assert VoiceConfigV2.AUDIO_FORMAT == "pcm"