Skip to content
3 changes: 2 additions & 1 deletion ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def _answer_policy_question(
if not excerpts:
return None

prompt = POLICY_PROMPT.format(excerpts=excerpts, question=question)
excerpts_text = os.linesep.join(excerpts)
prompt = POLICY_PROMPT.format(excerpts_text=excerpts_text, question=question)
response_text, history = self.client.ask_llm(
message=prompt,
chat_history=history,
Expand Down
7 changes: 6 additions & 1 deletion ai/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def __init__(
tools = [
t
for t in tools
if t.get("name") not in ("get_my_organization_details", "search_my_organization_policies")
if t.get("name")
not in (
"get_my_organization_details",
"search_my_organization_policies",
"get_my_pending_leaves",
)
]
self.tools = tools
self.message = message or ""
Expand Down
13 changes: 6 additions & 7 deletions ai/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
using proper grammar and spelling.
"""

POLICY_PROMPT = """
"You are a policy assistant. Answer the question using only the policy "
"excerpts below. If the answer is not contained in the excerpts, say you "
"couldn't find it in the policy documents.\n\n"
"Policy excerpts:\n"
f"{os.linesep.join(excerpts)}\n\n"
f"Question: {question}"
POLICY_PROMPT = """You are a policy assistant. Answer the question using only the policy excerpts below. If the answer is not contained in the excerpts, say you couldn't find it in the policy documents.

Policy excerpts:
{excerpts_text}

Question: {question}
"""
44 changes: 43 additions & 1 deletion ai/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ai.rag import RAGClient
from organizations.db import get_organization_ids_for_user
from organizations.db import get_my_approved_leaves_summary, get_organization_ids_for_user

AI_TOOLS = [
{
Expand Down Expand Up @@ -40,6 +40,17 @@
},
},
},
{
"name": "get_my_pending_leaves",
"description": (
"Returns the requesting user's approved leaves summary and leave policy excerpts "
"to determine how many leaves are pending/remaining. Use when the user asks "
"'how many leaves are pending of mine?', 'leaves remaining', 'my leave balance', "
"'how many days do I have left?', or similar. Combines approved leave count from "
"the database with leave policy from the organization to compute pending leaves."
),
"parameter_definitions": {},
},
]


Expand Down Expand Up @@ -73,6 +84,36 @@ def _fn(query: str, top_k: int = 5, **kwargs):
return _fn


def _make_get_my_pending_leaves(user_id: str):
"""Return a callable that fetches approved leaves + leave policy for the user."""

def _fn(**kwargs):
approved = get_my_approved_leaves_summary(user_id)
org_ids = get_organization_ids_for_user(user_id)
if not org_ids:
return {
"detail": "You are not a member of any organization. No leave data available.",
"approved_leaves": [],
"policy_excerpts": [],
}
policy_matches = RAGClient().query_policy_index(
"leave policy days allowance sick leave privilege leave PTO annual vacation",
top_k=5,
organization_ids=org_ids,
)
return {
"detail": "Your approved leaves and relevant leave policy. Use policy excerpts to determine total allowance and compute pending = allowance - approved.",
"approved_leaves": approved.get("approved_leaves", []),
"total_approved_days": approved.get("total_approved_days", 0),
"policy_excerpts": [
{"text": m.get("text"), "policy_name": m.get("policy_name")}
for m in policy_matches
],
}

return _fn


def get_ai_function_map(user_id: str | None = None):
mapping = {
"search_policy_embeddings": search_policy_embeddings,
Expand All @@ -81,4 +122,5 @@ def get_ai_function_map(user_id: str | None = None):
mapping["search_my_organization_policies"] = _make_search_my_organization_policies(
user_id
)
mapping["get_my_pending_leaves"] = _make_get_my_pending_leaves(user_id)
return mapping
47 changes: 47 additions & 0 deletions organizations/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from database.db import SessionLocal
from organizations.models import Organization, Policy, UserOrganization
from users.models import LeaveRequest


def get_organization_details(organization_name: str):
Expand Down Expand Up @@ -134,3 +135,49 @@ def get_policy_details(policy_name: str, organization_name: str):
"organization": organization_name,
}
return {"detail": "Policy not found", "name": policy_name}


def get_my_approved_leaves_summary(user_id: str) -> dict:
"""
Get the count of approved leaves for the given user, grouped by organization and leave type.
Used to compute pending/remaining leaves when combined with leave policy.
"""
with SessionLocal() as db:
rows = (
db.query(
LeaveRequest.organization_id,
LeaveRequest.leave_type,
func.count(LeaveRequest.id).label("count"),
)
.filter(
LeaveRequest.user_id == user_id,
LeaveRequest.is_accepted.is_(True),
)
.group_by(LeaveRequest.organization_id, LeaveRequest.leave_type)
.all()
)
orgs = {}
if rows:
unique_org_ids = list({r[0] for r in rows})
org_list = (
db.query(Organization)
.filter(Organization.id.in_(unique_org_ids))
.all()
)
orgs = {str(o.id): o.name for o in org_list}

by_org = {}
total_approved = 0
for org_id, leave_type, count in rows:
org_name = orgs.get(str(org_id), "Unknown")
if org_id not in by_org:
by_org[str(org_id)] = {"organization_name": org_name, "by_type": {}, "total": 0}
by_org[str(org_id)]["by_type"][str(leave_type.value)] = count
by_org[str(org_id)]["total"] += count
total_approved += count

return {
"approved_leaves": list(by_org.values()),
"total_approved_days": total_approved,
"organizations": list(by_org.keys()),
}
64 changes: 64 additions & 0 deletions tests/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,67 @@ def run(self):
assert response.status_code == 200
assert captured.get("user_id") == str(user.id)
assert captured.get("question") == "What is my organization?"


def test_get_my_pending_leaves_tool_registered_for_authenticated_user():
from ai.tools import AI_TOOLS, get_ai_function_map

tool_names = [t["name"] for t in AI_TOOLS]
assert "get_my_pending_leaves" in tool_names

fn_map = get_ai_function_map(user_id="test-uuid")
assert "get_my_pending_leaves" in fn_map


def test_get_my_pending_leaves_tool_excluded_when_user_id_none():
from ai.clients import CohereClient

client = CohereClient(user_id=None)
tool_names = [t["name"] for t in client.tools]
assert "get_my_pending_leaves" not in tool_names


def test_get_my_pending_leaves_returns_approved_and_policy(
app,
create_user,
create_organization,
create_user_organization,
create_leave_request,
monkeypatch,
):
from unittest.mock import MagicMock

from users.choices import LeaveType

from ai import tools as ai_tools

# Mock RAGClient so we don't hit Cohere/pgvector in tests
mock_rag = MagicMock()
mock_rag.query_policy_index.return_value = []
monkeypatch.setattr(ai_tools, "RAGClient", lambda: mock_rag)

user = create_user(username="pending-user", email="pending@example.com")
org = create_organization(name="Pending Org")
create_user_organization(user_id=user.id, organization_id=org.id)
create_leave_request(
user_id=user.id, organization_id=org.id, is_accepted=True, leave_type=LeaveType.SICK_LEAVE
)

fn_map = ai_tools.get_ai_function_map(user_id=str(user.id))
result = fn_map["get_my_pending_leaves"]()

assert "approved_leaves" in result
assert "total_approved_days" in result
assert "policy_excerpts" in result
assert result["total_approved_days"] == 1
assert len(result["approved_leaves"]) == 1


def test_policy_prompt_format_no_key_error():
"""Ensure POLICY_PROMPT.format() does not raise KeyError (e.g. 'os')."""
from ai.prompts import POLICY_PROMPT

# Verify template uses safe placeholders
prompt = POLICY_PROMPT.format(excerpts_text="Sample excerpt", question="How many days?")
assert "Sample excerpt" in prompt
assert "How many days?" in prompt
85 changes: 85 additions & 0 deletions tests/test_organizations_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,88 @@ def test_get_organization_ids_for_user_excludes_inactive(
ids = org_db.get_organization_ids_for_user(str(user.id))

assert ids == []


def test_get_my_approved_leaves_summary_empty_when_no_approved_leaves(
app, create_user, create_organization, create_user_organization
):
user = create_user(username="no-leaves", email="no-leaves@example.com")
org = create_organization(name="Leave Org")
create_user_organization(user_id=user.id, organization_id=org.id)

result = org_db.get_my_approved_leaves_summary(str(user.id))

assert result["total_approved_days"] == 0
assert result["approved_leaves"] == []
assert result["organizations"] == []


def test_get_my_approved_leaves_summary_counts_approved_only(
app,
create_user,
create_organization,
create_user_organization,
create_leave_request,
):
from users.choices import LeaveType

user = create_user(username="leave-user", email="leave@example.com")
org = create_organization(name="Leave Org")
create_user_organization(user_id=user.id, organization_id=org.id)

create_leave_request(
user_id=user.id, organization_id=org.id, is_accepted=True, leave_type=LeaveType.SICK_LEAVE
)
create_leave_request(
user_id=user.id, organization_id=org.id, is_accepted=True, leave_type=LeaveType.SICK_LEAVE
)
create_leave_request(
user_id=user.id,
organization_id=org.id,
is_accepted=False,
leave_type=LeaveType.PRIVILEGE_LEAVE,
)

result = org_db.get_my_approved_leaves_summary(str(user.id))

assert result["total_approved_days"] == 2
assert len(result["approved_leaves"]) == 1
assert result["approved_leaves"][0]["organization_name"] == "Leave Org"
assert result["approved_leaves"][0]["by_type"]["SICK_LEAVE"] == 2
assert result["approved_leaves"][0]["total"] == 2


def test_get_my_approved_leaves_summary_grouped_by_org_and_type(
app,
create_user,
create_organization,
create_user_organization,
create_leave_request,
):
from users.choices import LeaveType

user = create_user(username="multi-leave", email="multi@example.com")
org1 = create_organization(name="Org A")
org2 = create_organization(name="Org B")
create_user_organization(user_id=user.id, organization_id=org1.id)
create_user_organization(user_id=user.id, organization_id=org2.id)

create_leave_request(
user_id=user.id, organization_id=org1.id, is_accepted=True, leave_type=LeaveType.SICK_LEAVE
)
create_leave_request(
user_id=user.id,
organization_id=org1.id,
is_accepted=True,
leave_type=LeaveType.PRIVILEGE_LEAVE,
)
create_leave_request(
user_id=user.id, organization_id=org2.id, is_accepted=True, leave_type=LeaveType.SICK_LEAVE
)

result = org_db.get_my_approved_leaves_summary(str(user.id))

assert result["total_approved_days"] == 3
assert len(result["approved_leaves"]) == 2
org_names = {a["organization_name"] for a in result["approved_leaves"]}
assert org_names == {"Org A", "Org B"}