From 9ef30f368f7e584cf2c33e028b5d2a9aa0141567 Mon Sep 17 00:00:00 2001 From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com> Date: Wed, 28 Jan 2026 08:57:55 +0000 Subject: [PATCH 1/2] Refactor Home Assistant plugin and MCP client for improved configuration and error handling - Updated HomeAssistantPlugin to standardize string usage for configuration parameters. - Enhanced MCPClient to improve error handling and logging during memory operations. - Refactored OpenMemoryMCPService to streamline memory entry conversion and improve metadata handling. - Improved transcription job handling in transcription_jobs.py for better error reporting and session management. - Updated mock-services.yml to change model_url for testing compatibility with Docker environments. --- .../plugins/homeassistant/plugin.py | 249 ++++++++--------- .../services/memory/providers/mcp_client.py | 258 +++++++++--------- .../memory/providers/openmemory_mcp.py | 219 ++++++++------- .../workers/transcription_jobs.py | 186 +++++++------ tests/configs/mock-services.yml | 2 +- 5 files changed, 478 insertions(+), 436 deletions(-) diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py index 13683194..d456e89e 100644 --- a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py +++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py @@ -28,7 +28,7 @@ class HomeAssistantPlugin(BasePlugin): -> Returns: PluginResult with "I've turned off the hall light" """ - SUPPORTED_ACCESS_LEVELS: List[str] = ['transcript'] + SUPPORTED_ACCESS_LEVELS: List[str] = ["transcript"] name = "Home Assistant" description = "Wake word device control with Home Assistant integration" @@ -56,10 +56,10 @@ def __init__(self, config: Dict[str, Any]): self.cache_initialized = False # Configuration - self.ha_url = config.get('ha_url', 'http://localhost:8123') - self.ha_token = config.get('ha_token', '') - self.wake_word = config.get('wake_word', 'vivi') - self.timeout = config.get('timeout', 30) + self.ha_url = config.get("ha_url", "http://localhost:8123") + self.ha_token = config.get("ha_token", "") + self.wake_word = config.get("wake_word", "vivi") + self.timeout = config.get("timeout", 30) async def initialize(self): """ @@ -81,9 +81,7 @@ async def initialize(self): # Create MCP client (used for REST API calls, not MCP protocol) self.mcp_client = HAMCPClient( - base_url=self.ha_url, - token=self.ha_token, - timeout=self.timeout + base_url=self.ha_url, token=self.ha_token, timeout=self.timeout ) # Test basic API connectivity with Template API @@ -140,21 +138,17 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: should_continue=False ) """ - command = context.data.get('command', '') + command = context.data.get("command", "") if not command: - return PluginResult( - success=False, - message="No command provided", - should_continue=True - ) + return PluginResult(success=False, message="No command provided", should_continue=True) if not self.mcp_client: logger.error("MCP client not initialized") return PluginResult( success=False, message="Sorry, Home Assistant is not connected", - should_continue=True + should_continue=True, ) try: @@ -166,7 +160,7 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: return PluginResult( success=False, message="Sorry, I couldn't understand that command", - should_continue=True + should_continue=True, ) 
# Step 2: Resolve entities from parsed command @@ -174,47 +168,38 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: entity_ids = await self._resolve_entities(parsed) except ValueError as e: logger.warning(f"Entity resolution failed: {e}") - return PluginResult( - success=False, - message=str(e), - should_continue=True - ) + return PluginResult(success=False, message=str(e), should_continue=True) # Step 3: Determine service and domain # Extract domain from first entity (all should have same domain for area-based) - domain = entity_ids[0].split('.')[0] if entity_ids else 'light' + domain = entity_ids[0].split(".")[0] if entity_ids else "light" # Map action to service name service_map = { - 'turn_on': 'turn_on', - 'turn_off': 'turn_off', - 'toggle': 'toggle', - 'set_brightness': 'turn_on', # brightness uses turn_on with params - 'set_color': 'turn_on' # color uses turn_on with params + "turn_on": "turn_on", + "turn_off": "turn_off", + "toggle": "toggle", + "set_brightness": "turn_on", # brightness uses turn_on with params + "set_color": "turn_on", # color uses turn_on with params } - service = service_map.get(parsed.action, 'turn_on') + service = service_map.get(parsed.action, "turn_on") # Step 4: Call Home Assistant service - logger.info( - f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}" - ) + logger.info(f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}") result = await self.mcp_client.call_service( - domain=domain, - service=service, - entity_ids=entity_ids, - **parsed.parameters + domain=domain, service=service, entity_ids=entity_ids, **parsed.parameters ) # Step 5: Format user-friendly response entity_type_name = parsed.entity_type or domain - if parsed.target_type == 'area': + if parsed.target_type == "area": message = ( f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} " f"in {parsed.target}" ) - elif parsed.target_type == 'all_in_area': + elif parsed.target_type == "all_in_area": message = ( f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} " f"entities in {parsed.target}" @@ -227,14 +212,14 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: return PluginResult( success=True, data={ - 'action': parsed.action, - 'entity_ids': entity_ids, - 'target_type': parsed.target_type, - 'target': parsed.target, - 'ha_result': result + "action": parsed.action, + "entity_ids": entity_ids, + "target_type": parsed.target_type, + "target": parsed.target, + "ha_result": result, }, message=message, - should_continue=False # Stop normal processing - HA command handled + should_continue=False, # Stop normal processing - HA command handled ) except MCPError as e: @@ -242,14 +227,14 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]: return PluginResult( success=False, message=f"Sorry, Home Assistant couldn't execute that: {e}", - should_continue=True + should_continue=True, ) except Exception as e: logger.error(f"Command execution failed: {e}", exc_info=True) return PluginResult( success=False, message="Sorry, something went wrong while executing that command", - should_continue=True + should_continue=True, ) async def cleanup(self): @@ -298,23 +283,23 @@ async def _refresh_cache(self): # Create cache from datetime import datetime + self.entity_cache = EntityCache( areas=areas, area_entities=area_entities, entity_details=entity_details, - last_refresh=datetime.now() + 
last_refresh=datetime.now(), ) logger.info( - f"Entity cache refreshed: {len(areas)} areas, " - f"{len(entity_details)} entities" + f"Entity cache refreshed: {len(areas)} areas, " f"{len(entity_details)} entities" ) except Exception as e: logger.error(f"Failed to refresh entity cache: {e}", exc_info=True) raise - async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand']: + async def _parse_command_with_llm(self, command: str) -> Optional["ParsedCommand"]: """ Parse command using LLM with structured system prompt. @@ -336,6 +321,7 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand """ try: from advanced_omi_backend.llm_client import get_llm_client + from .command_parser import COMMAND_PARSER_SYSTEM_PROMPT, ParsedCommand llm_client = get_llm_client() @@ -347,36 +333,36 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand model=llm_client.model, messages=[ {"role": "system", "content": COMMAND_PARSER_SYSTEM_PROMPT}, - {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'} + {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'}, ], temperature=0.1, - max_tokens=150 + max_tokens=150, ) result_text = response.choices[0].message.content.strip() logger.debug(f"LLM response: {result_text}") # Remove markdown code blocks if present - if result_text.startswith('```'): - lines = result_text.split('\n') - result_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else result_text + if result_text.startswith("```"): + lines = result_text.split("\n") + result_text = "\n".join(lines[1:-1]) if len(lines) > 2 else result_text result_text = result_text.strip() # Parse JSON response result_json = json.loads(result_text) # Validate required fields - required_fields = ['action', 'target_type', 'target'] + required_fields = ["action", "target_type", "target"] if not all(field in result_json for field in required_fields): logger.warning(f"LLM response missing required fields: {result_json}") return None parsed = ParsedCommand( - action=result_json['action'], - target_type=result_json['target_type'], - target=result_json['target'], - entity_type=result_json.get('entity_type'), - parameters=result_json.get('parameters', {}) + action=result_json["action"], + target_type=result_json["target_type"], + target=result_json["target"], + entity_type=result_json.get("entity_type"), + parameters=result_json.get("parameters", {}), ) logger.info( @@ -394,7 +380,7 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand logger.error(f"LLM command parsing failed: {e}", exc_info=True) return None - async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: + async def _resolve_entities(self, parsed: "ParsedCommand") -> List[str]: """ Resolve ParsedCommand to actual Home Assistant entity IDs. 
@@ -424,11 +410,10 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: if not self.entity_cache: raise ValueError("Entity cache not initialized") - if parsed.target_type == 'area': + if parsed.target_type == "area": # Get entities in area, filtered by type entities = self.entity_cache.get_entities_in_area( - area=parsed.target, - entity_type=parsed.entity_type + area=parsed.target, entity_type=parsed.entity_type ) if not entities: @@ -444,12 +429,9 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: ) return entities - elif parsed.target_type == 'all_in_area': + elif parsed.target_type == "all_in_area": # Get ALL entities in area (no filter) - entities = self.entity_cache.get_entities_in_area( - area=parsed.target, - entity_type=None - ) + entities = self.entity_cache.get_entities_in_area(area=parsed.target, entity_type=None) if not entities: raise ValueError( @@ -460,7 +442,7 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]: logger.info(f"Resolved 'all in {parsed.target}' to {len(entities)} entities") return entities - elif parsed.target_type == 'entity': + elif parsed.target_type == "entity": # Fuzzy match entity by name entity_id = self.entity_cache.find_entity_by_name(parsed.target) @@ -501,37 +483,35 @@ async def _parse_command_fallback(self, command: str) -> Optional[Dict[str, Any] # Determine action tool = None - if any(word in command_lower for word in ['turn off', 'off', 'disable']): - tool = 'turn_off' - action_desc = 'turned off' - elif any(word in command_lower for word in ['turn on', 'on', 'enable']): - tool = 'turn_on' - action_desc = 'turned on' - elif 'toggle' in command_lower: - tool = 'toggle' - action_desc = 'toggled' + if any(word in command_lower for word in ["turn off", "off", "disable"]): + tool = "turn_off" + action_desc = "turned off" + elif any(word in command_lower for word in ["turn on", "on", "enable"]): + tool = "turn_on" + action_desc = "turned on" + elif "toggle" in command_lower: + tool = "toggle" + action_desc = "toggled" else: logger.warning(f"Unknown action in command: {command}") return None # Extract entity name from command entity_query = command_lower - for action_word in ['turn off', 'turn on', 'toggle', 'off', 'on', 'the']: - entity_query = entity_query.replace(action_word, '').strip() + for action_word in ["turn off", "turn on", "toggle", "off", "on", "the"]: + entity_query = entity_query.replace(action_word, "").strip() logger.info(f"Searching for entity: '{entity_query}'") # Return placeholder (this will work if entity ID matches pattern) return { "tool": tool, - "arguments": { - "entity_id": f"light.{entity_query.replace(' ', '_')}" - }, + "arguments": {"entity_id": f"light.{entity_query.replace(' ', '_')}"}, "friendly_name": entity_query.title(), - "action_desc": action_desc + "action_desc": action_desc, } - async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']: + async def _parse_command_hybrid(self, command: str) -> Optional["ParsedCommand"]: """ Hybrid command parser: Try LLM first, fallback to keywords. @@ -550,15 +530,13 @@ async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand'] ParsedCommand(action="turn_off", target_type="area", target="study", ...) 
""" import asyncio + from .command_parser import ParsedCommand # Try LLM parsing with timeout try: logger.debug("Attempting LLM-based command parsing...") - parsed = await asyncio.wait_for( - self._parse_command_with_llm(command), - timeout=5.0 - ) + parsed = await asyncio.wait_for(self._parse_command_with_llm(command), timeout=5.0) if parsed: logger.info("LLM parsing succeeded") @@ -581,16 +559,16 @@ async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand'] # Convert fallback format to ParsedCommand # Extract entity_id from arguments - entity_id = fallback_result['arguments'].get('entity_id', '') - entity_name = entity_id.split('.', 1)[1] if '.' in entity_id else entity_id + entity_id = fallback_result["arguments"].get("entity_id", "") + entity_name = entity_id.split(".", 1)[1] if "." in entity_id else entity_id # Simple heuristic: assume it's targeting a single entity parsed = ParsedCommand( - action=fallback_result['tool'], - target_type='entity', - target=entity_name.replace('_', ' '), + action=fallback_result["tool"], + target_type="entity", + target=entity_name.replace("_", " "), entity_type=None, - parameters={} + parameters={}, ) logger.info("Fallback parsing succeeded") @@ -629,64 +607,63 @@ async def test_connection(config: Dict[str, Any]) -> Dict[str, Any]: try: # Validate required config fields - required_fields = ['ha_url', 'ha_token'] + required_fields = ["ha_url", "ha_token"] missing_fields = [field for field in required_fields if not config.get(field)] if missing_fields: return { "success": False, "message": f"Missing required fields: {', '.join(missing_fields)}", - "status": "error" + "status": "error", } - ha_url = config.get('ha_url') - ha_token = config.get('ha_token') - timeout = config.get('timeout', 30) + ha_url = config.get("ha_url") + ha_token = config.get("ha_token") + timeout = config.get("timeout", 30) # Create temporary MCP client - mcp_client = HAMCPClient( - base_url=ha_url, - token=ha_token, - timeout=timeout - ) - - # Test API connectivity with Template API - logger.info(f"Testing Home Assistant API connection to {ha_url}...") - start_time = time.time() + mcp_client = HAMCPClient(base_url=ha_url, token=ha_token, timeout=timeout) - test_result = await mcp_client._render_template("{{ 1 + 1 }}") - connection_time_ms = int((time.time() - start_time) * 1000) - - if str(test_result).strip() != "2": - return { - "success": False, - "message": f"Unexpected template result: {test_result}", - "status": "error" - } - - # Try to fetch entities count for additional info try: - entities = await mcp_client.get_all_entities() - entity_count = len(entities) - except Exception: - entity_count = None + # Test API connectivity with Template API + logger.info(f"Testing Home Assistant API connection to {ha_url}...") + start_time = time.time() + + test_result = await mcp_client._render_template("{{ 1 + 1 }}") + connection_time_ms = int((time.time() - start_time) * 1000) + + if str(test_result).strip() != "2": + return { + "success": False, + "message": f"Unexpected template result: {test_result}", + "status": "error", + } + + # Try to fetch entities count for additional info + try: + entities = await mcp_client.discover_entities() + entity_count = len(entities) + except Exception: + entity_count = None - return { - "success": True, - "message": f"Successfully connected to Home Assistant at {ha_url}", - "status": "success", - "details": { - "ha_url": ha_url, - "connection_time_ms": connection_time_ms, - "entity_count": entity_count, - "api_test": "Template 
rendering successful" + return { + "success": True, + "message": f"Successfully connected to Home Assistant at {ha_url}", + "status": "success", + "details": { + "ha_url": ha_url, + "connection_time_ms": connection_time_ms, + "entity_count": entity_count, + "api_test": "Template rendering successful", + }, } - } + finally: + await mcp_client.close() except Exception as e: logger.error(f"Home Assistant connection test failed: {e}", exc_info=True) return { "success": False, "message": f"Connection test failed: {str(e)}", - "status": "error" + "status": "error", } diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py index 8c5b5389..1a4e545f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py @@ -6,7 +6,8 @@ import logging import uuid -from typing import List, Dict, Any, Optional +from typing import Any, Dict, List, Optional + import httpx memory_logger = logging.getLogger("memory_service") @@ -14,12 +15,12 @@ class MCPClient: """Client for communicating with OpenMemory servers. - + Uses the official OpenMemory REST API: - POST /api/v1/memories - Create new memory - GET /api/v1/memories - List memories - DELETE /api/v1/memories - Delete memories - + Attributes: server_url: Base URL of the OpenMemory server (default: http://localhost:8765) client_name: Client identifier for memory tagging @@ -27,8 +28,15 @@ class MCPClient: timeout: Request timeout in seconds client: HTTP client instance """ - - def __init__(self, server_url: str, client_name: str = "chronicle", user_id: str = "default", user_email: str = "", timeout: int = 30): + + def __init__( + self, + server_url: str, + client_name: str = "chronicle", + user_id: str = "default", + user_email: str = "", + timeout: int = 30, + ): """Initialize client for OpenMemory. Args: @@ -38,43 +46,44 @@ def __init__(self, server_url: str, client_name: str = "chronicle", user_id: str user_email: User email address for user metadata timeout: HTTP request timeout in seconds """ - self.server_url = server_url.rstrip('/') + self.server_url = server_url.rstrip("/") self.client_name = client_name self.user_id = user_id self.user_email = user_email self.timeout = timeout - + # Use custom CA certificate if available import os - ca_bundle = os.getenv('REQUESTS_CA_BUNDLE') + + ca_bundle = os.getenv("REQUESTS_CA_BUNDLE") verify = ca_bundle if ca_bundle and os.path.exists(ca_bundle) else True - + self.client = httpx.AsyncClient(timeout=timeout, verify=verify) - + async def close(self): """Close the HTTP client.""" await self.client.aclose() - + async def __aenter__(self): return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() - + async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List[str]: """Add memories to the OpenMemory server. - + Uses the REST API to create memories. 
OpenMemory will handle: - Memory extraction from text - Deduplication - Vector embedding and storage - + Args: text: Memory text to store - + Returns: List of created memory IDs - + Raises: MCPError: If the server request fails """ @@ -113,7 +122,7 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List default_metadata = { "source": "chronicle", "client": self.client_name, - "user_email": self.user_email + "user_email": self.user_email, } if metadata: default_metadata.update(metadata) @@ -125,58 +134,63 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List "text": text, "app": self.client_name, "metadata": default_metadata, - "infer": True + "infer": True, } - memory_logger.info(f"POSTing memory to {self.server_url}/api/v1/memories/ with payload={payload}") - - response = await self.client.post( - f"{self.server_url}/api/v1/memories/", - json=payload + memory_logger.info( + f"POSTing memory to {self.server_url}/api/v1/memories/ " + f"(user_id={self.user_id}, text_len={len(text)}, metadata_keys={list(default_metadata.keys())})" ) + memory_logger.debug(f"Full payload: {payload}") + + response = await self.client.post(f"{self.server_url}/api/v1/memories/", json=payload) response_body = response.text[:500] if response.status_code != 200 else "..." - memory_logger.info(f"OpenMemory response: status={response.status_code}, body={response_body}, headers={dict(response.headers)}") + memory_logger.info( + f"OpenMemory response: status={response.status_code}, body={response_body}, headers={dict(response.headers)}" + ) response.raise_for_status() - + result = response.json() - + # Handle None result - OpenMemory returns None when no memory is created # (due to deduplication, insufficient content, etc.) if result is None: - memory_logger.info("OpenMemory returned None - no memory created (likely deduplication)") + memory_logger.info( + "OpenMemory returned None - no memory created (likely deduplication)" + ) return [] - + # Handle error response if isinstance(result, dict) and "error" in result: memory_logger.error(f"OpenMemory error: {result['error']}") return [] - + # Extract memory ID from response if isinstance(result, dict): memory_id = result.get("id") or str(uuid.uuid4()) return [memory_id] elif isinstance(result, list): return [str(item.get("id", uuid.uuid4())) for item in result] - + # Default success response return [str(uuid.uuid4())] - + except httpx.HTTPError as e: memory_logger.error(f"HTTP error adding memories: {e}") raise MCPError(f"HTTP error: {e}") except Exception as e: memory_logger.error(f"Error adding memories: {e}") raise MCPError(f"Failed to add memories: {e}") - + async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any]]: """Search for memories using semantic similarity. 
- + Args: query: Search query text limit: Maximum number of results to return - + Returns: List of memory dictionaries with content and metadata """ @@ -185,38 +199,34 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any apps_response = await self.client.get(f"{self.server_url}/api/v1/apps/") apps_response.raise_for_status() apps_data = apps_response.json() - + if not apps_data.get("apps") or len(apps_data["apps"]) == 0: memory_logger.warning("No apps found in OpenMemory MCP for search") return [] - + # Find the app matching our client name, or use first app as fallback app_id = None for app in apps_data["apps"]: if app["name"] == self.client_name: app_id = app["id"] break - + if not app_id: - memory_logger.warning(f"App '{self.client_name}' not found, using first available app") + memory_logger.warning( + f"App '{self.client_name}' not found, using first available app" + ) app_id = apps_data["apps"][0]["id"] - + # Use app-specific memories endpoint with search - params = { - "user_id": self.user_id, - "search_query": query, - "page": 1, - "size": limit - } - + params = {"user_id": self.user_id, "search_query": query, "page": 1, "size": limit} + response = await self.client.get( - f"{self.server_url}/api/v1/apps/{app_id}/memories", - params=params + f"{self.server_url}/api/v1/apps/{app_id}/memories", params=params ) response.raise_for_status() - + result = response.json() - + # Extract memories from app-specific response format if isinstance(result, dict) and "memories" in result: memories = result["memories"] @@ -224,30 +234,32 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any memories = result else: memories = [] - + # Format memories for Chronicle formatted_memories = [] for memory in memories: - formatted_memories.append({ - "id": memory.get("id", str(uuid.uuid4())), - "content": memory.get("content", "") or memory.get("text", ""), - "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), - "created_at": memory.get("created_at"), - "score": memory.get("score", 0.0) # No score from list API - }) - + formatted_memories.append( + { + "id": memory.get("id", str(uuid.uuid4())), + "content": memory.get("content", "") or memory.get("text", ""), + "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), + "created_at": memory.get("created_at"), + "score": memory.get("score", 0.0), # No score from list API + } + ) + return formatted_memories[:limit] - + except Exception as e: memory_logger.error(f"Error searching memories: {e}") return [] - + async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]: """List all memories for the current user. 
- + Args: limit: Maximum number of memories to return - + Returns: List of memory dictionaries """ @@ -256,37 +268,34 @@ async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]: apps_response = await self.client.get(f"{self.server_url}/api/v1/apps/") apps_response.raise_for_status() apps_data = apps_response.json() - + if not apps_data.get("apps") or len(apps_data["apps"]) == 0: memory_logger.warning("No apps found in OpenMemory MCP") return [] - + # Find the app matching our client name, or use first app as fallback app_id = None for app in apps_data["apps"]: if app["name"] == self.client_name: app_id = app["id"] break - + if not app_id: - memory_logger.warning(f"App '{self.client_name}' not found, using first available app") + memory_logger.warning( + f"App '{self.client_name}' not found, using first available app" + ) app_id = apps_data["apps"][0]["id"] - + # Use app-specific memories endpoint - params = { - "user_id": self.user_id, - "page": 1, - "size": limit - } - + params = {"user_id": self.user_id, "page": 1, "size": limit} + response = await self.client.get( - f"{self.server_url}/api/v1/apps/{app_id}/memories", - params=params + f"{self.server_url}/api/v1/apps/{app_id}/memories", params=params ) response.raise_for_status() - + result = response.json() - + # Extract memories from app-specific response format if isinstance(result, dict) and "memories" in result: memories = result["memories"] @@ -294,29 +303,31 @@ async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]: memories = result else: memories = [] - + # Format memories formatted_memories = [] for memory in memories: - formatted_memories.append({ - "id": memory.get("id", str(uuid.uuid4())), - "content": memory.get("content", "") or memory.get("text", ""), - "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), - "created_at": memory.get("created_at") - }) - + formatted_memories.append( + { + "id": memory.get("id", str(uuid.uuid4())), + "content": memory.get("content", "") or memory.get("text", ""), + "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), + "created_at": memory.get("created_at"), + } + ) + return formatted_memories - + except Exception as e: memory_logger.error(f"Error listing memories: {e}") return [] - + async def delete_all_memories(self) -> int: """Delete all memories for the current user. - + Note: OpenMemory may not support bulk delete via REST API. This is typically done through MCP tools for safety. 
- + Returns: Number of memories that were deleted """ @@ -325,31 +336,29 @@ async def delete_all_memories(self) -> int: memories = await self.list_memories(limit=1000) if not memories: return 0 - + memory_ids = [m["id"] for m in memories] - + # Delete memories using the batch delete endpoint response = await self.client.request( "DELETE", f"{self.server_url}/api/v1/memories/", - json={ - "memory_ids": memory_ids, - "user_id": self.user_id - } + json={"memory_ids": memory_ids, "user_id": self.user_id}, ) response.raise_for_status() - + result = response.json() - + # Extract count from response if isinstance(result, dict): if "message" in result: # Parse message like "Successfully deleted 5 memories" import re - match = re.search(r'(\d+)', result["message"]) + + match = re.search(r"(\d+)", result["message"]) return int(match.group(1)) if match else len(memory_ids) return result.get("deleted_count", len(memory_ids)) - + return len(memory_ids) except Exception as e: @@ -368,8 +377,7 @@ async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]: try: # Use the memories endpoint with specific ID response = await self.client.get( - f"{self.server_url}/api/v1/memories/{memory_id}", - params={"user_id": self.user_id} + f"{self.server_url}/api/v1/memories/{memory_id}", params={"user_id": self.user_id} ) if response.status_code == 404: @@ -403,7 +411,7 @@ async def update_memory( self, memory_id: str, content: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ) -> bool: """Update a specific memory's content and/or metadata. @@ -431,8 +439,7 @@ async def update_memory( # Use PUT to update memory response = await self.client.put( - f"{self.server_url}/api/v1/memories/{memory_id}", - json=update_data + f"{self.server_url}/api/v1/memories/{memory_id}", json=update_data ) response.raise_for_status() @@ -446,12 +453,14 @@ async def update_memory( memory_logger.error(f"Error updating memory: {e}") return False - async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None) -> bool: + async def delete_memory( + self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + ) -> bool: """Delete a specific memory by ID. - + Args: memory_id: ID of the memory to delete - + Returns: True if deletion succeeded, False otherwise """ @@ -459,21 +468,18 @@ async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, use response = await self.client.request( "DELETE", f"{self.server_url}/api/v1/memories/", - json={ - "memory_ids": [memory_id], - "user_id": self.user_id - } + json={"memory_ids": [memory_id], "user_id": self.user_id}, ) response.raise_for_status() return True - + except Exception as e: memory_logger.warning(f"Error deleting memory {memory_id}: {e}") return False - + async def test_connection(self) -> bool: """Test connection to the OpenMemory server. 
- + Returns: True if server is reachable and responsive, False otherwise """ @@ -484,16 +490,23 @@ async def test_connection(self) -> bool: try: response = await self.client.get( f"{self.server_url}{endpoint}", - params={"user_id": self.user_id, "page": 1, "size": 1} - if endpoint == "/api/v1/memories" else {} + params=( + {"user_id": self.user_id, "page": 1, "size": 1} + if endpoint == "/api/v1/memories" + else {} + ), ) - if response.status_code in [200, 404, 422]: # 404/422 means endpoint exists but params wrong + if response.status_code in [ + 200, + 404, + 422, + ]: # 404/422 means endpoint exists but params wrong return True except: continue - + return False - + except Exception as e: memory_logger.error(f"OpenMemory server connection test failed: {e}") return False @@ -501,4 +514,5 @@ async def test_connection(self) -> bool: class MCPError(Exception): """Exception raised for MCP server communication errors.""" - pass \ No newline at end of file + + pass diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py index 922f2555..d5061a2c 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py @@ -10,9 +10,9 @@ import os import time import uuid -from typing import Optional, List, Tuple, Any, Dict +from typing import Any, Dict, List, Optional, Tuple -from ..base import MemoryServiceBase, MemoryEntry +from ..base import MemoryEntry, MemoryServiceBase from .mcp_client import MCPClient, MCPError memory_logger = logging.getLogger("memory_service") @@ -24,14 +24,14 @@ class OpenMemoryMCPService(MemoryServiceBase): This class implements the MemoryServiceBase interface by delegating memory operations to an OpenMemory MCP server. It handles the translation between Chronicle's memory service API and the standardized MCP operations. - + Key features: - Maintains compatibility with existing MemoryServiceBase interface - Leverages OpenMemory MCP's deduplication and processing - - Supports transcript-based memory extraction + - Supports transcript-based memory extraction - Provides user isolation and metadata management - Handles memory search and CRUD operations - + Attributes: server_url: URL of the OpenMemory MCP server timeout: Request timeout in seconds @@ -39,7 +39,7 @@ class OpenMemoryMCPService(MemoryServiceBase): mcp_client: Client for communicating with MCP server _initialized: Whether the service has been initialized """ - + def __init__( self, server_url: Optional[str] = None, @@ -67,42 +67,42 @@ def __init__( self.user_id = user_id or os.getenv("OPENMEMORY_USER_ID", "default") self.timeout = int(timeout or os.getenv("OPENMEMORY_TIMEOUT", "30")) self.mcp_client: Optional[MCPClient] = None - + async def initialize(self) -> None: """Initialize the OpenMemory MCP service. - + Sets up the MCP client connection and tests connectivity to ensure the service is ready for memory operations. 
- + Raises: RuntimeError: If initialization or connection test fails """ if self._initialized: return - + try: self.mcp_client = MCPClient( server_url=self.server_url, client_name=self.client_name, user_id=self.user_id, - timeout=self.timeout + timeout=self.timeout, ) - + # Test connection to OpenMemory MCP server is_connected = await self.mcp_client.test_connection() if not is_connected: raise RuntimeError(f"Cannot connect to OpenMemory MCP server at {self.server_url}") - + self._initialized = True memory_logger.info( f"✅ OpenMemory MCP service initialized successfully at {self.server_url} " f"(client: {self.client_name}, user: {self.user_id})" ) - + except Exception as e: memory_logger.error(f"OpenMemory MCP service initialization failed: {e}") raise RuntimeError(f"Initialization failed: {e}") - + async def add_memory( self, transcript: str, @@ -111,14 +111,14 @@ async def add_memory( user_id: str, user_email: str, allow_update: bool = False, - db_helper: Any = None + db_helper: Any = None, ) -> Tuple[bool, List[str]]: """Add memories extracted from a transcript. - + Processes a transcript to extract meaningful memories and stores them in the OpenMemory MCP server. Can either extract memories locally first or send the raw transcript to MCP for processing. - + Args: transcript: Raw transcript text to extract memories from client_id: Client identifier for tracking @@ -127,74 +127,78 @@ async def add_memory( user_email: User email address allow_update: Whether to allow updating existing memories (Note: MCP may handle this internally) db_helper: Optional database helper for relationship tracking - + Returns: Tuple of (success: bool, created_memory_ids: List[str]) - + Raises: MCPError: If MCP server communication fails """ await self._ensure_initialized() - + try: # Skip empty transcripts if not transcript or len(transcript.strip()) < 10: memory_logger.info(f"Skipping empty transcript for {source_id}") return True, [] - + # Use configured OpenMemory user (from config) for all Chronicle users # Chronicle user_id and email are stored in metadata for filtering enriched_transcript = f"[Source: {source_id}, Client: {client_id}] {transcript}" - memory_logger.info(f"Delegating memory processing to OpenMemory for user {user_id} (email: {user_email}), source {source_id}") + memory_logger.info( + f"Delegating memory processing to OpenMemory for user {user_id}, source {source_id}" + ) # Pass Chronicle user details in metadata for filtering/search metadata = { "chronicle_user_id": user_id, "chronicle_user_email": user_email, "source_id": source_id, - "client_id": client_id + "client_id": client_id, } - memory_ids = await self.mcp_client.add_memories(text=enriched_transcript, metadata=metadata) - + memory_ids = await self.mcp_client.add_memories( + text=enriched_transcript, metadata=metadata + ) + # Update database relationships if helper provided if memory_ids and db_helper: await self._update_database_relationships(db_helper, source_id, memory_ids) - + if memory_ids: - memory_logger.info(f"✅ OpenMemory MCP processed memory for {source_id}: {len(memory_ids)} memories") + memory_logger.info( + f"✅ OpenMemory MCP processed memory for {source_id}: {len(memory_ids)} memories" + ) return True, memory_ids - + # NOOP due to deduplication is SUCCESS, not failure - memory_logger.info(f"✅ OpenMemory MCP processed {source_id}: no new memories needed (likely deduplication)") + memory_logger.info( + f"✅ OpenMemory MCP processed {source_id}: no new memories needed (likely deduplication)" + ) return True, [] - + except 
MCPError as e: memory_logger.error(f"❌ OpenMemory MCP error for {source_id}: {e}") raise e except Exception as e: memory_logger.error(f"❌ OpenMemory MCP service failed for {source_id}: {e}") raise e - + async def search_memories( - self, - query: str, - user_id: str, - limit: int = 10, - score_threshold: float = 0.0 + self, query: str, user_id: str, limit: int = 10, score_threshold: float = 0.0 ) -> List[MemoryEntry]: """Search memories using semantic similarity. - + Uses the OpenMemory MCP server to perform semantic search across stored memories for the specified user. - + Args: query: Search query text user_id: User identifier to filter memories limit: Maximum number of results to return score_threshold: Minimum similarity score (ignored - OpenMemory MCP server controls filtering) - + Returns: List of matching MemoryEntry objects ordered by relevance """ @@ -206,8 +210,7 @@ async def search_memories( try: # Get more results since we'll filter by user results = await self.mcp_client.search_memory( - query=query, - limit=limit * 3 # Get extra to account for filtering + query=query, limit=limit * 3 # Get extra to account for filtering ) # Convert MCP results to MemoryEntry objects and filter by user @@ -222,7 +225,9 @@ async def search_memories( if len(memory_entries) >= limit: break # Got enough results - memory_logger.info(f"🔍 Found {len(memory_entries)} memories for query '{query}' (user: {user_id})") + memory_logger.info( + f"🔍 Found {len(memory_entries)} memories for query '{query}' (user: {user_id})" + ) return memory_entries except MCPError as e: @@ -231,21 +236,17 @@ async def search_memories( except Exception as e: memory_logger.error(f"Search memories failed: {e}") return [] - - async def get_all_memories( - self, - user_id: str, - limit: int = 100 - ) -> List[MemoryEntry]: + + async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: """Get all memories for a specific user. - + Retrieves all stored memories for the given user without similarity filtering. - + Args: user_id: User identifier limit: Maximum number of memories to return - + Returns: List of MemoryEntry objects for the user """ @@ -280,7 +281,9 @@ async def get_all_memories( memory_logger.error(f"Get all memories failed: {e}") return [] - async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Optional[MemoryEntry]: + async def get_memory( + self, memory_id: str, user_id: Optional[str] = None + ) -> Optional[MemoryEntry]: """Get a specific memory by ID. Args: @@ -323,7 +326,7 @@ async def update_memory( content: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None, - user_email: Optional[str] = None + user_email: Optional[str] = None, ) -> bool: """Update a specific memory's content and/or metadata. @@ -346,9 +349,7 @@ async def update_memory( try: success = await self.mcp_client.update_memory( - memory_id=memory_id, - content=content, - metadata=metadata + memory_id=memory_id, content=content, metadata=metadata ) if success: @@ -362,18 +363,20 @@ async def update_memory( # Restore original user_id self.mcp_client.user_id = original_user_id - async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None) -> bool: + async def delete_memory( + self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + ) -> bool: """Delete a specific memory by ID. 
- + Args: memory_id: Unique identifier of the memory to delete - + Returns: True if successfully deleted, False otherwise """ if not self._initialized: await self.initialize() - + try: success = await self.mcp_client.delete_memory(memory_id) if success: @@ -382,38 +385,38 @@ async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, use except Exception as e: memory_logger.error(f"Delete memory failed: {e}") return False - + async def delete_all_user_memories(self, user_id: str) -> int: """Delete all memories for a specific user. - + Args: user_id: User identifier - + Returns: Number of memories that were deleted """ if not self._initialized: await self.initialize() - + # Update MCP client user context for this operation original_user_id = self.mcp_client.user_id self.mcp_client.user_id = self.user_id # Use configured user ID - + try: count = await self.mcp_client.delete_all_memories() memory_logger.info(f"🗑️ Deleted {count} memories for user {user_id} via OpenMemory MCP") return count - + except Exception as e: memory_logger.error(f"Delete user memories failed: {e}") return 0 finally: # Restore original user_id self.mcp_client.user_id = original_user_id - + async def test_connection(self) -> bool: """Test if the memory service and its dependencies are working. - + Returns: True if all connections are healthy, False otherwise """ @@ -424,7 +427,7 @@ async def test_connection(self) -> bool: except Exception as e: memory_logger.error(f"Connection test failed: {e}") return False - + def shutdown(self) -> None: """Shutdown the memory service and clean up resources.""" if self.mcp_client: @@ -433,56 +436,73 @@ def shutdown(self) -> None: self._initialized = False self.mcp_client = None memory_logger.info("OpenMemory MCP service shut down") - + # Private helper methods - + def _ensure_client(self) -> MCPClient: """Ensure MCP client is available and return it.""" if self.mcp_client is None: raise RuntimeError("OpenMemory MCP client not initialized") return self.mcp_client - - def _mcp_result_to_memory_entry(self, mcp_result: Dict[str, Any], user_id: str) -> Optional[MemoryEntry]: + + def _mcp_result_to_memory_entry( + self, mcp_result: Dict[str, Any], user_id: str + ) -> Optional[MemoryEntry]: """Convert OpenMemory MCP server result to MemoryEntry object. 
- + Args: mcp_result: Result dictionary from OpenMemory MCP server user_id: User identifier to include in metadata - + Returns: MemoryEntry object or None if conversion fails """ try: # OpenMemory MCP results may have different formats, adapt as needed - memory_id = mcp_result.get('id', str(uuid.uuid4())) - content = mcp_result.get('content', '') or mcp_result.get('memory', '') or mcp_result.get('text', '') or mcp_result.get('data', '') - + memory_id = mcp_result.get("id", str(uuid.uuid4())) + content = ( + mcp_result.get("content", "") + or mcp_result.get("memory", "") + or mcp_result.get("text", "") + or mcp_result.get("data", "") + ) + if not content: memory_logger.warning(f"Empty content in MCP result: {mcp_result}") return None - + # Build metadata with OpenMemory context - metadata = mcp_result.get('metadata', {}) + metadata = mcp_result.get("metadata", {}) if not metadata: metadata = {} - + # Ensure we have user context - metadata.update({ - 'user_id': user_id, - 'source': 'openmemory_mcp', - 'client_name': self.client_name, - 'mcp_server': self.server_url - }) - + metadata.update( + { + "user_id": user_id, + "source": "openmemory_mcp", + "client_name": self.client_name, + "mcp_server": self.server_url, + } + ) + # Extract similarity score if available (for search results) - score = mcp_result.get('score') or mcp_result.get('similarity') or mcp_result.get('relevance') + score = ( + mcp_result.get("score") + or mcp_result.get("similarity") + or mcp_result.get("relevance") + ) # Extract timestamps - created_at = mcp_result.get('created_at') or mcp_result.get('timestamp') or mcp_result.get('date') + created_at = ( + mcp_result.get("created_at") + or mcp_result.get("timestamp") + or mcp_result.get("date") + ) if created_at is None: created_at = str(int(time.time())) - updated_at = mcp_result.get('updated_at') or mcp_result.get('modified_at') + updated_at = mcp_result.get("updated_at") or mcp_result.get("modified_at") if updated_at is None: updated_at = str(created_at) # Default to created_at if not provided @@ -493,21 +513,18 @@ def _mcp_result_to_memory_entry(self, mcp_result: Dict[str, Any], user_id: str) embedding=None, # OpenMemory MCP server handles embeddings internally score=score, created_at=str(created_at), - updated_at=str(updated_at) + updated_at=str(updated_at), ) - + except Exception as e: memory_logger.error(f"Failed to convert MCP result to MemoryEntry: {e}") return None - + async def _update_database_relationships( - self, - db_helper: Any, - source_id: str, - created_ids: List[str] + self, db_helper: Any, source_id: str, created_ids: List[str] ) -> None: """Update database relationships for created memories. 
- + Args: db_helper: Database helper instance source_id: Source session identifier @@ -517,4 +534,4 @@ async def _update_database_relationships( try: await db_helper.add_memory_reference(source_id, memory_id, "created") except Exception as db_error: - memory_logger.error(f"Database relationship update failed: {db_error}") \ No newline at end of file + memory_logger.error(f"Database relationship update failed: {db_error}") diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index 0ce9d77e..e54f3393 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -5,35 +5,44 @@ """ import asyncio -import os import logging +import os import time import uuid from datetime import datetime from pathlib import Path -from typing import Dict, Any -from rq import get_current_job -from rq.job import Job -from rq.exceptions import NoSuchJobError +from typing import Any, Dict -from advanced_omi_backend.models.job import JobPriority, BaseRQJob, async_job -from advanced_omi_backend.models.conversation import Conversation -from advanced_omi_backend.models.audio_chunk import AudioChunkDocument from beanie.operators import In +from rq import get_current_job +from rq.exceptions import NoSuchJobError +from rq.job import Job +from advanced_omi_backend.config import get_backend_config from advanced_omi_backend.controllers.queue_controller import ( - transcription_queue, - redis_conn, JOB_RESULT_TTL, REDIS_URL, + redis_conn, start_post_conversation_jobs, + transcription_queue, ) -from advanced_omi_backend.utils.conversation_utils import analyze_speech, mark_conversation_deleted -from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_wav_from_conversation, convert_audio_to_chunks -from advanced_omi_backend.services.plugin_service import get_plugin_router -from advanced_omi_backend.services.transcription import get_transcription_provider, is_transcription_available +from advanced_omi_backend.models.audio_chunk import AudioChunkDocument +from advanced_omi_backend.models.conversation import Conversation +from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator -from advanced_omi_backend.config import get_backend_config +from advanced_omi_backend.services.plugin_service import get_plugin_router +from advanced_omi_backend.services.transcription import ( + get_transcription_provider, + is_transcription_available, +) +from advanced_omi_backend.utils.audio_chunk_utils import ( + convert_audio_to_chunks, + reconstruct_wav_from_conversation, +) +from advanced_omi_backend.utils.conversation_utils import ( + analyze_speech, + mark_conversation_deleted, +) logger = logging.getLogger(__name__) @@ -63,7 +72,9 @@ async def apply_speaker_recognition( Updated list of segments with identified speakers """ try: - from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient + from advanced_omi_backend.speaker_recognition_client import ( + SpeakerRecognitionClient, + ) speaker_client = SpeakerRecognitionClient() if not speaker_client.enabled: @@ -177,7 +188,7 @@ async def transcribe_full_audio_job( # Extract user_id and client_id for plugin context user_id = str(conversation.user_id) if conversation.user_id else None - client_id = conversation.client_id if hasattr(conversation, 
'client_id') else None + client_id = conversation.client_id if hasattr(conversation, "client_id") else None # Get the transcription provider provider = get_transcription_provider(mode="batch") @@ -195,15 +206,14 @@ async def transcribe_full_audio_job( wav_data = await reconstruct_wav_from_conversation(conversation_id) logger.info( - f"📦 Reconstructed audio from MongoDB chunks: " - f"{len(wav_data) / 1024 / 1024:.2f} MB" + f"📦 Reconstructed audio from MongoDB chunks: " f"{len(wav_data) / 1024 / 1024:.2f} MB" ) # Transcribe the audio directly from memory (no disk I/O needed) transcription_result = await provider.transcribe( audio_data=wav_data, # Pass bytes directly, already in memory sample_rate=16000, - diarize=True + diarize=True, ) except ValueError as e: @@ -224,7 +234,9 @@ async def transcribe_full_audio_job( # Trigger transcript-level plugins BEFORE speech validation # This ensures wake-word commands execute even if conversation gets deleted - logger.info(f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}") + logger.info( + f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}" + ) if transcript_text: try: from advanced_omi_backend.services.plugin_service import init_plugin_router @@ -236,7 +248,9 @@ async def transcribe_full_audio_job( if not plugin_router: logger.info("🔧 Initializing plugin router in worker process...") plugin_router = init_plugin_router() - logger.info(f"🔧 After init, plugin_router: {plugin_router is not None}, plugins count: {len(plugin_router.plugins) if plugin_router else 0}") + logger.info( + f"🔧 After init, plugin_router: {plugin_router is not None}, plugins count: {len(plugin_router.plugins) if plugin_router else 0}" + ) # Initialize async plugins if plugin_router: @@ -245,18 +259,24 @@ async def transcribe_full_audio_job( await plugin.initialize() logger.info(f"✅ Plugin '{plugin_id}' initialized in worker") except Exception as e: - logger.exception(f"Failed to initialize plugin '{plugin_id}' in worker: {e}") + logger.exception( + f"Failed to initialize plugin '{plugin_id}' in worker: {e}" + ) - logger.info(f"🔍 DEBUG: Plugin router final check: {plugin_router is not None}, has {len(plugin_router.plugins) if plugin_router else 0} plugins") + logger.info( + f"🔍 DEBUG: Plugin router final check: {plugin_router is not None}, has {len(plugin_router.plugins) if plugin_router else 0} plugins" + ) if plugin_router: - logger.info(f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}") + logger.info( + f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}" + ) plugin_data = { - 'transcript': transcript_text, - 'segment_id': f"{conversation_id}_batch", - 'conversation_id': conversation_id, - 'segments': segments, - 'word_count': len(words), + "transcript": transcript_text, + "segment_id": f"{conversation_id}_batch", + "conversation_id": conversation_id, + "segments": segments, + "word_count": len(words), } logger.info( @@ -265,10 +285,10 @@ async def transcribe_full_audio_job( ) plugin_results = await plugin_router.dispatch_event( - event='transcript.batch', + event="transcript.batch", user_id=user_id, data=plugin_data, - metadata={'client_id': client_id} + metadata={"client_id": client_id}, ) logger.info( @@ -276,7 +296,9 @@ async def transcribe_full_audio_job( ) if plugin_results: - logger.info(f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode") + logger.info( + f"✅ Triggered {len(plugin_results)} transcript plugins in 
batch mode" + ) for result in plugin_results: if result.message: logger.info(f" Plugin: {result.message}") @@ -332,7 +354,9 @@ async def transcribe_full_audio_job( logger.info(f"✅ Cancelled dependent job: {job_id}") except Exception as e: if isinstance(e, NoSuchJobError): - logger.debug(f"Job {job_id} hash not found (likely already completed or expired)") + logger.debug( + f"Job {job_id} hash not found (likely already completed or expired)" + ) else: logger.debug(f"Job {job_id} not found or already completed: {e}") @@ -363,8 +387,8 @@ async def transcribe_full_audio_job( processing_time = time.time() - start_time # Check if we should use provider segments as fallback - transcription_config = get_backend_config('transcription') - use_provider_segments = transcription_config.get('use_provider_segments', False) + transcription_config = get_backend_config("transcription") + use_provider_segments = transcription_config.get("use_provider_segments", False) # If flag enabled and provider returned segments, use them # Otherwise, speaker service will create segments via diarization @@ -375,7 +399,7 @@ async def transcribe_full_audio_job( speaker=str(seg.get("speaker", "0")), # Convert to string for Pydantic validation start=seg.get("start", 0.0), end=seg.get("end", 0.0), - text=seg.get("text", "") + text=seg.get("text", ""), ) for seg in segments ] @@ -399,7 +423,7 @@ async def transcribe_full_audio_job( word=w.get("word", ""), start=w.get("start", 0.0), end=w.get("end", 0.0), - confidence=w.get("confidence") + confidence=w.get("confidence"), ) for w in words ] @@ -409,7 +433,9 @@ async def transcribe_full_audio_job( "trigger": trigger, "audio_file_size": len(wav_data), "word_count": len(words), - "segments_created_by": "provider" if (use_provider_segments and segments) else "speaker_service", + "segments_created_by": ( + "provider" if (use_provider_segments and segments) else "speaker_service" + ), } conversation.add_transcript_version( @@ -528,9 +554,7 @@ async def transcribe_full_audio_job( async def create_audio_only_conversation( - session_id: str, - user_id: str, - client_id: str + session_id: str, user_id: str, client_id: str ) -> "Conversation": """ Create or reuse conversation for batch transcription fallback. @@ -544,7 +568,7 @@ async def create_audio_only_conversation( placeholder_conversation = await Conversation.find_one( Conversation.client_id == session_id, Conversation.always_persist == True, - In(Conversation.processing_status, ["pending_transcription", "transcription_failed"]) + In(Conversation.processing_status, ["pending_transcription", "transcription_failed"]), ) if placeholder_conversation: @@ -582,20 +606,13 @@ async def create_audio_only_conversation( ) await conversation.insert() - logger.info( - f"✅ Created batch transcription conversation {session_id[:12]} for fallback" - ) + logger.info(f"✅ Created batch transcription conversation {session_id[:12]} for fallback") return conversation @async_job(redis=True, beanie=True) async def transcription_fallback_check_job( - session_id: str, - user_id: str, - client_id: str, - timeout_seconds: int = 1800, - *, - redis_client=None + session_id: str, user_id: str, client_id: str, timeout_seconds: int = 1800, *, redis_client=None ) -> Dict[str, Any]: """ Check if streaming transcription succeeded, fallback to batch if needed. 
@@ -617,9 +634,7 @@ async def transcription_fallback_check_job(
     logger.info(f"🔍 Checking transcription status for session {session_id[:12]}")

     # Find conversation by session_id (client_id for streaming sessions)
-    conversation = await Conversation.find_one(
-        Conversation.client_id == session_id
-    )
+    conversation = await Conversation.find_one(Conversation.client_id == session_id)

     # Check if transcript exists (streaming succeeded)
     if conversation and conversation.active_transcript and conversation.transcript:
@@ -630,7 +645,7 @@ async def transcription_fallback_check_job(
         return {
             "status": "pass_through",
             "transcript_source": "streaming",
-            "conversation_id": conversation.conversation_id
+            "conversation_id": conversation.conversation_id,
         }

     # No transcript → Trigger batch fallback
@@ -674,7 +689,7 @@ async def transcription_fallback_check_job(
                 "status": "skipped",
                 "reason": "no_audio",
                 "message": "No audio was captured for this session",
-                "session_id": session_id
+                "session_id": session_id,
             }

         logger.info(
@@ -718,7 +733,7 @@ async def transcription_fallback_check_job(
                 "status": "skipped",
                 "reason": "no_matching_audio",
                 "message": "No audio matched this session in Redis stream",
-                "session_id": session_id
+                "session_id": session_id,
             }

         # Combine audio chunks in order
@@ -769,7 +784,7 @@ async def transcription_fallback_check_job(
             job_timeout=1800,
             job_id=f"transcribe_{conversation.conversation_id[:12]}",
             description=f"Batch transcription fallback for {session_id[:8]}",
-            meta={"session_id": session_id, "client_id": client_id}
+            meta={"session_id": session_id, "client_id": client_id},
         )

         logger.info(f"🔄 Enqueued batch transcription fallback job {batch_job.id}")
@@ -779,10 +794,11 @@ async def transcription_fallback_check_job(
         waited = 0
         while waited < max_wait:
             batch_job.refresh()
+            # Check is_failed BEFORE is_finished - a failed job is also "finished"
+            if batch_job.is_failed:
+                raise RuntimeError(f"Batch transcription failed: {batch_job.exc_info}")
             if batch_job.is_finished:
-                if batch_job.is_failed:
-                    raise Exception(f"Batch transcription failed: {batch_job.exc_info}")
-                logger.info(f"✅ Batch transcription completed successfully")
+                logger.info("✅ Batch transcription completed successfully")
                 break
             await asyncio.sleep(2)
             waited += 2
@@ -797,7 +813,7 @@ async def transcription_fallback_check_job(
             transcript_version_id=version_id,
             depends_on_job=None,  # Batch already completed (we waited for it)
             client_id=client_id,
-            end_reason="batch_fallback"
+            end_reason="batch_fallback",
         )

         logger.info(
@@ -810,7 +826,7 @@ async def transcription_fallback_check_job(
             "transcript_source": "batch",
             "conversation_id": conversation.conversation_id,
             "batch_job_id": batch_job.id,
-            "post_job_ids": post_jobs
+            "post_job_ids": post_jobs,
         }


@@ -876,7 +892,9 @@ async def stream_speech_detection_job(
     # Track when session closes for graceful shutdown
     session_closed_at = None
-    final_check_grace_period = 15  # Wait up to 15 seconds for final transcription after session closes
+    final_check_grace_period = (
+        15  # Wait up to 15 seconds for final transcription after session closes
+    )
     last_speech_analysis = None  # Track last analysis for detailed logging

     # Main loop: Listen for speech
@@ -894,7 +912,9 @@ async def stream_speech_detection_job(
         if session_closed and session_closed_at is None:
             # Session just closed - start grace period for final transcription
             session_closed_at = time.time()
-            logger.info(f"🛑 Session closed, waiting up to {final_check_grace_period}s for final transcription results...")
+            logger.info(
+                f"🛑 Session closed, waiting up to {final_check_grace_period}s for final transcription results..."
+            )

         # Exit if grace period expired without speech
         if session_closed_at and (time.time() - session_closed_at) > final_check_grace_period:
@@ -922,7 +942,9 @@ async def stream_speech_detection_job(
             grace_elapsed = time.time() - session_closed_at
             if grace_elapsed > 5 and not combined.get("chunk_count", 0):
                 # 5+ seconds with no transcription activity at all - likely API key issue
-                logger.error(f"❌ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue")
+                logger.error(
+                    f"❌ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue"
+                )
                 logger.error(f"❌ Session failed - check transcription service configuration")
                 break
@@ -947,7 +969,9 @@ async def stream_speech_detection_job(
         )

         if not speech_analysis.get("has_speech", False):
-            logger.info(f"⏳ Waiting for more speech - {speech_analysis.get('reason', 'unknown reason')}")
+            logger.info(
+                f"⏳ Waiting for more speech - {speech_analysis.get('reason', 'unknown reason')}"
+            )
             await asyncio.sleep(2)
             continue
@@ -1123,10 +1147,14 @@ async def stream_speech_detection_job(
         }

     # Session ended without speech
-    reason = last_speech_analysis.get('reason', 'No transcription received') if last_speech_analysis else 'No transcription received'
+    reason = (
+        last_speech_analysis.get("reason", "No transcription received")
+        if last_speech_analysis
+        else "No transcription received"
+    )

     # Distinguish between transcription failures (error) vs legitimate no speech (info)
-    if reason == 'No transcription received':
+    if reason == "No transcription received":
         logger.error(
             f"❌ Session failed - transcription service did not respond\n"
             f"   Reason: {reason}\n"
@@ -1149,11 +1177,13 @@ async def stream_speech_detection_job(
     conversation = await Conversation.find_one(
         Conversation.client_id == session_id,
         Conversation.always_persist == True,
-        Conversation.processing_status == "pending_transcription"
+        Conversation.processing_status == "pending_transcription",
     )

     if conversation:
-        logger.info(f"🔴 Found always_persist placeholder conversation {conversation.conversation_id} for failed session {session_id[:12]}")
+        logger.info(
+            f"🔴 Found always_persist placeholder conversation {conversation.conversation_id} for failed session {session_id[:12]}"
+        )

         # Update conversation with failure status
         conversation.processing_status = "transcription_failed"

         await conversation.save()

-        logger.warning(f"🔴 Marked conversation {conversation.conversation_id} as transcription_failed")
+        logger.warning(
+            f"🔴 Marked conversation {conversation.conversation_id} as transcription_failed"
+        )
     else:
-        logger.info(f"ℹ️ No always_persist placeholder conversation found for session {session_id[:12]}")
+        logger.info(
+            f"ℹ️ No always_persist placeholder conversation found for session {session_id[:12]}"
+        )

     # Enqueue fallback check job for failed streaming sessions
     # This will attempt batch transcription as a fallback
@@ -1174,10 +1208,10 @@ async def stream_speech_detection_job(
         user_id,
         client_id,
         timeout_seconds=1800,  # 30 minutes for batch transcription
-        job_timeout=2400, # 40 minutes job timeout
+        job_timeout=2400,  # 40 minutes job timeout
         job_id=f"fallback_check_{session_id[:12]}",
         description=f"Transcription fallback check for {session_id[:8]} (no speech)",
-        meta={"session_id": session_id, "client_id": client_id, "no_speech": True}
+        meta={"session_id": session_id, "client_id": client_id, "no_speech": True},
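The reordered polling above is worth calling out: the patch checks is_failed before is_finished so a job failure can never be mistaken for success. The same loop, isolated as a synchronous sketch under that assumption; names are illustrative:

    import time

    from rq.job import Job

    def wait_for_batch_job(job: Job, max_wait: int = 300, poll: int = 2) -> None:
        """Poll an RQ job until it finishes, raising on failure.

        is_failed is checked first, mirroring the patch: treating a failed
        job as merely "finished" would silently swallow the error.
        """
        waited = 0
        while waited < max_wait:
            job.refresh()
            if job.is_failed:
                raise RuntimeError(f"Job failed: {job.exc_info}")
            if job.is_finished:
                return
            time.sleep(poll)
            waited += poll
        raise TimeoutError(f"Job {job.id} did not finish within {max_wait}s")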
"client_id": client_id, "no_speech": True}, ) logger.info( diff --git a/tests/configs/mock-services.yml b/tests/configs/mock-services.yml index dd6a097c..6735cd61 100644 --- a/tests/configs/mock-services.yml +++ b/tests/configs/mock-services.yml @@ -51,7 +51,7 @@ models: description: Mock STT for testing (batch) model_provider: mock model_type: stt - model_url: http://localhost:9999 + model_url: http://host.docker.internal:9999 name: mock-stt operations: stt_transcribe: From ea0c591229259fa18c862b9a114111dd37649817 Mon Sep 17 00:00:00 2001 From: Ankush Malaker <43288948+AnkushMalaker@users.noreply.github.com> Date: Wed, 28 Jan 2026 19:14:29 +0000 Subject: [PATCH 2/2] Add Recording Context and UI Enhancements - Introduced a new RecordingContext to manage audio recording state and functionality, including start/stop actions and duration tracking. - Updated various components to utilize the new RecordingContext, replacing previous audio recording hooks for improved consistency. - Added a GlobalRecordingIndicator component to display recording status across the application. - Enhanced the Layout component to include the GlobalRecordingIndicator for better user feedback during audio recording sessions. - Refactored audio-related components to accept the new RecordingContext type, ensuring type safety and clarity in props. - Implemented configuration options for managing provider segments in transcription, allowing for more flexible audio processing based on user settings. - Added raw segments JSON display in the Conversations page for better debugging and data visibility. --- .../services/transcription/__init__.py | 20 +- backends/advanced/webui/src/App.tsx | 7 +- .../src/components/audio/SimpleDebugPanel.tsx | 4 +- .../components/audio/SimplifiedControls.tsx | 4 +- .../src/components/audio/StatusDisplay.tsx | 4 +- .../layout/GlobalRecordingIndicator.tsx | 66 +++++ .../webui/src/components/layout/Layout.tsx | 6 +- .../RecordingContext.tsx} | 228 +++++++++++------- .../webui/src/pages/Conversations.tsx | 12 + .../advanced/webui/src/pages/LiveRecord.tsx | 4 +- 10 files changed, 254 insertions(+), 101 deletions(-) create mode 100644 backends/advanced/webui/src/components/layout/GlobalRecordingIndicator.tsx rename backends/advanced/webui/src/{hooks/useSimpleAudioRecording.ts => contexts/RecordingContext.tsx} (84%) diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py index 99b79a6f..d87fd2e3 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py @@ -15,8 +15,14 @@ import httpx import websockets +from advanced_omi_backend.config_loader import get_backend_config from advanced_omi_backend.model_registry import get_models_registry -from .base import BaseTranscriptionProvider, BatchTranscriptionProvider, StreamingTranscriptionProvider + +from .base import ( + BaseTranscriptionProvider, + BatchTranscriptionProvider, + StreamingTranscriptionProvider, +) logger = logging.getLogger(__name__) @@ -148,9 +154,15 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = words = _dotted_get(data, extract.get("words")) or [] segments = _dotted_get(data, extract.get("segments")) or [] - # Ignore segments from all providers - speaker service creates them via diarization - segments = [] - logger.debug(f"Transcription: Extracted {len(words)} words, ignoring 
provider segments (speaker service will create them)") + # Check config to decide whether to keep or discard provider segments + transcription_config = get_backend_config("transcription") + use_provider_segments = transcription_config.get("use_provider_segments", False) + + if not use_provider_segments: + segments = [] + logger.debug(f"Transcription: Extracted {len(words)} words, ignoring provider segments (use_provider_segments=false)") + else: + logger.debug(f"Transcription: Extracted {len(words)} words, keeping {len(segments)} provider segments (use_provider_segments=true)") return {"text": text, "words": words, "segments": segments} diff --git a/backends/advanced/webui/src/App.tsx b/backends/advanced/webui/src/App.tsx index d2c31a05..33400f92 100644 --- a/backends/advanced/webui/src/App.tsx +++ b/backends/advanced/webui/src/App.tsx @@ -1,6 +1,7 @@ import { BrowserRouter as Router, Routes, Route } from 'react-router-dom' import { AuthProvider } from './contexts/AuthContext' import { ThemeProvider } from './contexts/ThemeContext' +import { RecordingProvider } from './contexts/RecordingContext' import Layout from './components/layout/Layout' import LoginPage from './pages/LoginPage' import Chat from './pages/Chat' @@ -28,7 +29,8 @@ function App() { - + + } /> - + + diff --git a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx index af5d4a3c..db17d626 100644 --- a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx +++ b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx @@ -1,7 +1,7 @@ -import { SimpleAudioRecordingReturn } from '../../hooks/useSimpleAudioRecording' +import { RecordingContextType } from '../../contexts/RecordingContext' interface SimpleDebugPanelProps { - recording: SimpleAudioRecordingReturn + recording: RecordingContextType } export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) { diff --git a/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx b/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx index f81142c5..a3299deb 100644 --- a/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx +++ b/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx @@ -1,8 +1,8 @@ import { Mic, MicOff, Loader2 } from 'lucide-react' -import { SimpleAudioRecordingReturn } from '../../hooks/useSimpleAudioRecording' +import { RecordingContextType } from '../../contexts/RecordingContext' interface SimplifiedControlsProps { - recording: SimpleAudioRecordingReturn + recording: RecordingContextType } const getStepText = (step: string): string => { diff --git a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx index 1e28ee52..d151ef4d 100644 --- a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx +++ b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx @@ -1,9 +1,9 @@ import React from 'react' import { Check, Loader2, AlertCircle, Mic, Wifi, Play, Radio } from 'lucide-react' -import { SimpleAudioRecordingReturn, RecordingStep } from '../../hooks/useSimpleAudioRecording' +import { RecordingContextType, RecordingStep } from '../../contexts/RecordingContext' interface StatusDisplayProps { - recording: SimpleAudioRecordingReturn + recording: RecordingContextType } interface StepInfo { diff --git a/backends/advanced/webui/src/components/layout/GlobalRecordingIndicator.tsx 
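The use_provider_segments flag now appears in two places (the transcription service above and the batch job in the first patch), and both reduce to the same decision. A pure-function version of that decision, convenient for unit testing; the function name is illustrative, not part of the patch:

    from typing import Any, Dict, List

    def select_segments(
        provider_segments: List[Dict[str, Any]],
        transcription_config: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Keep provider segments only when the config opts in; otherwise
        return [] so the speaker service creates segments via diarization."""
        if transcription_config.get("use_provider_segments", False):
            return provider_segments
        return []

    # select_segments(segments, {"use_provider_segments": True})  -> provider segments kept
    # select_segments(segments, {})                               -> [] (diarization path)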
diff --git a/backends/advanced/webui/src/App.tsx b/backends/advanced/webui/src/App.tsx
index d2c31a05..33400f92 100644
--- a/backends/advanced/webui/src/App.tsx
+++ b/backends/advanced/webui/src/App.tsx
@@ -1,6 +1,7 @@
 import { BrowserRouter as Router, Routes, Route } from 'react-router-dom'
 import { AuthProvider } from './contexts/AuthContext'
 import { ThemeProvider } from './contexts/ThemeContext'
+import { RecordingProvider } from './contexts/RecordingContext'
 import Layout from './components/layout/Layout'
 import LoginPage from './pages/LoginPage'
 import Chat from './pages/Chat'
@@ -28,7 +29,8 @@ function App() {
     <ThemeProvider>
       <AuthProvider>
-        <Router>
+        <RecordingProvider>
+          <Router>
           <Routes>
             <Route path="/login" element={<LoginPage />} />
-        </Router>
+          </Router>
+        </RecordingProvider>
       </AuthProvider>
diff --git a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx
index af5d4a3c..db17d626 100644
--- a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx
+++ b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx
@@ -1,7 +1,7 @@
-import { SimpleAudioRecordingReturn } from '../../hooks/useSimpleAudioRecording'
+import { RecordingContextType } from '../../contexts/RecordingContext'

 interface SimpleDebugPanelProps {
-  recording: SimpleAudioRecordingReturn
+  recording: RecordingContextType
 }

 export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) {
diff --git a/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx b/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx
index f81142c5..a3299deb 100644
--- a/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx
+++ b/backends/advanced/webui/src/components/audio/SimplifiedControls.tsx
@@ -1,8 +1,8 @@
 import { Mic, MicOff, Loader2 } from 'lucide-react'
-import { SimpleAudioRecordingReturn } from '../../hooks/useSimpleAudioRecording'
+import { RecordingContextType } from '../../contexts/RecordingContext'

 interface SimplifiedControlsProps {
-  recording: SimpleAudioRecordingReturn
+  recording: RecordingContextType
 }

 const getStepText = (step: string): string => {
diff --git a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx
index 1e28ee52..d151ef4d 100644
--- a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx
+++ b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx
@@ -1,9 +1,9 @@
 import React from 'react'
 import { Check, Loader2, AlertCircle, Mic, Wifi, Play, Radio } from 'lucide-react'
-import { SimpleAudioRecordingReturn, RecordingStep } from '../../hooks/useSimpleAudioRecording'
+import { RecordingContextType, RecordingStep } from '../../contexts/RecordingContext'

 interface StatusDisplayProps {
-  recording: SimpleAudioRecordingReturn
+  recording: RecordingContextType
 }

 interface StepInfo {
diff --git a/backends/advanced/webui/src/components/layout/GlobalRecordingIndicator.tsx b/backends/advanced/webui/src/components/layout/GlobalRecordingIndicator.tsx
new file mode 100644
index 00000000..2bbb9147
--- /dev/null
+++ b/backends/advanced/webui/src/components/layout/GlobalRecordingIndicator.tsx
@@ -0,0 +1,66 @@
+import { useNavigate, useLocation } from 'react-router-dom'
+import { Radio, Square, Zap, Archive } from 'lucide-react'
+import { useRecording } from '../../contexts/RecordingContext'
+
+export default function GlobalRecordingIndicator() {
+  const navigate = useNavigate()
+  const location = useLocation()
+  const { isRecording, recordingDuration, mode, stopRecording, formatDuration } = useRecording()
+
+  // Don't show if not recording
+  if (!isRecording) return null
+
+  // Don't show on the Live Record page (it has its own UI)
+  if (location.pathname === '/live-record') return null
+
+  return (
+    <div>
+      {/* Pulsing red dot */}
+      <span>
+        <span />
+        <span />
+      </span>
+
+      {/* Recording info */}
+      <div>
+        <span>{formatDuration(recordingDuration)}</span>
+        <span>
+          {mode === 'streaming' ? (
+            <>
+              <Zap />
+              Streaming
+            </>
+          ) : (
+            <>
+              <Archive />
+              Batch
+            </>
+          )}
+        </span>
+      </div>
+
+      {/* Actions */}
+      <div>
+        {/* Navigate to Live Record */}
+        <button onClick={() => navigate('/live-record')}>
+          <Radio />
+        </button>
+
+        {/* Stop button */}
+        <button onClick={stopRecording}>
+          <Square />
+        </button>
+      </div>
+    </div>
+  )
+}
diff --git a/backends/advanced/webui/src/components/layout/Layout.tsx b/backends/advanced/webui/src/components/layout/Layout.tsx
index 5a7e10be..814634d9 100644
--- a/backends/advanced/webui/src/components/layout/Layout.tsx
+++ b/backends/advanced/webui/src/components/layout/Layout.tsx
@@ -2,6 +2,7 @@ import { Link, useLocation, Outlet } from 'react-router-dom'
 import { Music, MessageSquare, MessageCircle, Brain, Users, Upload, Settings, LogOut, Sun, Moon, Shield, Radio, Layers, Calendar, Puzzle, Zap } from 'lucide-react'
 import { useAuth } from '../../contexts/AuthContext'
 import { useTheme } from '../../contexts/ThemeContext'
+import GlobalRecordingIndicator from './GlobalRecordingIndicator'

 export default function Layout() {
   const location = useLocation()
@@ -37,6 +38,9 @@ export default function Layout() {

+      {/* Global Recording Indicator */}
+      <GlobalRecordingIndicator />
+
-
+
       {/* User info */}
diff --git a/backends/advanced/webui/src/hooks/useSimpleAudioRecording.ts b/backends/advanced/webui/src/contexts/RecordingContext.tsx
similarity index 84%
rename from backends/advanced/webui/src/hooks/useSimpleAudioRecording.ts
rename to backends/advanced/webui/src/contexts/RecordingContext.tsx
index d34c8ea6..ba39badc 100644
--- a/backends/advanced/webui/src/hooks/useSimpleAudioRecording.ts
+++ b/backends/advanced/webui/src/contexts/RecordingContext.tsx
@@ -1,6 +1,7 @@
-import { useState, useRef, useCallback, useEffect } from 'react'
+import { createContext, useContext, useState, useRef, useCallback, useEffect, ReactNode } from 'react'
 import { BACKEND_URL } from '../services/api'
 import { getStorageKey } from '../utils/storage'
+import { useAuth } from './AuthContext'

 export type RecordingStep = 'idle' | 'mic' | 'websocket' | 'audio-start' | 'streaming' | 'stopping' | 'error'
 export type RecordingMode = 'batch' | 'streaming'
@@ -14,7 +15,7 @@ export interface DebugStats {
   connectionAttempts: number
 }

-export interface SimpleAudioRecordingReturn {
+export interface RecordingContextType {
   // Current state
   currentStep: RecordingStep
   isRecording: boolean
@@ -36,14 +37,19 @@
   canAccessMicrophone: boolean
 }

-export const useSimpleAudioRecording = (): SimpleAudioRecordingReturn => {
+const RecordingContext = createContext<RecordingContextType | undefined>(undefined)
+
+export function RecordingProvider({ children }: { children: ReactNode }) {
+  const { user } = useAuth()
+
   // Basic state
   const [currentStep, setCurrentStep] = useState<RecordingStep>('idle')
   const [isRecording, setIsRecording] = useState(false)
   const [recordingDuration, setRecordingDuration] = useState(0)
   const [error, setError] = useState<string | null>(null)
   const [mode, setMode] = useState<RecordingMode>('streaming')
-
+  const [analyserState, setAnalyserState] = useState<AnalyserNode | null>(null)
+
   // Debug stats
   const [debugStats, setDebugStats] = useState<DebugStats>({
     chunksSent: 0,
@@ -53,7 +59,7 @@
     sessionStartTime: null,
     connectionAttempts: 0
   })
-
+
   // Refs for direct access
   const wsRef = useRef<WebSocket | null>(null)
   const mediaStreamRef = useRef<MediaStream | null>(null)
@@ -64,9 +70,7 @@
   const keepAliveIntervalRef = useRef<ReturnType<typeof setInterval>>()
   const chunkCountRef = useRef(0)
   const audioProcessingStartedRef = useRef(false)
-
-  // Note: user was unused and removed
-
+
   // Check if we're on localhost or using HTTPS
   const isLocalhost = window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1'
   const isHttps = window.location.protocol === 'https:'
@@ -78,64 +82,65 @@
   const isDevelopmentHost = devAllowedHosts.includes(window.location.hostname)

   const canAccessMicrophone = isLocalhost || isHttps || isDevelopmentHost
-
+
   // Format duration helper
   const formatDuration = useCallback((seconds: number) => {
     const mins = Math.floor(seconds / 60)
     const secs = seconds % 60
     return `${mins}:${secs.toString().padStart(2, '0')}`
   }, [])
-
+
   // Cleanup function
   const cleanup = useCallback(() => {
     console.log('🧹 Cleaning up audio recording resources')
-
+
     // Stop audio processing
     audioProcessingStartedRef.current = false
-
+
     // Clean up media stream
     if (mediaStreamRef.current) {
       mediaStreamRef.current.getTracks().forEach(track => track.stop())
       mediaStreamRef.current = null
     }
-
+
     // Clean up audio context
     if (audioContextRef.current?.state !== 'closed') {
       audioContextRef.current?.close()
     }
     audioContextRef.current = null
     analyserRef.current = null
+    setAnalyserState(null)
     processorRef.current = null
-
+
     // Clean up WebSocket
     if (wsRef.current) {
       wsRef.current.close()
       wsRef.current = null
     }
-
+
     // Clear intervals
     if (durationIntervalRef.current) {
       clearInterval(durationIntervalRef.current)
       durationIntervalRef.current = undefined
     }
-
+
     if (keepAliveIntervalRef.current) {
       clearInterval(keepAliveIntervalRef.current)
       keepAliveIntervalRef.current = undefined
     }
-
+
     // Reset counters
     chunkCountRef.current = 0
   }, [])
-
+
   // Step 1: Get microphone access
   const getMicrophoneAccess = useCallback(async (): Promise<MediaStream> => {
     console.log('🎤 Step 1: Requesting microphone access')
-
+
     if (!canAccessMicrophone) {
       throw new Error('Microphone access requires HTTPS or localhost')
     }
-
+
     const stream = await navigator.mediaDevices.getUserMedia({
       audio: {
         sampleRate: 16000,
@@ -145,12 +150,26 @@
         autoGainControl: true
       }
     })
-
+
     mediaStreamRef.current = stream
+
+    // Track when mic permission is revoked
+    stream.getTracks().forEach(track => {
+      track.onended = () => {
+        console.log('🎤 Microphone track ended (permission revoked or device disconnected)')
+        if (isRecording) {
+          setError('Microphone disconnected or permission revoked')
+          setCurrentStep('error')
+          cleanup()
+          setIsRecording(false)
+        }
+      }
+    })
+
     console.log('✅ Microphone access granted')
     return stream
-  }, [canAccessMicrophone])
-
+  }, [canAccessMicrophone, isRecording, cleanup])
+
   // Step 2: Connect WebSocket
   const connectWebSocket = useCallback(async (): Promise<WebSocket> => {
     console.log('🔗 Step 2: Connecting to WebSocket')
@@ -159,7 +178,7 @@
     if (!token) {
       throw new Error('No authentication token found')
     }
-
+
     // Build WebSocket URL using BACKEND_URL from API service (handles base path correctly)
     const { protocol } = window.location
     const wsProtocol = protocol === 'https:' ? 'wss:' : 'ws:'
@@ -176,23 +195,23 @@
       // BACKEND_URL is empty (same origin)
       wsUrl = `${wsProtocol}//${window.location.host}/ws?codec=pcm&token=${token}&device_name=webui-recorder`
     }
-
+
     return new Promise((resolve, reject) => {
       const ws = new WebSocket(wsUrl)
       // Don't set binaryType yet - only when needed for audio chunks
-
+
       ws.onopen = () => {
         console.log('🔌 WebSocket connected')
-
+
         // Add stabilization delay before resolving
         setTimeout(() => {
           wsRef.current = ws
-          setDebugStats(prev => ({
-            ...prev,
+          setDebugStats(prev => ({
+            ...prev,
             connectionAttempts: prev.connectionAttempts + 1,
             sessionStartTime: new Date()
           }))
-
+
           // Start keepalive ping every 30 seconds
           keepAliveIntervalRef.current = setInterval(() => {
            if (ws.readyState === WebSocket.OPEN) {
@@ -204,27 +223,35 @@
             }
           }
         }, 30000)
-
+
          console.log('✅ WebSocket stabilized and ready')
          resolve(ws)
        }, 100) // 100ms stabilization delay
      }
-
+
      ws.onclose = (event) => {
        console.log('🔌 WebSocket disconnected:', event.code, event.reason)
        wsRef.current = null
-
+
        if (keepAliveIntervalRef.current) {
          clearInterval(keepAliveIntervalRef.current)
          keepAliveIntervalRef.current = undefined
        }
+
+        // If recording was active, set error state
+        if (isRecording) {
+          setError('WebSocket connection lost')
+          setCurrentStep('error')
+          cleanup()
+          setIsRecording(false)
+        }
      }
-
+
      ws.onerror = (error) => {
        console.error('🔌 WebSocket error:', error)
        reject(new Error('Failed to connect to backend'))
      }
-
+
      ws.onmessage = (event) => {
        console.log('📨 Received message from server:', event.data)
        setDebugStats(prev => ({ ...prev, messagesReceived: prev.messagesReceived + 1 }))
@@ -262,8 +289,8 @@
        }
      }
    })
-  }, [])
-
+  }, [isRecording, cleanup])
+
   // Step 3: Send audio-start message
   const sendAudioStartMessage = useCallback(async (ws: WebSocket): Promise<void> => {
     console.log('📤 Step 3: Sending audio-start message')
@@ -286,7 +313,7 @@
     ws.send(JSON.stringify(startMessage) + '\n')
     console.log('✅ Audio-start message sent with mode:', mode)
   }, [mode])
-
+
   // Step 4: Start audio streaming
   const startAudioStreaming = useCallback(async (stream: MediaStream, ws: WebSocket): Promise<void> => {
     console.log('🎵 Step 4: Starting audio streaming')
@@ -312,10 +339,11 @@
     audioContextRef.current = audioContext
     analyserRef.current = analyser
+    setAnalyserState(analyser)

     // Wait brief moment for backend to process audio-start
     await new Promise(resolve => setTimeout(resolve, 100))
-
+
     // Set up audio processing
     const processor = audioContext.createScriptProcessor(4096, 1, 1)
     source.connect(processor)
@@ -356,15 +384,13 @@
         return
       }

-      // inputData already declared above for audio level calculation
-
       // Convert float32 to int16 PCM
       const pcmBuffer = new Int16Array(inputData.length)
       for (let i = 0; i < inputData.length; i++) {
         const sample = Math.max(-1, Math.min(1, inputData[i]))
         pcmBuffer[i] = sample < 0 ? sample * 0x8000 : sample * 0x7FFF
       }
-
+
       try {
         const chunkHeader = {
           type: 'audio-chunk',
           data: {
             rate: 16000,
             width: 2,
             channels: 1
           },
           payload_length: pcmBuffer.byteLength
         }
-
+
         // Set binary type for WebSocket before sending binary data
         if (ws.binaryType !== 'arraybuffer') {
           ws.binaryType = 'arraybuffer'
           console.log('🔧 Set WebSocket binaryType to arraybuffer for audio chunks')
         }
-
+
         ws.send(JSON.stringify(chunkHeader) + '\n')
         ws.send(new Uint8Array(pcmBuffer.buffer, pcmBuffer.byteOffset, pcmBuffer.byteLength))
@@ -395,75 +421,75 @@
         }
       } catch (error) {
         console.error('Failed to send audio chunk:', error)
-        setDebugStats(prev => ({
-          ...prev,
+        setDebugStats(prev => ({
+          ...prev,
           lastError: error instanceof Error ? error.message : 'Chunk send failed',
           lastErrorTime: new Date()
         }))
       }
     }
-
+
     processorRef.current = processor
     audioProcessingStartedRef.current = true
-
+
     console.log('✅ Audio streaming started')
   }, [])
-
+
   // Main start recording function - sequential flow
   const startRecording = useCallback(async () => {
     try {
       setError(null)
       setCurrentStep('mic')
-
+
       // Step 1: Get microphone access
       const stream = await getMicrophoneAccess()
-
+
       setCurrentStep('websocket')
       // Step 2: Connect WebSocket (includes stabilization delay)
       const ws = await connectWebSocket()
-
+
       setCurrentStep('audio-start')
       // Step 3: Send audio-start message
       await sendAudioStartMessage(ws)
-
+
       setCurrentStep('streaming')
       // Step 4: Start audio streaming (includes processing delay)
       await startAudioStreaming(stream, ws)
-
+
       // All steps complete - mark as recording
       setIsRecording(true)
       setRecordingDuration(0)
-
+
       // Start duration timer
       durationIntervalRef.current = setInterval(() => {
         setRecordingDuration(prev => prev + 1)
       }, 1000)
-
+
       console.log('🎉 Recording started successfully!')
-
+
     } catch (error) {
       console.error('❌ Recording failed:', error)
       setCurrentStep('error')
       setError(error instanceof Error ? error.message : 'Recording failed')
-      setDebugStats(prev => ({
-        ...prev,
+      setDebugStats(prev => ({
+        ...prev,
         lastError: error instanceof Error ? error.message : 'Recording failed',
         lastErrorTime: new Date()
       }))
       cleanup()
     }
   }, [getMicrophoneAccess, connectWebSocket, sendAudioStartMessage, startAudioStreaming, cleanup])
-
+
   // Stop recording function
   const stopRecording = useCallback(() => {
     if (!isRecording) return
-
+
     console.log('🛑 Stopping recording')
     setCurrentStep('stopping')
-
+
     // Stop audio processing
     audioProcessingStartedRef.current = false
-
+
     // Send audio-stop message if WebSocket is still open
     if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
       try {
@@ -478,37 +504,67 @@
         console.error('Failed to send audio-stop:', error)
       }
     }
-
+
     // Cleanup resources
     cleanup()
-
+
     // Reset state
     setIsRecording(false)
     setRecordingDuration(0)
     setCurrentStep('idle')
-
+
     console.log('✅ Recording stopped')
   }, [isRecording, cleanup])
-
-  // Cleanup on unmount
+
+  // Stop recording when user logs out
   useEffect(() => {
-    return () => {
-      cleanup()
+    if (!user && isRecording) {
+      console.log('🔐 User logged out, stopping recording')
+      stopRecording()
+    }
+  }, [user, isRecording, stopRecording])
+
+  // Warn user before closing tab during recording
+  useEffect(() => {
+    const handleBeforeUnload = (event: BeforeUnloadEvent) => {
+      if (isRecording) {
+        event.preventDefault()
+        event.returnValue = 'Recording in progress. Are you sure you want to leave?'
+        return event.returnValue
+      }
     }
-  }, [cleanup])
-
-  return {
-    currentStep,
-    isRecording,
-    recordingDuration,
-    error,
-    mode,
-    startRecording,
-    stopRecording,
-    setMode,
-    analyser: analyserRef.current,
-    debugStats,
-    formatDuration,
-    canAccessMicrophone
+
+    window.addEventListener('beforeunload', handleBeforeUnload)
+    return () => window.removeEventListener('beforeunload', handleBeforeUnload)
+  }, [isRecording])
+
+  // NOTE: No cleanup on unmount - recording persists across navigation
+  // This is intentional for the global recording feature
+
+  return (
+    <RecordingContext.Provider
+      value={{
+        currentStep,
+        isRecording,
+        recordingDuration,
+        error,
+        mode,
+        startRecording,
+        stopRecording,
+        setMode,
+        analyser: analyserState,
+        debugStats,
+        formatDuration,
+        canAccessMicrophone
+      }}
+    >
+      {children}
+    </RecordingContext.Provider>
+  )
+}
+
+export function useRecording() {
+  const context = useContext(RecordingContext)
+  if (context === undefined) {
+    throw new Error('useRecording must be used within a RecordingProvider')
+  }
-}
\ No newline at end of file
+  return context
+}
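For reference, the client above frames every WebSocket message as a JSON header line terminated by '\n', followed, for audio chunks, by the raw PCM payload as a second binary send. A Python sketch of the same framing using the websockets library; the host/port, token, and the exact audio-start fields are assumptions (the diff elides startMessage, and mode is inferred from the log line):

    import asyncio
    import json

    import websockets

    async def send_pcm_chunk(ws, pcm_bytes: bytes) -> None:
        # JSON header line first, then the raw PCM payload as binary.
        # rate/width/channels mirror the WebUI capture settings above
        # (16 kHz, 16-bit, mono); treat them as an assumption.
        header = {
            "type": "audio-chunk",
            "data": {"rate": 16000, "width": 2, "channels": 1},
            "payload_length": len(pcm_bytes),
        }
        await ws.send(json.dumps(header) + "\n")
        await ws.send(pcm_bytes)

    async def main() -> None:
        # Hypothetical local backend; the query params match the wsUrl in the TSX.
        url = "ws://localhost:8000/ws?codec=pcm&token=TOKEN&device_name=webui-recorder"
        async with websockets.connect(url) as ws:
            # audio-start fields beyond type/mode are not shown in the diff.
            await ws.send(json.dumps({"type": "audio-start", "mode": "streaming"}) + "\n")
            await send_pcm_chunk(ws, b"\x00\x00" * 1600)  # 100 ms of silence at 16 kHz

    asyncio.run(main())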
diff --git a/backends/advanced/webui/src/pages/Conversations.tsx b/backends/advanced/webui/src/pages/Conversations.tsx
index ef57e738..d8861859 100644
--- a/backends/advanced/webui/src/pages/Conversations.tsx
+++ b/backends/advanced/webui/src/pages/Conversations.tsx
@@ -1225,6 +1225,18 @@ export default function Conversations() {
                       <div>Memory Count: {conversation.memory_count || 0}</div>
                       <div>Client ID: {conversation.client_id}</div>
+
+                      {/* Raw Segments JSON */}
+                      {conversation.segments && conversation.segments.length > 0 && (
+                        <details>
+                          <summary>
+                            Raw Segments ({conversation.segments.length})
+                          </summary>
+                          <pre>
+                            {JSON.stringify(conversation.segments, null, 2)}
+                          </pre>
+                        </details>
+                      )}
                     </div>
                   )}
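The Raw Segments panel pretty-prints conversation.segments verbatim, so each entry should match the segment shape built in the first patch (speaker stored as a string, start/end in seconds, text). A hedged Pydantic sketch of that shape for reading the JSON outside the UI; the fields are inferred from the TranscriptSegment(...) construction above, and the class name here is illustrative:

    from pydantic import BaseModel

    class Segment(BaseModel):
        """One entry of conversation.segments (fields inferred from patch 1)."""
        speaker: str  # stored as a string, e.g. "0"
        start: float  # seconds
        end: float    # seconds
        text: str

    # Example of what the panel would render:
    print(Segment(speaker="0", start=0.0, end=2.4, text="hello").model_dump())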
diff --git a/backends/advanced/webui/src/pages/LiveRecord.tsx b/backends/advanced/webui/src/pages/LiveRecord.tsx
index 4b763746..b4a61fb4 100644
--- a/backends/advanced/webui/src/pages/LiveRecord.tsx
+++ b/backends/advanced/webui/src/pages/LiveRecord.tsx
@@ -1,12 +1,12 @@
 import { Radio, Zap, Archive } from 'lucide-react'
-import { useSimpleAudioRecording } from '../hooks/useSimpleAudioRecording'
+import { useRecording } from '../contexts/RecordingContext'
 import SimplifiedControls from '../components/audio/SimplifiedControls'
 import StatusDisplay from '../components/audio/StatusDisplay'
 import AudioVisualizer from '../components/audio/AudioVisualizer'
 import SimpleDebugPanel from '../components/audio/SimpleDebugPanel'

 export default function LiveRecord() {
-  const recording = useSimpleAudioRecording()
+  const recording = useRecording()

   return (