From 45cc94ddc5dd6ff6646401b24a038705ca8f16b1 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Mon, 9 Feb 2026 19:56:47 +0530 Subject: [PATCH 01/20] updates summary --- ai/apis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai/apis.py b/ai/apis.py index 3e5fae4..77aad13 100644 --- a/ai/apis.py +++ b/ai/apis.py @@ -13,7 +13,7 @@ @app.post( "/ai_assistant", status_code=status.HTTP_200_OK, - summary="Chat with Database", + summary="Chat with Documents", response_description="Answer from the AI", ) async def ai_assistant(request: QNARequestBody) -> QNAResponseBody: From 0f3c32d12a18317318c28b042cacf076e4f3e13c Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Mon, 9 Feb 2026 19:57:01 +0530 Subject: [PATCH 02/20] updates agent --- ai/agent.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 76 insertions(+), 5 deletions(-) diff --git a/ai/agent.py b/ai/agent.py index 1825e4e..3a9420f 100644 --- a/ai/agent.py +++ b/ai/agent.py @@ -2,9 +2,28 @@ from typing import Any from ai.clients import CohereClient +from ai.rag import RAGClient SESSION_MEMORY: dict[str, list[dict[str, str]]] = {} MAX_HISTORY = 20 +POLICY_KEYWORDS = { + "policy", + "leave", + "pto", + "vacation", + "sick", + "holiday", + "absence", + "bereavement", + "maternity", + "paternity", + "parental", + "time off", + "work from home", + "remote work", + "attendance", + "benefits", +} class SchedulingAgent: @@ -19,15 +38,67 @@ def _trim_history(self, history: list[dict[str, str]]) -> list[dict[str, str]]: return history return history[-MAX_HISTORY:] - def run(self) -> dict[str, Any]: - history = [] - if self.session_id: - history = list(SESSION_MEMORY.get(self.session_id, [])) + def _is_policy_question(self, question: str) -> bool: + lowered = question.lower() + return any(keyword in lowered for keyword in POLICY_KEYWORDS) + + def _answer_policy_question( + self, + question: str, + history: list[dict[str, str]], + ) -> tuple[str, list] | None: + rag_client = RAGClient() + top_k = int(os.getenv("POLICY_RAG_TOP_K", "5")) + matches = rag_client.query_policy_index(question, top_k=top_k) + if not matches: + return None + + excerpts = [] + for match in matches: + title = match.get("document_name") or match.get("policy_name") or "Policy" + chunk_index = match.get("chunk_index", "unknown") + text = match.get("text", "") + if text: + excerpts.append(f"[{title} - chunk {chunk_index}] {text}") + + if not excerpts: + return None + + prompt = ( + "You are a policy assistant. Answer the question using only the policy " + "excerpts below. If the answer is not contained in the excerpts, say you " + "couldn't find it in the policy documents.\n\n" + "Policy excerpts:\n" + f"{os.linesep.join(excerpts)}\n\n" + f"Question: {question}" + ) response_text, history = self.client.ask_llm( - message=self.question, + message=prompt, chat_history=history, max_steps=self.max_steps, ) + return response_text, history + + def run(self) -> dict[str, Any]: + history = [] + if self.session_id: + history = list(SESSION_MEMORY.get(self.session_id, [])) + if self._is_policy_question(self.question): + policy_result = self._answer_policy_question(self.question, history) + if policy_result: + response_text, history = policy_result + else: + response_text, history = self.client.ask_llm( + message=self.question, + chat_history=history, + max_steps=self.max_steps, + ) + else: + response_text, history = self.client.ask_llm( + message=self.question, + chat_history=history, + max_steps=self.max_steps, + ) if self.session_id: SESSION_MEMORY[self.session_id] = self._trim_history(history) return { From 72145ba1954f84591666f88fe774a449f84eb57c Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:21 +0530 Subject: [PATCH 03/20] feat(users): validate duplicate username/email before creating user (409) Co-authored-by: Cursor --- users/apis.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/users/apis.py b/users/apis.py index d0f2c5f..1b0fadf 100644 --- a/users/apis.py +++ b/users/apis.py @@ -129,13 +129,29 @@ async def get_organizations_for_user(user_id: str, db: Session = Depends(get_db) @app.post("/users", response_model=UserResponse) async def create_user(user: UserRequest, db: Session = Depends(get_db)): + username_lower = user.username.lower() + email_lower = (user.email or "").strip().lower() + + existing_by_username = db.query(User).filter(User.username == username_lower).first() + if existing_by_username: + raise HTTPException( + status_code=409, + detail="A user with this username already exists", + ) + + if email_lower and db.query(User).filter(User.email == email_lower).first(): + raise HTTPException( + status_code=409, + detail="A user with this email already exists", + ) + password_hash = hash_password(user.password) new_user = User( first_name=user.first_name.lower(), last_name=user.last_name.lower(), - username=user.username.lower(), + username=username_lower, password_hash=password_hash, - email=user.email, + email=email_lower or user.email, phone=user.phone, gender=user.gender, date_of_birth=user.date_of_birth, From 1cfdc55e58fcb69ec3a1e8b72cdda7ba6044053f Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:24 +0530 Subject: [PATCH 04/20] feat(orgs): add get_my_organization_details and get_organization_ids_for_user Co-authored-by: Cursor --- organizations/db.py | 58 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/organizations/db.py b/organizations/db.py index 0bce1a7..e2d3042 100644 --- a/organizations/db.py +++ b/organizations/db.py @@ -1,7 +1,7 @@ from sqlalchemy import func from database.db import SessionLocal -from organizations.models import Organization, Policy +from organizations.models import Organization, Policy, UserOrganization def get_organization_details(organization_name: str): @@ -25,6 +25,62 @@ def get_organization_details(organization_name: str): return {"detail": "Organization not found", "name": organization_name} +def get_my_organization_details(user_id: str): + """ + Get organization details for the given user by looking up their memberships + in UserOrganization and returning the organizations they belong to. + """ + with SessionLocal() as db: + memberships = ( + db.query(UserOrganization) + .filter( + UserOrganization.user_id == user_id, + UserOrganization.is_active.is_(True), + ) + .order_by(UserOrganization.joined_date.desc()) + .all() + ) + if not memberships: + return { + "detail": "You are not a member of any organization.", + "organizations": [], + "total": 0, + } + organizations = [] + for m in memberships: + org = db.query(Organization).filter(Organization.id == m.organization_id).first() + if org: + organizations.append({ + "id": str(org.id), + "name": org.name, + "description": org.description, + "address": org.address, + "email": org.email, + "phone": org.phone, + "is_active": org.is_active, + "membership_joined_date": str(m.joined_date) if m.joined_date else None, + }) + return { + "organizations": organizations, + "total": len(organizations), + "message": f"Found {len(organizations)} organization(s) you belong to.", + } + + +def get_organization_ids_for_user(user_id: str) -> list[str]: + """Return list of organization IDs (as strings) the user belongs to (active memberships).""" + with SessionLocal() as db: + rows = ( + db.query(UserOrganization.organization_id) + .filter( + UserOrganization.user_id == user_id, + UserOrganization.is_active.is_(True), + ) + .all() + ) + return [str(r[0]) for r in rows] + + def get_policies_for_organization(organization_name: str): """Get all policies for an organization.""" with SessionLocal() as db: From 7f9e55a600eea7748bb475ccfebf60161cbd6034 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:24 +0530 Subject: [PATCH 05/20] feat(orgs): add get_my_organization_details tool for AI agent Co-authored-by: Cursor --- organizations/tools.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/organizations/tools.py b/organizations/tools.py index 0f84e4d..9c9d2b2 100644 --- a/organizations/tools.py +++ b/organizations/tools.py @@ -1,4 +1,13 @@ ORGANIZATION_TOOLS = [ + { + "name": "get_my_organization_details", + "description": ( + "Returns the organization(s) that the requesting user belongs to. " + "Use this when the user asks about 'my organization', 'details of my organization', " + "'tell me about my organization', or similar. Looks up the user's memberships and returns org details." + ), + "parameter_definitions": {}, + }, { "name": "get_organization_details", "description": "Returns organization details by searching for the name.", From eb7143c8cd6f79662742f9a7038f528ebf868dd0 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:25 +0530 Subject: [PATCH 06/20] feat(orgs): wire get_my_organization_details with user_id in function map Co-authored-by: Cursor --- organizations/constants.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/organizations/constants.py b/organizations/constants.py index d3d4ad1..47bb404 100644 --- a/organizations/constants.py +++ b/organizations/constants.py @@ -1,12 +1,26 @@ from organizations.db import ( + get_my_organization_details, get_organization_details, get_policies_for_organization, get_policy_details, ) -def get_organization_function_map(): - return { + +def _make_get_my_organization_details(user_id: str): + """Return a no-arg callable that fetches org details for the given user_id.""" + + def _fn(**kwargs): + return get_my_organization_details(user_id) + + return _fn + + +def get_organization_function_map(user_id: str | None = None): + mapping = { "get_organization_details": get_organization_details, "get_policies_for_organization": get_policies_for_organization, "get_policy_details": get_policy_details, } + if user_id is not None: + mapping["get_my_organization_details"] = _make_get_my_organization_details(user_id) + return mapping From f01b1a8b162745902ff511d21a6025ab6a409295 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:25 +0530 Subject: [PATCH 07/20] feat(ai): filter RAG query_policy_index by organization_ids Co-authored-by: Cursor --- ai/rag.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/ai/rag.py b/ai/rag.py index b6cc7eb..21e1f60 100644 --- a/ai/rag.py +++ b/ai/rag.py @@ -1,12 +1,12 @@ import logging import os import re +import uuid from typing import Iterable +import cohere import httpx from sqlalchemy import delete, select - -from ai.clients import CohereClient from ai.db import PolicyEmbedding from database.db import SessionLocal @@ -17,6 +17,7 @@ class RAGClient: def __init__(self, embed_model: str | None = None): self.embed_model = embed_model or os.getenv("COHERE_EMBED_MODEL", "embed-english-v3.0") + self.client = cohere.Client(os.getenv("COHERE_API_KEY")) def _looks_like_text(self, raw: bytes) -> bool: if not raw: @@ -94,11 +95,14 @@ def _chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 200) -> l return chunks def _embed_texts(self, texts: Iterable[str], input_type: str) -> list[list[float]]: - client = CohereClient(model=self.embed_model) - response = client.embed_texts(list(texts), input_type=input_type) + response = self.client.embed( + texts=list(texts), + model=self.embed_model, + input_type=input_type, + ) logger.info(f"Embedded {len(texts)} texts with model {self.embed_model} & response: {response}") - return response + return response.embeddings or [] def index_policy_document( self, @@ -154,7 +158,12 @@ def remove_policy_from_index(self, policy_id: str) -> dict: return {"status": "skipped", "reason": "policy_not_found"} return {"status": "removed", "count": result.rowcount} - def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]: + def query_policy_index( + self, + query: str, + top_k: int = 5, + organization_ids: list[str] | None = None, + ) -> list[dict]: query_embedding = self._embed_texts([query], input_type="search_query") if not query_embedding: return [] @@ -166,6 +175,12 @@ def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]: .order_by(PolicyEmbedding.embedding.cosine_distance(query_vector)) .limit(max(top_k, 1)) ) + if organization_ids: + try: + uuids = [uuid.UUID(oid) for oid in organization_ids] + stmt = stmt.where(PolicyEmbedding.organization_id.in_(uuids)) + except (ValueError, TypeError): + pass results = db.execute(stmt).scalars().all() response = [] for record in results: From 92d3a2b7438867a2d7195a9fbe017d58911e55d3 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:26 +0530 Subject: [PATCH 08/20] feat(ai): add search_my_organization_policies tool (user-scoped policy search) Co-authored-by: Cursor --- ai/tools.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 ai/tools.py diff --git a/ai/tools.py b/ai/tools.py new file mode 100644 index 0000000..07d07d6 --- /dev/null +++ b/ai/tools.py @@ -0,0 +1,84 @@ +from ai.rag import RAGClient +from organizations.db import get_organization_ids_for_user + +AI_TOOLS = [ + { + "name": "search_my_organization_policies", + "description": ( + "Searches policy documents that belong to the requesting user's organization(s) " + "and returns relevant excerpts to answer their question. Use this when the user " + "asks about leave policy, PTO, sick leave, vacation, benefits, or any policy " + "that applies to their organization. Answers are based only on policies from " + "organizations the user is a member of." + ), + "parameter_definitions": { + "query": { + "description": "The user's policy question or what they want to know.", + "type": "str", + "required": True, + }, + "top_k": { + "description": "Number of relevant policy chunks to return (default 5).", + "type": "int", + "required": False, + }, + }, + }, + { + "name": "search_policy_embeddings", + "description": "Searches policy document embeddings to answer policy questions.", + "parameter_definitions": { + "query": { + "description": "The policy question or search query.", + "type": "str", + "required": True, + }, + "top_k": { + "description": "Number of relevant chunks to return.", + "type": "int", + "required": False, + }, + }, + }, +] + + +def search_policy_embeddings(query: str, top_k: int = 5): + return RAGClient().query_policy_index(query, top_k=top_k) + + +def _make_search_my_organization_policies(user_id: str): + """Return a callable that searches policy content scoped to the user's organizations.""" + + def _fn(query: str, top_k: int = 5, **kwargs): + org_ids = get_organization_ids_for_user(user_id) + if not org_ids: + return { + "detail": "You are not a member of any organization. No policies to search.", + "matches": [], + } + matches = RAGClient().query_policy_index( + query, top_k=top_k, organization_ids=org_ids + ) + if not matches: + return { + "detail": "No matching policy content found in your organization's policies.", + "matches": [], + } + return { + "detail": f"Found {len(matches)} relevant excerpt(s) from your organization's policies.", + "matches": matches, + } + + return _fn + + +def get_ai_function_map(user_id: str | None = None): + mapping = { + "search_policy_embeddings": search_policy_embeddings, + } + if user_id is not None: + mapping["search_my_organization_policies"] = _make_search_my_organization_policies( + user_id + ) + return mapping From e537c46e93e19d471a9b622d6773d4effdb66b56 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:26 +0530 Subject: [PATCH 09/20] feat(ai): PolicyAgent accepts user_id and passes to CohereClient Co-authored-by: Cursor --- ai/agent.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ai/agent.py b/ai/agent.py index 8965477..6b73c19 100644 --- a/ai/agent.py +++ b/ai/agent.py @@ -3,6 +3,7 @@ from ai.clients import CohereClient from ai.rag import RAGClient +from ai.prompts import POLICY_PROMPT SESSION_MEMORY: dict[str, list[dict[str, str]]] = {} MAX_HISTORY = 20 @@ -27,11 +28,17 @@ class PolicyAgent: - def __init__(self, question: str, session_id: str | None = None): + def __init__( + self, + question: str, + session_id: str | None = None, + user_id: str | None = None, + ): self.question = question self.session_id = session_id + self.user_id = user_id self.max_steps = int(os.getenv("AI_AGENT_MAX_STEPS", "8")) - self.client = CohereClient() + self.client = CohereClient(user_id=user_id) def _trim_history(self, history: list[dict[str, str]]) -> list[dict[str, str]]: if len(history) <= MAX_HISTORY: @@ -64,14 +71,7 @@ def _answer_policy_question( if not excerpts: return None - prompt = ( - "You are a policy assistant. Answer the question using only the policy " - "excerpts below. If the answer is not contained in the excerpts, say you " - "couldn't find it in the policy documents.\n\n" - "Policy excerpts:\n" - f"{os.linesep.join(excerpts)}\n\n" - f"Question: {question}" - ) + prompt = POLICY_PROMPT.format(excerpts=excerpts, question=question) response_text, history = self.client.ask_llm( message=prompt, chat_history=history, From 221aba2bf8302c308118c134653cf001078c6488 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:29 +0530 Subject: [PATCH 10/20] feat(ai): CohereClient accepts user_id, filter user-scoped tools when no user Co-authored-by: Cursor --- ai/clients.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ai/clients.py b/ai/clients.py index 7c9102e..53ad3d8 100644 --- a/ai/clients.py +++ b/ai/clients.py @@ -3,19 +3,33 @@ import cohere from ai.prompts import PREAMBLE +from ai.tools import AI_TOOLS, get_ai_function_map from organizations.constants import get_organization_function_map from organizations.tools import ORGANIZATION_TOOLS class CohereClient: - def __init__(self, message: str | None = None, model: str | None = None): + def __init__( + self, + message: str | None = None, + model: str | None = None, + user_id: str | None = None, + ): self.client = cohere.Client(os.getenv("COHERE_API_KEY")) self.model = model or os.getenv("COHERE_LLM_MODEL") self.preamble = PREAMBLE self.function_map = { - **get_organization_function_map(), + **get_organization_function_map(user_id=user_id), + **get_ai_function_map(user_id=user_id), } - self.tools = [*ORGANIZATION_TOOLS] + tools = [*ORGANIZATION_TOOLS, *AI_TOOLS] + if user_id is None: + tools = [ + t + for t in tools + if t.get("name") not in ("get_my_organization_details", "search_my_organization_policies") + ] + self.tools = tools self.message = message or "" def chat( From 7a709e5733d273307556f23038dd1d6c35e5bc7a Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:29 +0530 Subject: [PATCH 11/20] feat(ai): ai_assistant passes current_user id to PolicyAgent Co-authored-by: Cursor --- ai/apis.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ai/apis.py b/ai/apis.py index 7c3b8f1..73bf53d 100644 --- a/ai/apis.py +++ b/ai/apis.py @@ -7,6 +7,7 @@ from ai.db import PolicyEmbedding from ai.models import QNARequestBody, QNAResponseBody from application.app import app +from auth.dependencies import require_authenticated_user from database.db import get_db @@ -16,7 +17,10 @@ summary="Chat with Documents", response_description="Answer from the AI", ) -async def ai_assistant(request: QNARequestBody) -> QNAResponseBody: +async def ai_assistant( + request: QNARequestBody, + current_user=Depends(require_authenticated_user), +) -> QNAResponseBody: """ Payload for the endpoint: { @@ -24,9 +28,11 @@ async def ai_assistant(request: QNARequestBody) -> QNAResponseBody: } """ session_id = request.session_id or str(uuid.uuid4()) + user_id = current_user.user_id if current_user else None result = PolicyAgent( question=request.question, session_id=session_id, + user_id=user_id, ).run() return { "question": request.question, From e2b7b28549f637ce6ddd10861b0e3d1fbbaf178e Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:30 +0530 Subject: [PATCH 12/20] docs(ai): preamble prefer search_my_organization_policies for user org policy questions Co-authored-by: Cursor --- ai/prompts.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ai/prompts.py b/ai/prompts.py index 3c44b64..6fc74f3 100644 --- a/ai/prompts.py +++ b/ai/prompts.py @@ -9,6 +9,9 @@ ## Response Rules - Use tools to fetch organization and policy data when answering leave questions. +- When the user asks about policy that applies to them or their organization (e.g. leave policy, PTO, benefits), + use the search_my_organization_policies tool to get relevant excerpts from their organization's policies and answer from that. +- For other policy-related questions (e.g. about a named organization), use search_policy_embeddings. - If a user asks about leave counts or pending requests and no tool data is available, ask a brief follow-up question or explain the limitation. - Always reference the organization name and policy details when answering leave questions. @@ -17,3 +20,12 @@ Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling. """ + +POLICY_PROMPT = """ +"You are a policy assistant. Answer the question using only the policy " +"excerpts below. If the answer is not contained in the excerpts, say you " +"couldn't find it in the policy documents.\n\n" +"Policy excerpts:\n" +f"{os.linesep.join(excerpts)}\n\n" +f"Question: {question}" +""" \ No newline at end of file From 286a29634e36c405cd45d3438e67bc5e5a374705 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:33 +0530 Subject: [PATCH 13/20] chore: application app updates Co-authored-by: Cursor --- application/app.py | 6 +++--- database/db.py | 7 ------- docker-compose.yml | 6 +++--- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/application/app.py b/application/app.py index a8313b7..1ab1a00 100644 --- a/application/app.py +++ b/application/app.py @@ -9,8 +9,8 @@ from database.db import drop_db, init_db app = FastAPI( - title="Leave Query API", - description="API for managing leave queries using AI", + title="Policy AI Agent API", + description="API for managing policy queries using AI", version="1.0.0", dependencies=[Depends(require_authenticated_user)], ) @@ -45,7 +45,7 @@ async def drop_database(): @app.get("/") async def root(): return { - "message": "Welcome to Leave Query API using AI", + "message": "Welcome to Policy AI Agent API", "version": "1.0.0", "docs": "/docs", } diff --git a/database/db.py b/database/db.py index fb6e6ab..960d477 100644 --- a/database/db.py +++ b/database/db.py @@ -52,13 +52,6 @@ def drop_users_table(): init_db() -def drop_appointments_table(): - from appointments.models import Appointment - - Appointment.__table__.drop(bind=engine, checkfirst=True) - init_db() - - def drop_organizations_table(): from organizations.models import Organization diff --git a/docker-compose.yml b/docker-compose.yml index 56bbac6..23c689b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,8 +3,8 @@ services: build: context: . dockerfile: compose/Dockerfile - image: leave-query-app - container_name: leave-query-app + image: policy-agent-app + container_name: policy-agent-app ports: - "8000:8000" env_file: @@ -24,7 +24,7 @@ services: postgres: image: pgvector/pgvector:pg16 - container_name: leave-query-db + container_name: policy-agent-db env_file: - .env volumes: From 77286ca77baf8bff028664ce5ccbf8ad551eaf7f Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:35 +0530 Subject: [PATCH 14/20] test(users): duplicate username/email 409, delete 404 Co-authored-by: Cursor --- tests/test_users_api.py | 54 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_users_api.py b/tests/test_users_api.py index ea74212..2271fb9 100644 --- a/tests/test_users_api.py +++ b/tests/test_users_api.py @@ -76,3 +76,57 @@ def test_user_type_helpers(): require_admin("REGULAR") with pytest.raises(HTTPException): coerce_user_type("visitor") + + +def test_create_user_duplicate_username_returns_409(client, create_user): + create_user(username="existing", email="first@example.com") + response = client.post( + "/users", + json={ + "first_name": "Other", + "last_name": "User", + "username": "existing", + "password": "secret123", + "email": "second@example.com", + "phone": "5555552222", + "gender": "male", + "user_type": "REGULAR", + "date_of_birth": "1992-01-01T00:00:00", + }, + ) + assert response.status_code == 409 + assert "username" in response.json()["detail"].lower() + + +def test_create_user_duplicate_email_returns_409(client, create_user): + create_user(username="user1", email="same@example.com") + response = client.post( + "/users", + json={ + "first_name": "Other", + "last_name": "User", + "username": "user2", + "password": "secret123", + "email": "same@example.com", + "phone": "5555553333", + "gender": "female", + "user_type": "REGULAR", + "date_of_birth": "1993-01-01T00:00:00", + }, + ) + assert response.status_code == 409 + assert "email" in response.json()["detail"].lower() + + +def test_delete_user_not_found_returns_404(client, create_user, auth_headers): + admin = create_user( + username="admin", + email="admin@example.com", + user_type=UserType.ADMIN, + ) + response = client.delete( + "/users/00000000-0000-0000-0000-000000000000", + headers=auth_headers(admin), + ) + assert response.status_code == 404 + assert response.json()["detail"] == "User not found" From e2b525417769b06eb6b4d83ed2024f1be3b1c744 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:35 +0530 Subject: [PATCH 15/20] test(auth): protected endpoint 401 without token and invalid token Co-authored-by: Cursor --- tests/test_auth.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_auth.py b/tests/test_auth.py index ac2c896..b7e55ae 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -40,3 +40,17 @@ def test_login_invalid_credentials(client): json={"username": "missing", "password": "nope"}, ) assert response.status_code == 401 + + +def test_protected_endpoint_without_token_returns_401(client): + response = client.get("/users") + assert response.status_code == 401 + assert "detail" in response.json() + + +def test_protected_endpoint_with_invalid_token_returns_401(client): + response = client.get( + "/users", + headers={"Authorization": "Bearer invalid-token-here"}, + ) + assert response.status_code == 401 From f277e2f6f365e40b50a94d31e5bd4e7d8f4dd019 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:36 +0530 Subject: [PATCH 16/20] test(app): root includes docs Co-authored-by: Cursor --- tests/test_app.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_app.py b/tests/test_app.py index 09b42d7..75a60e9 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -2,7 +2,7 @@ def test_root(client): response = client.get("/") assert response.status_code == 200 payload = response.json() - assert payload["message"] == "Welcome to Leave Query API using AI" + assert payload["message"] == "Welcome to Policy AI Agent API" assert payload["version"] == "1.0.0" @@ -12,3 +12,11 @@ def test_health(client): payload = response.json() assert payload["status"] == "healthy" assert "timestamp" in payload + + +def test_root_includes_docs(client): + response = client.get("/") + assert response.status_code == 200 + payload = response.json() + assert "docs" in payload + assert payload["docs"] == "/docs" From 74efbf02b6d2b78e49a7488142d6cea738f7bb58 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:42 +0530 Subject: [PATCH 17/20] test(ai): fix DummyClient/DummyAgent for user_id, add user_id and spy tests Co-authored-by: Cursor --- tests/test_ai.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/test_ai.py b/tests/test_ai.py index 593f42a..bc4b5d1 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -4,6 +4,9 @@ def test_agent_preserves_sensitive_text(monkeypatch): class DummyClient: + def __init__(self, message=None, model=None, user_id=None, **kwargs): + pass + def ask_llm(self, message=None, chat_history=None, max_steps=8): history = [ { @@ -24,6 +27,9 @@ def ask_llm(self, message=None, chat_history=None, max_steps=8): def test_agent_trims_history(monkeypatch): class DummyClient: + def __init__(self, message=None, model=None, user_id=None, **kwargs): + pass + def ask_llm(self, message=None, chat_history=None, max_steps=8): history = [ {"role": "USER", "message": f"msg {idx}"} @@ -44,7 +50,7 @@ def test_ai_assistant_endpoint_uses_agent( client, monkeypatch, create_user, auth_headers ): class DummyAgent: - def __init__(self, question, session_id=None): + def __init__(self, question, session_id=None, user_id=None): self.question = question self.session_id = session_id @@ -66,3 +72,66 @@ def run(self): payload = response.json() assert payload["response"] == "ok" assert payload["messages"][0]["message"] == "ok" + + +def test_policy_agent_passes_user_id_to_cohere_client(monkeypatch): + captured = {} + + class FakeCohereClient: + def __init__(self, message=None, model=None, user_id=None): + captured["user_id"] = user_id + + def ask_llm(self, message=None, chat_history=None, max_steps=8): + return "ok", [] + + monkeypatch.setattr(agent_module, "CohereClient", FakeCohereClient) + agent_module.SESSION_MEMORY.clear() + + agent_module.PolicyAgent("hello", session_id="s1", user_id="user-123").run() + assert captured.get("user_id") == "user-123" + + +def test_policy_agent_accepts_none_user_id(monkeypatch): + captured = {} + + class FakeCohereClient: + def __init__(self, message=None, model=None, user_id=None): + captured["user_id"] = user_id + + def ask_llm(self, message=None, chat_history=None, max_steps=8): + return "ok", [] + + monkeypatch.setattr(agent_module, "CohereClient", FakeCohereClient) + agent_module.SESSION_MEMORY.clear() + + agent_module.PolicyAgent("hello", session_id="s2").run() + assert captured.get("user_id") is None + + +def test_ai_assistant_passes_user_id_to_agent( + client, monkeypatch, create_user, auth_headers +): + captured = {} + + class SpyAgent: + def __init__(self, question, session_id=None, user_id=None): + captured["user_id"] = user_id + captured["question"] = question + + def run(self): + return { + "response": "ok", + "session_id": None, + "messages": [], + } + + monkeypatch.setattr(ai_apis, "PolicyAgent", SpyAgent) + user = create_user(username="spy-user", email="spy@example.com") + response = client.post( + "/ai_assistant", + json={"question": "What is my organization?"}, + headers=auth_headers(user), + ) + assert response.status_code == 200 + assert captured.get("user_id") == str(user.id) + assert captured.get("question") == "What is my organization?" From 987d529abd121ba00ab7489779f02e0551f239a8 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:42 +0530 Subject: [PATCH 18/20] test(leave_requests): delete leave request 404 Co-authored-by: Cursor --- tests/test_leave_requests_api.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_leave_requests_api.py b/tests/test_leave_requests_api.py index a85ebb8..dfeb4ac 100644 --- a/tests/test_leave_requests_api.py +++ b/tests/test_leave_requests_api.py @@ -342,3 +342,14 @@ def test_leave_request_not_found(client, create_user, auth_headers): headers=auth_headers(user), ) assert response.status_code == 404 + + +def test_delete_leave_request_not_found(client, create_user, auth_headers): + user = create_user(username="user", email="user@example.com") + + response = client.delete( + "/leave_requests/00000000-0000-0000-0000-000000000000", + headers=auth_headers(user), + ) + assert response.status_code == 404 + assert response.json()["detail"] == "Leave request not found" From fa2171f259ae71403701c227bf362aec76347b76 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:43 +0530 Subject: [PATCH 19/20] test(user_organizations): get organizations for user 404 when user not found Co-authored-by: Cursor --- tests/test_user_organizations_api.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_user_organizations_api.py b/tests/test_user_organizations_api.py index 28c5eae..cdc5e78 100644 --- a/tests/test_user_organizations_api.py +++ b/tests/test_user_organizations_api.py @@ -171,3 +171,14 @@ def test_join_invalid_user(client, create_user, create_organization, auth_header }, ) assert response.status_code == 404 + + +def test_get_organizations_for_user_not_found(client, create_user, auth_headers): + user = create_user(username="member", email="member@example.com") + + response = client.get( + "/users/00000000-0000-0000-0000-000000000000/organizations", + headers=auth_headers(user), + ) + assert response.status_code == 404 + assert response.json()["detail"] == "User not found" From 9e294a9d18047932850b454e3c7b3b3303fcfcb9 Mon Sep 17 00:00:00 2001 From: Aakash-Pandit Date: Sat, 14 Feb 2026 16:56:43 +0530 Subject: [PATCH 20/20] test(orgs): unit tests for get_my_organization_details and get_organization_ids_for_user Co-authored-by: Cursor --- tests/test_organizations_db.py | 100 +++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/test_organizations_db.py diff --git a/tests/test_organizations_db.py b/tests/test_organizations_db.py new file mode 100644 index 0000000..662477d --- /dev/null +++ b/tests/test_organizations_db.py @@ -0,0 +1,100 @@ +"""Unit tests for organizations.db helpers (get_my_organization_details, get_organization_ids_for_user).""" + +import organizations.db as org_db + + +def test_get_my_organization_details_returns_orgs_for_user( + app, create_user, create_organization, create_user_organization +): + user = create_user(username="member", email="member@example.com") + org = create_organization(name="Acme Corp", description="Test org") + create_user_organization(user_id=user.id, organization_id=org.id) + + result = org_db.get_my_organization_details(str(user.id)) + + assert result["total"] == 1 + assert len(result["organizations"]) == 1 + assert result["organizations"][0]["name"] == "Acme Corp" + assert result["organizations"][0]["id"] == str(org.id) + assert "membership_joined_date" in result["organizations"][0] + assert "message" in result + + +def test_get_my_organization_details_returns_empty_when_user_has_no_orgs(app, create_user): + user = create_user(username="loner", email="loner@example.com") + + result = org_db.get_my_organization_details(str(user.id)) + + assert result["total"] == 0 + assert result["organizations"] == [] + assert "not a member" in result["detail"].lower() + + +def test_get_my_organization_details_ignores_inactive_membership( + app, create_user, create_organization, create_user_organization +): + user = create_user(username="inactive", email="inactive@example.com") + org = create_organization(name="Left Org") + create_user_organization( + user_id=user.id, + organization_id=org.id, + is_active=False, + ) + + result = org_db.get_my_organization_details(str(user.id)) + + assert result["total"] == 0 + assert result["organizations"] == [] + + +def test_get_my_organization_details_multiple_orgs( + app, create_user, create_organization, create_user_organization +): + user = create_user(username="multi", email="multi@example.com") + org1 = create_organization(name="Org One") + org2 = create_organization(name="Org Two") + create_user_organization(user_id=user.id, organization_id=org1.id) + create_user_organization(user_id=user.id, organization_id=org2.id) + + result = org_db.get_my_organization_details(str(user.id)) + + assert result["total"] == 2 + names = {o["name"] for o in result["organizations"]} + assert names == {"Org One", "Org Two"} + + +def test_get_organization_ids_for_user_returns_ids( + app, create_user, create_organization, create_user_organization +): + user = create_user(username="u", email="u@example.com") + org = create_organization(name="My Org") + create_user_organization(user_id=user.id, organization_id=org.id) + + ids = org_db.get_organization_ids_for_user(str(user.id)) + + assert len(ids) == 1 + assert ids[0] == str(org.id) + + +def test_get_organization_ids_for_user_empty_when_no_membership(app, create_user): + user = create_user(username="nobody", email="nobody@example.com") + + ids = org_db.get_organization_ids_for_user(str(user.id)) + + assert ids == [] + + +def test_get_organization_ids_for_user_excludes_inactive( + app, create_user, create_organization, create_user_organization +): + user = create_user(username="ex", email="ex@example.com") + org = create_organization(name="Ex Org") + create_user_organization( + user_id=user.id, + organization_id=org.id, + is_active=False, + ) + + ids = org_db.get_organization_ids_for_user(str(user.id)) + + assert ids == []