Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MAX_HISTORY = 20


class SchedulingAgent:
class PolicyAgent:
def __init__(self, question: str, session_id: str | None = None):
self.question = question
self.session_id = session_id
Expand Down
12 changes: 2 additions & 10 deletions ai/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import Depends, Query, status
from sqlalchemy.orm import Session

from ai.agent import SchedulingAgent
from ai.agent import PolicyAgent
from ai.db import PolicyEmbedding
from ai.models import QNARequestBody, QNAResponseBody
from application.app import app
Expand All @@ -21,18 +21,10 @@ async def ai_assistant(request: QNARequestBody) -> QNAResponseBody:
Payload for the endpoint:
{
"question": "give me date of birth of Dr. ruso lamba"
"question": "give me date of birth of Dr. fila delphia"
"question": "give me date of birth of patient jake funro"
"question": "what is the speciality of Dr. ruso lamba"
"question": "Which organ is andre russo associated with?"
"question": "give me contact details of andre russo"
"question": "get list of doctors for patient jack kallis"
"question": "Is Dr. mark ruffello available on 2026-01-30 02:00 for 45 minutes?"
"question": "create appointment for patient jack kallis on 2026-01-30 at 01:30"
}
"""
session_id = request.session_id or str(uuid.uuid4())
result = SchedulingAgent(
result = PolicyAgent(
question=request.question,
session_id=session_id,
).run()
Expand Down
15 changes: 9 additions & 6 deletions ai/prompts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
PREAMBLE = """
## Task & Context
You help people manage leave queries, appointments, and understand organization policies.
Use the provided tools to look up appointments, check availability, query organization
details, and understand leave policies.
You help users manage leave queries and understand organization policies.
Assume the user belongs to an organization and may ask questions like:
- "How many leaves are pending?"
- "What is the leave policy for this company?"
- "Show my leave requests for this month."
Use the provided tools to look up organization details and policies.

## Response Rules
- Use tools to check availability before creating appointments.
- Provide clear confirmation when appointments are created or cancelled.
- When asked about leave policies, use organization tools to fetch accurate data.
- Use tools to fetch organization and policy data when answering leave questions.
- If a user asks about leave counts or pending requests and no tool data is available,
ask a brief follow-up question or explain the limitation.
- Always reference the organization name and policy details when answering leave questions.

## Style Guide
Expand Down
16 changes: 7 additions & 9 deletions organizations/constants.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
def get_organization_function_map():
from organizations.db import (
get_all_organizations,
get_organization_by_name,
get_policies_for_organization,
get_policy_details,
)
from organizations.db import (
get_organization_details,
get_policies_for_organization,
get_policy_details,
)

def get_organization_function_map():
return {
"get_organization_by_name": get_organization_by_name,
"get_all_organizations": get_all_organizations,
"get_organization_details": get_organization_details,
"get_policies_for_organization": get_policies_for_organization,
"get_policy_details": get_policy_details,
}
37 changes: 7 additions & 30 deletions organizations/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from organizations.models import Organization, Policy


def get_organization_by_name(name: str):
def get_organization_details(organization_name: str):
"""Get organization details by name."""
with SessionLocal() as db:
org = (
db.query(Organization)
.filter(func.lower(Organization.name).contains(name.lower()))
.filter(func.lower(Organization.name).contains(organization_name.lower()))
.first()
)
if org:
Expand All @@ -22,30 +22,7 @@ def get_organization_by_name(name: str):
"phone": org.phone,
"is_active": org.is_active,
}
return {"detail": "Organization not found", "name": name}


def get_all_organizations():
"""Get all active organizations."""
with SessionLocal() as db:
orgs = (
db.query(Organization)
.filter(Organization.is_active == True)
.order_by(Organization.name)
.all()
)
return {
"organizations": [
{
"id": str(org.id),
"name": org.name,
"description": org.description,
"is_active": org.is_active,
}
for org in orgs
],
"total": len(orgs),
}
return {"detail": "Organization not found", "name": organization_name}


def get_policies_for_organization(organization_name: str):
Expand All @@ -61,7 +38,7 @@ def get_policies_for_organization(organization_name: str):

policies = (
db.query(Policy)
.filter(Policy.organization_id == org.id, Policy.is_active == True)
.filter(Policy.organization_id == org.id, Policy.is_active.is_(True))
.all()
)
return {
Expand All @@ -81,23 +58,23 @@ def get_policies_for_organization(organization_name: str):
}


def get_policy_details(policy_name: str):
def get_policy_details(policy_name: str, organization_name: str):
"""Get policy details by name."""
with SessionLocal() as db:
policy = (
db.query(Policy)
.filter(func.lower(Policy.name).contains(policy_name.lower()))
.filter(func.lower(Organization.name).contains(organization_name.lower()))
.first()
)
if policy:
org = db.query(Organization).filter(Organization.id == policy.organization_id).first()
return {
"id": str(policy.id),
"name": policy.name,
"description": policy.description,
"document_name": policy.document_name,
"file_path": policy.file,
"is_active": policy.is_active,
"organization": org.name if org else None,
"organization": organization_name,
}
return {"detail": "Policy not found", "name": policy_name}
11 changes: 3 additions & 8 deletions organizations/tools.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
ORGANIZATION_TOOLS = [
{
"name": "get_organization_by_name",
"name": "get_organization_details",
"description": "Returns organization details by searching for the name.",
"parameter_definitions": {
"name": {
"description": "The name or partial name of the organization.",
"organization_name": {
"description": "The name of the organization.",
"type": "str",
"required": True,
},
},
},
{
"name": "get_all_organizations",
"description": "Returns a list of all active organizations.",
"parameter_definitions": {},
},
{
"name": "get_policies_for_organization",
"description": "Returns all policies for a specific organization including document details.",
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def ask_llm(self, message=None, chat_history=None, max_steps=8):
monkeypatch.setattr(agent_module, "CohereClient", DummyClient)
agent_module.SESSION_MEMORY.clear()

agent = agent_module.SchedulingAgent("hello", session_id="s1")
agent = agent_module.PolicyAgent("hello", session_id="s1")
result = agent.run()
assert "test@example.com" in result["response"]
assert "555-555-1234" in result["messages"][0]["message"]
Expand All @@ -34,7 +34,7 @@ def ask_llm(self, message=None, chat_history=None, max_steps=8):
monkeypatch.setattr(agent_module, "CohereClient", DummyClient)
agent_module.SESSION_MEMORY.clear()

agent = agent_module.SchedulingAgent("hello", session_id="s2")
agent = agent_module.PolicyAgent("hello", session_id="s2")
result = agent.run()
assert len(result["messages"]) == agent_module.MAX_HISTORY + 5
assert len(agent_module.SESSION_MEMORY["s2"]) == agent_module.MAX_HISTORY
Expand All @@ -55,7 +55,7 @@ def run(self):
"messages": [{"role": "CHATBOT", "message": "ok"}],
}

monkeypatch.setattr(ai_apis, "SchedulingAgent", DummyAgent)
monkeypatch.setattr(ai_apis, "PolicyAgent", DummyAgent)
user = create_user(username="ai-user", email="ai-user@example.com")
response = client.post(
"/ai_assistant",
Expand Down