diff --git a/backend/src/interview_helper/context_manager/session_context_manager.py b/backend/src/interview_helper/context_manager/session_context_manager.py index 23f953e..923e2a2 100644 --- a/backend/src/interview_helper/context_manager/session_context_manager.py +++ b/backend/src/interview_helper/context_manager/session_context_manager.py @@ -363,10 +363,22 @@ async def teardown_session(self, session_id: SessionId) -> None: # Wait for any finishing audio handlers event = None + recording_state_to_clear: tuple[ProjectId, str] | None = None async with self.lock: if session_id in self.active_audio_sessions: event = self.cleanup_waiting_event[session_id] + session_data = self.session_data.get(session_id) + if ( + session_data is not None + and session_data.project in self.recording_state + ): + current_session_id, current_user_name = self.recording_state[ + session_data.project + ] + if current_session_id == session_id: + recording_state_to_clear = (session_data.project, current_user_name) + if event is not None: await event.wait() @@ -401,6 +413,16 @@ async def teardown_session(self, session_id: SessionId) -> None: _ = await task_group.__aexit__(None, None, None) + # Ensure we clear stale recording ownership when a websocket disconnects abruptly. + if recording_state_to_clear is not None: + project_id, user_name = recording_state_to_clear + await self.set_recording_state( + project_id=project_id, + session_id=session_id, + user_name=user_name, + is_recording=False, + ) + async def ingest_audio( self, session_id: SessionId, project_id: ProjectId, audio_chunk: AudioChunk ): diff --git a/backend/src/interview_helper/context_manager/tests/test_session_manager.py b/backend/src/interview_helper/context_manager/tests/test_session_manager.py index 6886630..7da2d62 100644 --- a/backend/src/interview_helper/context_manager/tests/test_session_manager.py +++ b/backend/src/interview_helper/context_manager/tests/test_session_manager.py @@ -99,3 +99,26 @@ async def test_get_settings(): # Ensure that this causes an error so we don't inadvertently use it in tests with pytest.raises(AssertionError): cm.get_settings() + + +async def test_teardown_clears_recording_state_for_disconnected_session(): + context_manager = AppContextManager( + (), ai_processer=FakeAnalyzer, db=PersistentDatabase.new_in_memory() + ) + project_id = ProjectId(ULID()) + ctx = await context_manager.new_session(UserId(ULID()), project_id) + + await context_manager.set_recording_state( + project_id=project_id, + session_id=ctx.session_id, + user_name="Test User", + is_recording=True, + ) + assert await context_manager.get_recording_state(project_id) == ( + ctx.session_id, + "Test User", + ) + + await context_manager.teardown_session(ctx.session_id) + + assert await context_manager.get_recording_state(project_id) is None