Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
45cc94d
updates summary
Aakash-Pandit Feb 9, 2026
0f3c32d
updates agent
Aakash-Pandit Feb 9, 2026
2da6e07
Merge branch 'main' of github.com:Aakash-Pandit/LeaveQuery-AI into co…
Aakash-Pandit Feb 9, 2026
72145ba
feat(users): validate duplicate username/email before creating user (…
Aakash-Pandit Feb 14, 2026
1cfdc55
feat(orgs): add get_my_organization_details and get_organization_ids_…
Aakash-Pandit Feb 14, 2026
7f9e55a
feat(orgs): add get_my_organization_details tool for AI agent
Aakash-Pandit Feb 14, 2026
eb7143c
feat(orgs): wire get_my_organization_details with user_id in function…
Aakash-Pandit Feb 14, 2026
f01b1a8
feat(ai): filter RAG query_policy_index by organization_ids
Aakash-Pandit Feb 14, 2026
92d3a2b
feat(ai): add search_my_organization_policies tool (user-scoped polic…
Aakash-Pandit Feb 14, 2026
e537c46
feat(ai): PolicyAgent accepts user_id and passes to CohereClient
Aakash-Pandit Feb 14, 2026
221aba2
feat(ai): CohereClient accepts user_id, filter user-scoped tools when…
Aakash-Pandit Feb 14, 2026
7a709e5
feat(ai): ai_assistant passes current_user id to PolicyAgent
Aakash-Pandit Feb 14, 2026
e2b7b28
docs(ai): preamble prefer search_my_organization_policies for user or…
Aakash-Pandit Feb 14, 2026
286a296
chore: application app updates
Aakash-Pandit Feb 14, 2026
77286ca
test(users): duplicate username/email 409, delete 404
Aakash-Pandit Feb 14, 2026
e2b5254
test(auth): protected endpoint 401 without token and invalid token
Aakash-Pandit Feb 14, 2026
f277e2f
test(app): root includes docs
Aakash-Pandit Feb 14, 2026
74efbf0
test(ai): fix DummyClient/DummyAgent for user_id, add user_id and spy…
Aakash-Pandit Feb 14, 2026
987d529
test(leave_requests): delete leave request 404
Aakash-Pandit Feb 14, 2026
fa2171f
test(user_organizations): get organizations for user 404 when user no…
Aakash-Pandit Feb 14, 2026
9e294a9
test(orgs): unit tests for get_my_organization_details and get_organi…
Aakash-Pandit Feb 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,103 @@
from typing import Any

from ai.clients import CohereClient
from ai.rag import RAGClient
from ai.prompts import POLICY_PROMPT

# In-memory chat history per session, keyed by session_id.
# NOTE(review): process-local — history is lost on restart and is not
# shared across workers; confirm this is acceptable for the deployment.
SESSION_MEMORY: dict[str, list[dict[str, str]]] = {}
# Maximum number of history entries kept per session (see _trim_history).
MAX_HISTORY = 20
# Substrings matched (case-insensitively) by PolicyAgent._is_policy_question
# to decide whether a question should go through the policy-RAG path.
POLICY_KEYWORDS = {
    "policy",
    "leave",
    "pto",
    "vacation",
    "sick",
    "holiday",
    "absence",
    "bereavement",
    "maternity",
    "paternity",
    "parental",
    "time off",
    "work from home",
    "remote work",
    "attendance",
    "benefits",
}


class PolicyAgent:
def __init__(
    self,
    question: str,
    session_id: str | None = None,
    user_id: str | None = None,
):
    """Create an agent for one question.

    Args:
        question: The user's question to answer.
        session_id: Optional key into SESSION_MEMORY for chat history.
        user_id: Optional authenticated user id; forwarded to
            CohereClient so user-scoped tools become available.
    """
    # The scraped diff left a stale duplicate `def` header and a stale
    # `self.client = CohereClient()` line here; this is the post-diff form.
    self.question = question
    self.session_id = session_id
    self.user_id = user_id
    # Cap on tool-use iterations; overridable via environment.
    self.max_steps = int(os.getenv("AI_AGENT_MAX_STEPS", "8"))
    self.client = CohereClient(user_id=user_id)

def _trim_history(self, history: list[dict[str, str]]) -> list[dict[str, str]]:
    """Return at most the last MAX_HISTORY entries of *history*."""
    overflow = len(history) - MAX_HISTORY
    return history[overflow:] if overflow > 0 else history

def run(self) -> dict[str, Any]:
history = []
if self.session_id:
history = list(SESSION_MEMORY.get(self.session_id, []))
def _is_policy_question(self, question: str) -> bool:
    """True when the question mentions any known policy keyword."""
    text = question.lower()
    for keyword in POLICY_KEYWORDS:
        if keyword in text:
            return True
    return False

def _answer_policy_question(
    self,
    question: str,
    history: list[dict[str, str]],
) -> tuple[str, list] | None:
    """Answer a policy question via RAG over the policy index.

    Queries the embedding index, formats the retrieved chunks into a
    prompt, and asks the LLM. Returns (response_text, history), or
    None when no usable excerpts were retrieved so the caller can fall
    back to the plain tool-driven path.
    """
    rag_client = RAGClient()
    top_k = int(os.getenv("POLICY_RAG_TOP_K", "5"))
    matches = rag_client.query_policy_index(question, top_k=top_k)
    if not matches:
        return None

    excerpts = []
    for match in matches:
        title = match.get("document_name") or match.get("policy_name") or "Policy"
        chunk_index = match.get("chunk_index", "unknown")
        text = match.get("text", "")
        if text:
            excerpts.append(f"[{title} - chunk {chunk_index}] {text}")

    if not excerpts:
        return None

    # Join the excerpts so the template receives readable text rather
    # than a Python list repr.
    prompt = POLICY_PROMPT.format(
        excerpts="\n".join(excerpts),
        question=question,
    )
    # Fix: the scraped diff retained a stale `message=self.question`
    # line alongside `message=prompt` (a duplicate-keyword SyntaxError);
    # the RAG prompt is what must be sent.
    response_text, history = self.client.ask_llm(
        message=prompt,
        chat_history=history,
        max_steps=self.max_steps,
    )
    return response_text, history

def run(self) -> dict[str, Any]:
history = []
if self.session_id:
history = list(SESSION_MEMORY.get(self.session_id, []))
if self._is_policy_question(self.question):
policy_result = self._answer_policy_question(self.question, history)
if policy_result:
response_text, history = policy_result
else:
response_text, history = self.client.ask_llm(
message=self.question,
chat_history=history,
max_steps=self.max_steps,
)
else:
response_text, history = self.client.ask_llm(
message=self.question,
chat_history=history,
max_steps=self.max_steps,
)
if self.session_id:
SESSION_MEMORY[self.session_id] = self._trim_history(history)
return {
Expand Down
10 changes: 8 additions & 2 deletions ai/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,32 @@
from ai.db import PolicyEmbedding
from ai.models import QNARequestBody, QNAResponseBody
from application.app import app
from auth.dependencies import require_authenticated_user
from database.db import get_db


@app.post(
"/ai_assistant",
status_code=status.HTTP_200_OK,
summary="Chat with Database",
summary="Chat with Documents",
response_description="Answer from the AI",
)
async def ai_assistant(request: QNARequestBody) -> QNAResponseBody:
async def ai_assistant(
request: QNARequestBody,
current_user=Depends(require_authenticated_user),
) -> QNAResponseBody:
"""
Payload for the endpoint:
{
"question": "give me date of birth of Dr. ruso lamba"
}
"""
session_id = request.session_id or str(uuid.uuid4())
user_id = current_user.user_id if current_user else None
result = PolicyAgent(
question=request.question,
session_id=session_id,
user_id=user_id,
).run()
return {
"question": request.question,
Expand Down
20 changes: 17 additions & 3 deletions ai/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,33 @@
import cohere

from ai.prompts import PREAMBLE
from ai.tools import AI_TOOLS, get_ai_function_map
from organizations.constants import get_organization_function_map
from organizations.tools import ORGANIZATION_TOOLS


class CohereClient:
def __init__(
    self,
    message: str | None = None,
    model: str | None = None,
    user_id: str | None = None,
):
    """Configure the Cohere chat client, tools, and function map.

    Args:
        message: Optional initial message (stored as "" when None).
        model: Chat model name; falls back to COHERE_LLM_MODEL env var.
        user_id: Authenticated user id, or None for anonymous callers.
            User-scoped tools are only exposed when a user_id is known.
    """
    # The scraped diff retained stale pre-diff lines (duplicate `def`
    # header, old function_map spread, old self.tools); this is the
    # post-diff form.
    self.client = cohere.Client(os.getenv("COHERE_API_KEY"))
    self.model = model or os.getenv("COHERE_LLM_MODEL")
    self.preamble = PREAMBLE
    self.function_map = {
        **get_organization_function_map(user_id=user_id),
        **get_ai_function_map(user_id=user_id),
    }
    tools = [*ORGANIZATION_TOOLS, *AI_TOOLS]
    if user_id is None:
        # Anonymous callers must not see user-scoped tools.
        user_scoped = {"get_my_organization_details", "search_my_organization_policies"}
        tools = [t for t in tools if t.get("name") not in user_scoped]
    self.tools = tools
    self.message = message or ""

def chat(
Expand Down
12 changes: 12 additions & 0 deletions ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

## Response Rules
- Use tools to fetch organization and policy data when answering leave questions.
- When the user asks about policy that applies to them or their organization (e.g. leave policy, PTO, benefits),
use the search_my_organization_policies tool to get relevant excerpts from their organization's policies and answer from that.
- For other policy-related questions (e.g. about a named organization), use search_policy_embeddings.
- If a user asks about leave counts or pending requests and no tool data is available,
ask a brief follow-up question or explain the limitation.
- Always reference the organization name and policy details when answering leave questions.
Expand All @@ -17,3 +20,12 @@
Unless the user asks for a different style of answer, you should answer in full sentences,
using proper grammar and spelling.
"""

# Template rendered via POLICY_PROMPT.format(excerpts=..., question=...)
# in PolicyAgent._answer_policy_question.
# Fix: the previous value was a verbatim paste of concatenated string
# fragments (including f-string syntax like {os.linesep.join(excerpts)})
# inside a plain string, so .format(excerpts=..., question=...) raised on
# the bogus replacement field instead of rendering.
POLICY_PROMPT = """You are a policy assistant. Answer the question using only the policy
excerpts below. If the answer is not contained in the excerpts, say you
couldn't find it in the policy documents.

Policy excerpts:
{excerpts}

Question: {question}
"""
27 changes: 21 additions & 6 deletions ai/rag.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import os
import re
import uuid
from typing import Iterable

import cohere
import httpx
from sqlalchemy import delete, select

from ai.clients import CohereClient
from ai.db import PolicyEmbedding
from database.db import SessionLocal

Expand All @@ -17,6 +17,7 @@
class RAGClient:
def __init__(self, embed_model: str | None = None):
self.embed_model = embed_model or os.getenv("COHERE_EMBED_MODEL", "embed-english-v3.0")
self.client = cohere.Client(os.getenv("COHERE_API_KEY"))

def _looks_like_text(self, raw: bytes) -> bool:
if not raw:
Expand Down Expand Up @@ -94,11 +95,14 @@ def _chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 200) -> l
return chunks

def _embed_texts(self, texts: Iterable[str], input_type: str) -> list[list[float]]:
client = CohereClient(model=self.embed_model)
response = client.embed_texts(list(texts), input_type=input_type)
response = self.client.embed(
texts=list(texts),
model=self.embed_model,
input_type=input_type,
)

logger.info(f"Embedded {len(texts)} texts with model {self.embed_model} & response: {response}")
return response
return response.embeddings or []

def index_policy_document(
self,
Expand Down Expand Up @@ -154,7 +158,12 @@ def remove_policy_from_index(self, policy_id: str) -> dict:
return {"status": "skipped", "reason": "policy_not_found"}
return {"status": "removed", "count": result.rowcount}

def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]:
def query_policy_index(
self,
query: str,
top_k: int = 5,
organization_ids: list[str] | None = None,
) -> list[dict]:
query_embedding = self._embed_texts([query], input_type="search_query")
if not query_embedding:
return []
Expand All @@ -166,6 +175,12 @@ def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]:
.order_by(PolicyEmbedding.embedding.cosine_distance(query_vector))
.limit(max(top_k, 1))
)
if organization_ids:
try:
uuids = [uuid.UUID(oid) for oid in organization_ids]
stmt = stmt.where(PolicyEmbedding.organization_id.in_(uuids))
except (ValueError, TypeError):
pass
results = db.execute(stmt).scalars().all()
response = []
for record in results:
Expand Down
84 changes: 84 additions & 0 deletions ai/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from ai.rag import RAGClient
from organizations.db import get_organization_ids_for_user

# Cohere tool specifications exposed to the chat model. Each entry's
# "name" must have a matching callable in get_ai_function_map(); the
# user-scoped tool is filtered out by CohereClient when no user_id is
# available.
AI_TOOLS = [
    {
        "name": "search_my_organization_policies",
        "description": (
            "Searches policy documents that belong to the requesting user's organization(s) "
            "and returns relevant excerpts to answer their question. Use this when the user "
            "asks about leave policy, PTO, sick leave, vacation, benefits, or any policy "
            "that applies to their organization. Answers are based only on policies from "
            "organizations the user is a member of."
        ),
        "parameter_definitions": {
            "query": {
                "description": "The user's policy question or what they want to know.",
                "type": "str",
                "required": True,
            },
            "top_k": {
                "description": "Number of relevant policy chunks to return (default 5).",
                "type": "int",
                "required": False,
            },
        },
    },
    {
        "name": "search_policy_embeddings",
        "description": "Searches policy document embeddings to answer policy questions.",
        "parameter_definitions": {
            "query": {
                "description": "The policy question or search query.",
                "type": "str",
                "required": True,
            },
            "top_k": {
                "description": "Number of relevant chunks to return.",
                "type": "int",
                "required": False,
            },
        },
    },
]


def search_policy_embeddings(query: str, top_k: int = 5):
    """Search the global policy embedding index; return matching chunks."""
    rag = RAGClient()
    return rag.query_policy_index(query, top_k=top_k)


def _make_search_my_organization_policies(user_id: str):
    """Return a callable that searches policy content scoped to the user's organizations."""

    def search(query: str, top_k: int = 5, **kwargs):
        # Resolve the caller's organizations first; without membership
        # there is nothing to search.
        org_ids = get_organization_ids_for_user(user_id)
        if not org_ids:
            return {
                "detail": "You are not a member of any organization. No policies to search.",
                "matches": [],
            }
        matches = RAGClient().query_policy_index(
            query, top_k=top_k, organization_ids=org_ids
        )
        if matches:
            detail = f"Found {len(matches)} relevant excerpt(s) from your organization's policies."
        else:
            detail = "No matching policy content found in your organization's policies."
        return {"detail": detail, "matches": matches}

    return search


def get_ai_function_map(user_id: str | None = None):
    """Map AI tool names to callables.

    The user-scoped search tool is included only when a user_id is
    known, mirroring the tool filtering in CohereClient.
    """
    functions = {"search_policy_embeddings": search_policy_embeddings}
    if user_id is not None:
        functions["search_my_organization_policies"] = (
            _make_search_my_organization_policies(user_id)
        )
    return functions
6 changes: 3 additions & 3 deletions application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from database.db import drop_db, init_db

app = FastAPI(
title="Leave Query API",
description="API for managing leave queries using AI",
title="Policy AI Agent API",
description="API for managing policy queries using AI",
version="1.0.0",
dependencies=[Depends(require_authenticated_user)],
)
Expand Down Expand Up @@ -45,7 +45,7 @@ async def drop_database():
@app.get("/")
async def root():
    """Landing endpoint: service name, version, and docs location."""
    # Fix: the scraped diff retained the stale pre-diff "message" line,
    # leaving a duplicate dict key whose old value was silently shadowed;
    # only the post-diff message belongs here.
    return {
        "message": "Welcome to Policy AI Agent API",
        "version": "1.0.0",
        "docs": "/docs",
    }
Expand Down
7 changes: 0 additions & 7 deletions database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ def drop_users_table():
init_db()


def drop_appointments_table():
    # Drop the appointments table if it exists, then re-run init_db()
    # to recreate the remaining schema. Import is local to avoid a
    # module-level dependency on appointments.
    from appointments.models import Appointment

    Appointment.__table__.drop(bind=engine, checkfirst=True)
    init_db()


def drop_organizations_table():
from organizations.models import Organization

Expand Down
Loading