Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,22 @@ async def teardown_session(self, session_id: SessionId) -> None:

# Wait for any finishing audio handlers
event = None
recording_state_to_clear: tuple[ProjectId, str] | None = None
async with self.lock:
if session_id in self.active_audio_sessions:
event = self.cleanup_waiting_event[session_id]

session_data = self.session_data.get(session_id)
if (
session_data is not None
and session_data.project in self.recording_state
):
current_session_id, current_user_name = self.recording_state[
session_data.project
]
if current_session_id == session_id:
recording_state_to_clear = (session_data.project, current_user_name)

if event is not None:
await event.wait()

Expand Down Expand Up @@ -401,6 +413,16 @@ async def teardown_session(self, session_id: SessionId) -> None:

_ = await task_group.__aexit__(None, None, None)

# Ensure we clear stale recording ownership when a websocket disconnects abruptly.
if recording_state_to_clear is not None:
project_id, user_name = recording_state_to_clear
await self.set_recording_state(
project_id=project_id,
session_id=session_id,
user_name=user_name,
is_recording=False,
)

async def ingest_audio(
self, session_id: SessionId, project_id: ProjectId, audio_chunk: AudioChunk
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,26 @@ async def test_get_settings():
# Ensure that this causes an error so we don't inadvertently use it in tests
with pytest.raises(AssertionError):
cm.get_settings()


async def test_teardown_clears_recording_state_for_disconnected_session():
context_manager = AppContextManager(
(), ai_processer=FakeAnalyzer, db=PersistentDatabase.new_in_memory()
)
project_id = ProjectId(ULID())
ctx = await context_manager.new_session(UserId(ULID()), project_id)

await context_manager.set_recording_state(
project_id=project_id,
session_id=ctx.session_id,
user_name="Test User",
is_recording=True,
)
assert await context_manager.get_recording_state(project_id) == (
ctx.session_id,
"Test User",
)

await context_manager.teardown_session(ctx.session_id)

assert await context_manager.get_recording_state(project_id) is None
Loading