Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/alembic/versions/869cfd49ebd5_initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def upgrade() -> None:
_ = op.create_table(
"project",
sa.Column("project_id", sa.String(length=26), nullable=False),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("creator_user_id", sa.String(length=26), nullable=False),
sa.Column(
"created_at",
Expand Down
132 changes: 129 additions & 3 deletions backend/src/interview_helper/context_manager/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dataclasses import dataclass
import interview_helper.context_manager.models as models
from ulid import ULID
import logging


class PersistentDatabase:
Expand Down Expand Up @@ -366,6 +367,7 @@ class ProjectListing(TypedDict):
id: str
name: str
creator_name: str
creator_user_id: str
created_at: str


Expand All @@ -374,12 +376,13 @@ def get_all_projects(db: PersistentDatabase) -> Sequence[ProjectListing]:
Gets all projects with creator name and creation date, sorted by creation date (descending)
"""
with db.begin() as conn:
rows: Sequence[tuple[str, str, str, DateTime]] = (
rows: Sequence[tuple[str, str, str, str, DateTime]] = (
conn.execute(
sa.select(
models.Project.project_id,
models.Project.name,
models.User.full_name,
models.Project.creator_user_id,
models.Project.created_at,
)
.join(
Expand All @@ -392,12 +395,13 @@ def get_all_projects(db: PersistentDatabase) -> Sequence[ProjectListing]:
)

projects: list[ProjectListing] = []
for project_id, project_name, creator_name, created_at in rows:
for project_id, project_name, creator_name, creator_user_id, created_at in rows:
projects.append(
{
"id": project_id,
"name": project_name,
"creator_name": creator_name,
"creator_user_id": creator_user_id,
"created_at": created_at.isoformat(), # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
}
)
Expand Down Expand Up @@ -433,6 +437,7 @@ def create_new_project(
"id": project_id,
"name": project_name,
"creator_name": user.full_name,
"creator_user_id": str(user.user_id),
"created_at": created_at.isoformat(), # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
}

Expand All @@ -449,6 +454,7 @@ def get_project_by_id(
models.Project.project_id,
models.Project.name,
models.User.full_name,
models.Project.creator_user_id,
models.Project.created_at,
)
.join(models.User, models.Project.creator_user_id == models.User.user_id)
Expand All @@ -458,16 +464,49 @@ def get_project_by_id(
if result is None:
return None

project_id_str, project_name, creator_name, created_at = result.tuple()
project_id_str, project_name, creator_name, creator_user_id, created_at = (
result.tuple()
)

return {
"id": project_id_str,
"name": project_name,
"creator_name": creator_name,
"creator_user_id": creator_user_id,
"created_at": created_at.isoformat(), # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
}


@dataclass
class ProjectCreatorInfo:
creator_user_id: UserId
name: str


def get_project_creator_and_name(
db: PersistentDatabase, project_id: ProjectId
) -> ProjectCreatorInfo | None:
"""
Gets the creator user ID and project name for a project.
Returns None if the project doesn't exist.
"""
with db.begin() as conn:
result = conn.execute(
sa.select(models.Project.creator_user_id, models.Project.name).where(
models.Project.project_id == str(project_id)
)
).one_or_none()

if result is None:
return None

creator_user_id_str, project_name = result.tuple()
return ProjectCreatorInfo(
creator_user_id=UserId.from_str(creator_user_id_str),
name=project_name,
)
Comment thread
StarDylan marked this conversation as resolved.


class AnalysisRow(BaseModel):
analysis_id: str
text: str
Expand Down Expand Up @@ -847,3 +886,90 @@ def get_session_sequence_number(
).scalar_one()

return result # pyright: ignore[reportAny]


def get_project_session_count(db: PersistentDatabase, project_id: ProjectId) -> int:
"""
Gets the number of sessions for a project
"""
with db.begin() as conn:
result = conn.execute(
sa.select(sa.func.count(models.Session.session_id)).where(
models.Session.project_id == str(project_id)
)
).scalar_one()

return int(result)


def delete_project(
db: PersistentDatabase, project_id: ProjectId, audio_recordings_dir: str
) -> None:
"""
Comment thread
StarDylan marked this conversation as resolved.
Deletes a project and all related data including:
- AI analyses
- Transcriptions
- Sessions (and their audio files)
- The project itself

Args:
db: The database instance
project_id: The project ID to delete
audio_recordings_dir: The directory where audio recordings are stored

Note:
This function first commits all database deletes, then deletes audio files.
This ensures transaction safety - if the DB delete fails, files remain intact.
If file deletion fails after DB commit, at least the DB is consistent.
"""
recordings_path = Path(audio_recordings_dir)

# Collect session IDs within transaction, then commit DB deletes before touching filesystem
with db.begin() as conn:
# Get all session IDs for this project to delete audio files later
session_ids_result = conn.execute(
sa.select(models.Session.session_id).where(
models.Session.project_id == str(project_id)
)
).all()

session_ids: list[str] = [str(row[0]) for row in session_ids_result] # pyright: ignore[reportAny]

# Delete AI analyses
_ = conn.execute(
sa.delete(models.AIAnalysis).where(
models.AIAnalysis.project_id == str(project_id)
)
)

# Delete transcriptions
_ = conn.execute(
sa.delete(models.Transcription).where(
models.Transcription.project_id == str(project_id)
)
)

# Delete sessions
_ = conn.execute(
sa.delete(models.Session).where(
models.Session.project_id == str(project_id)
)
)

# Delete the project itself
_ = conn.execute(
sa.delete(models.Project).where(
models.Project.project_id == str(project_id)
)
)
# Transaction commits here when exiting the context manager

# Now that DB deletes are committed, delete audio files from filesystem
for session_id in session_ids:
audio_file = recordings_path / f"recording-{session_id}.wav"
if audio_file.exists():
try:
audio_file.unlink()
except OSError as e:
# Log the error but don't fail - DB is already consistent
logging.warning(f"Failed to delete audio file {audio_file}: {e}")
91 changes: 91 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
get_project_by_id,
get_all_transcripts,
get_all_ai_analyses,
get_project_session_count,
delete_project,
get_project_creator_and_name,
)
from interview_helper.context_manager.types import ProjectId, TranscriptId

Expand Down Expand Up @@ -422,6 +425,25 @@ async def websocket_endpoint(
logger.info(f"Closed session {context.session_id} for user {ticket.user_id}")


@app.get("/user/me")
async def get_current_user(token: Annotated[str, Depends(oidc_scheme)]):
"""
Returns the current user's information
"""
clean_token = token.removeprefix("Bearer ")
user_claims = verify_jwt_token(clean_token, jwks_client, CLIENT_ID, signing_algos)

user_info = await get_user_info_from_oidc_provider(clean_token, userinfo_endpoint)
name = f"{user_info.given_name or ''} {user_info.family_name or ''}".strip()
user = get_or_add_user_by_oidc_id(session_manager.db, user_claims.sub, name)

return {
"user_id": str(user.user_id),
"full_name": user.full_name,
"oidc_id": user.oidc_id,
}


@app.get("/project")
async def list_all_projects(token: Annotated[str, Depends(oidc_scheme)]):
"""
Expand Down Expand Up @@ -456,6 +478,75 @@ async def create_project(
return new_project


@app.delete("/project/{project_id}")
async def delete_project_endpoint(
project_id: str, confirmed_name: str, token: Annotated[str, Depends(oidc_scheme)]
):
"""
Deletes a project and all associated data (sessions, transcriptions, audio files, questions).
Only the project creator can delete the project.
Requires confirmation by providing the exact project name.
"""
clean_token = token.removeprefix("Bearer ")
user_claims = verify_jwt_token(clean_token, jwks_client, CLIENT_ID, signing_algos)

# Get user info
user_info = await get_user_info_from_oidc_provider(clean_token, userinfo_endpoint)
name = f"{user_info.given_name or ''} {user_info.family_name or ''}".strip()
user_id = get_or_add_user_by_oidc_id(
session_manager.db, user_claims.sub, name
).user_id

# Verify project exists and get creator info
project_id_typed = ProjectId.from_str(project_id)

project_info = get_project_creator_and_name(session_manager.db, project_id_typed)
if project_info is None:
raise HTTPException(status_code=404, detail="Project not found")

# Check if user is the creator
if project_info.creator_user_id != user_id:
raise HTTPException(
status_code=403, detail="Only the project creator can delete this project"
)

# Verify the confirmed name matches
if confirmed_name != project_info.name:
raise HTTPException(
status_code=400, detail="Project name confirmation does not match"
)
Comment thread
StarDylan marked this conversation as resolved.

# Delete the project and all related data
delete_project(
session_manager.db,
project_id_typed,
session_manager.get_settings().audio_recordings_dir,
)

return {"status": "success", "message": "Project deleted successfully"}


@app.get("/project/{project_id}/info")
async def get_project_info(
project_id: str, token: Annotated[str, Depends(oidc_scheme)]
):
"""
Gets project information including session count for delete confirmation
"""
clean_token = token.removeprefix("Bearer ")
_user_claims = verify_jwt_token(clean_token, jwks_client, CLIENT_ID, signing_algos)

project_id_typed = ProjectId.from_str(project_id)
project = get_project_by_id(session_manager.db, project_id_typed)

if not project:
raise HTTPException(status_code=404, detail="Project not found")

session_count = get_project_session_count(session_manager.db, project_id_typed)

return {**project, "session_count": session_count}


@app.get("/project/{project_id}/download/transcript")
async def download_transcript(
project_id: str, token: Annotated[str, Depends(oidc_scheme)]
Expand Down
Loading
Loading