Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions backend/app/services/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def __init__(self):
self.script_connections: Dict[str, Set[WebSocket]] = {}
# User sessions: user_id -> {project_id, script_id, ws}
self.user_sessions: Dict[str, Dict[str, Any]] = {}
# Reverse mappings for performance: project_id -> set of user_ids, script_id -> set of user_ids
self.project_user_ids: Dict[str, Set[str]] = {}
self.script_user_ids: Dict[str, Set[str]] = {}
# Node locks: script_id -> {node_id -> {user_id, lock_time, expires_at}}
self.node_locks: Dict[str, Dict[str, Dict[str, Any]]] = {}
# Lock timeout in minutes
Expand All @@ -33,8 +36,10 @@ async def broadcast_project_active_users(self, project_id: str):
def get_script_active_users(self, script_id: str) -> List[Dict[str, Any]]:
"""Return users currently editing a specific script."""
users = []
for uid, session in self.user_sessions.items():
if session.get("script_id") == script_id:
user_ids = self.script_user_ids.get(script_id, set())
for uid in user_ids:
session = self.user_sessions.get(uid)
if session:
users.append({
"id": uid,
"username": session.get("username", "Unknown"),
Expand All @@ -58,6 +63,16 @@ async def connect_project(self, websocket: WebSocket, project_id: str, user_id:

self.project_connections[project_id].add(websocket)

# Remove from previous project/script mappings if they exist
if user_id in self.user_sessions:
old_session = self.user_sessions[user_id]
old_pid = old_session.get("project_id")
old_sid = old_session.get("script_id")
if old_pid and old_pid in self.project_user_ids:
self.project_user_ids[old_pid].discard(user_id)
if old_sid and old_sid in self.script_user_ids:
self.script_user_ids[old_sid].discard(user_id)

# Update user session
self.user_sessions[user_id] = {
"project_id": project_id,
Expand All @@ -66,6 +81,11 @@ async def connect_project(self, websocket: WebSocket, project_id: str, user_id:
"connected_at": datetime.now().isoformat()
}

# Add to project mapping
if project_id not in self.project_user_ids:
self.project_user_ids[project_id] = set()
self.project_user_ids[project_id].add(user_id)

# Broadcast updated list of active users to everyone
await self.broadcast_project_active_users(project_id)

Expand All @@ -82,6 +102,11 @@ async def connect_script(self, websocket: WebSocket, script_id: str, user_id: st

# Update user session
if user_id in self.user_sessions:
# Remove from previous script mapping if it changed
old_sid = self.user_sessions[user_id].get("script_id")
if old_sid and old_sid != script_id and old_sid in self.script_user_ids:
self.script_user_ids[old_sid].discard(user_id)

self.user_sessions[user_id].update({
"script_id": script_id,
"ws": websocket,
Expand All @@ -95,6 +120,11 @@ async def connect_script(self, websocket: WebSocket, script_id: str, user_id: st
"connected_at": datetime.now().isoformat()
}

# Add to script mapping
if script_id not in self.script_user_ids:
self.script_user_ids[script_id] = set()
self.script_user_ids[script_id].add(user_id)

# Notify others that a user joined the script
await self.broadcast_to_script(
script_id,
Expand Down Expand Up @@ -434,10 +464,12 @@ async def broadcast_to_script(
def get_project_active_users(self, project_id: str) -> List[Dict[str, Any]]:
"""Get list of active users in a project."""
active_users = []
for user_id, session in self.user_sessions.items():
if session.get("project_id") == project_id:
user_ids = self.project_user_ids.get(project_id, set())
for uid in user_ids:
session = self.user_sessions.get(uid)
if session:
active_users.append({
"id": user_id, # Changed to match test expectations
"id": uid,
"username": session.get("username", "Unknown"),
"connected_at": session.get("connected_at"),
"editing_script": session.get("script_id")
Expand Down Expand Up @@ -522,6 +554,12 @@ async def disconnect(self, websocket: WebSocket, user_id: Optional[str] = None):
}
)

# Remove from project/script mappings
if project_id and project_id in self.project_user_ids:
self.project_user_ids[project_id].discard(user_id)
if script_id and script_id in self.script_user_ids:
self.script_user_ids[script_id].discard(user_id)

# Remove user session before broadcasting updated lists
del self.user_sessions[user_id]

Expand Down