Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
811e756
adds leave request model
Aakash-Pandit Feb 6, 2026
40fac3b
updates gitignore
Aakash-Pandit Feb 6, 2026
b1d7218
remove appointments
Aakash-Pandit Feb 6, 2026
9171ee2
adds test cases
Aakash-Pandit Feb 6, 2026
699670c
updates client
Aakash-Pandit Feb 6, 2026
6f1131d
adds db clear functions
Aakash-Pandit Feb 6, 2026
56e795e
updates test file config
Aakash-Pandit Feb 6, 2026
17b4c4e
adds choices
Aakash-Pandit Feb 6, 2026
ec267bd
adds leave request APIs
Aakash-Pandit Feb 6, 2026
6541c71
updates tools
Aakash-Pandit Feb 6, 2026
8ac05d1
updates db queries
Aakash-Pandit Feb 6, 2026
2dcb1b7
adds organization api
Aakash-Pandit Feb 6, 2026
c4a1985
updates mapping
Aakash-Pandit Feb 6, 2026
53e19ea
adds user organization
Aakash-Pandit Feb 6, 2026
55a7683
adds utils
Aakash-Pandit Feb 6, 2026
496ee11
updates api
Aakash-Pandit Feb 6, 2026
c6d96d8
updates apis
Aakash-Pandit Feb 7, 2026
9184939
updates api endpoint
Aakash-Pandit Feb 7, 2026
8218dd5
updates compose file
Aakash-Pandit Feb 9, 2026
485dc8a
adds command to remove volumes and orphaned containers
Aakash-Pandit Feb 9, 2026
2d368c4
adds pgvector
Aakash-Pandit Feb 9, 2026
38503b5
updates client
Aakash-Pandit Feb 9, 2026
18f531f
initializes pgvector db
Aakash-Pandit Feb 9, 2026
8807b66
updates api
Aakash-Pandit Feb 9, 2026
3d7f84e
updates utils
Aakash-Pandit Feb 9, 2026
43553f9
adds policy embedding
Aakash-Pandit Feb 9, 2026
1e4e521
adds client of RAG
Aakash-Pandit Feb 9, 2026
f9e1209
adds libs
Aakash-Pandit Feb 9, 2026
983010f
adds api to get embedding
Aakash-Pandit Feb 9, 2026
73cbcde
updates test cases
Aakash-Pandit Feb 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,5 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/

