diff --git a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py
index 13683194..d456e89e 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/homeassistant/plugin.py
@@ -28,7 +28,7 @@ class HomeAssistantPlugin(BasePlugin):
-> Returns: PluginResult with "I've turned off the hall light"
"""
- SUPPORTED_ACCESS_LEVELS: List[str] = ['transcript']
+ SUPPORTED_ACCESS_LEVELS: List[str] = ["transcript"]
name = "Home Assistant"
description = "Wake word device control with Home Assistant integration"
@@ -56,10 +56,10 @@ def __init__(self, config: Dict[str, Any]):
self.cache_initialized = False
# Configuration
- self.ha_url = config.get('ha_url', 'http://localhost:8123')
- self.ha_token = config.get('ha_token', '')
- self.wake_word = config.get('wake_word', 'vivi')
- self.timeout = config.get('timeout', 30)
+ self.ha_url = config.get("ha_url", "http://localhost:8123")
+ self.ha_token = config.get("ha_token", "")
+ self.wake_word = config.get("wake_word", "vivi")
+ self.timeout = config.get("timeout", 30)
async def initialize(self):
"""
@@ -81,9 +81,7 @@ async def initialize(self):
# Create MCP client (used for REST API calls, not MCP protocol)
self.mcp_client = HAMCPClient(
- base_url=self.ha_url,
- token=self.ha_token,
- timeout=self.timeout
+ base_url=self.ha_url, token=self.ha_token, timeout=self.timeout
)
# Test basic API connectivity with Template API
@@ -140,21 +138,17 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
should_continue=False
)
"""
- command = context.data.get('command', '')
+ command = context.data.get("command", "")
if not command:
- return PluginResult(
- success=False,
- message="No command provided",
- should_continue=True
- )
+ return PluginResult(success=False, message="No command provided", should_continue=True)
if not self.mcp_client:
logger.error("MCP client not initialized")
return PluginResult(
success=False,
message="Sorry, Home Assistant is not connected",
- should_continue=True
+ should_continue=True,
)
try:
@@ -166,7 +160,7 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
return PluginResult(
success=False,
message="Sorry, I couldn't understand that command",
- should_continue=True
+ should_continue=True,
)
# Step 2: Resolve entities from parsed command
@@ -174,47 +168,38 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
entity_ids = await self._resolve_entities(parsed)
except ValueError as e:
logger.warning(f"Entity resolution failed: {e}")
- return PluginResult(
- success=False,
- message=str(e),
- should_continue=True
- )
+ return PluginResult(success=False, message=str(e), should_continue=True)
# Step 3: Determine service and domain
# Extract domain from first entity (all should have same domain for area-based)
- domain = entity_ids[0].split('.')[0] if entity_ids else 'light'
+ domain = entity_ids[0].split(".")[0] if entity_ids else "light"
# Map action to service name
service_map = {
- 'turn_on': 'turn_on',
- 'turn_off': 'turn_off',
- 'toggle': 'toggle',
- 'set_brightness': 'turn_on', # brightness uses turn_on with params
- 'set_color': 'turn_on' # color uses turn_on with params
+ "turn_on": "turn_on",
+ "turn_off": "turn_off",
+ "toggle": "toggle",
+ "set_brightness": "turn_on", # brightness uses turn_on with params
+ "set_color": "turn_on", # color uses turn_on with params
}
- service = service_map.get(parsed.action, 'turn_on')
+ service = service_map.get(parsed.action, "turn_on")
# Step 4: Call Home Assistant service
- logger.info(
- f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}"
- )
+ logger.info(f"Calling {domain}.{service} for {len(entity_ids)} entities: {entity_ids}")
result = await self.mcp_client.call_service(
- domain=domain,
- service=service,
- entity_ids=entity_ids,
- **parsed.parameters
+ domain=domain, service=service, entity_ids=entity_ids, **parsed.parameters
)
# Step 5: Format user-friendly response
entity_type_name = parsed.entity_type or domain
- if parsed.target_type == 'area':
+ if parsed.target_type == "area":
message = (
f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} "
f"{entity_type_name}{'s' if len(entity_ids) != 1 else ''} "
f"in {parsed.target}"
)
- elif parsed.target_type == 'all_in_area':
+ elif parsed.target_type == "all_in_area":
message = (
f"I've {parsed.action.replace('_', ' ')} {len(entity_ids)} "
f"entities in {parsed.target}"
@@ -227,14 +212,14 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
return PluginResult(
success=True,
data={
- 'action': parsed.action,
- 'entity_ids': entity_ids,
- 'target_type': parsed.target_type,
- 'target': parsed.target,
- 'ha_result': result
+ "action": parsed.action,
+ "entity_ids": entity_ids,
+ "target_type": parsed.target_type,
+ "target": parsed.target,
+ "ha_result": result,
},
message=message,
- should_continue=False # Stop normal processing - HA command handled
+ should_continue=False, # Stop normal processing - HA command handled
)
except MCPError as e:
@@ -242,14 +227,14 @@ async def on_transcript(self, context: PluginContext) -> Optional[PluginResult]:
return PluginResult(
success=False,
message=f"Sorry, Home Assistant couldn't execute that: {e}",
- should_continue=True
+ should_continue=True,
)
except Exception as e:
logger.error(f"Command execution failed: {e}", exc_info=True)
return PluginResult(
success=False,
message="Sorry, something went wrong while executing that command",
- should_continue=True
+ should_continue=True,
)
async def cleanup(self):
@@ -298,23 +283,23 @@ async def _refresh_cache(self):
# Create cache
from datetime import datetime
+
self.entity_cache = EntityCache(
areas=areas,
area_entities=area_entities,
entity_details=entity_details,
- last_refresh=datetime.now()
+ last_refresh=datetime.now(),
)
logger.info(
- f"Entity cache refreshed: {len(areas)} areas, "
- f"{len(entity_details)} entities"
+ f"Entity cache refreshed: {len(areas)} areas, " f"{len(entity_details)} entities"
)
except Exception as e:
logger.error(f"Failed to refresh entity cache: {e}", exc_info=True)
raise
- async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand']:
+ async def _parse_command_with_llm(self, command: str) -> Optional["ParsedCommand"]:
"""
Parse command using LLM with structured system prompt.
@@ -336,6 +321,7 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand
"""
try:
from advanced_omi_backend.llm_client import get_llm_client
+
from .command_parser import COMMAND_PARSER_SYSTEM_PROMPT, ParsedCommand
llm_client = get_llm_client()
@@ -347,36 +333,36 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand
model=llm_client.model,
messages=[
{"role": "system", "content": COMMAND_PARSER_SYSTEM_PROMPT},
- {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'}
+ {"role": "user", "content": f'Command: "{command}"\n\nReturn JSON only.'},
],
temperature=0.1,
- max_tokens=150
+ max_tokens=150,
)
result_text = response.choices[0].message.content.strip()
logger.debug(f"LLM response: {result_text}")
# Remove markdown code blocks if present
- if result_text.startswith('```'):
- lines = result_text.split('\n')
- result_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else result_text
+ if result_text.startswith("```"):
+ lines = result_text.split("\n")
+ result_text = "\n".join(lines[1:-1]) if len(lines) > 2 else result_text
result_text = result_text.strip()
# Parse JSON response
result_json = json.loads(result_text)
# Validate required fields
- required_fields = ['action', 'target_type', 'target']
+ required_fields = ["action", "target_type", "target"]
if not all(field in result_json for field in required_fields):
logger.warning(f"LLM response missing required fields: {result_json}")
return None
parsed = ParsedCommand(
- action=result_json['action'],
- target_type=result_json['target_type'],
- target=result_json['target'],
- entity_type=result_json.get('entity_type'),
- parameters=result_json.get('parameters', {})
+ action=result_json["action"],
+ target_type=result_json["target_type"],
+ target=result_json["target"],
+ entity_type=result_json.get("entity_type"),
+ parameters=result_json.get("parameters", {}),
)
logger.info(
@@ -394,7 +380,7 @@ async def _parse_command_with_llm(self, command: str) -> Optional['ParsedCommand
logger.error(f"LLM command parsing failed: {e}", exc_info=True)
return None
- async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]:
+ async def _resolve_entities(self, parsed: "ParsedCommand") -> List[str]:
"""
Resolve ParsedCommand to actual Home Assistant entity IDs.
@@ -424,11 +410,10 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]:
if not self.entity_cache:
raise ValueError("Entity cache not initialized")
- if parsed.target_type == 'area':
+ if parsed.target_type == "area":
# Get entities in area, filtered by type
entities = self.entity_cache.get_entities_in_area(
- area=parsed.target,
- entity_type=parsed.entity_type
+ area=parsed.target, entity_type=parsed.entity_type
)
if not entities:
@@ -444,12 +429,9 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]:
)
return entities
- elif parsed.target_type == 'all_in_area':
+ elif parsed.target_type == "all_in_area":
# Get ALL entities in area (no filter)
- entities = self.entity_cache.get_entities_in_area(
- area=parsed.target,
- entity_type=None
- )
+ entities = self.entity_cache.get_entities_in_area(area=parsed.target, entity_type=None)
if not entities:
raise ValueError(
@@ -460,7 +442,7 @@ async def _resolve_entities(self, parsed: 'ParsedCommand') -> List[str]:
logger.info(f"Resolved 'all in {parsed.target}' to {len(entities)} entities")
return entities
- elif parsed.target_type == 'entity':
+ elif parsed.target_type == "entity":
# Fuzzy match entity by name
entity_id = self.entity_cache.find_entity_by_name(parsed.target)
@@ -501,37 +483,35 @@ async def _parse_command_fallback(self, command: str) -> Optional[Dict[str, Any]
# Determine action
tool = None
- if any(word in command_lower for word in ['turn off', 'off', 'disable']):
- tool = 'turn_off'
- action_desc = 'turned off'
- elif any(word in command_lower for word in ['turn on', 'on', 'enable']):
- tool = 'turn_on'
- action_desc = 'turned on'
- elif 'toggle' in command_lower:
- tool = 'toggle'
- action_desc = 'toggled'
+ if any(word in command_lower for word in ["turn off", "off", "disable"]):
+ tool = "turn_off"
+ action_desc = "turned off"
+ elif any(word in command_lower for word in ["turn on", "on", "enable"]):
+ tool = "turn_on"
+ action_desc = "turned on"
+ elif "toggle" in command_lower:
+ tool = "toggle"
+ action_desc = "toggled"
else:
logger.warning(f"Unknown action in command: {command}")
return None
# Extract entity name from command
entity_query = command_lower
- for action_word in ['turn off', 'turn on', 'toggle', 'off', 'on', 'the']:
- entity_query = entity_query.replace(action_word, '').strip()
+ for action_word in ["turn off", "turn on", "toggle", "off", "on", "the"]:
+ entity_query = entity_query.replace(action_word, "").strip()
logger.info(f"Searching for entity: '{entity_query}'")
# Return placeholder (this will work if entity ID matches pattern)
return {
"tool": tool,
- "arguments": {
- "entity_id": f"light.{entity_query.replace(' ', '_')}"
- },
+ "arguments": {"entity_id": f"light.{entity_query.replace(' ', '_')}"},
"friendly_name": entity_query.title(),
- "action_desc": action_desc
+ "action_desc": action_desc,
}
- async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']:
+ async def _parse_command_hybrid(self, command: str) -> Optional["ParsedCommand"]:
"""
Hybrid command parser: Try LLM first, fallback to keywords.
@@ -550,15 +530,13 @@ async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']
ParsedCommand(action="turn_off", target_type="area", target="study", ...)
"""
import asyncio
+
from .command_parser import ParsedCommand
# Try LLM parsing with timeout
try:
logger.debug("Attempting LLM-based command parsing...")
- parsed = await asyncio.wait_for(
- self._parse_command_with_llm(command),
- timeout=5.0
- )
+ parsed = await asyncio.wait_for(self._parse_command_with_llm(command), timeout=5.0)
if parsed:
logger.info("LLM parsing succeeded")
@@ -581,16 +559,16 @@ async def _parse_command_hybrid(self, command: str) -> Optional['ParsedCommand']
# Convert fallback format to ParsedCommand
# Extract entity_id from arguments
- entity_id = fallback_result['arguments'].get('entity_id', '')
- entity_name = entity_id.split('.', 1)[1] if '.' in entity_id else entity_id
+ entity_id = fallback_result["arguments"].get("entity_id", "")
+ entity_name = entity_id.split(".", 1)[1] if "." in entity_id else entity_id
# Simple heuristic: assume it's targeting a single entity
parsed = ParsedCommand(
- action=fallback_result['tool'],
- target_type='entity',
- target=entity_name.replace('_', ' '),
+ action=fallback_result["tool"],
+ target_type="entity",
+ target=entity_name.replace("_", " "),
entity_type=None,
- parameters={}
+ parameters={},
)
logger.info("Fallback parsing succeeded")
@@ -629,64 +607,63 @@ async def test_connection(config: Dict[str, Any]) -> Dict[str, Any]:
try:
# Validate required config fields
- required_fields = ['ha_url', 'ha_token']
+ required_fields = ["ha_url", "ha_token"]
missing_fields = [field for field in required_fields if not config.get(field)]
if missing_fields:
return {
"success": False,
"message": f"Missing required fields: {', '.join(missing_fields)}",
- "status": "error"
+ "status": "error",
}
- ha_url = config.get('ha_url')
- ha_token = config.get('ha_token')
- timeout = config.get('timeout', 30)
+ ha_url = config.get("ha_url")
+ ha_token = config.get("ha_token")
+ timeout = config.get("timeout", 30)
# Create temporary MCP client
- mcp_client = HAMCPClient(
- base_url=ha_url,
- token=ha_token,
- timeout=timeout
- )
-
- # Test API connectivity with Template API
- logger.info(f"Testing Home Assistant API connection to {ha_url}...")
- start_time = time.time()
+ mcp_client = HAMCPClient(base_url=ha_url, token=ha_token, timeout=timeout)
- test_result = await mcp_client._render_template("{{ 1 + 1 }}")
- connection_time_ms = int((time.time() - start_time) * 1000)
-
- if str(test_result).strip() != "2":
- return {
- "success": False,
- "message": f"Unexpected template result: {test_result}",
- "status": "error"
- }
-
- # Try to fetch entities count for additional info
try:
- entities = await mcp_client.get_all_entities()
- entity_count = len(entities)
- except Exception:
- entity_count = None
+ # Test API connectivity with Template API
+ logger.info(f"Testing Home Assistant API connection to {ha_url}...")
+ start_time = time.time()
+
+ test_result = await mcp_client._render_template("{{ 1 + 1 }}")
+ connection_time_ms = int((time.time() - start_time) * 1000)
+
+ if str(test_result).strip() != "2":
+ return {
+ "success": False,
+ "message": f"Unexpected template result: {test_result}",
+ "status": "error",
+ }
+
+ # Try to fetch entities count for additional info
+ try:
+ entities = await mcp_client.discover_entities()
+ entity_count = len(entities)
+ except Exception:
+ entity_count = None
- return {
- "success": True,
- "message": f"Successfully connected to Home Assistant at {ha_url}",
- "status": "success",
- "details": {
- "ha_url": ha_url,
- "connection_time_ms": connection_time_ms,
- "entity_count": entity_count,
- "api_test": "Template rendering successful"
+ return {
+ "success": True,
+ "message": f"Successfully connected to Home Assistant at {ha_url}",
+ "status": "success",
+ "details": {
+ "ha_url": ha_url,
+ "connection_time_ms": connection_time_ms,
+ "entity_count": entity_count,
+ "api_test": "Template rendering successful",
+ },
}
- }
+ finally:
+ await mcp_client.close()
except Exception as e:
logger.error(f"Home Assistant connection test failed: {e}", exc_info=True)
return {
"success": False,
"message": f"Connection test failed: {str(e)}",
- "status": "error"
+ "status": "error",
}
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py
index 8c5b5389..1a4e545f 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py
@@ -6,7 +6,8 @@
import logging
import uuid
-from typing import List, Dict, Any, Optional
+from typing import Any, Dict, List, Optional
+
import httpx
memory_logger = logging.getLogger("memory_service")
@@ -14,12 +15,12 @@
class MCPClient:
"""Client for communicating with OpenMemory servers.
-
+
Uses the official OpenMemory REST API:
- POST /api/v1/memories - Create new memory
- GET /api/v1/memories - List memories
- DELETE /api/v1/memories - Delete memories
-
+
Attributes:
server_url: Base URL of the OpenMemory server (default: http://localhost:8765)
client_name: Client identifier for memory tagging
@@ -27,8 +28,15 @@ class MCPClient:
timeout: Request timeout in seconds
client: HTTP client instance
"""
-
- def __init__(self, server_url: str, client_name: str = "chronicle", user_id: str = "default", user_email: str = "", timeout: int = 30):
+
+ def __init__(
+ self,
+ server_url: str,
+ client_name: str = "chronicle",
+ user_id: str = "default",
+ user_email: str = "",
+ timeout: int = 30,
+ ):
"""Initialize client for OpenMemory.
Args:
@@ -38,43 +46,44 @@ def __init__(self, server_url: str, client_name: str = "chronicle", user_id: str
user_email: User email address for user metadata
timeout: HTTP request timeout in seconds
"""
- self.server_url = server_url.rstrip('/')
+ self.server_url = server_url.rstrip("/")
self.client_name = client_name
self.user_id = user_id
self.user_email = user_email
self.timeout = timeout
-
+
# Use custom CA certificate if available
import os
- ca_bundle = os.getenv('REQUESTS_CA_BUNDLE')
+
+ ca_bundle = os.getenv("REQUESTS_CA_BUNDLE")
verify = ca_bundle if ca_bundle and os.path.exists(ca_bundle) else True
-
+
self.client = httpx.AsyncClient(timeout=timeout, verify=verify)
-
+
async def close(self):
"""Close the HTTP client."""
await self.client.aclose()
-
+
async def __aenter__(self):
return self
-
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
-
+
async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List[str]:
"""Add memories to the OpenMemory server.
-
+
Uses the REST API to create memories. OpenMemory will handle:
- Memory extraction from text
- Deduplication
- Vector embedding and storage
-
+
Args:
text: Memory text to store
-
+
Returns:
List of created memory IDs
-
+
Raises:
MCPError: If the server request fails
"""
@@ -113,7 +122,7 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List
default_metadata = {
"source": "chronicle",
"client": self.client_name,
- "user_email": self.user_email
+ "user_email": self.user_email,
}
if metadata:
default_metadata.update(metadata)
@@ -125,58 +134,63 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List
"text": text,
"app": self.client_name,
"metadata": default_metadata,
- "infer": True
+ "infer": True,
}
- memory_logger.info(f"POSTing memory to {self.server_url}/api/v1/memories/ with payload={payload}")
-
- response = await self.client.post(
- f"{self.server_url}/api/v1/memories/",
- json=payload
+ memory_logger.info(
+ f"POSTing memory to {self.server_url}/api/v1/memories/ "
+ f"(user_id={self.user_id}, text_len={len(text)}, metadata_keys={list(default_metadata.keys())})"
)
+ memory_logger.debug(f"Full payload: {payload}")
+
+ response = await self.client.post(f"{self.server_url}/api/v1/memories/", json=payload)
response_body = response.text[:500] if response.status_code != 200 else "..."
- memory_logger.info(f"OpenMemory response: status={response.status_code}, body={response_body}, headers={dict(response.headers)}")
+ memory_logger.info(
+ f"OpenMemory response: status={response.status_code}, body={response_body}, headers={dict(response.headers)}"
+ )
response.raise_for_status()
-
+
result = response.json()
-
+
# Handle None result - OpenMemory returns None when no memory is created
# (due to deduplication, insufficient content, etc.)
if result is None:
- memory_logger.info("OpenMemory returned None - no memory created (likely deduplication)")
+ memory_logger.info(
+ "OpenMemory returned None - no memory created (likely deduplication)"
+ )
return []
-
+
# Handle error response
if isinstance(result, dict) and "error" in result:
memory_logger.error(f"OpenMemory error: {result['error']}")
return []
-
+
# Extract memory ID from response
if isinstance(result, dict):
memory_id = result.get("id") or str(uuid.uuid4())
return [memory_id]
elif isinstance(result, list):
return [str(item.get("id", uuid.uuid4())) for item in result]
-
+
# Default success response
return [str(uuid.uuid4())]
-
+
except httpx.HTTPError as e:
memory_logger.error(f"HTTP error adding memories: {e}")
raise MCPError(f"HTTP error: {e}")
except Exception as e:
memory_logger.error(f"Error adding memories: {e}")
raise MCPError(f"Failed to add memories: {e}")
-
+
async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
"""Search for memories using semantic similarity.
-
+
Args:
query: Search query text
limit: Maximum number of results to return
-
+
Returns:
List of memory dictionaries with content and metadata
"""
@@ -185,38 +199,34 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any
apps_response = await self.client.get(f"{self.server_url}/api/v1/apps/")
apps_response.raise_for_status()
apps_data = apps_response.json()
-
+
if not apps_data.get("apps") or len(apps_data["apps"]) == 0:
memory_logger.warning("No apps found in OpenMemory MCP for search")
return []
-
+
# Find the app matching our client name, or use first app as fallback
app_id = None
for app in apps_data["apps"]:
if app["name"] == self.client_name:
app_id = app["id"]
break
-
+
if not app_id:
- memory_logger.warning(f"App '{self.client_name}' not found, using first available app")
+ memory_logger.warning(
+ f"App '{self.client_name}' not found, using first available app"
+ )
app_id = apps_data["apps"][0]["id"]
-
+
# Use app-specific memories endpoint with search
- params = {
- "user_id": self.user_id,
- "search_query": query,
- "page": 1,
- "size": limit
- }
-
+ params = {"user_id": self.user_id, "search_query": query, "page": 1, "size": limit}
+
response = await self.client.get(
- f"{self.server_url}/api/v1/apps/{app_id}/memories",
- params=params
+ f"{self.server_url}/api/v1/apps/{app_id}/memories", params=params
)
response.raise_for_status()
-
+
result = response.json()
-
+
# Extract memories from app-specific response format
if isinstance(result, dict) and "memories" in result:
memories = result["memories"]
@@ -224,30 +234,32 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any
memories = result
else:
memories = []
-
+
# Format memories for Chronicle
formatted_memories = []
for memory in memories:
- formatted_memories.append({
- "id": memory.get("id", str(uuid.uuid4())),
- "content": memory.get("content", "") or memory.get("text", ""),
- "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}),
- "created_at": memory.get("created_at"),
- "score": memory.get("score", 0.0) # No score from list API
- })
-
+ formatted_memories.append(
+ {
+ "id": memory.get("id", str(uuid.uuid4())),
+ "content": memory.get("content", "") or memory.get("text", ""),
+ "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}),
+ "created_at": memory.get("created_at"),
+ "score": memory.get("score", 0.0), # No score from list API
+ }
+ )
+
return formatted_memories[:limit]
-
+
except Exception as e:
memory_logger.error(f"Error searching memories: {e}")
return []
-
+
async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]:
"""List all memories for the current user.
-
+
Args:
limit: Maximum number of memories to return
-
+
Returns:
List of memory dictionaries
"""
@@ -256,37 +268,34 @@ async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]:
apps_response = await self.client.get(f"{self.server_url}/api/v1/apps/")
apps_response.raise_for_status()
apps_data = apps_response.json()
-
+
if not apps_data.get("apps") or len(apps_data["apps"]) == 0:
memory_logger.warning("No apps found in OpenMemory MCP")
return []
-
+
# Find the app matching our client name, or use first app as fallback
app_id = None
for app in apps_data["apps"]:
if app["name"] == self.client_name:
app_id = app["id"]
break
-
+
if not app_id:
- memory_logger.warning(f"App '{self.client_name}' not found, using first available app")
+ memory_logger.warning(
+ f"App '{self.client_name}' not found, using first available app"
+ )
app_id = apps_data["apps"][0]["id"]
-
+
# Use app-specific memories endpoint
- params = {
- "user_id": self.user_id,
- "page": 1,
- "size": limit
- }
-
+ params = {"user_id": self.user_id, "page": 1, "size": limit}
+
response = await self.client.get(
- f"{self.server_url}/api/v1/apps/{app_id}/memories",
- params=params
+ f"{self.server_url}/api/v1/apps/{app_id}/memories", params=params
)
response.raise_for_status()
-
+
result = response.json()
-
+
# Extract memories from app-specific response format
if isinstance(result, dict) and "memories" in result:
memories = result["memories"]
@@ -294,29 +303,31 @@ async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]:
memories = result
else:
memories = []
-
+
# Format memories
formatted_memories = []
for memory in memories:
- formatted_memories.append({
- "id": memory.get("id", str(uuid.uuid4())),
- "content": memory.get("content", "") or memory.get("text", ""),
- "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}),
- "created_at": memory.get("created_at")
- })
-
+ formatted_memories.append(
+ {
+ "id": memory.get("id", str(uuid.uuid4())),
+ "content": memory.get("content", "") or memory.get("text", ""),
+ "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}),
+ "created_at": memory.get("created_at"),
+ }
+ )
+
return formatted_memories
-
+
except Exception as e:
memory_logger.error(f"Error listing memories: {e}")
return []
-
+
async def delete_all_memories(self) -> int:
"""Delete all memories for the current user.
-
+
Note: OpenMemory may not support bulk delete via REST API.
This is typically done through MCP tools for safety.
-
+
Returns:
Number of memories that were deleted
"""
@@ -325,31 +336,29 @@ async def delete_all_memories(self) -> int:
memories = await self.list_memories(limit=1000)
if not memories:
return 0
-
+
memory_ids = [m["id"] for m in memories]
-
+
# Delete memories using the batch delete endpoint
response = await self.client.request(
"DELETE",
f"{self.server_url}/api/v1/memories/",
- json={
- "memory_ids": memory_ids,
- "user_id": self.user_id
- }
+ json={"memory_ids": memory_ids, "user_id": self.user_id},
)
response.raise_for_status()
-
+
result = response.json()
-
+
# Extract count from response
if isinstance(result, dict):
if "message" in result:
# Parse message like "Successfully deleted 5 memories"
import re
- match = re.search(r'(\d+)', result["message"])
+
+ match = re.search(r"(\d+)", result["message"])
return int(match.group(1)) if match else len(memory_ids)
return result.get("deleted_count", len(memory_ids))
-
+
return len(memory_ids)
except Exception as e:
@@ -368,8 +377,7 @@ async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]:
try:
# Use the memories endpoint with specific ID
response = await self.client.get(
- f"{self.server_url}/api/v1/memories/{memory_id}",
- params={"user_id": self.user_id}
+ f"{self.server_url}/api/v1/memories/{memory_id}", params={"user_id": self.user_id}
)
if response.status_code == 404:
@@ -403,7 +411,7 @@ async def update_memory(
self,
memory_id: str,
content: Optional[str] = None,
- metadata: Optional[Dict[str, Any]] = None
+ metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""Update a specific memory's content and/or metadata.
@@ -431,8 +439,7 @@ async def update_memory(
# Use PUT to update memory
response = await self.client.put(
- f"{self.server_url}/api/v1/memories/{memory_id}",
- json=update_data
+ f"{self.server_url}/api/v1/memories/{memory_id}", json=update_data
)
response.raise_for_status()
@@ -446,12 +453,14 @@ async def update_memory(
memory_logger.error(f"Error updating memory: {e}")
return False
- async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None) -> bool:
+ async def delete_memory(
+ self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None
+ ) -> bool:
"""Delete a specific memory by ID.
-
+
Args:
memory_id: ID of the memory to delete
-
+
Returns:
True if deletion succeeded, False otherwise
"""
@@ -459,21 +468,18 @@ async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, use
response = await self.client.request(
"DELETE",
f"{self.server_url}/api/v1/memories/",
- json={
- "memory_ids": [memory_id],
- "user_id": self.user_id
- }
+ json={"memory_ids": [memory_id], "user_id": self.user_id},
)
response.raise_for_status()
return True
-
+
except Exception as e:
memory_logger.warning(f"Error deleting memory {memory_id}: {e}")
return False
-
+
async def test_connection(self) -> bool:
"""Test connection to the OpenMemory server.
-
+
Returns:
True if server is reachable and responsive, False otherwise
"""
@@ -484,16 +490,23 @@ async def test_connection(self) -> bool:
try:
response = await self.client.get(
f"{self.server_url}{endpoint}",
- params={"user_id": self.user_id, "page": 1, "size": 1}
- if endpoint == "/api/v1/memories" else {}
+ params=(
+ {"user_id": self.user_id, "page": 1, "size": 1}
+ if endpoint == "/api/v1/memories"
+ else {}
+ ),
)
- if response.status_code in [200, 404, 422]: # 404/422 means endpoint exists but params wrong
+ if response.status_code in [
+ 200,
+ 404,
+ 422,
+ ]: # 404/422 means endpoint exists but params wrong
return True
except:
continue
-
+
return False
-
+
except Exception as e:
memory_logger.error(f"OpenMemory server connection test failed: {e}")
return False
@@ -501,4 +514,5 @@ async def test_connection(self) -> bool:
class MCPError(Exception):
"""Exception raised for MCP server communication errors."""
- pass
\ No newline at end of file
+
+ pass
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py
index 922f2555..d5061a2c 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py
@@ -10,9 +10,9 @@
import os
import time
import uuid
-from typing import Optional, List, Tuple, Any, Dict
+from typing import Any, Dict, List, Optional, Tuple
-from ..base import MemoryServiceBase, MemoryEntry
+from ..base import MemoryEntry, MemoryServiceBase
from .mcp_client import MCPClient, MCPError
memory_logger = logging.getLogger("memory_service")
@@ -24,14 +24,14 @@ class OpenMemoryMCPService(MemoryServiceBase):
This class implements the MemoryServiceBase interface by delegating memory
operations to an OpenMemory MCP server. It handles the translation between
Chronicle's memory service API and the standardized MCP operations.
-
+
Key features:
- Maintains compatibility with existing MemoryServiceBase interface
- Leverages OpenMemory MCP's deduplication and processing
- - Supports transcript-based memory extraction
+ - Supports transcript-based memory extraction
- Provides user isolation and metadata management
- Handles memory search and CRUD operations
-
+
Attributes:
server_url: URL of the OpenMemory MCP server
timeout: Request timeout in seconds
@@ -39,7 +39,7 @@ class OpenMemoryMCPService(MemoryServiceBase):
mcp_client: Client for communicating with MCP server
_initialized: Whether the service has been initialized
"""
-
+
def __init__(
self,
server_url: Optional[str] = None,
@@ -67,42 +67,42 @@ def __init__(
self.user_id = user_id or os.getenv("OPENMEMORY_USER_ID", "default")
self.timeout = int(timeout or os.getenv("OPENMEMORY_TIMEOUT", "30"))
self.mcp_client: Optional[MCPClient] = None
-
+
async def initialize(self) -> None:
"""Initialize the OpenMemory MCP service.
-
+
Sets up the MCP client connection and tests connectivity to ensure
the service is ready for memory operations.
-
+
Raises:
RuntimeError: If initialization or connection test fails
"""
if self._initialized:
return
-
+
try:
self.mcp_client = MCPClient(
server_url=self.server_url,
client_name=self.client_name,
user_id=self.user_id,
- timeout=self.timeout
+ timeout=self.timeout,
)
-
+
# Test connection to OpenMemory MCP server
is_connected = await self.mcp_client.test_connection()
if not is_connected:
raise RuntimeError(f"Cannot connect to OpenMemory MCP server at {self.server_url}")
-
+
self._initialized = True
memory_logger.info(
f"✅ OpenMemory MCP service initialized successfully at {self.server_url} "
f"(client: {self.client_name}, user: {self.user_id})"
)
-
+
except Exception as e:
memory_logger.error(f"OpenMemory MCP service initialization failed: {e}")
raise RuntimeError(f"Initialization failed: {e}")
-
+
async def add_memory(
self,
transcript: str,
@@ -111,14 +111,14 @@ async def add_memory(
user_id: str,
user_email: str,
allow_update: bool = False,
- db_helper: Any = None
+ db_helper: Any = None,
) -> Tuple[bool, List[str]]:
"""Add memories extracted from a transcript.
-
+
Processes a transcript to extract meaningful memories and stores them
in the OpenMemory MCP server. Can either extract memories locally first
or send the raw transcript to MCP for processing.
-
+
Args:
transcript: Raw transcript text to extract memories from
client_id: Client identifier for tracking
@@ -127,74 +127,78 @@ async def add_memory(
user_email: User email address
allow_update: Whether to allow updating existing memories (Note: MCP may handle this internally)
db_helper: Optional database helper for relationship tracking
-
+
Returns:
Tuple of (success: bool, created_memory_ids: List[str])
-
+
Raises:
MCPError: If MCP server communication fails
"""
await self._ensure_initialized()
-
+
try:
# Skip empty transcripts
if not transcript or len(transcript.strip()) < 10:
memory_logger.info(f"Skipping empty transcript for {source_id}")
return True, []
-
+
# Use configured OpenMemory user (from config) for all Chronicle users
# Chronicle user_id and email are stored in metadata for filtering
enriched_transcript = f"[Source: {source_id}, Client: {client_id}] {transcript}"
- memory_logger.info(f"Delegating memory processing to OpenMemory for user {user_id} (email: {user_email}), source {source_id}")
+ memory_logger.info(
+ f"Delegating memory processing to OpenMemory for user {user_id}, source {source_id}"
+ )
# Pass Chronicle user details in metadata for filtering/search
metadata = {
"chronicle_user_id": user_id,
"chronicle_user_email": user_email,
"source_id": source_id,
- "client_id": client_id
+ "client_id": client_id,
}
- memory_ids = await self.mcp_client.add_memories(text=enriched_transcript, metadata=metadata)
-
+ memory_ids = await self.mcp_client.add_memories(
+ text=enriched_transcript, metadata=metadata
+ )
+
# Update database relationships if helper provided
if memory_ids and db_helper:
await self._update_database_relationships(db_helper, source_id, memory_ids)
-
+
if memory_ids:
- memory_logger.info(f"✅ OpenMemory MCP processed memory for {source_id}: {len(memory_ids)} memories")
+ memory_logger.info(
+ f"✅ OpenMemory MCP processed memory for {source_id}: {len(memory_ids)} memories"
+ )
return True, memory_ids
-
+
# NOOP due to deduplication is SUCCESS, not failure
- memory_logger.info(f"✅ OpenMemory MCP processed {source_id}: no new memories needed (likely deduplication)")
+ memory_logger.info(
+ f"✅ OpenMemory MCP processed {source_id}: no new memories needed (likely deduplication)"
+ )
return True, []
-
+
except MCPError as e:
memory_logger.error(f"❌ OpenMemory MCP error for {source_id}: {e}")
raise e
except Exception as e:
memory_logger.error(f"❌ OpenMemory MCP service failed for {source_id}: {e}")
raise e
-
+
async def search_memories(
- self,
- query: str,
- user_id: str,
- limit: int = 10,
- score_threshold: float = 0.0
+ self, query: str, user_id: str, limit: int = 10, score_threshold: float = 0.0
) -> List[MemoryEntry]:
"""Search memories using semantic similarity.
-
+
Uses the OpenMemory MCP server to perform semantic search across
stored memories for the specified user.
-
+
Args:
query: Search query text
user_id: User identifier to filter memories
limit: Maximum number of results to return
score_threshold: Minimum similarity score (ignored - OpenMemory MCP server controls filtering)
-
+
Returns:
List of matching MemoryEntry objects ordered by relevance
"""
@@ -206,8 +210,7 @@ async def search_memories(
try:
# Get more results since we'll filter by user
results = await self.mcp_client.search_memory(
- query=query,
- limit=limit * 3 # Get extra to account for filtering
+ query=query, limit=limit * 3 # Get extra to account for filtering
)
# Convert MCP results to MemoryEntry objects and filter by user
@@ -222,7 +225,9 @@ async def search_memories(
if len(memory_entries) >= limit:
break # Got enough results
- memory_logger.info(f"🔍 Found {len(memory_entries)} memories for query '{query}' (user: {user_id})")
+ memory_logger.info(
+ f"🔍 Found {len(memory_entries)} memories for query '{query}' (user: {user_id})"
+ )
return memory_entries
except MCPError as e:
@@ -231,21 +236,17 @@ async def search_memories(
except Exception as e:
memory_logger.error(f"Search memories failed: {e}")
return []
-
- async def get_all_memories(
- self,
- user_id: str,
- limit: int = 100
- ) -> List[MemoryEntry]:
+
+ async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]:
"""Get all memories for a specific user.
-
+
Retrieves all stored memories for the given user without
similarity filtering.
-
+
Args:
user_id: User identifier
limit: Maximum number of memories to return
-
+
Returns:
List of MemoryEntry objects for the user
"""
@@ -280,7 +281,9 @@ async def get_all_memories(
memory_logger.error(f"Get all memories failed: {e}")
return []
- async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Optional[MemoryEntry]:
+ async def get_memory(
+ self, memory_id: str, user_id: Optional[str] = None
+ ) -> Optional[MemoryEntry]:
"""Get a specific memory by ID.
Args:
@@ -323,7 +326,7 @@ async def update_memory(
content: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None,
- user_email: Optional[str] = None
+ user_email: Optional[str] = None,
) -> bool:
"""Update a specific memory's content and/or metadata.
@@ -346,9 +349,7 @@ async def update_memory(
try:
success = await self.mcp_client.update_memory(
- memory_id=memory_id,
- content=content,
- metadata=metadata
+ memory_id=memory_id, content=content, metadata=metadata
)
if success:
@@ -362,18 +363,20 @@ async def update_memory(
# Restore original user_id
self.mcp_client.user_id = original_user_id
- async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None) -> bool:
+ async def delete_memory(
+ self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None
+ ) -> bool:
"""Delete a specific memory by ID.
-
+
Args:
memory_id: Unique identifier of the memory to delete
-
+
Returns:
True if successfully deleted, False otherwise
"""
if not self._initialized:
await self.initialize()
-
+
try:
success = await self.mcp_client.delete_memory(memory_id)
if success:
@@ -382,38 +385,38 @@ async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, use
except Exception as e:
memory_logger.error(f"Delete memory failed: {e}")
return False
-
+
async def delete_all_user_memories(self, user_id: str) -> int:
"""Delete all memories for a specific user.
-
+
Args:
user_id: User identifier
-
+
Returns:
Number of memories that were deleted
"""
if not self._initialized:
await self.initialize()
-
+
# Update MCP client user context for this operation
original_user_id = self.mcp_client.user_id
self.mcp_client.user_id = self.user_id # Use configured user ID
-
+
try:
count = await self.mcp_client.delete_all_memories()
memory_logger.info(f"🗑️ Deleted {count} memories for user {user_id} via OpenMemory MCP")
return count
-
+
except Exception as e:
memory_logger.error(f"Delete user memories failed: {e}")
return 0
finally:
# Restore original user_id
self.mcp_client.user_id = original_user_id
-
+
async def test_connection(self) -> bool:
"""Test if the memory service and its dependencies are working.
-
+
Returns:
True if all connections are healthy, False otherwise
"""
@@ -424,7 +427,7 @@ async def test_connection(self) -> bool:
except Exception as e:
memory_logger.error(f"Connection test failed: {e}")
return False
-
+
def shutdown(self) -> None:
"""Shutdown the memory service and clean up resources."""
if self.mcp_client:
@@ -433,56 +436,73 @@ def shutdown(self) -> None:
self._initialized = False
self.mcp_client = None
memory_logger.info("OpenMemory MCP service shut down")
-
+
# Private helper methods
-
+
def _ensure_client(self) -> MCPClient:
"""Ensure MCP client is available and return it."""
if self.mcp_client is None:
raise RuntimeError("OpenMemory MCP client not initialized")
return self.mcp_client
-
- def _mcp_result_to_memory_entry(self, mcp_result: Dict[str, Any], user_id: str) -> Optional[MemoryEntry]:
+
+ def _mcp_result_to_memory_entry(
+ self, mcp_result: Dict[str, Any], user_id: str
+ ) -> Optional[MemoryEntry]:
"""Convert OpenMemory MCP server result to MemoryEntry object.
-
+
Args:
mcp_result: Result dictionary from OpenMemory MCP server
user_id: User identifier to include in metadata
-
+
Returns:
MemoryEntry object or None if conversion fails
"""
try:
# OpenMemory MCP results may have different formats, adapt as needed
- memory_id = mcp_result.get('id', str(uuid.uuid4()))
- content = mcp_result.get('content', '') or mcp_result.get('memory', '') or mcp_result.get('text', '') or mcp_result.get('data', '')
-
+ memory_id = mcp_result.get("id", str(uuid.uuid4()))
+ content = (
+ mcp_result.get("content", "")
+ or mcp_result.get("memory", "")
+ or mcp_result.get("text", "")
+ or mcp_result.get("data", "")
+ )
+
if not content:
memory_logger.warning(f"Empty content in MCP result: {mcp_result}")
return None
-
+
# Build metadata with OpenMemory context
- metadata = mcp_result.get('metadata', {})
+ metadata = mcp_result.get("metadata", {})
if not metadata:
metadata = {}
-
+
# Ensure we have user context
- metadata.update({
- 'user_id': user_id,
- 'source': 'openmemory_mcp',
- 'client_name': self.client_name,
- 'mcp_server': self.server_url
- })
-
+ metadata.update(
+ {
+ "user_id": user_id,
+ "source": "openmemory_mcp",
+ "client_name": self.client_name,
+ "mcp_server": self.server_url,
+ }
+ )
+
# Extract similarity score if available (for search results)
- score = mcp_result.get('score') or mcp_result.get('similarity') or mcp_result.get('relevance')
+ score = (
+ mcp_result.get("score")
+ or mcp_result.get("similarity")
+ or mcp_result.get("relevance")
+ )
# Extract timestamps
- created_at = mcp_result.get('created_at') or mcp_result.get('timestamp') or mcp_result.get('date')
+ created_at = (
+ mcp_result.get("created_at")
+ or mcp_result.get("timestamp")
+ or mcp_result.get("date")
+ )
if created_at is None:
created_at = str(int(time.time()))
- updated_at = mcp_result.get('updated_at') or mcp_result.get('modified_at')
+ updated_at = mcp_result.get("updated_at") or mcp_result.get("modified_at")
if updated_at is None:
updated_at = str(created_at) # Default to created_at if not provided
@@ -493,21 +513,18 @@ def _mcp_result_to_memory_entry(self, mcp_result: Dict[str, Any], user_id: str)
embedding=None, # OpenMemory MCP server handles embeddings internally
score=score,
created_at=str(created_at),
- updated_at=str(updated_at)
+ updated_at=str(updated_at),
)
-
+
except Exception as e:
memory_logger.error(f"Failed to convert MCP result to MemoryEntry: {e}")
return None
-
+
async def _update_database_relationships(
- self,
- db_helper: Any,
- source_id: str,
- created_ids: List[str]
+ self, db_helper: Any, source_id: str, created_ids: List[str]
) -> None:
"""Update database relationships for created memories.
-
+
Args:
db_helper: Database helper instance
source_id: Source session identifier
@@ -517,4 +534,4 @@ async def _update_database_relationships(
try:
await db_helper.add_memory_reference(source_id, memory_id, "created")
except Exception as db_error:
- memory_logger.error(f"Database relationship update failed: {db_error}")
\ No newline at end of file
+ memory_logger.error(f"Database relationship update failed: {db_error}")
diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
index 99b79a6f..d87fd2e3 100644
--- a/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
+++ b/backends/advanced/src/advanced_omi_backend/services/transcription/__init__.py
@@ -15,8 +15,14 @@
import httpx
import websockets
+from advanced_omi_backend.config_loader import get_backend_config
from advanced_omi_backend.model_registry import get_models_registry
-from .base import BaseTranscriptionProvider, BatchTranscriptionProvider, StreamingTranscriptionProvider
+
+from .base import (
+ BaseTranscriptionProvider,
+ BatchTranscriptionProvider,
+ StreamingTranscriptionProvider,
+)
logger = logging.getLogger(__name__)
@@ -148,9 +154,15 @@ async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool =
words = _dotted_get(data, extract.get("words")) or []
segments = _dotted_get(data, extract.get("segments")) or []
- # Ignore segments from all providers - speaker service creates them via diarization
- segments = []
- logger.debug(f"Transcription: Extracted {len(words)} words, ignoring provider segments (speaker service will create them)")
+ # Check config to decide whether to keep or discard provider segments
+ transcription_config = get_backend_config("transcription")
+ use_provider_segments = transcription_config.get("use_provider_segments", False)
+
+ if not use_provider_segments:
+ segments = []
+ logger.debug(f"Transcription: Extracted {len(words)} words, ignoring provider segments (use_provider_segments=false)")
+ else:
+ logger.debug(f"Transcription: Extracted {len(words)} words, keeping {len(segments)} provider segments (use_provider_segments=true)")
return {"text": text, "words": words, "segments": segments}
diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
index 0ce9d77e..e54f3393 100644
--- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
+++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py
@@ -5,35 +5,44 @@
"""
import asyncio
-import os
import logging
+import os
import time
import uuid
from datetime import datetime
from pathlib import Path
-from typing import Dict, Any
-from rq import get_current_job
-from rq.job import Job
-from rq.exceptions import NoSuchJobError
+from typing import Any, Dict
-from advanced_omi_backend.models.job import JobPriority, BaseRQJob, async_job
-from advanced_omi_backend.models.conversation import Conversation
-from advanced_omi_backend.models.audio_chunk import AudioChunkDocument
from beanie.operators import In
+from rq import get_current_job
+from rq.exceptions import NoSuchJobError
+from rq.job import Job
+from advanced_omi_backend.config import get_backend_config
from advanced_omi_backend.controllers.queue_controller import (
- transcription_queue,
- redis_conn,
JOB_RESULT_TTL,
REDIS_URL,
+ redis_conn,
start_post_conversation_jobs,
+ transcription_queue,
)
-from advanced_omi_backend.utils.conversation_utils import analyze_speech, mark_conversation_deleted
-from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_wav_from_conversation, convert_audio_to_chunks
-from advanced_omi_backend.services.plugin_service import get_plugin_router
-from advanced_omi_backend.services.transcription import get_transcription_provider, is_transcription_available
+from advanced_omi_backend.models.audio_chunk import AudioChunkDocument
+from advanced_omi_backend.models.conversation import Conversation
+from advanced_omi_backend.models.job import BaseRQJob, JobPriority, async_job
from advanced_omi_backend.services.audio_stream import TranscriptionResultsAggregator
-from advanced_omi_backend.config import get_backend_config
+from advanced_omi_backend.services.plugin_service import get_plugin_router
+from advanced_omi_backend.services.transcription import (
+ get_transcription_provider,
+ is_transcription_available,
+)
+from advanced_omi_backend.utils.audio_chunk_utils import (
+ convert_audio_to_chunks,
+ reconstruct_wav_from_conversation,
+)
+from advanced_omi_backend.utils.conversation_utils import (
+ analyze_speech,
+ mark_conversation_deleted,
+)
logger = logging.getLogger(__name__)
@@ -63,7 +72,9 @@ async def apply_speaker_recognition(
Updated list of segments with identified speakers
"""
try:
- from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient
+ from advanced_omi_backend.speaker_recognition_client import (
+ SpeakerRecognitionClient,
+ )
speaker_client = SpeakerRecognitionClient()
if not speaker_client.enabled:
@@ -177,7 +188,7 @@ async def transcribe_full_audio_job(
# Extract user_id and client_id for plugin context
user_id = str(conversation.user_id) if conversation.user_id else None
- client_id = conversation.client_id if hasattr(conversation, 'client_id') else None
+ client_id = conversation.client_id if hasattr(conversation, "client_id") else None
# Get the transcription provider
provider = get_transcription_provider(mode="batch")
@@ -195,15 +206,14 @@ async def transcribe_full_audio_job(
wav_data = await reconstruct_wav_from_conversation(conversation_id)
logger.info(
- f"📦 Reconstructed audio from MongoDB chunks: "
- f"{len(wav_data) / 1024 / 1024:.2f} MB"
+ f"📦 Reconstructed audio from MongoDB chunks: " f"{len(wav_data) / 1024 / 1024:.2f} MB"
)
# Transcribe the audio directly from memory (no disk I/O needed)
transcription_result = await provider.transcribe(
audio_data=wav_data, # Pass bytes directly, already in memory
sample_rate=16000,
- diarize=True
+ diarize=True,
)
except ValueError as e:
@@ -224,7 +234,9 @@ async def transcribe_full_audio_job(
# Trigger transcript-level plugins BEFORE speech validation
# This ensures wake-word commands execute even if conversation gets deleted
- logger.info(f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}")
+ logger.info(
+ f"🔍 DEBUG: About to trigger plugins - transcript_text exists: {bool(transcript_text)}"
+ )
if transcript_text:
try:
from advanced_omi_backend.services.plugin_service import init_plugin_router
@@ -236,7 +248,9 @@ async def transcribe_full_audio_job(
if not plugin_router:
logger.info("🔧 Initializing plugin router in worker process...")
plugin_router = init_plugin_router()
- logger.info(f"🔧 After init, plugin_router: {plugin_router is not None}, plugins count: {len(plugin_router.plugins) if plugin_router else 0}")
+ logger.info(
+ f"🔧 After init, plugin_router: {plugin_router is not None}, plugins count: {len(plugin_router.plugins) if plugin_router else 0}"
+ )
# Initialize async plugins
if plugin_router:
@@ -245,18 +259,24 @@ async def transcribe_full_audio_job(
await plugin.initialize()
logger.info(f"✅ Plugin '{plugin_id}' initialized in worker")
except Exception as e:
- logger.exception(f"Failed to initialize plugin '{plugin_id}' in worker: {e}")
+ logger.exception(
+ f"Failed to initialize plugin '{plugin_id}' in worker: {e}"
+ )
- logger.info(f"🔍 DEBUG: Plugin router final check: {plugin_router is not None}, has {len(plugin_router.plugins) if plugin_router else 0} plugins")
+ logger.info(
+ f"🔍 DEBUG: Plugin router final check: {plugin_router is not None}, has {len(plugin_router.plugins) if plugin_router else 0} plugins"
+ )
if plugin_router:
- logger.info(f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}")
+ logger.info(
+ f"🔍 DEBUG: Preparing to trigger transcript plugins for conversation {conversation_id}"
+ )
plugin_data = {
- 'transcript': transcript_text,
- 'segment_id': f"{conversation_id}_batch",
- 'conversation_id': conversation_id,
- 'segments': segments,
- 'word_count': len(words),
+ "transcript": transcript_text,
+ "segment_id": f"{conversation_id}_batch",
+ "conversation_id": conversation_id,
+ "segments": segments,
+ "word_count": len(words),
}
logger.info(
@@ -265,10 +285,10 @@ async def transcribe_full_audio_job(
)
plugin_results = await plugin_router.dispatch_event(
- event='transcript.batch',
+ event="transcript.batch",
user_id=user_id,
data=plugin_data,
- metadata={'client_id': client_id}
+ metadata={"client_id": client_id},
)
logger.info(
@@ -276,7 +296,9 @@ async def transcribe_full_audio_job(
)
if plugin_results:
- logger.info(f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode")
+ logger.info(
+ f"✅ Triggered {len(plugin_results)} transcript plugins in batch mode"
+ )
for result in plugin_results:
if result.message:
logger.info(f" Plugin: {result.message}")
@@ -332,7 +354,9 @@ async def transcribe_full_audio_job(
logger.info(f"✅ Cancelled dependent job: {job_id}")
except Exception as e:
if isinstance(e, NoSuchJobError):
- logger.debug(f"Job {job_id} hash not found (likely already completed or expired)")
+ logger.debug(
+ f"Job {job_id} hash not found (likely already completed or expired)"
+ )
else:
logger.debug(f"Job {job_id} not found or already completed: {e}")
@@ -363,8 +387,8 @@ async def transcribe_full_audio_job(
processing_time = time.time() - start_time
# Check if we should use provider segments as fallback
- transcription_config = get_backend_config('transcription')
- use_provider_segments = transcription_config.get('use_provider_segments', False)
+ transcription_config = get_backend_config("transcription")
+ use_provider_segments = transcription_config.get("use_provider_segments", False)
# If flag enabled and provider returned segments, use them
# Otherwise, speaker service will create segments via diarization
@@ -375,7 +399,7 @@ async def transcribe_full_audio_job(
speaker=str(seg.get("speaker", "0")), # Convert to string for Pydantic validation
start=seg.get("start", 0.0),
end=seg.get("end", 0.0),
- text=seg.get("text", "")
+ text=seg.get("text", ""),
)
for seg in segments
]
@@ -399,7 +423,7 @@ async def transcribe_full_audio_job(
word=w.get("word", ""),
start=w.get("start", 0.0),
end=w.get("end", 0.0),
- confidence=w.get("confidence")
+ confidence=w.get("confidence"),
)
for w in words
]
@@ -409,7 +433,9 @@ async def transcribe_full_audio_job(
"trigger": trigger,
"audio_file_size": len(wav_data),
"word_count": len(words),
- "segments_created_by": "provider" if (use_provider_segments and segments) else "speaker_service",
+ "segments_created_by": (
+ "provider" if (use_provider_segments and segments) else "speaker_service"
+ ),
}
conversation.add_transcript_version(
@@ -528,9 +554,7 @@ async def transcribe_full_audio_job(
async def create_audio_only_conversation(
- session_id: str,
- user_id: str,
- client_id: str
+ session_id: str, user_id: str, client_id: str
) -> "Conversation":
"""
Create or reuse conversation for batch transcription fallback.
@@ -544,7 +568,7 @@ async def create_audio_only_conversation(
placeholder_conversation = await Conversation.find_one(
Conversation.client_id == session_id,
Conversation.always_persist == True,
- In(Conversation.processing_status, ["pending_transcription", "transcription_failed"])
+ In(Conversation.processing_status, ["pending_transcription", "transcription_failed"]),
)
if placeholder_conversation:
@@ -582,20 +606,13 @@ async def create_audio_only_conversation(
)
await conversation.insert()
- logger.info(
- f"✅ Created batch transcription conversation {session_id[:12]} for fallback"
- )
+ logger.info(f"✅ Created batch transcription conversation {session_id[:12]} for fallback")
return conversation
@async_job(redis=True, beanie=True)
async def transcription_fallback_check_job(
- session_id: str,
- user_id: str,
- client_id: str,
- timeout_seconds: int = 1800,
- *,
- redis_client=None
+ session_id: str, user_id: str, client_id: str, timeout_seconds: int = 1800, *, redis_client=None
) -> Dict[str, Any]:
"""
Check if streaming transcription succeeded, fallback to batch if needed.
@@ -617,9 +634,7 @@ async def transcription_fallback_check_job(
logger.info(f"🔍 Checking transcription status for session {session_id[:12]}")
# Find conversation by session_id (client_id for streaming sessions)
- conversation = await Conversation.find_one(
- Conversation.client_id == session_id
- )
+ conversation = await Conversation.find_one(Conversation.client_id == session_id)
# Check if transcript exists (streaming succeeded)
if conversation and conversation.active_transcript and conversation.transcript:
@@ -630,7 +645,7 @@ async def transcription_fallback_check_job(
return {
"status": "pass_through",
"transcript_source": "streaming",
- "conversation_id": conversation.conversation_id
+ "conversation_id": conversation.conversation_id,
}
# No transcript → Trigger batch fallback
@@ -674,7 +689,7 @@ async def transcription_fallback_check_job(
"status": "skipped",
"reason": "no_audio",
"message": "No audio was captured for this session",
- "session_id": session_id
+ "session_id": session_id,
}
logger.info(
@@ -718,7 +733,7 @@ async def transcription_fallback_check_job(
"status": "skipped",
"reason": "no_matching_audio",
"message": "No audio matched this session in Redis stream",
- "session_id": session_id
+ "session_id": session_id,
}
# Combine audio chunks in order
@@ -769,7 +784,7 @@ async def transcription_fallback_check_job(
job_timeout=1800,
job_id=f"transcribe_{conversation.conversation_id[:12]}",
description=f"Batch transcription fallback for {session_id[:8]}",
- meta={"session_id": session_id, "client_id": client_id}
+ meta={"session_id": session_id, "client_id": client_id},
)
logger.info(f"🔄 Enqueued batch transcription fallback job {batch_job.id}")
@@ -779,10 +794,11 @@ async def transcription_fallback_check_job(
waited = 0
while waited < max_wait:
batch_job.refresh()
+ # Check is_failed BEFORE is_finished - a failed job is also "finished"
+ if batch_job.is_failed:
+ raise RuntimeError(f"Batch transcription failed: {batch_job.exc_info}")
if batch_job.is_finished:
- if batch_job.is_failed:
- raise Exception(f"Batch transcription failed: {batch_job.exc_info}")
- logger.info(f"✅ Batch transcription completed successfully")
+ logger.info("✅ Batch transcription completed successfully")
break
await asyncio.sleep(2)
waited += 2
@@ -797,7 +813,7 @@ async def transcription_fallback_check_job(
transcript_version_id=version_id,
depends_on_job=None, # Batch already completed (we waited for it)
client_id=client_id,
- end_reason="batch_fallback"
+ end_reason="batch_fallback",
)
logger.info(
@@ -810,7 +826,7 @@ async def transcription_fallback_check_job(
"transcript_source": "batch",
"conversation_id": conversation.conversation_id,
"batch_job_id": batch_job.id,
- "post_job_ids": post_jobs
+ "post_job_ids": post_jobs,
}
@@ -876,7 +892,9 @@ async def stream_speech_detection_job(
# Track when session closes for graceful shutdown
session_closed_at = None
- final_check_grace_period = 15 # Wait up to 15 seconds for final transcription after session closes
+ final_check_grace_period = (
+ 15 # Wait up to 15 seconds for final transcription after session closes
+ )
last_speech_analysis = None # Track last analysis for detailed logging
# Main loop: Listen for speech
@@ -894,7 +912,9 @@ async def stream_speech_detection_job(
if session_closed and session_closed_at is None:
# Session just closed - start grace period for final transcription
session_closed_at = time.time()
- logger.info(f"🛑 Session closed, waiting up to {final_check_grace_period}s for final transcription results...")
+ logger.info(
+ f"🛑 Session closed, waiting up to {final_check_grace_period}s for final transcription results..."
+ )
# Exit if grace period expired without speech
if session_closed_at and (time.time() - session_closed_at) > final_check_grace_period:
@@ -922,7 +942,9 @@ async def stream_speech_detection_job(
grace_elapsed = time.time() - session_closed_at
if grace_elapsed > 5 and not combined.get("chunk_count", 0):
# 5+ seconds with no transcription activity at all - likely API key issue
- logger.error(f"❌ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue")
+ logger.error(
+ f"❌ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue"
+ )
logger.error(f"❌ Session failed - check transcription service configuration")
break
@@ -947,7 +969,9 @@ async def stream_speech_detection_job(
)
if not speech_analysis.get("has_speech", False):
- logger.info(f"⏳ Waiting for more speech - {speech_analysis.get('reason', 'unknown reason')}")
+ logger.info(
+ f"⏳ Waiting for more speech - {speech_analysis.get('reason', 'unknown reason')}"
+ )
await asyncio.sleep(2)
continue
@@ -1123,10 +1147,14 @@ async def stream_speech_detection_job(
}
# Session ended without speech
- reason = last_speech_analysis.get('reason', 'No transcription received') if last_speech_analysis else 'No transcription received'
+ reason = (
+ last_speech_analysis.get("reason", "No transcription received")
+ if last_speech_analysis
+ else "No transcription received"
+ )
# Distinguish between transcription failures (error) vs legitimate no speech (info)
- if reason == 'No transcription received':
+ if reason == "No transcription received":
logger.error(
f"❌ Session failed - transcription service did not respond\n"
f" Reason: {reason}\n"
@@ -1149,11 +1177,13 @@ async def stream_speech_detection_job(
conversation = await Conversation.find_one(
Conversation.client_id == session_id,
Conversation.always_persist == True,
- Conversation.processing_status == "pending_transcription"
+ Conversation.processing_status == "pending_transcription",
)
if conversation:
- logger.info(f"🔴 Found always_persist placeholder conversation {conversation.conversation_id} for failed session {session_id[:12]}")
+ logger.info(
+ f"🔴 Found always_persist placeholder conversation {conversation.conversation_id} for failed session {session_id[:12]}"
+ )
# Update conversation with failure status
conversation.processing_status = "transcription_failed"
@@ -1162,9 +1192,13 @@ async def stream_speech_detection_job(
await conversation.save()
- logger.warning(f"🔴 Marked conversation {conversation.conversation_id} as transcription_failed")
+ logger.warning(
+ f"🔴 Marked conversation {conversation.conversation_id} as transcription_failed"
+ )
else:
- logger.info(f"ℹ️ No always_persist placeholder conversation found for session {session_id[:12]}")
+ logger.info(
+ f"ℹ️ No always_persist placeholder conversation found for session {session_id[:12]}"
+ )
# Enqueue fallback check job for failed streaming sessions
# This will attempt batch transcription as a fallback
@@ -1174,10 +1208,10 @@ async def stream_speech_detection_job(
user_id,
client_id,
timeout_seconds=1800, # 30 minutes for batch transcription
- job_timeout=2400, # 40 minutes job timeout
+ job_timeout=2400, # 40 minutes job timeout
job_id=f"fallback_check_{session_id[:12]}",
description=f"Transcription fallback check for {session_id[:8]} (no speech)",
- meta={"session_id": session_id, "client_id": client_id, "no_speech": True}
+ meta={"session_id": session_id, "client_id": client_id, "no_speech": True},
)
logger.info(
diff --git a/backends/advanced/webui/src/App.tsx b/backends/advanced/webui/src/App.tsx
index d2c31a05..33400f92 100644
--- a/backends/advanced/webui/src/App.tsx
+++ b/backends/advanced/webui/src/App.tsx
@@ -1,6 +1,7 @@
import { BrowserRouter as Router, Routes, Route } from 'react-router-dom'
import { AuthProvider } from './contexts/AuthContext'
import { ThemeProvider } from './contexts/ThemeContext'
+import { RecordingProvider } from './contexts/RecordingContext'
import Layout from './components/layout/Layout'
import LoginPage from './pages/LoginPage'
import Chat from './pages/Chat'
@@ -28,7 +29,8 @@ function App() {
+ {JSON.stringify(conversation.segments, null, 2)}
+
+