diff --git a/.gitignore b/.gitignore index b7faf40..7938694 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,5 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +uploads/* \ No newline at end of file diff --git a/Makefile b/Makefile index 82f26e6..2fc31c7 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,9 @@ start: stop: docker-compose down +remove: + docker-compose down -v --remove-orphans + build: docker-compose build diff --git a/ai/apis.py b/ai/apis.py index cfdc70a..3e5fae4 100644 --- a/ai/apis.py +++ b/ai/apis.py @@ -1,10 +1,13 @@ import uuid -from fastapi import status +from fastapi import Depends, Query, status +from sqlalchemy.orm import Session from ai.agent import SchedulingAgent +from ai.db import PolicyEmbedding from ai.models import QNARequestBody, QNAResponseBody from application.app import app +from database.db import get_db @app.post( @@ -39,3 +42,42 @@ async def ai_assistant(request: QNARequestBody) -> QNAResponseBody: "session_id": result.get("session_id") or session_id, "messages": result.get("messages"), } + + +@app.get("/ai/policy-embeddings") +async def get_policy_embeddings( + policy_id: str | None = None, + organization_id: str | None = None, + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), + db: Session = Depends(get_db), +): + query = db.query(PolicyEmbedding) + if policy_id: + query = query.filter(PolicyEmbedding.policy_id == policy_id) + if organization_id: + query = query.filter(PolicyEmbedding.organization_id == organization_id) + rows = ( + query.order_by(PolicyEmbedding.created.desc()).offset(offset).limit(limit).all() + ) + embeddings = [] + for row in rows: + item = { + "id": str(row.id), + "policy_id": str(row.policy_id), + "organization_id": str(row.organization_id), + "policy_name": row.policy_name, + "description": row.description, + "document_name": row.document_name, + "file_path": row.file_path, + "chunk_index": row.chunk_index, + "text": row.text, + "created": row.created, + } + embeddings.append(item) + return { + "items": embeddings, + "total": len(embeddings), + "limit": limit, + "offset": offset, + } diff --git a/ai/clients.py b/ai/clients.py index 7517740..7c9102e 100644 --- a/ai/clients.py +++ b/ai/clients.py @@ -3,22 +3,19 @@ import cohere from ai.prompts import PREAMBLE -from appointments.constants import get_appointment_function_map -from appointments.tools import APPOINTMENT_TOOLS from organizations.constants import get_organization_function_map from organizations.tools import ORGANIZATION_TOOLS class CohereClient: - def __init__(self, message: str | None = None): + def __init__(self, message: str | None = None, model: str | None = None): self.client = cohere.Client(os.getenv("COHERE_API_KEY")) - self.model = os.getenv("COHERE_LLM_MODEL") + self.model = model or os.getenv("COHERE_LLM_MODEL") self.preamble = PREAMBLE self.function_map = { - **get_appointment_function_map(), **get_organization_function_map(), } - self.tools = [*APPOINTMENT_TOOLS, *ORGANIZATION_TOOLS] + self.tools = [*ORGANIZATION_TOOLS] self.message = message or "" def chat( @@ -100,3 +97,11 @@ def ask_llm( return str(e), chat_history or [] return (response.text if response else ""), history + + def embed_texts(self, texts: list[str], input_type: str) -> list[list[float]]: + response = self.client.embed( + texts=texts, + model=self.model, + input_type=input_type, + ) + return response.embeddings or [] diff --git a/ai/db.py b/ai/db.py new file mode 100644 index 0000000..d81d45a --- /dev/null +++ b/ai/db.py @@ -0,0 +1,27 @@ +import os +import uuid + +from pgvector.sqlalchemy import Vector +from sqlalchemy import Column, DateTime, Integer, String, Text +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func + +from database.db import Base + +DEFAULT_EMBED_DIM = int(os.getenv("RAG_EMBED_DIM", "1024")) + + +class PolicyEmbedding(Base): + __tablename__ = "policy_embeddings" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + policy_id = Column(UUID(as_uuid=True), nullable=False, index=True) + organization_id = Column(UUID(as_uuid=True), nullable=False, index=True) + policy_name = Column(String(255), nullable=False) + description = Column(Text, nullable=True) + document_name = Column(String(255), nullable=True) + file_path = Column(String(500), nullable=False) + chunk_index = Column(Integer, nullable=False) + text = Column(Text, nullable=False) + embedding = Column(Vector(DEFAULT_EMBED_DIM), nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now()) diff --git a/ai/rag.py b/ai/rag.py new file mode 100644 index 0000000..b6cc7eb --- /dev/null +++ b/ai/rag.py @@ -0,0 +1,185 @@ +import logging +import os +import re +from typing import Iterable + +import httpx +from sqlalchemy import delete, select + +from ai.clients import CohereClient +from ai.db import PolicyEmbedding +from database.db import SessionLocal + + +logger = logging.getLogger(__name__) + + +class RAGClient: + def __init__(self, embed_model: str | None = None): + self.embed_model = embed_model or os.getenv("COHERE_EMBED_MODEL", "embed-english-v3.0") + + def _looks_like_text(self, raw: bytes) -> bool: + if not raw: + return False + if b"\x00" in raw: + return False + sample = raw[:2048] + non_printable = sum(1 for byte in sample if byte < 9 or (13 < byte < 32)) + return non_printable / max(len(sample), 1) < 0.2 + + def _extract_text_from_pdf(self, file_path: str) -> str: + try: + from pypdf import PdfReader # type: ignore + except Exception: + return "" + try: + reader = PdfReader(file_path) + return "\n".join(page.extract_text() or "" for page in reader.pages).strip() + except Exception: + return "" + + def _extract_text_from_docx(self, file_path: str) -> str: + try: + from docx import Document # type: ignore + except Exception: + return "" + try: + doc = Document(file_path) + return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip() + except Exception: + return "" + + def _read_text_from_source(self, file_path: str) -> str: + if file_path.startswith("http://") or file_path.startswith("https://"): + try: + response = httpx.get(file_path, timeout=20.0) + response.raise_for_status() + raw = response.content + except Exception: + return "" + else: + if not os.path.exists(file_path): + return "" + with open(file_path, "rb") as handle: + raw = handle.read() + + ext = os.path.splitext(file_path)[1].lower() + if ext == ".pdf": + return self._extract_text_from_pdf(file_path) + if ext == ".docx": + return self._extract_text_from_docx(file_path) + + if self._looks_like_text(raw): + return raw.decode("utf-8", errors="ignore").strip() + + return "" + + def _clean_text(self, text: str) -> str: + return re.sub(r"\s+", " ", text).strip() + + def _chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 200) -> list[str]: + cleaned = self._clean_text(text) + if not cleaned: + return [] + chunks = [] + start = 0 + while start < len(cleaned): + end = min(len(cleaned), start + max_chars) + chunk = cleaned[start:end].strip() + if chunk: + chunks.append(chunk) + if end >= len(cleaned): + break + start = max(0, end - overlap) + return chunks + + def _embed_texts(self, texts: Iterable[str], input_type: str) -> list[list[float]]: + client = CohereClient(model=self.embed_model) + response = client.embed_texts(list(texts), input_type=input_type) + + logger.info(f"Embedded {len(texts)} texts with model {self.embed_model} & response: {response}") + return response + + def index_policy_document( + self, + policy_id: str, + organization_id: str, + policy_name: str, + description: str | None, + document_name: str | None, + file_path: str, + ) -> dict: + text = self._read_text_from_source(file_path) + if not text: + return {"status": "skipped", "reason": "no_text_extracted"} + + chunks = self._chunk_text(text) + if not chunks: + return {"status": "skipped", "reason": "no_chunks_created"} + + embeddings = self._embed_texts(chunks, input_type="search_document") + if not embeddings: + return {"status": "skipped", "reason": "no_embeddings_created"} + + with SessionLocal() as db: + db.execute( + delete(PolicyEmbedding).where(PolicyEmbedding.policy_id == policy_id) + ) + for idx, (chunk, embedding) in enumerate( + zip(chunks, embeddings, strict=False) + ): + db.add( + PolicyEmbedding( + policy_id=policy_id, + organization_id=organization_id, + policy_name=policy_name, + description=description, + document_name=document_name, + file_path=file_path, + chunk_index=idx, + text=chunk, + embedding=embedding, + ) + ) + db.commit() + return {"status": "indexed", "chunks": len(chunks)} + + def remove_policy_from_index(self, policy_id: str) -> dict: + with SessionLocal() as db: + result = db.execute( + delete(PolicyEmbedding).where(PolicyEmbedding.policy_id == policy_id) + ) + db.commit() + if result.rowcount == 0: + return {"status": "skipped", "reason": "policy_not_found"} + return {"status": "removed", "count": result.rowcount} + + def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]: + query_embedding = self._embed_texts([query], input_type="search_query") + if not query_embedding: + return [] + query_vector = query_embedding[0] + + with SessionLocal() as db: + stmt = ( + select(PolicyEmbedding) + .order_by(PolicyEmbedding.embedding.cosine_distance(query_vector)) + .limit(max(top_k, 1)) + ) + results = db.execute(stmt).scalars().all() + response = [] + for record in results: + response.append( + { + "policy_id": str(record.policy_id), + "organization_id": str(record.organization_id), + "policy_name": record.policy_name, + "description": record.description, + "document_name": record.document_name, + "file_path": record.file_path, + "chunk_index": record.chunk_index, + "text": record.text, + "score": None, + } + ) + return response diff --git a/application/app.py b/application/app.py index 3291965..a8313b7 100644 --- a/application/app.py +++ b/application/app.py @@ -25,7 +25,7 @@ app.add_middleware(AuthenticationMiddleware, backend=JWTAuthBackend()) import ai.apis -import appointments.apis +import ai.db import auth.apis import organizations.apis import users.apis diff --git a/appointments/__init__.py b/appointments/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/appointments/apis.py b/appointments/apis.py deleted file mode 100644 index df41625..0000000 --- a/appointments/apis.py +++ /dev/null @@ -1,96 +0,0 @@ -from datetime import datetime - -from fastapi import Depends, HTTPException -from sqlalchemy.orm import Session - -from application.app import app -from appointments.models import ( - Appointment, - AppointmentItem, - AppointmentRequest, - AppointmentResponse, - AppointmentsListResponse, -) -from database.db import drop_appointments_table, get_db - - -@app.get("/appointments", response_model=AppointmentsListResponse) -async def get_appointments(db: Session = Depends(get_db)): - rows = db.query(Appointment).order_by(Appointment.created_at.desc()).all() - appointments = [ - AppointmentItem( - id=str(row.id), - title=row.title, - description=row.description, - date_and_time=row.date_and_time, - duration=row.duration, - status=row.status, - ) - for row in rows - ] - total = len(appointments) - message = "No appointments found" if total == 0 else "Appointments retrieved" - return AppointmentsListResponse( - appointments=appointments, - total=total, - message=message, - ) - - -@app.get("/appointments/{appointment_id}", response_model=AppointmentItem) -async def get_appointment(appointment_id: str, db: Session = Depends(get_db)): - appointment = db.query(Appointment).filter(Appointment.id == appointment_id).first() - if not appointment: - raise HTTPException(status_code=404, detail="Appointment not found") - - return AppointmentItem( - id=str(appointment.id), - title=appointment.title, - description=appointment.description, - date_and_time=appointment.date_and_time, - duration=appointment.duration, - status=appointment.status, - ) - - -@app.post("/appointments", response_model=AppointmentResponse) -async def create_appointment( - appointment: AppointmentRequest, db: Session = Depends(get_db) -): - new_appointment = Appointment( - title=appointment.title, - description=appointment.description, - date_and_time=appointment.date_and_time, - duration=appointment.duration, - status=appointment.status, - created_at=datetime.now(), - ) - db.add(new_appointment) - db.commit() - db.refresh(new_appointment) - - return AppointmentResponse( - id=str(new_appointment.id), - title=new_appointment.title, - description=new_appointment.description, - date_and_time=new_appointment.date_and_time, - duration=new_appointment.duration, - status=new_appointment.status, - ) - - -@app.delete("/appointments/{appointment_id}") -async def delete_appointment(appointment_id: str, db: Session = Depends(get_db)): - appointment = db.query(Appointment).filter(Appointment.id == appointment_id).first() - if not appointment: - raise HTTPException(status_code=404, detail="Appointment not found") - db.delete(appointment) - db.commit() - - return {"status": "ok", "message": "Appointment deleted"} - - -@app.delete("/admin/drop-appointments-db") -async def drop_appointments_db_table(): - drop_appointments_table() - return {"status": "ok", "message": "Appointments database table dropped"} diff --git a/appointments/choices.py b/appointments/choices.py deleted file mode 100644 index 5d1ce3e..0000000 --- a/appointments/choices.py +++ /dev/null @@ -1,7 +0,0 @@ -import enum - - -class AppointmentStatus(str, enum.Enum): - scheduled = "scheduled" - cancelled = "cancelled" - completed = "completed" diff --git a/appointments/constants.py b/appointments/constants.py deleted file mode 100644 index 5e8e3c2..0000000 --- a/appointments/constants.py +++ /dev/null @@ -1,16 +0,0 @@ -def get_appointment_function_map(): - from appointments.db import ( - cancel_appointment, - check_time_slot_availability, - create_new_appointment, - get_appointment_by_title, - get_appointments_by_date, - ) - - return { - "get_appointments_by_date": get_appointments_by_date, - "get_appointment_by_title": get_appointment_by_title, - "check_time_slot_availability": check_time_slot_availability, - "create_new_appointment": create_new_appointment, - "cancel_appointment": cancel_appointment, - } diff --git a/appointments/db.py b/appointments/db.py deleted file mode 100644 index e9697db..0000000 --- a/appointments/db.py +++ /dev/null @@ -1,143 +0,0 @@ -from datetime import datetime, timedelta - -from sqlalchemy import func - -from appointments.models import Appointment, AppointmentStatus -from database.db import SessionLocal - - -def get_appointments_by_date(date: str): - """Get all appointments for a specific date.""" - with SessionLocal() as db: - target_date = datetime.fromisoformat(date).date() - appointments = ( - db.query(Appointment) - .filter(func.date(Appointment.date_and_time) == target_date) - .order_by(Appointment.date_and_time) - .all() - ) - return { - "date": str(target_date), - "appointments": [ - { - "id": str(appt.id), - "title": appt.title, - "description": appt.description, - "date_and_time": appt.date_and_time.isoformat(), - "duration": appt.duration, - "status": appt.status.value, - } - for appt in appointments - ], - "total": len(appointments), - } - - -def get_appointment_by_title(title: str): - """Get appointment details by title.""" - with SessionLocal() as db: - appointment = ( - db.query(Appointment) - .filter(func.lower(Appointment.title).contains(title.lower())) - .first() - ) - if appointment: - return { - "id": str(appointment.id), - "title": appointment.title, - "description": appointment.description, - "date_and_time": appointment.date_and_time.isoformat(), - "duration": appointment.duration, - "status": appointment.status.value, - } - return {"detail": "Appointment not found", "title": title} - - -def check_time_slot_availability(date_and_time: str, duration: int = 30): - """Check if a time slot is available.""" - with SessionLocal() as db: - requested_start = datetime.fromisoformat(date_and_time) - requested_end = requested_start + timedelta(minutes=duration) - - appointments = ( - db.query(Appointment) - .filter(Appointment.status == AppointmentStatus.scheduled) - .all() - ) - - for appt in appointments: - appt_end = appt.date_and_time + timedelta(minutes=appt.duration) - if appt.date_and_time < requested_end and appt_end > requested_start: - return { - "available": False, - "date_and_time": date_and_time, - "duration": duration, - "conflict_with": appt.title, - } - - return { - "available": True, - "date_and_time": date_and_time, - "duration": duration, - } - - -def create_new_appointment( - title: str, - date_and_time: str, - duration: int = 30, - description: str | None = None, -): - """Create a new appointment.""" - with SessionLocal() as db: - # Check availability first - availability = check_time_slot_availability(date_and_time, duration) - if not availability.get("available"): - return { - "detail": "Time slot not available", - "conflict_with": availability.get("conflict_with"), - } - - appointment = Appointment( - title=title, - description=description, - date_and_time=datetime.fromisoformat(date_and_time), - duration=duration, - status=AppointmentStatus.scheduled, - created_at=datetime.now(), - ) - db.add(appointment) - db.commit() - db.refresh(appointment) - - return { - "detail": "Appointment created", - "appointment": { - "id": str(appointment.id), - "title": appointment.title, - "date_and_time": appointment.date_and_time.isoformat(), - "duration": appointment.duration, - "status": appointment.status.value, - }, - } - - -def cancel_appointment(title: str): - """Cancel an appointment by title.""" - with SessionLocal() as db: - appointment = ( - db.query(Appointment) - .filter(func.lower(Appointment.title).contains(title.lower())) - .first() - ) - if not appointment: - return {"detail": "Appointment not found", "title": title} - - appointment.status = AppointmentStatus.cancelled - db.commit() - - return { - "detail": "Appointment cancelled", - "title": appointment.title, - "status": appointment.status.value, - } diff --git a/appointments/models.py b/appointments/models.py deleted file mode 100644 index b4466a7..0000000 --- a/appointments/models.py +++ /dev/null @@ -1,55 +0,0 @@ -import uuid -from datetime import datetime - -from pydantic import BaseModel -from sqlalchemy import Column, DateTime, Enum, Integer, String -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.sql import func - -from appointments.choices import AppointmentStatus -from database.db import Base - - -class Appointment(Base): - __tablename__ = "appointments" - - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - title = Column(String, nullable=False) - description = Column(String, nullable=True) - date_and_time = Column(DateTime, nullable=False) - duration = Column(Integer, nullable=False, default=30) - status = Column(Enum(AppointmentStatus, name="appointment_status"), nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - - -class AppointmentRequest(BaseModel): - title: str - description: str | None = None - date_and_time: datetime - duration: int = 30 - status: AppointmentStatus = AppointmentStatus.scheduled - - -class AppointmentItem(BaseModel): - id: str - title: str - description: str | None = None - date_and_time: datetime - duration: int - status: AppointmentStatus - - -class AppointmentResponse(BaseModel): - id: str - title: str - description: str | None = None - date_and_time: datetime - duration: int - status: AppointmentStatus - - -class AppointmentsListResponse(BaseModel): - appointments: list[AppointmentItem] - total: int - message: str diff --git a/appointments/tools.py b/appointments/tools.py deleted file mode 100644 index 61f0633..0000000 --- a/appointments/tools.py +++ /dev/null @@ -1,77 +0,0 @@ -APPOINTMENT_TOOLS = [ - { - "name": "get_appointments_by_date", - "description": "Returns all appointments for a specific date.", - "parameter_definitions": { - "date": { - "description": "The date to check in ISO format (YYYY-MM-DD).", - "type": "str", - "required": True, - }, - }, - }, - { - "name": "get_appointment_by_title", - "description": "Returns appointment details by searching for the title.", - "parameter_definitions": { - "title": { - "description": "The title or partial title of the appointment.", - "type": "str", - "required": True, - }, - }, - }, - { - "name": "check_time_slot_availability", - "description": "Checks if a specific time slot is available for scheduling.", - "parameter_definitions": { - "date_and_time": { - "description": "The date and time to check in ISO format (YYYY-MM-DDTHH:MM:SS).", - "type": "str", - "required": True, - }, - "duration": { - "description": "Duration in minutes (default 30).", - "type": "int", - "required": False, - }, - }, - }, - { - "name": "create_new_appointment", - "description": "Creates a new appointment at the specified time.", - "parameter_definitions": { - "title": { - "description": "The title of the appointment.", - "type": "str", - "required": True, - }, - "date_and_time": { - "description": "The date and time in ISO format (YYYY-MM-DDTHH:MM:SS).", - "type": "str", - "required": True, - }, - "duration": { - "description": "Duration in minutes (default 30).", - "type": "int", - "required": False, - }, - "description": { - "description": "Optional description for the appointment.", - "type": "str", - "required": False, - }, - }, - }, - { - "name": "cancel_appointment", - "description": "Cancels an appointment by title.", - "parameter_definitions": { - "title": { - "description": "The title of the appointment to cancel.", - "type": "str", - "required": True, - }, - }, - }, -] diff --git a/database/db.py b/database/db.py index 5c9e971..fb6e6ab 100644 --- a/database/db.py +++ b/database/db.py @@ -1,6 +1,6 @@ import os -from sqlalchemy import create_engine +from sqlalchemy import create_engine, text from sqlalchemy.orm import declarative_base, sessionmaker @@ -32,6 +32,12 @@ def get_db(): def init_db(): + try: + with engine.connect() as connection: + connection = connection.execution_options(isolation_level="AUTOCOMMIT") + connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) + except Exception: + pass Base.metadata.create_all(bind=engine) @@ -65,3 +71,17 @@ def drop_policies_table(): Policy.__table__.drop(bind=engine, checkfirst=True) init_db() + + +def drop_leave_requests_table(): + from users.models import LeaveRequest + + LeaveRequest.__table__.drop(bind=engine, checkfirst=True) + init_db() + + +def drop_user_organizations_table(): + from organizations.models import UserOrganization + + UserOrganization.__table__.drop(bind=engine, checkfirst=True) + init_db() diff --git a/docker-compose.yml b/docker-compose.yml index 1acbd97..56bbac6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,7 +23,7 @@ services: start_period: 40s postgres: - image: postgres:16-alpine + image: pgvector/pgvector:pg16 container_name: leave-query-db env_file: - .env diff --git a/organizations/apis.py b/organizations/apis.py index 6089b8d..a7bc956 100644 --- a/organizations/apis.py +++ b/organizations/apis.py @@ -1,9 +1,10 @@ -from datetime import datetime +import logging -from fastapi import Depends, HTTPException +from fastapi import Depends, File, Form, HTTPException, UploadFile from sqlalchemy.orm import Session from application.app import app +from ai.rag import RAGClient from database.db import get_db from organizations.models import ( Organization, @@ -13,16 +14,25 @@ OrganizationsListResponse, Policy, PolicyItem, - PolicyRequest, PolicyResponse, PoliciesListResponse, + UserOrganization, + UserOrganizationItem, + UserOrganizationRequest, + UserOrganizationResponse, + UserOrganizationsListResponse, + UserOrganizationUpdate, ) +from organizations.utils import delete_file_if_exists, save_upload_file +from users.models import User + +logger = logging.getLogger(__name__) # Organization APIs @app.get("/organizations", response_model=OrganizationsListResponse) async def get_organizations(db: Session = Depends(get_db)): - rows = db.query(Organization).order_by(Organization.created_at.desc()).all() + rows = db.query(Organization).order_by(Organization.created.desc()).all() organizations = [ OrganizationItem( id=str(row.id), @@ -32,7 +42,7 @@ async def get_organizations(db: Session = Depends(get_db)): email=row.email, phone=row.phone, is_active=row.is_active, - created_at=row.created_at, + created=row.created, ) for row in rows ] @@ -58,7 +68,7 @@ async def get_organization(organization_id: str, db: Session = Depends(get_db)): email=org.email, phone=org.phone, is_active=org.is_active, - created_at=org.created_at, + created=org.created, ) @@ -137,7 +147,7 @@ async def delete_organization(organization_id: str, db: Session = Depends(get_db # Policy APIs @app.get("/policies", response_model=PoliciesListResponse) async def get_policies(db: Session = Depends(get_db)): - rows = db.query(Policy).order_by(Policy.created_at.desc()).all() + rows = db.query(Policy).order_by(Policy.created.desc()).all() policies = [] for row in rows: org = db.query(Organization).filter(Organization.id == row.organization_id).first() @@ -148,10 +158,10 @@ async def get_policies(db: Session = Depends(get_db)): organization_name=org.name if org else None, name=row.name, description=row.description, - max_leave_days=row.max_leave_days, - carry_forward_days=row.carry_forward_days, + document_name=row.document_name, + file_path=row.file, is_active=row.is_active, - created_at=row.created_at, + created=row.created, ) ) total = len(policies) @@ -175,10 +185,10 @@ async def get_policy(policy_id: str, db: Session = Depends(get_db)): organization_name=org.name if org else None, name=policy.name, description=policy.description, - max_leave_days=policy.max_leave_days, - carry_forward_days=policy.carry_forward_days, + document_name=policy.document_name, + file_path=policy.file, is_active=policy.is_active, - created_at=policy.created_at, + created=policy.created, ) @@ -188,7 +198,7 @@ async def get_organization_policies(organization_id: str, db: Session = Depends( if not org: raise HTTPException(status_code=404, detail="Organization not found") - rows = db.query(Policy).filter(Policy.organization_id == organization_id).order_by(Policy.created_at.desc()).all() + rows = db.query(Policy).filter(Policy.organization_id == organization_id).order_by(Policy.created.desc()).all() policies = [ PolicyItem( id=str(row.id), @@ -196,10 +206,10 @@ async def get_organization_policies(organization_id: str, db: Session = Depends( organization_name=org.name, name=row.name, description=row.description, - max_leave_days=row.max_leave_days, - carry_forward_days=row.carry_forward_days, + document_name=row.document_name, + file_path=row.file, is_active=row.is_active, - created_at=row.created_at, + created=row.created, ) for row in rows ] @@ -213,31 +223,57 @@ async def get_organization_policies(organization_id: str, db: Session = Depends( @app.post("/policies", response_model=PolicyResponse) -async def create_policy(policy: PolicyRequest, db: Session = Depends(get_db)): +async def create_policy( + organization_id: str = Form(...), + name: str = Form(...), + description: str = Form(None), + is_active: bool = Form(True), + file: UploadFile = File(None), + db: Session = Depends(get_db), +): # Verify organization exists - org = db.query(Organization).filter(Organization.id == policy.organization_id).first() + org = db.query(Organization).filter(Organization.id == organization_id).first() if not org: raise HTTPException(status_code=404, detail="Organization not found") + # Handle file upload + file_path = None + document_name = None + if file and file.filename: + document_name = file.filename + file_path = await save_upload_file(file) + new_policy = Policy( - organization_id=policy.organization_id, - name=policy.name, - description=policy.description, - max_leave_days=policy.max_leave_days, - carry_forward_days=policy.carry_forward_days, - is_active=policy.is_active, + organization_id=organization_id, + name=name, + description=description, + document_name=document_name, + file=file_path, + is_active=is_active, ) db.add(new_policy) db.commit() db.refresh(new_policy) + if file_path: + try: + RAGClient().index_policy_document( + policy_id=str(new_policy.id), + organization_id=str(new_policy.organization_id), + policy_name=new_policy.name, + description=new_policy.description, + document_name=new_policy.document_name, + file_path=file_path, + ) + except Exception as exc: + logger.exception("Failed to index policy document", extra={"error": str(exc)}) return PolicyResponse( id=str(new_policy.id), organization_id=str(new_policy.organization_id), name=new_policy.name, description=new_policy.description, - max_leave_days=new_policy.max_leave_days, - carry_forward_days=new_policy.carry_forward_days, + document_name=new_policy.document_name, + file_path=new_policy.file, is_active=new_policy.is_active, ) @@ -245,7 +281,11 @@ async def create_policy(policy: PolicyRequest, db: Session = Depends(get_db)): @app.put("/policies/{policy_id}", response_model=PolicyResponse) async def update_policy( policy_id: str, - policy: PolicyRequest, + organization_id: str = Form(...), + name: str = Form(...), + description: str = Form(None), + is_active: bool = Form(True), + file: UploadFile = File(None), db: Session = Depends(get_db), ): existing_policy = db.query(Policy).filter(Policy.id == policy_id).first() @@ -253,26 +293,43 @@ async def update_policy( raise HTTPException(status_code=404, detail="Policy not found") # Verify organization exists - org = db.query(Organization).filter(Organization.id == policy.organization_id).first() + org = db.query(Organization).filter(Organization.id == organization_id).first() if not org: raise HTTPException(status_code=404, detail="Organization not found") - existing_policy.organization_id = policy.organization_id - existing_policy.name = policy.name - existing_policy.description = policy.description - existing_policy.max_leave_days = policy.max_leave_days - existing_policy.carry_forward_days = policy.carry_forward_days - existing_policy.is_active = policy.is_active + # Handle file upload + if file and file.filename: + # Delete old file if exists + delete_file_if_exists(existing_policy.file) + existing_policy.document_name = file.filename + existing_policy.file = await save_upload_file(file) + + existing_policy.organization_id = organization_id + existing_policy.name = name + existing_policy.description = description + existing_policy.is_active = is_active db.commit() db.refresh(existing_policy) + if file and file.filename and existing_policy.file: + try: + RAGClient().index_policy_document( + policy_id=str(existing_policy.id), + organization_id=str(existing_policy.organization_id), + policy_name=existing_policy.name, + description=existing_policy.description, + document_name=existing_policy.document_name, + file_path=existing_policy.file, + ) + except Exception as exc: + logger.exception("Failed to reindex policy document", extra={"error": str(exc)}) return PolicyResponse( id=str(existing_policy.id), organization_id=str(existing_policy.organization_id), name=existing_policy.name, description=existing_policy.description, - max_leave_days=existing_policy.max_leave_days, - carry_forward_days=existing_policy.carry_forward_days, + document_name=existing_policy.document_name, + file_path=existing_policy.file, is_active=existing_policy.is_active, ) @@ -282,6 +339,192 @@ async def delete_policy(policy_id: str, db: Session = Depends(get_db)): policy = db.query(Policy).filter(Policy.id == policy_id).first() if not policy: raise HTTPException(status_code=404, detail="Policy not found") + + # Delete associated file if exists + delete_file_if_exists(policy.file) + try: + RAGClient().remove_policy_from_index(str(policy.id)) + except Exception as exc: + logger.exception("Failed to remove policy from index", extra={"error": str(exc)}) + db.delete(policy) db.commit() return {"status": "ok", "message": "Policy deleted"} + + +# User Organization (Membership) APIs +@app.get("/user_organizations", response_model=UserOrganizationsListResponse) +async def get_user_organizations(db: Session = Depends(get_db)): + """Get all user-organization memberships.""" + rows = db.query(UserOrganization).order_by(UserOrganization.created.desc()).all() + memberships = [] + for row in rows: + user = db.query(User).filter(User.id == row.user_id).first() + org = db.query(Organization).filter(Organization.id == row.organization_id).first() + memberships.append( + UserOrganizationItem( + id=str(row.id), + user_id=str(row.user_id), + username=user.username if user else None, + organization_id=str(row.organization_id), + organization_name=org.name if org else None, + joined_date=row.joined_date, + left_date=row.left_date, + is_active=row.is_active, + created=row.created, + ) + ) + total = len(memberships) + message = "No memberships found" if total == 0 else "Memberships retrieved" + return UserOrganizationsListResponse( + memberships=memberships, + total=total, + message=message, + ) + + +@app.get("/user_organizations/{membership_id}", response_model=UserOrganizationItem) +async def get_user_organization(membership_id: str, db: Session = Depends(get_db)): + """Get a specific user-organization membership.""" + membership = db.query(UserOrganization).filter(UserOrganization.id == membership_id).first() + if not membership: + raise HTTPException(status_code=404, detail="Membership not found") + + user = db.query(User).filter(User.id == membership.user_id).first() + org = db.query(Organization).filter(Organization.id == membership.organization_id).first() + return UserOrganizationItem( + id=str(membership.id), + user_id=str(membership.user_id), + username=user.username if user else None, + organization_id=str(membership.organization_id), + organization_name=org.name if org else None, + joined_date=membership.joined_date, + left_date=membership.left_date, + is_active=membership.is_active, + created=membership.created, + ) + + +@app.get("/organizations/{organization_id}/members", response_model=UserOrganizationsListResponse) +async def get_members_for_organization(organization_id: str, db: Session = Depends(get_db)): + """Get all members of an organization.""" + org = db.query(Organization).filter(Organization.id == organization_id).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + rows = db.query(UserOrganization).filter(UserOrganization.organization_id == organization_id).order_by(UserOrganization.joined_date.desc()).all() + memberships = [] + for row in rows: + user = db.query(User).filter(User.id == row.user_id).first() + memberships.append( + UserOrganizationItem( + id=str(row.id), + user_id=str(row.user_id), + username=user.username if user else None, + organization_id=str(row.organization_id), + organization_name=org.name, + joined_date=row.joined_date, + left_date=row.left_date, + is_active=row.is_active, + created=row.created, + ) + ) + total = len(memberships) + message = "No members found" if total == 0 else "Members retrieved" + return UserOrganizationsListResponse( + memberships=memberships, + total=total, + message=message, + ) + + +@app.post("/user_organizations", response_model=UserOrganizationResponse) +async def join_organization( + membership: UserOrganizationRequest, + db: Session = Depends(get_db), +): + """Add a user to an organization (join).""" + # Verify user exists + user = db.query(User).filter(User.id == membership.user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Verify organization exists + org = db.query(Organization).filter(Organization.id == membership.organization_id).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # Check if already a member + existing = ( + db.query(UserOrganization) + .filter( + UserOrganization.user_id == membership.user_id, + UserOrganization.organization_id == membership.organization_id, + UserOrganization.is_active.is_(True), + ) + .first() + ) + if existing: + raise HTTPException(status_code=400, detail="User is already a member of this organization") + + new_membership = UserOrganization( + user_id=membership.user_id, + organization_id=membership.organization_id, + joined_date=membership.joined_date.date(), + left_date=membership.left_date.date() if membership.left_date else None, + is_active=membership.is_active, + ) + db.add(new_membership) + db.commit() + db.refresh(new_membership) + + return UserOrganizationResponse( + id=str(new_membership.id), + user_id=str(new_membership.user_id), + organization_id=str(new_membership.organization_id), + joined_date=new_membership.joined_date, + left_date=new_membership.left_date, + is_active=new_membership.is_active, + ) + + +@app.patch("/user_organizations/{membership_id}", response_model=UserOrganizationResponse) +async def update_membership( + membership_id: str, + update: UserOrganizationUpdate, + db: Session = Depends(get_db), +): + """Update a user-organization membership (e.g., set left_date when leaving).""" + membership = db.query(UserOrganization).filter(UserOrganization.id == membership_id).first() + if not membership: + raise HTTPException(status_code=404, detail="Membership not found") + + if update.joined_date is not None: + membership.joined_date = update.joined_date.date() + if update.left_date is not None: + membership.left_date = update.left_date.date() + if update.is_active is not None: + membership.is_active = update.is_active + + db.commit() + db.refresh(membership) + + return UserOrganizationResponse( + id=str(membership.id), + user_id=str(membership.user_id), + organization_id=str(membership.organization_id), + joined_date=membership.joined_date, + left_date=membership.left_date, + is_active=membership.is_active, + ) + + +@app.delete("/user_organizations/{membership_id}") +async def delete_membership(membership_id: str, db: Session = Depends(get_db)): + """Delete a user-organization membership record.""" + membership = db.query(UserOrganization).filter(UserOrganization.id == membership_id).first() + if not membership: + raise HTTPException(status_code=404, detail="Membership not found") + db.delete(membership) + db.commit() + return {"status": "ok", "message": "Membership deleted"} diff --git a/organizations/constants.py b/organizations/constants.py index 6478069..e0fd7ee 100644 --- a/organizations/constants.py +++ b/organizations/constants.py @@ -1,7 +1,6 @@ def get_organization_function_map(): from organizations.db import ( get_all_organizations, - get_leave_allowance, get_organization_by_name, get_policies_for_organization, get_policy_details, @@ -12,5 +11,4 @@ def get_organization_function_map(): "get_all_organizations": get_all_organizations, "get_policies_for_organization": get_policies_for_organization, "get_policy_details": get_policy_details, - "get_leave_allowance": get_leave_allowance, } diff --git a/organizations/db.py b/organizations/db.py index 489aa3f..1b23cb1 100644 --- a/organizations/db.py +++ b/organizations/db.py @@ -71,8 +71,8 @@ def get_policies_for_organization(organization_name: str): "id": str(policy.id), "name": policy.name, "description": policy.description, - "max_leave_days": policy.max_leave_days, - "carry_forward_days": policy.carry_forward_days, + "document_name": policy.document_name, + "file_path": policy.file, "is_active": policy.is_active, } for policy in policies @@ -95,46 +95,9 @@ def get_policy_details(policy_name: str): "id": str(policy.id), "name": policy.name, "description": policy.description, - "max_leave_days": policy.max_leave_days, - "carry_forward_days": policy.carry_forward_days, + "document_name": policy.document_name, + "file_path": policy.file, "is_active": policy.is_active, "organization": org.name if org else None, } return {"detail": "Policy not found", "name": policy_name} - - -def get_leave_allowance(organization_name: str, policy_name: str | None = None): - """Get leave allowance for an organization or specific policy.""" - with SessionLocal() as db: - org = ( - db.query(Organization) - .filter(func.lower(Organization.name).contains(organization_name.lower())) - .first() - ) - if not org: - return {"detail": "Organization not found"} - - query = db.query(Policy).filter( - Policy.organization_id == org.id, - Policy.is_active == True, - ) - - if policy_name: - query = query.filter(func.lower(Policy.name).contains(policy_name.lower())) - - policies = query.all() - - if not policies: - return {"detail": "No policies found", "organization": org.name} - - return { - "organization": org.name, - "leave_policies": [ - { - "policy_name": p.name, - "max_leave_days": p.max_leave_days, - "carry_forward_days": p.carry_forward_days, - } - for p in policies - ], - } diff --git a/organizations/models.py b/organizations/models.py index 5d63d13..28fe91e 100644 --- a/organizations/models.py +++ b/organizations/models.py @@ -2,7 +2,7 @@ from datetime import datetime from pydantic import BaseModel -from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text, Boolean +from sqlalchemy import Boolean, Column, Date, DateTime, ForeignKey, String, Text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from sqlalchemy.sql import func @@ -21,11 +21,12 @@ class Organization(Base): email = Column(String(255), nullable=True) phone = Column(String(50), nullable=True) is_active = Column(Boolean, default=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + created = Column(DateTime(timezone=True), server_default=func.now()) + modified = Column(DateTime(timezone=True), onupdate=func.now()) - # Relationship to policies + # Relationships policies = relationship("Policy", back_populates="organization", cascade="all, delete-orphan") + members = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan") class Policy(Base): @@ -35,16 +36,32 @@ class Policy(Base): organization_id = Column(UUID(as_uuid=True), ForeignKey("organizations.id"), nullable=False) name = Column(String(255), nullable=False) description = Column(Text, nullable=True) - max_leave_days = Column(Integer, nullable=False, default=20) - carry_forward_days = Column(Integer, nullable=False, default=5) + document_name = Column(String(255), nullable=True) + file = Column(String(500), nullable=True) is_active = Column(Boolean, default=True) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + created = Column(DateTime(timezone=True), server_default=func.now()) + modified = Column(DateTime(timezone=True), onupdate=func.now()) # Relationship to organization organization = relationship("Organization", back_populates="policies") +class UserOrganization(Base): + __tablename__ = "user_organizations" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + organization_id = Column(UUID(as_uuid=True), ForeignKey("organizations.id"), nullable=False) + joined_date = Column(Date, nullable=False) + left_date = Column(Date, nullable=True) + is_active = Column(Boolean, default=True) + created = Column(DateTime(timezone=True), server_default=func.now()) + modified = Column(DateTime(timezone=True), onupdate=func.now()) + + # Relationships + organization = relationship("Organization", back_populates="members") + + # Pydantic Models - Organization class OrganizationRequest(BaseModel): name: str @@ -63,7 +80,7 @@ class OrganizationItem(BaseModel): email: str | None = None phone: str | None = None is_active: bool - created_at: datetime | None = None + created: datetime | None = None class OrganizationResponse(BaseModel): @@ -87,8 +104,6 @@ class PolicyRequest(BaseModel): organization_id: str name: str description: str | None = None - max_leave_days: int = 20 - carry_forward_days: int = 5 is_active: bool = True @@ -98,10 +113,10 @@ class PolicyItem(BaseModel): organization_name: str | None = None name: str description: str | None = None - max_leave_days: int - carry_forward_days: int + document_name: str | None = None + file_path: str | None = None is_active: bool - created_at: datetime | None = None + created: datetime | None = None class PolicyResponse(BaseModel): @@ -109,8 +124,8 @@ class PolicyResponse(BaseModel): organization_id: str name: str description: str | None = None - max_leave_days: int - carry_forward_days: int + document_name: str | None = None + file_path: str | None = None is_active: bool @@ -118,3 +133,45 @@ class PoliciesListResponse(BaseModel): policies: list[PolicyItem] total: int message: str + + +# Pydantic Models - UserOrganization +class UserOrganizationRequest(BaseModel): + user_id: str + organization_id: str + joined_date: datetime + left_date: datetime | None = None + is_active: bool = True + + +class UserOrganizationUpdate(BaseModel): + joined_date: datetime | None = None + left_date: datetime | None = None + is_active: bool | None = None + + +class UserOrganizationItem(BaseModel): + id: str + user_id: str + username: str | None = None + organization_id: str + organization_name: str | None = None + joined_date: datetime + left_date: datetime | None = None + is_active: bool + created: datetime | None = None + + +class UserOrganizationResponse(BaseModel): + id: str + user_id: str + organization_id: str + joined_date: datetime + left_date: datetime | None = None + is_active: bool + + +class UserOrganizationsListResponse(BaseModel): + memberships: list[UserOrganizationItem] + total: int + message: str diff --git a/organizations/tools.py b/organizations/tools.py index ed6da1e..fc7cc5b 100644 --- a/organizations/tools.py +++ b/organizations/tools.py @@ -1,39 +1,39 @@ ORGANIZATION_TOOLS = [ { - "name": "get_policies_for_organization", - "description": "Returns all leave policies for a specific organization.", + "name": "get_organization_by_name", + "description": "Returns organization details by searching for the name.", "parameter_definitions": { - "organization_name": { - "description": "The name of the organization.", + "name": { + "description": "The name or partial name of the organization.", "type": "str", "required": True, }, }, }, { - "name": "get_policy_details", - "description": "Returns details of a specific leave policy by name.", - "parameter_definitions": { - "policy_name": { - "description": "The name of the policy.", - "type": "str", - "required": True, - }, - }, + "name": "get_all_organizations", + "description": "Returns a list of all active organizations.", + "parameter_definitions": {}, }, { - "name": "get_leave_allowance", - "description": "Returns the leave allowance (max leave days, carry forward days) for an organization.", + "name": "get_policies_for_organization", + "description": "Returns all policies for a specific organization including document details.", "parameter_definitions": { "organization_name": { "description": "The name of the organization.", "type": "str", "required": True, }, + }, + }, + { + "name": "get_policy_details", + "description": "Returns details of a specific policy by name including document name and file.", + "parameter_definitions": { "policy_name": { - "description": "Optional specific policy name to filter by.", + "description": "The name of the policy.", "type": "str", - "required": False, + "required": True, }, }, }, diff --git a/organizations/utils.py b/organizations/utils.py new file mode 100644 index 0000000..7bfce68 --- /dev/null +++ b/organizations/utils.py @@ -0,0 +1,36 @@ +import os +import uuid + +from fastapi import UploadFile + +# Directory for storing uploaded policy files +UPLOAD_DIR = "uploads/policies" +os.makedirs(UPLOAD_DIR, exist_ok=True) + + +def delete_file_if_exists(file_path: str | None) -> None: + """Delete file path if it exists (handles relative paths).""" + if not file_path: + return + if os.path.isabs(file_path): + candidate_paths = [file_path] + else: + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + candidate_paths = [file_path, os.path.join(base_dir, file_path)] + for candidate in candidate_paths: + if os.path.exists(candidate): + os.remove(candidate) + break + + +async def save_upload_file(upload_file: UploadFile) -> str: + """Save an uploaded file and return the file path.""" + file_extension = os.path.splitext(upload_file.filename)[1] if upload_file.filename else "" + unique_filename = f"{uuid.uuid4()}{file_extension}" + file_path = os.path.join(UPLOAD_DIR, unique_filename) + + content = await upload_file.read() + with open(file_path, "wb") as f: + f.write(content) + + return file_path diff --git a/requirements.txt b/requirements.txt index 8631f1b..edbb28f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,3 +15,7 @@ cohere==5.20.1 rich==13.9.0 pytest==8.3.4 httpx==0.27.2 +pgvector==0.3.2 +pypdf==6.7.0 +pymupdf==1.26.7 +python-docx==1.2.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f51575d..4445f2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,18 +19,15 @@ os.environ.setdefault("JWT_ALGORITHM", "HS256") os.environ.setdefault("JWT_EXPIRE_MINUTES", "60") -import appointments.db as appointments_db import auth.backend as auth_backend import database.db as db import organizations.db as organizations_db from application.app import app as fastapi_app from auth.jwt import create_access_token from auth.passwords import hash_password -from appointments.choices import AppointmentStatus -from appointments.models import Appointment -from organizations.models import Organization, Policy -from users.choices import UserType -from users.models import User +from organizations.models import Organization, Policy, UserOrganization +from users.choices import LeaveType, UserType +from users.models import LeaveRequest, User _original_uuid_bind_processor = UUID.bind_processor @@ -67,28 +64,11 @@ def db_engine(tmp_path_factory): return create_engine(url, connect_args=connect_args) -@pytest.fixture(scope="session") +@pytest.fixture() def app(db_engine): db.engine = db_engine - db.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=db_engine) - db.Base.metadata.bind = db_engine - - appointments_db.SessionLocal = db.SessionLocal - auth_backend.SessionLocal = db.SessionLocal - organizations_db.SessionLocal = db.SessionLocal - - def override_get_db(): - session = db.SessionLocal() - try: - yield session - finally: - session.close() - - db.Base.metadata.create_all(bind=db_engine) - fastapi_app.dependency_overrides[db.get_db] = override_get_db - yield fastapi_app - fastapi_app.dependency_overrides.clear() - db.Base.metadata.drop_all(bind=db_engine) + db.SessionLocal.configure(bind=db_engine) + return fastapi_app @pytest.fixture(autouse=True) @@ -152,31 +132,6 @@ def _create_user( return _create_user -@pytest.fixture() -def create_appointment(db_session): - def _create_appointment( - *, - title="Test Appointment", - description=None, - date_and_time=None, - duration=30, - status=AppointmentStatus.scheduled, - ): - appointment = Appointment( - title=title, - description=description, - date_and_time=date_and_time or datetime(2026, 1, 30, 10, 0, 0), - duration=duration, - status=status, - ) - db_session.add(appointment) - db_session.commit() - db_session.refresh(appointment) - return appointment - - return _create_appointment - - @pytest.fixture() def create_organization(db_session): def _create_organization( @@ -211,16 +166,16 @@ def _create_policy( organization_id, name="Default Policy", description="Default leave policy", - max_leave_days=20, - carry_forward_days=5, + document_name=None, + file_path=None, is_active=True, ): policy = Policy( organization_id=organization_id, name=name, description=description, - max_leave_days=max_leave_days, - carry_forward_days=carry_forward_days, + document_name=document_name, + file=file_path, is_active=is_active, ) db_session.add(policy) @@ -231,6 +186,58 @@ def _create_policy( return _create_policy +@pytest.fixture() +def create_user_organization(db_session): + def _create_user_organization( + *, + user_id, + organization_id, + joined_date=None, + left_date=None, + is_active=True, + ): + membership = UserOrganization( + user_id=user_id, + organization_id=organization_id, + joined_date=joined_date or datetime(2026, 1, 1).date(), + left_date=left_date, + is_active=is_active, + ) + db_session.add(membership) + db_session.commit() + db_session.refresh(membership) + return membership + + return _create_user_organization + + +@pytest.fixture() +def create_leave_request(db_session): + def _create_leave_request( + *, + user_id, + organization_id, + date=None, + leave_type=LeaveType.SICK_LEAVE, + reason=None, + is_accepted=False, + ): + leave_request = LeaveRequest( + user_id=user_id, + organization_id=organization_id, + date=date or datetime(2026, 3, 15).date(), + leave_type=leave_type, + reason=reason, + is_accepted=is_accepted, + ) + db_session.add(leave_request) + db_session.commit() + db_session.refresh(leave_request) + return leave_request + + return _create_leave_request + + @pytest.fixture() def auth_headers(): def _auth_headers(user): diff --git a/tests/test_appointments_api.py b/tests/test_appointments_api.py deleted file mode 100644 index edbfe18..0000000 --- a/tests/test_appointments_api.py +++ /dev/null @@ -1,42 +0,0 @@ -def test_appointments_crud(client, create_user, auth_headers): - user = create_user(username="appt-user", email="appt-user@example.com") - - create_response = client.post( - "/appointments", - headers=auth_headers(user), - json={ - "title": "Team Meeting", - "description": "Weekly sync", - "date_and_time": "2026-01-30T10:00:00", - "duration": 30, - "status": "scheduled", - }, - ) - assert create_response.status_code == 200 - appointment_id = create_response.json()["id"] - - list_response = client.get("/appointments", headers=auth_headers(user)) - assert list_response.status_code == 200 - assert list_response.json()["total"] == 1 - - get_response = client.get( - f"/appointments/{appointment_id}", headers=auth_headers(user) - ) - assert get_response.status_code == 200 - assert get_response.json()["id"] == appointment_id - assert get_response.json()["title"] == "Team Meeting" - - delete_response = client.delete( - f"/appointments/{appointment_id}", headers=auth_headers(user) - ) - assert delete_response.status_code == 200 - assert delete_response.json()["message"] == "Appointment deleted" - - -def test_appointment_not_found(client, create_user, auth_headers): - user = create_user(username="appt-user", email="appt-user@example.com") - response = client.get( - "/appointments/00000000-0000-0000-0000-000000000000", - headers=auth_headers(user), - ) - assert response.status_code == 404 diff --git a/tests/test_leave_requests_api.py b/tests/test_leave_requests_api.py new file mode 100644 index 0000000..a85ebb8 --- /dev/null +++ b/tests/test_leave_requests_api.py @@ -0,0 +1,344 @@ +from users.choices import LeaveType, UserType + + +def test_apply_leave_request( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="leave-user", email="leave-user@example.com") + org = create_organization(name="Leave Test Org") + create_user_organization(user_id=user.id, organization_id=org.id) + + response = client.post( + "/leave_requests", + headers=auth_headers(user), + json={ + "organization_id": str(org.id), + "date": "2026-03-20T00:00:00", + "leave_type": LeaveType.SICK_LEAVE, + "reason": "Not feeling well", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["leave_type"] == LeaveType.SICK_LEAVE + assert data["is_accepted"] is False + assert data["organization_id"] == str(org.id) + assert data["reason"] == "Not feeling well" + + +def test_apply_leave_request_not_member(client, create_user, create_organization, auth_headers): + user = create_user(username="non-member", email="non-member@example.com") + org = create_organization(name="Another Org") + + response = client.post( + "/leave_requests", + headers=auth_headers(user), + json={ + "organization_id": str(org.id), + "date": "2026-03-20T00:00:00", + "leave_type": LeaveType.SICK_LEAVE, + }, + ) + assert response.status_code == 403 + assert "not an active member" in response.json()["detail"] + + +def test_apply_duplicate_leave_request( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="dup-user", email="dup-user@example.com") + org = create_organization(name="Dup Org") + create_user_organization(user_id=user.id, organization_id=org.id) + + # First request + response1 = client.post( + "/leave_requests", + headers=auth_headers(user), + json={ + "organization_id": str(org.id), + "date": "2026-03-21T00:00:00", + "leave_type": LeaveType.SICK_LEAVE, + }, + ) + assert response1.status_code == 200 + + # Duplicate request for same date + response2 = client.post( + "/leave_requests", + headers=auth_headers(user), + json={ + "organization_id": str(org.id), + "date": "2026-03-21T00:00:00", + "leave_type": LeaveType.PRIVILEGE_LEAVE, + }, + ) + assert response2.status_code == 400 + assert "already have a leave request" in response2.json()["detail"] + + +def test_get_leave_requests_regular_user( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="List Test Org") + user1 = create_user(username="user1", email="user1@example.com") + user2 = create_user(username="user2", email="user2@example.com") + create_user_organization(user_id=user1.id, organization_id=org.id) + create_user_organization(user_id=user2.id, organization_id=org.id) + + # Create leave requests for both users + create_leave_request(user_id=user1.id, organization_id=org.id) + create_leave_request(user_id=user2.id, organization_id=org.id) + + # User1 should only see their own leave request + response = client.get("/leave_requests", headers=auth_headers(user1)) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + +def test_get_leave_requests_admin_user( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Admin Test Org") + admin = create_user( + username="admin", email="admin@example.com", user_type=UserType.ADMIN + ) + user = create_user(username="regular", email="regular@example.com") + create_user_organization(user_id=admin.id, organization_id=org.id) + create_user_organization(user_id=user.id, organization_id=org.id) + + # Create leave requests for both + create_leave_request(user_id=admin.id, organization_id=org.id) + create_leave_request(user_id=user.id, organization_id=org.id) + + # Admin should see all leave requests + response = client.get("/leave_requests", headers=auth_headers(admin)) + assert response.status_code == 200 + assert response.json()["total"] == 2 + + +def test_get_organization_leave_requests( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org1 = create_organization(name="Org One") + org2 = create_organization(name="Org Two") + admin = create_user( + username="admin", email="admin@example.com", user_type=UserType.ADMIN + ) + user = create_user(username="user", email="user@example.com") + create_user_organization(user_id=user.id, organization_id=org1.id) + create_user_organization(user_id=user.id, organization_id=org2.id) + + # Create leave requests for different orgs + create_leave_request(user_id=user.id, organization_id=org1.id) + create_leave_request(user_id=user.id, organization_id=org2.id) + + # Get leave requests for org1 only + response = client.get( + f"/organizations/{org1.id}/leave_requests", headers=auth_headers(admin) + ) + assert response.status_code == 200 + assert response.json()["total"] == 1 + assert response.json()["leave_requests"][0]["organization_id"] == str(org1.id) + + +def test_review_accept_leave_request_admin_only( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Review Test Org") + admin = create_user( + username="admin", email="admin@example.com", user_type=UserType.ADMIN + ) + user = create_user(username="employee", email="employee@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user.id, organization_id=org.id) + + # Admin accepts the leave request + response = client.patch( + f"/leave_requests/{leave_request.id}/review", + headers=auth_headers(admin), + json={"is_accepted": True}, + ) + assert response.status_code == 200 + data = response.json() + assert data["is_accepted"] is True + assert data["reviewed_by"] == str(admin.id) + assert data["reviewed_at"] is not None + + +def test_review_leave_request_non_admin_denied( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Deny Test Org") + user = create_user(username="regular", email="regular@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user.id, organization_id=org.id) + + # Regular user tries to review - should fail + response = client.patch( + f"/leave_requests/{leave_request.id}/review", + headers=auth_headers(user), + json={"is_accepted": True}, + ) + assert response.status_code == 403 + + +def test_review_reject_leave_request_admin_only( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Reject Test Org") + admin = create_user( + username="admin", email="admin@example.com", user_type=UserType.ADMIN + ) + user = create_user(username="employee", email="employee@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request( + user_id=user.id, organization_id=org.id, is_accepted=True + ) + + # Admin rejects the leave request + response = client.patch( + f"/leave_requests/{leave_request.id}/review", + headers=auth_headers(admin), + json={"is_accepted": False}, + ) + assert response.status_code == 200 + assert response.json()["is_accepted"] is False + + +def test_delete_own_leave_request( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Delete Test Org") + user = create_user(username="user", email="user@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user.id, organization_id=org.id) + + response = client.delete( + f"/leave_requests/{leave_request.id}", + headers=auth_headers(user), + ) + assert response.status_code == 200 + assert response.json()["message"] == "Leave request deleted" + + +def test_delete_other_user_leave_request_denied( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Delete Deny Org") + user1 = create_user(username="user1", email="user1@example.com") + user2 = create_user(username="user2", email="user2@example.com") + create_user_organization(user_id=user1.id, organization_id=org.id) + create_user_organization(user_id=user2.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user1.id, organization_id=org.id) + + # User2 tries to delete User1's leave request - should fail + response = client.delete( + f"/leave_requests/{leave_request.id}", + headers=auth_headers(user2), + ) + assert response.status_code == 403 + + +def test_admin_can_delete_any_leave_request( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Admin Delete Org") + admin = create_user( + username="admin", email="admin@example.com", user_type=UserType.ADMIN + ) + user = create_user(username="employee", email="employee@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user.id, organization_id=org.id) + + # Admin can delete any user's leave request + response = client.delete( + f"/leave_requests/{leave_request.id}", + headers=auth_headers(admin), + ) + assert response.status_code == 200 + + +def test_get_single_leave_request( + client, + create_user, + create_organization, + create_user_organization, + create_leave_request, + auth_headers, +): + org = create_organization(name="Single Test Org") + user = create_user(username="user", email="user@example.com") + create_user_organization(user_id=user.id, organization_id=org.id) + + leave_request = create_leave_request(user_id=user.id, organization_id=org.id) + + response = client.get( + f"/leave_requests/{leave_request.id}", + headers=auth_headers(user), + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(leave_request.id) + assert data["organization_id"] == str(org.id) + assert data["organization_name"] == org.name + + +def test_leave_request_not_found(client, create_user, auth_headers): + user = create_user(username="user", email="user@example.com") + + response = client.get( + "/leave_requests/00000000-0000-0000-0000-000000000000", + headers=auth_headers(user), + ) + assert response.status_code == 404 diff --git a/tests/test_organizations_api.py b/tests/test_organizations_api.py index 3c15369..6e796ca 100644 --- a/tests/test_organizations_api.py +++ b/tests/test_organizations_api.py @@ -63,23 +63,20 @@ def test_policies_crud(client, create_user, create_organization, auth_headers): user = create_user(username="policy-user", email="policy-user@example.com") org = create_organization(name="Test Org") - # Create policy + # Create policy (using form data for file upload support) create_response = client.post( "/policies", headers=auth_headers(user), - json={ + data={ "organization_id": str(org.id), "name": "Standard Leave Policy", "description": "Standard policy for all employees", - "max_leave_days": 25, - "carry_forward_days": 5, - "is_active": True, + "is_active": "true", }, ) assert create_response.status_code == 200 policy_id = create_response.json()["id"] assert create_response.json()["name"] == "Standard Leave Policy" - assert create_response.json()["max_leave_days"] == 25 # List policies list_response = client.get("/policies", headers=auth_headers(user)) @@ -98,21 +95,19 @@ def test_policies_crud(client, create_user, create_organization, auth_headers): assert org_policies_response.status_code == 200 assert org_policies_response.json()["total"] == 1 - # Update policy + # Update policy (using form data) update_response = client.put( f"/policies/{policy_id}", headers=auth_headers(user), - json={ + data={ "organization_id": str(org.id), "name": "Updated Leave Policy", "description": "Updated description", - "max_leave_days": 30, - "carry_forward_days": 10, - "is_active": True, + "is_active": "true", }, ) assert update_response.status_code == 200 - assert update_response.json()["max_leave_days"] == 30 + assert update_response.json()["name"] == "Updated Leave Policy" # Delete policy delete_response = client.delete(f"/policies/{policy_id}", headers=auth_headers(user)) @@ -134,11 +129,40 @@ def test_policy_requires_valid_organization(client, create_user, auth_headers): response = client.post( "/policies", headers=auth_headers(user), - json={ + data={ "organization_id": "00000000-0000-0000-0000-000000000000", "name": "Invalid Policy", - "max_leave_days": 20, - "carry_forward_days": 5, }, ) assert response.status_code == 404 + + +def test_policy_with_file_upload(client, create_user, create_organization, auth_headers): + import io + + user = create_user(username="file-user", email="file-user@example.com") + org = create_organization(name="File Test Org") + + # Create a test file + test_file_content = b"This is a test policy document content." + test_file = io.BytesIO(test_file_content) + + # Create policy with file upload + create_response = client.post( + "/policies", + headers=auth_headers(user), + data={ + "organization_id": str(org.id), + "name": "Policy With File", + "description": "Policy with uploaded document", + "is_active": "true", + }, + files={"file": ("test_policy.pdf", test_file, "application/pdf")}, + ) + assert create_response.status_code == 200 + assert create_response.json()["document_name"] == "test_policy.pdf" + assert create_response.json()["file_path"] is not None + + # Clean up - delete the policy + policy_id = create_response.json()["id"] + client.delete(f"/policies/{policy_id}", headers=auth_headers(user)) diff --git a/tests/test_user_organizations_api.py b/tests/test_user_organizations_api.py new file mode 100644 index 0000000..28c5eae --- /dev/null +++ b/tests/test_user_organizations_api.py @@ -0,0 +1,173 @@ +def test_join_organization(client, create_user, create_organization, auth_headers): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + + response = client.post( + "/user_organizations", + headers=auth_headers(user), + json={ + "user_id": str(user.id), + "organization_id": str(org.id), + "joined_date": "2026-01-15T00:00:00", + "is_active": True, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["user_id"] == str(user.id) + assert data["organization_id"] == str(org.id) + assert data["is_active"] is True + + +def test_cannot_join_same_organization_twice( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + create_user_organization(user_id=user.id, organization_id=org.id) + + response = client.post( + "/user_organizations", + headers=auth_headers(user), + json={ + "user_id": str(user.id), + "organization_id": str(org.id), + "joined_date": "2026-02-01T00:00:00", + }, + ) + assert response.status_code == 400 + assert "already a member" in response.json()["detail"] + + +def test_get_user_organizations( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + create_user_organization(user_id=user.id, organization_id=org.id) + + response = client.get("/user_organizations", headers=auth_headers(user)) + assert response.status_code == 200 + assert response.json()["total"] == 1 + + +def test_get_single_membership( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + membership = create_user_organization(user_id=user.id, organization_id=org.id) + + response = client.get( + f"/user_organizations/{membership.id}", headers=auth_headers(user) + ) + assert response.status_code == 200 + assert response.json()["id"] == str(membership.id) + + +def test_get_organizations_for_user( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org1 = create_organization(name="Company A") + org2 = create_organization(name="Company B") + create_user_organization(user_id=user.id, organization_id=org1.id) + create_user_organization(user_id=user.id, organization_id=org2.id) + + response = client.get( + f"/users/{user.id}/organizations", headers=auth_headers(user) + ) + assert response.status_code == 200 + assert response.json()["total"] == 2 + + +def test_get_members_for_organization( + client, create_user, create_organization, create_user_organization, auth_headers +): + user1 = create_user(username="member1", email="member1@example.com") + user2 = create_user(username="member2", email="member2@example.com") + org = create_organization(name="Test Company") + create_user_organization(user_id=user1.id, organization_id=org.id) + create_user_organization(user_id=user2.id, organization_id=org.id) + + response = client.get( + f"/organizations/{org.id}/members", headers=auth_headers(user1) + ) + assert response.status_code == 200 + assert response.json()["total"] == 2 + + +def test_update_membership_left_date( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + membership = create_user_organization(user_id=user.id, organization_id=org.id) + + # User leaves the company + response = client.patch( + f"/user_organizations/{membership.id}", + headers=auth_headers(user), + json={ + "left_date": "2026-06-30T00:00:00", + "is_active": False, + }, + ) + assert response.status_code == 200 + assert response.json()["is_active"] is False + assert response.json()["left_date"] is not None + + +def test_delete_membership( + client, create_user, create_organization, create_user_organization, auth_headers +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + membership = create_user_organization(user_id=user.id, organization_id=org.id) + + response = client.delete( + f"/user_organizations/{membership.id}", headers=auth_headers(user) + ) + assert response.status_code == 200 + assert response.json()["message"] == "Membership deleted" + + +def test_membership_not_found(client, create_user, auth_headers): + user = create_user(username="member", email="member@example.com") + + response = client.get( + "/user_organizations/00000000-0000-0000-0000-000000000000", + headers=auth_headers(user), + ) + assert response.status_code == 404 + + +def test_join_invalid_organization(client, create_user, auth_headers): + user = create_user(username="member", email="member@example.com") + + response = client.post( + "/user_organizations", + headers=auth_headers(user), + json={ + "user_id": str(user.id), + "organization_id": "00000000-0000-0000-0000-000000000000", + "joined_date": "2026-01-15T00:00:00", + }, + ) + assert response.status_code == 404 + + +def test_join_invalid_user(client, create_user, create_organization, auth_headers): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Test Company") + + response = client.post( + "/user_organizations", + headers=auth_headers(user), + json={ + "user_id": "00000000-0000-0000-0000-000000000000", + "organization_id": str(org.id), + "joined_date": "2026-01-15T00:00:00", + }, + ) + assert response.status_code == 404 diff --git a/tests/test_users_api.py b/tests/test_users_api.py index f5f9aed..ea74212 100644 --- a/tests/test_users_api.py +++ b/tests/test_users_api.py @@ -1,7 +1,7 @@ import pytest from fastapi import HTTPException -from users.apis import _coerce_user_type, _require_admin +from users.utils import coerce_user_type, require_admin from users.choices import UserType @@ -70,9 +70,9 @@ def test_delete_user(client, create_user, auth_headers): def test_user_type_helpers(): - assert _coerce_user_type("admin") == UserType.ADMIN - _require_admin("ADMIN") + assert coerce_user_type("admin") == UserType.ADMIN + require_admin("ADMIN") with pytest.raises(HTTPException): - _require_admin("REGULAR") + require_admin("REGULAR") with pytest.raises(HTTPException): - _coerce_user_type("visitor") + coerce_user_type("visitor") diff --git a/users/apis.py b/users/apis.py index e37a1c7..d0f2c5f 100644 --- a/users/apis.py +++ b/users/apis.py @@ -5,28 +5,33 @@ from application.app import app from auth.passwords import hash_password -from database.db import drop_users_table, get_db +from database.db import drop_leave_requests_table, drop_users_table, get_db +from organizations.models import ( + Organization, + UserOrganization, + UserOrganizationItem, + UserOrganizationsListResponse, +) from users.choices import UserType -from users.models import User, UserItem, UserRequest, UserResponse, UsersListResponse - - -def _coerce_user_type(value: str) -> UserType: - normalized = value.strip().upper() - try: - return UserType[normalized] - except KeyError: - raise HTTPException(status_code=403, detail="Admin access required") - - -def _require_admin(user_type: str) -> None: - if _coerce_user_type(user_type) != UserType.ADMIN: - raise HTTPException(status_code=403, detail="Admin access required") - - -def _require_authenticated_user(request: Request): - if not request.user or not request.user.is_authenticated: - raise HTTPException(status_code=401, detail="Authentication required") - return request.user +from users.models import ( + LeaveRequest, + LeaveRequestCreate, + LeaveRequestItem, + LeaveRequestResponse, + LeaveRequestReview, + LeaveRequestsListResponse, + User, + UserItem, + UserRequest, + UserResponse, + UsersListResponse, +) +from users.utils import ( + build_leave_request_item, + coerce_user_type, + require_admin, + require_authenticated_user, +) @app.get("/users", response_model=UsersListResponse) @@ -34,8 +39,8 @@ async def get_users( request: Request, db: Session = Depends(get_db), ): - _require_admin(request.user.user_type) - rows = db.query(User).order_by(User.created_at.desc()).all() + require_admin(request.user.user_type) + rows = db.query(User).order_by(User.created.desc()).all() users = [ UserItem( id=str(row.id), @@ -65,7 +70,7 @@ async def get_user( request: Request, db: Session = Depends(get_db), ): - current_user = _require_authenticated_user(request) + current_user = require_authenticated_user(request) if current_user.user_id != user_id: raise HTTPException(status_code=403, detail="User access restricted") user = db.query(User).filter(User.id == user_id).first() @@ -84,6 +89,44 @@ async def get_user( ) +@app.get("/users/{user_id}/organizations", response_model=UserOrganizationsListResponse) +async def get_organizations_for_user(user_id: str, db: Session = Depends(get_db)): + """Get all organizations a user belongs to.""" + user = db.query(User).filter(User.id == user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + rows = ( + db.query(UserOrganization) + .filter(UserOrganization.user_id == user_id) + .order_by(UserOrganization.joined_date.desc()) + .all() + ) + memberships = [] + for row in rows: + org = db.query(Organization).filter(Organization.id == row.organization_id).first() + memberships.append( + UserOrganizationItem( + id=str(row.id), + user_id=str(row.user_id), + username=user.username, + organization_id=str(row.organization_id), + organization_name=org.name if org else None, + joined_date=row.joined_date, + left_date=row.left_date, + is_active=row.is_active, + created=row.created, + ) + ) + total = len(memberships) + message = "No organizations found for user" if total == 0 else "Organizations retrieved" + return UserOrganizationsListResponse( + memberships=memberships, + total=total, + message=message, + ) + + @app.post("/users", response_model=UserResponse) async def create_user(user: UserRequest, db: Session = Depends(get_db)): password_hash = hash_password(user.password) @@ -96,7 +139,6 @@ async def create_user(user: UserRequest, db: Session = Depends(get_db)): phone=user.phone, gender=user.gender, date_of_birth=user.date_of_birth, - created_at=datetime.now(), user_type=user.user_type, ) db.add(new_user) @@ -130,3 +172,224 @@ async def delete_user(user_id: str, db: Session = Depends(get_db)): async def drop_users_db_table(): drop_users_table() return {"status": "ok", "message": "Users database table dropped"} + + +# Leave Request APIs +@app.get("/leave_requests", response_model=LeaveRequestsListResponse) +async def get_leave_requests( + request: Request, + db: Session = Depends(get_db), +): + """Get all leave requests. Admin sees all, regular users see only their own.""" + current_user = require_authenticated_user(request) + + if coerce_user_type(current_user.user_type) == UserType.ADMIN: + rows = db.query(LeaveRequest).order_by(LeaveRequest.applied_at.desc()).all() + else: + rows = ( + db.query(LeaveRequest) + .filter(LeaveRequest.user_id == current_user.user_id) + .order_by(LeaveRequest.applied_at.desc()) + .all() + ) + + leave_requests = [build_leave_request_item(row, db) for row in rows] + + total = len(leave_requests) + message = "No leave requests found" if total == 0 else "Leave requests retrieved" + return LeaveRequestsListResponse( + leave_requests=leave_requests, + total=total, + message=message, + ) + + +@app.get("/leave_requests/{leave_request_id}", response_model=LeaveRequestItem) +async def get_leave_request( + leave_request_id: str, + request: Request, + db: Session = Depends(get_db), +): + """Get a specific leave request.""" + current_user = require_authenticated_user(request) + + leave_request = db.query(LeaveRequest).filter(LeaveRequest.id == leave_request_id).first() + if not leave_request: + raise HTTPException(status_code=404, detail="Leave request not found") + + # Regular users can only see their own leave requests + if ( + coerce_user_type(current_user.user_type) != UserType.ADMIN + and str(leave_request.user_id) != current_user.user_id + ): + raise HTTPException(status_code=403, detail="Access denied") + + return build_leave_request_item(leave_request, db) + + +@app.get("/organizations/{organization_id}/leave_requests", response_model=LeaveRequestsListResponse) +async def get_organization_leave_requests( + organization_id: str, + request: Request, + db: Session = Depends(get_db), +): + """Get all leave requests for an organization. Admin only.""" + require_admin(request.user.user_type) + + org = db.query(Organization).filter(Organization.id == organization_id).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + rows = ( + db.query(LeaveRequest) + .filter(LeaveRequest.organization_id == organization_id) + .order_by(LeaveRequest.applied_at.desc()) + .all() + ) + + leave_requests = [build_leave_request_item(row, db) for row in rows] + + total = len(leave_requests) + message = "No leave requests found" if total == 0 else "Leave requests retrieved" + return LeaveRequestsListResponse( + leave_requests=leave_requests, + total=total, + message=message, + ) + + +@app.post("/leave_requests", response_model=LeaveRequestResponse) +async def apply_leave_request( + leave_request: LeaveRequestCreate, + request: Request, + db: Session = Depends(get_db), +): + """Apply for a leave request. User must be a member of the organization.""" + current_user = require_authenticated_user(request) + + # Verify organization exists + org = db.query(Organization).filter(Organization.id == leave_request.organization_id).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # Verify user is an active member of the organization + membership = ( + db.query(UserOrganization) + .filter( + UserOrganization.user_id == current_user.user_id, + UserOrganization.organization_id == leave_request.organization_id, + UserOrganization.is_active.is_(True), + ) + .first() + ) + if not membership: + raise HTTPException( + status_code=403, detail="You are not an active member of this organization" + ) + + # Check if user already has a leave request for the same date in this organization + existing = ( + db.query(LeaveRequest) + .filter( + LeaveRequest.user_id == current_user.user_id, + LeaveRequest.organization_id == leave_request.organization_id, + LeaveRequest.date == leave_request.date.date(), + ) + .first() + ) + if existing: + raise HTTPException( + status_code=400, detail="You already have a leave request for this date in this organization" + ) + + new_leave_request = LeaveRequest( + user_id=current_user.user_id, + organization_id=leave_request.organization_id, + date=leave_request.date.date(), + leave_type=leave_request.leave_type, + reason=leave_request.reason, + is_accepted=False, + ) + db.add(new_leave_request) + db.commit() + db.refresh(new_leave_request) + + return LeaveRequestResponse( + id=str(new_leave_request.id), + user_id=str(new_leave_request.user_id), + organization_id=str(new_leave_request.organization_id), + date=new_leave_request.date, + leave_type=new_leave_request.leave_type, + reason=new_leave_request.reason, + is_accepted=new_leave_request.is_accepted, + reviewed_by=None, + reviewed_at=None, + applied_at=new_leave_request.applied_at, + created=new_leave_request.created, + ) + + +@app.patch("/leave_requests/{leave_request_id}/review", response_model=LeaveRequestResponse) +async def review_leave_request( + leave_request_id: str, + review: LeaveRequestReview, + request: Request, + db: Session = Depends(get_db), +): + """Review (accept/reject) a leave request. Admin only.""" + require_admin(request.user.user_type) + + leave_request = db.query(LeaveRequest).filter(LeaveRequest.id == leave_request_id).first() + if not leave_request: + raise HTTPException(status_code=404, detail="Leave request not found") + + leave_request.is_accepted = review.is_accepted + leave_request.reviewed_by = request.user.user_id + leave_request.reviewed_at = datetime.now() + db.commit() + db.refresh(leave_request) + + return LeaveRequestResponse( + id=str(leave_request.id), + user_id=str(leave_request.user_id), + organization_id=str(leave_request.organization_id), + date=leave_request.date, + leave_type=leave_request.leave_type, + reason=leave_request.reason, + is_accepted=leave_request.is_accepted, + reviewed_by=str(leave_request.reviewed_by) if leave_request.reviewed_by else None, + reviewed_at=leave_request.reviewed_at, + applied_at=leave_request.applied_at, + created=leave_request.created, + ) + + +@app.delete("/leave_requests/{leave_request_id}") +async def delete_leave_request( + leave_request_id: str, + request: Request, + db: Session = Depends(get_db), +): + """Delete a leave request. Users can delete their own, admins can delete any.""" + current_user = require_authenticated_user(request) + + leave_request = db.query(LeaveRequest).filter(LeaveRequest.id == leave_request_id).first() + if not leave_request: + raise HTTPException(status_code=404, detail="Leave request not found") + + # Regular users can only delete their own leave requests + if ( + coerce_user_type(current_user.user_type) != UserType.ADMIN + and str(leave_request.user_id) != current_user.user_id + ): + raise HTTPException(status_code=403, detail="Access denied") + + db.delete(leave_request) + db.commit() + return {"status": "ok", "message": "Leave request deleted"} + + +@app.delete("/admin/drop-leave_requests-db") +async def drop_leave_requests_db_table(): + drop_leave_requests_table() + return {"status": "ok", "message": "Leave requests database table dropped"} diff --git a/users/choices.py b/users/choices.py index 00322ef..2753cf9 100644 --- a/users/choices.py +++ b/users/choices.py @@ -4,3 +4,8 @@ class UserType(str, Enum): ADMIN = "ADMIN" REGULAR = "REGULAR" + + +class LeaveType(str, Enum): + SICK_LEAVE = "SICK_LEAVE" + PRIVILEGE_LEAVE = "PRIVILEGE_LEAVE" diff --git a/users/models.py b/users/models.py index 4f761e9..38ae34c 100644 --- a/users/models.py +++ b/users/models.py @@ -2,12 +2,13 @@ from datetime import datetime from pydantic import BaseModel -from sqlalchemy import Column, Date, DateTime, Enum, String +from sqlalchemy import Boolean, Column, Date, DateTime, Enum, ForeignKey, String, Text from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship from sqlalchemy.sql import func from database.db import Base -from users.choices import UserType +from users.choices import LeaveType, UserType class User(Base): @@ -25,7 +26,34 @@ class User(Base): Enum(UserType, name="user_type"), nullable=False, default=UserType.REGULAR ) date_of_birth = Column(Date, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) + created = Column(DateTime(timezone=True), server_default=func.now()) + modified = Column(DateTime(timezone=True), onupdate=func.now()) + + # Relationship to leave requests + leave_requests = relationship("LeaveRequest", foreign_keys="LeaveRequest.user_id", back_populates="user", cascade="all, delete-orphan") + + +class LeaveRequest(Base): + __tablename__ = "leave_requests" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False) + organization_id = Column(UUID(as_uuid=True), ForeignKey("organizations.id"), nullable=False) + date = Column(Date, nullable=False) + leave_type = Column( + Enum(LeaveType, name="leave_type"), nullable=False + ) + reason = Column(Text, nullable=True) + is_accepted = Column(Boolean, default=False) + reviewed_by = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + reviewed_at = Column(DateTime(timezone=True), nullable=True) + applied_at = Column(DateTime(timezone=True), server_default=func.now()) + created = Column(DateTime(timezone=True), server_default=func.now()) + modified = Column(DateTime(timezone=True), onupdate=func.now()) + + # Relationships + user = relationship("User", foreign_keys=[user_id], back_populates="leave_requests") + reviewer = relationship("User", foreign_keys=[reviewed_by]) class UserRequest(BaseModel): @@ -68,3 +96,52 @@ class UsersListResponse(BaseModel): users: list[UserItem] total: int message: str + + +# Pydantic Models - LeaveRequest +class LeaveRequestCreate(BaseModel): + organization_id: str + date: datetime + leave_type: LeaveType + reason: str | None = None + + +class LeaveRequestItem(BaseModel): + id: str + user_id: str + username: str | None = None + organization_id: str + organization_name: str | None = None + date: datetime + leave_type: LeaveType + reason: str | None = None + is_accepted: bool + reviewed_by: str | None = None + reviewer_name: str | None = None + reviewed_at: datetime | None = None + applied_at: datetime | None = None + created: datetime | None = None + + +class LeaveRequestResponse(BaseModel): + id: str + user_id: str + organization_id: str + date: datetime + leave_type: LeaveType + reason: str | None = None + is_accepted: bool + reviewed_by: str | None = None + reviewed_at: datetime | None = None + applied_at: datetime | None = None + created: datetime | None = None + + +class LeaveRequestsListResponse(BaseModel): + leave_requests: list[LeaveRequestItem] + total: int + message: str + + +class LeaveRequestReview(BaseModel): + is_accepted: bool diff --git a/users/utils.py b/users/utils.py new file mode 100644 index 0000000..dddbe8a --- /dev/null +++ b/users/utils.py @@ -0,0 +1,51 @@ +from fastapi import HTTPException, Request +from sqlalchemy.orm import Session + +from organizations.models import Organization +from users.choices import UserType +from users.models import LeaveRequestItem, User + + +def coerce_user_type(value: str) -> UserType: + """Convert a string value to UserType enum.""" + normalized = value.strip().upper() + try: + return UserType[normalized] + except KeyError: + raise HTTPException(status_code=403, detail="Admin access required") + + +def require_admin(user_type: str) -> None: + """Raise 403 if user is not an admin.""" + if coerce_user_type(user_type) != UserType.ADMIN: + raise HTTPException(status_code=403, detail="Admin access required") + + +def require_authenticated_user(request: Request): + """Raise 401 if user is not authenticated.""" + if not request.user or not request.user.is_authenticated: + raise HTTPException(status_code=401, detail="Authentication required") + return request.user + + +def build_leave_request_item(row, db: Session) -> LeaveRequestItem: + """Build LeaveRequestItem from a LeaveRequest row.""" + user = db.query(User).filter(User.id == row.user_id).first() + org = db.query(Organization).filter(Organization.id == row.organization_id).first() + reviewer = db.query(User).filter(User.id == row.reviewed_by).first() if row.reviewed_by else None + return LeaveRequestItem( + id=str(row.id), + user_id=str(row.user_id), + username=user.username if user else None, + organization_id=str(row.organization_id), + organization_name=org.name if org else None, + date=row.date, + leave_type=row.leave_type, + reason=row.reason, + is_accepted=row.is_accepted, + reviewed_by=str(row.reviewed_by) if row.reviewed_by else None, + reviewer_name=reviewer.username if reviewer else None, + reviewed_at=row.reviewed_at, + applied_at=row.applied_at, + created=row.created, + )