uploads/*
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ start:
stop:
docker-compose down

remove:
docker-compose down -v --remove-orphans

build:
docker-compose build

Expand Down
44 changes: 43 additions & 1 deletion ai/apis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import uuid

from fastapi import status
from fastapi import Depends, Query, status
from sqlalchemy.orm import Session

from ai.agent import SchedulingAgent
from ai.db import PolicyEmbedding
from ai.models import QNARequestBody, QNAResponseBody
from application.app import app
from database.db import get_db


@app.post(
Expand Down Expand Up @@ -39,3 +42,42 @@ async def ai_assistant(request: QNARequestBody) -> QNAResponseBody:
"session_id": result.get("session_id") or session_id,
"messages": result.get("messages"),
}


@app.get("/ai/policy-embeddings")
async def get_policy_embeddings(
    policy_id: str | None = None,
    organization_id: str | None = None,
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """Return a paginated listing of stored policy-embedding chunks.

    Results are ordered newest-first and may be narrowed with exact-match
    filters on ``policy_id`` and/or ``organization_id``. The raw embedding
    vector is intentionally omitted from the payload.
    """
    query = db.query(PolicyEmbedding)
    if policy_id:
        query = query.filter(PolicyEmbedding.policy_id == policy_id)
    if organization_id:
        query = query.filter(PolicyEmbedding.organization_id == organization_id)
    rows = (
        query.order_by(PolicyEmbedding.created.desc()).offset(offset).limit(limit).all()
    )
    embeddings = [
        {
            "id": str(row.id),
            "policy_id": str(row.policy_id),
            "organization_id": str(row.organization_id),
            "policy_name": row.policy_name,
            "description": row.description,
            "document_name": row.document_name,
            "file_path": row.file_path,
            "chunk_index": row.chunk_index,
            "text": row.text,
            "created": row.created,
        }
        for row in rows
    ]
    return {
        "items": embeddings,
        # NOTE(review): this is the size of the current page, not the total
        # number of matching rows — confirm callers don't expect a full count.
        "total": len(embeddings),
        "limit": limit,
        "offset": offset,
    }
17 changes: 11 additions & 6 deletions ai/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,19 @@
import cohere

from ai.prompts import PREAMBLE
from appointments.constants import get_appointment_function_map
from appointments.tools import APPOINTMENT_TOOLS
from organizations.constants import get_organization_function_map
from organizations.tools import ORGANIZATION_TOOLS


class CohereClient:
def __init__(self, message: str | None = None):
def __init__(self, message: str | None = None, model: str | None = None):
self.client = cohere.Client(os.getenv("COHERE_API_KEY"))
self.model = os.getenv("COHERE_LLM_MODEL")
self.model = model or os.getenv("COHERE_LLM_MODEL")
self.preamble = PREAMBLE
self.function_map = {
**get_appointment_function_map(),
**get_organization_function_map(),
}
self.tools = [*APPOINTMENT_TOOLS, *ORGANIZATION_TOOLS]
self.tools = [*ORGANIZATION_TOOLS]
self.message = message or ""

def chat(
Expand Down Expand Up @@ -100,3 +97,11 @@ def ask_llm(
return str(e), chat_history or []

return (response.text if response else ""), history

def embed_texts(self, texts: list[str], input_type: str) -> list[list[float]]:
    """Embed *texts* with the client's configured model.

    Returns one vector per input text, or an empty list when the API
    response carries no embeddings.
    """
    result = self.client.embed(
        model=self.model,
        texts=texts,
        input_type=input_type,
    )
    return result.embeddings or []
27 changes: 27 additions & 0 deletions ai/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import uuid

from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, DateTime, Integer, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func

from database.db import Base

DEFAULT_EMBED_DIM = int(os.getenv("RAG_EMBED_DIM", "1024"))


class PolicyEmbedding(Base):
    """One embedded chunk of a policy document, stored in pgvector."""

    __tablename__ = "policy_embeddings"

    # Surrogate primary key, generated client-side.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Source policy and owning organization; indexed for filtered lookups.
    # NOTE(review): plain UUID columns, no FK constraint — confirm intended.
    policy_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    organization_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    policy_name = Column(String(255), nullable=False)
    description = Column(Text, nullable=True)
    document_name = Column(String(255), nullable=True)
    # Local path or URL the text was extracted from.
    file_path = Column(String(500), nullable=False)
    # 0-based position of this chunk within its source document.
    chunk_index = Column(Integer, nullable=False)
    # The chunk's plain text; `embedding` is its vector (dim from RAG_EMBED_DIM).
    text = Column(Text, nullable=False)
    embedding = Column(Vector(DEFAULT_EMBED_DIM), nullable=False)
    # Row creation time, assigned by the database.
    created = Column(DateTime(timezone=True), server_default=func.now())
185 changes: 185 additions & 0 deletions ai/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import io
import logging
import os
import re
from typing import Iterable

import httpx
from sqlalchemy import delete, select

from ai.clients import CohereClient
from ai.db import PolicyEmbedding
from database.db import SessionLocal


logger = logging.getLogger(__name__)


class RAGClient:
def __init__(self, embed_model: str | None = None):
self.embed_model = embed_model or os.getenv("COHERE_EMBED_MODEL", "embed-english-v3.0")

def _looks_like_text(self, raw: bytes) -> bool:
if not raw:
return False
if b"\x00" in raw:
return False
sample = raw[:2048]
non_printable = sum(1 for byte in sample if byte < 9 or (13 < byte < 32))
return non_printable / max(len(sample), 1) < 0.2

def _extract_text_from_pdf(self, file_path: str) -> str:
try:
from pypdf import PdfReader # type: ignore
except Exception:
return ""
try:
reader = PdfReader(file_path)
return "\n".join(page.extract_text() or "" for page in reader.pages).strip()
except Exception:
return ""

def _extract_text_from_docx(self, file_path: str) -> str:
try:
from docx import Document # type: ignore
except Exception:
return ""
try:
doc = Document(file_path)
return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
except Exception:
return ""

def _read_text_from_source(self, file_path: str) -> str:
if file_path.startswith("http://") or file_path.startswith("https://"):
try:
response = httpx.get(file_path, timeout=20.0)
response.raise_for_status()
raw = response.content
except Exception:
return ""
else:
if not os.path.exists(file_path):
return ""
with open(file_path, "rb") as handle:
raw = handle.read()

ext = os.path.splitext(file_path)[1].lower()
if ext == ".pdf":
return self._extract_text_from_pdf(file_path)
if ext == ".docx":
return self._extract_text_from_docx(file_path)

if self._looks_like_text(raw):
return raw.decode("utf-8", errors="ignore").strip()

return ""

def _clean_text(self, text: str) -> str:
return re.sub(r"\s+", " ", text).strip()

def _chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 200) -> list[str]:
cleaned = self._clean_text(text)
if not cleaned:
return []
chunks = []
start = 0
while start < len(cleaned):
end = min(len(cleaned), start + max_chars)
chunk = cleaned[start:end].strip()
if chunk:
chunks.append(chunk)
if end >= len(cleaned):
break
start = max(0, end - overlap)
return chunks

def _embed_texts(self, texts: Iterable[str], input_type: str) -> list[list[float]]:
client = CohereClient(model=self.embed_model)
response = client.embed_texts(list(texts), input_type=input_type)

logger.info(f"Embedded {len(texts)} texts with model {self.embed_model} & response: {response}")
return response

def index_policy_document(
self,
policy_id: str,
organization_id: str,
policy_name: str,
description: str | None,
document_name: str | None,
file_path: str,
) -> dict:
text = self._read_text_from_source(file_path)
if not text:
return {"status": "skipped", "reason": "no_text_extracted"}

chunks = self._chunk_text(text)
if not chunks:
return {"status": "skipped", "reason": "no_chunks_created"}

embeddings = self._embed_texts(chunks, input_type="search_document")
if not embeddings:
return {"status": "skipped", "reason": "no_embeddings_created"}

with SessionLocal() as db:
db.execute(
delete(PolicyEmbedding).where(PolicyEmbedding.policy_id == policy_id)
)
for idx, (chunk, embedding) in enumerate(
zip(chunks, embeddings, strict=False)
):
db.add(
PolicyEmbedding(
policy_id=policy_id,
organization_id=organization_id,
policy_name=policy_name,
description=description,
document_name=document_name,
file_path=file_path,
chunk_index=idx,
text=chunk,
embedding=embedding,
)
)
db.commit()
return {"status": "indexed", "chunks": len(chunks)}

def remove_policy_from_index(self, policy_id: str) -> dict:
with SessionLocal() as db:
result = db.execute(
delete(PolicyEmbedding).where(PolicyEmbedding.policy_id == policy_id)
)
db.commit()
if result.rowcount == 0:
return {"status": "skipped", "reason": "policy_not_found"}
return {"status": "removed", "count": result.rowcount}

def query_policy_index(self, query: str, top_k: int = 5) -> list[dict]:
query_embedding = self._embed_texts([query], input_type="search_query")
if not query_embedding:
return []
query_vector = query_embedding[0]

with SessionLocal() as db:
stmt = (
select(PolicyEmbedding)
.order_by(PolicyEmbedding.embedding.cosine_distance(query_vector))
.limit(max(top_k, 1))
)
results = db.execute(stmt).scalars().all()
response = []
for record in results:
response.append(
{
"policy_id": str(record.policy_id),
"organization_id": str(record.organization_id),
"policy_name": record.policy_name,
"description": record.description,
"document_name": record.document_name,
"file_path": record.file_path,
"chunk_index": record.chunk_index,
"text": record.text,
"score": None,
}
)
return response
2 changes: 1 addition & 1 deletion application/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
app.add_middleware(AuthenticationMiddleware, backend=JWTAuthBackend())

import ai.apis
import appointments.apis
import ai.db
import auth.apis
import organizations.apis
import users.apis
Expand Down
Empty file removed appointments/__init__.py
Empty file.
Loading