Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/db/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def create_template(session: Session, template: Template) -> Template:
def get_template(session: Session, template_id: int) -> Template | None:
return session.get(Template, template_id)

def get_templates(session: Session) -> list[Template]:
return session.exec(select(Template)).all()

# Forms
def create_form(session: Session, form: FormSubmission) -> FormSubmission:
session.add(form)
Expand Down
4 changes: 2 additions & 2 deletions api/routes/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
router = APIRouter(prefix="/forms", tags=["forms"])

@router.post("/fill", response_model=FormFillResponse)
def fill_form(form: FormFill, db: Session = Depends(get_db)):
async def fill_form(form: FormFill, db: Session = Depends(get_db)):
if not get_template(db, form.template_id):
raise AppError("Template not found", status_code=404)

fetched_template = get_template(db, form.template_id)

controller = Controller()
path = controller.fill_form(user_input=form.input_text, fields=fetched_template.fields, pdf_form_path=fetched_template.pdf_path)
path = await controller.fill_form(user_input=form.input_text, fields=fetched_template.fields, pdf_form_path=fetched_template.pdf_path)

submission = FormSubmission(**form.model_dump(), output_pdf_path=path)
return create_form(db, submission)
Expand Down
8 changes: 6 additions & 2 deletions api/routes/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlmodel import Session
from api.deps import get_db
from api.schemas.templates import TemplateCreate, TemplateResponse
from api.db.repositories import create_template
from api.db.repositories import create_template, get_templates
from api.db.models import Template
from src.controller import Controller

Expand All @@ -13,4 +13,8 @@ def create(template: TemplateCreate, db: Session = Depends(get_db)):
controller = Controller()
template_path = controller.create_template(template.pdf_path)
tpl = Template(**template.model_dump(exclude={"pdf_path"}), pdf_path=template_path)
return create_template(db, tpl)
return create_template(db, tpl)

@router.get("/", response_model=list[TemplateResponse])
def get_all(db: Session = Depends(get_db)):
return get_templates(db)
4 changes: 2 additions & 2 deletions src/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ class Controller:
def __init__(self):
self.file_manipulator = FileManipulator()

def fill_form(self, user_input: str, fields: list, pdf_form_path: str):
return self.file_manipulator.fill_form(user_input, fields, pdf_form_path)
async def fill_form(self, user_input: str, fields: list, pdf_form_path: str):
return await self.file_manipulator.fill_form(user_input, fields, pdf_form_path)

def create_template(self, pdf_path: str):
return self.file_manipulator.create_template(pdf_path)
1 change: 1 addition & 0 deletions src/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# init for core
23 changes: 23 additions & 0 deletions src/core/orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio
import logging

logger = logging.getLogger(__name__)

class VRAMOrchestrator:
    """
    Serializes access to VRAM-hungry models (Whisper, Ollama) so that
    hardware-constrained devices never run them concurrently and OOM.

    Implemented as a process-wide singleton exposing one shared
    asyncio.Lock that callers hold around any model invocation.
    """

    _instance = None
    _lock = asyncio.Lock()

    def __new__(cls):
        # Classic singleton: create the shared instance once, then reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def lock(self):
        # The single class-level asyncio.Lock shared by every instance.
        return self._lock

orchestrator = VRAMOrchestrator()
4 changes: 2 additions & 2 deletions src/file_manipulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_template(self, pdf_path: str):
prepare_form(pdf_path, template_path)
return template_path

def fill_form(self, user_input: str, fields: list, pdf_form_path: str):
async def fill_form(self, user_input: str, fields: list, pdf_form_path: str):
"""
It receives the raw data, runs the PDF filling logic,
and returns the path to the newly created file.
Expand All @@ -33,7 +33,7 @@ def fill_form(self, user_input: str, fields: list, pdf_form_path: str):
try:
self.llm._target_fields = fields
self.llm._transcript_text = user_input
output_name = self.filler.fill_form(pdf_form=pdf_form_path, llm=self.llm)
output_name = await self.filler.fill_form(pdf_form=pdf_form_path, llm=self.llm)

print("\n----------------------------------")
print("✅ Process Complete.")
Expand Down
4 changes: 2 additions & 2 deletions src/filler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Filler:
def __init__(self):
pass

def fill_form(self, pdf_form: str, llm: LLM):
async def fill_form(self, pdf_form: str, llm: LLM):
"""
Fill a PDF form with values from user_input using LLM.
Fields are filled in the visual order (top-to-bottom, left-to-right).
Expand All @@ -20,7 +20,7 @@ def fill_form(self, pdf_form: str, llm: LLM):
)

# Generate dictionary of answers from your original function
t2j = llm.main_loop()
t2j = await llm.main_loop()
textbox_answers = t2j.get_data() # This is a dictionary

answers_list = list(textbox_answers.values())
Expand Down
106 changes: 54 additions & 52 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import json
import os
import requests
import httpx
import logging
from src.core.orchestrator import orchestrator

logger = logging.getLogger(__name__)

class LLM:
def __init__(self, transcript_text=None, target_fields=None, json=None):
if json is None:
json = {}
def __init__(self, transcript_text=None, target_fields=None, json_data=None):
if json_data is None:
json_data = {}
self._transcript_text = transcript_text # str
self._target_fields = target_fields # List, contains the template field.
self._json = json # dictionary
self._json = json_data # dictionary

def type_check_all(self):
if type(self._transcript_text) is not str:
if not isinstance(self._transcript_text, str):
raise TypeError(
f"ERROR in LLM() attributes ->\
Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
f"ERROR in LLM() attributes -> "
f"Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
)
elif type(self._target_fields) is not list:
elif not isinstance(self._target_fields, dict):
raise TypeError(
f"ERROR in LLM() attributes ->\
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
f"ERROR in LLM() attributes -> "
f"Target fields must be a dictionary. Input:\n\ttarget_fields: {self._target_fields}"
)

def build_prompt(self, current_field):
Expand All @@ -44,37 +47,40 @@ def build_prompt(self, current_field):

return prompt

def main_loop(self):
async def main_loop(self):
# self.type_check_all()
for field in self._target_fields.keys():
prompt = self.build_prompt(field)
# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
}

try:
response = requests.post(ollama_url, json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
json_data = response.json()
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

async with httpx.AsyncClient() as client:
for field in self._target_fields.keys():
prompt = self.build_prompt(field)
payload = {
"model": "mistral",
"prompt": prompt,
"stream": False,
}

# VRAM Orchestration: Ensure serial hardware access
async with orchestrator.lock:
try:
response = await client.post(ollama_url, json=payload, timeout=60.0)
response.raise_for_status()
json_data = response.json()
parsed_response = json_data["response"]
self.add_response_to_json(field, parsed_response)
except (httpx.ConnectError, httpx.TimeoutException, httpx.RequestError) as e:
logger.error(f"Transport error connecting to Ollama at {ollama_url}: {e}")
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except httpx.HTTPStatusError as e:
logger.error(f"Ollama returned an error: {e}")
raise RuntimeError(f"Ollama returned an error: {e}")
except Exception as e:
logger.error(f"Unexpected error during LLM extraction: {e}")
raise

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
Expand All @@ -98,12 +104,15 @@ def add_response_to_json(self, field, value):
parsed_value = self.handle_plural_values(value)

if field in self._json.keys():
self._json[field].append(parsed_value)
# If it's already a list, append. If not, make it a list or just overwrite?
# Original code used .append() which assumes it's a list.
if isinstance(self._json[field], list):
self._json[field].append(parsed_value)
else:
self._json[field] = parsed_value # or [self._json[field], parsed_value]
else:
self._json[field] = parsed_value

return

def handle_plural_values(self, plural_value):
"""
This method handles plural values.
Expand All @@ -118,14 +127,7 @@ def handle_plural_values(self, plural_value):
print(
f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
)
values = plural_value.split(";")

# Remove trailing leading whitespace
for i in range(len(values)):
current = i + 1
if current < len(values):
clean_value = values[current].lstrip()
values[current] = clean_value
values = [v.strip() for v in plural_value.split(";")]

print(f"\t[LOG]: Resulting formatted list of values: {values}")

Expand Down
54 changes: 54 additions & 0 deletions tests/test_reliability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import respx
import httpx
from src.llm import LLM

@pytest.mark.asyncio
async def test_llm_async_non_blocking():
    """
    Verify that an LLM extraction can be scheduled as an asyncio Task against
    a mocked Ollama endpoint and that the mocked response lands in the
    extracted data.

    The real yield point is the `await client.post(...)` inside
    `LLM.main_loop`; respx answers instantly here, so this test checks the
    async plumbing rather than actual latency overlap.
    """
    # Local import: the module-level `import asyncio` sits at the bottom of
    # this file and only works by accident of import order — keep this test
    # self-contained instead of depending on it.
    import asyncio

    llm = LLM(
        transcript_text="The date is 2026-03-18.",
        target_fields={"date": "incident date"}
    )

    ollama_url = "http://localhost:11434/api/generate"

    async with respx.mock:
        respx.post(ollama_url).mock(
            return_value=httpx.Response(200, json={"response": "2026-03-18"})
        )

        # Schedule the "heavy" extraction as a background task, then await it.
        task = asyncio.create_task(llm.main_loop())
        await task

    assert llm.get_data()["date"] == "2026-03-18"

@pytest.mark.asyncio
async def test_ollama_timeout_handling():
    """
    Verify that the system handles Ollama timeouts/connection failures
    gracefully by surfacing a ConnectionError to the caller.
    """
    subject = LLM(
        transcript_text="Test text",
        target_fields={"test": "field"}
    )

    async with respx.mock:
        # Make every POST to the Ollama endpoint raise a connection timeout.
        respx.post("http://localhost:11434/api/generate").mock(
            side_effect=httpx.ConnectTimeout
        )

        with pytest.raises(ConnectionError):
            await subject.main_loop()

import asyncio
38 changes: 38 additions & 0 deletions tests/test_vram_orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
import pytest
from src.core.orchestrator import orchestrator

@pytest.mark.asyncio
async def test_orchestrator_lock_serialization():
    """
    Verify that the VRAMOrchestrator lock correctly serializes concurrent
    requests: a task that arrives while the lock is held must not enter its
    critical section until the holder releases it.
    """
    # Shared trace of start/end events appended from inside the lock.
    execution_order = []

    async def task(name, duration):
        # Hold the orchestrator's shared lock for `duration` seconds,
        # recording entry and exit so ordering can be asserted afterwards.
        async with orchestrator.lock:
            execution_order.append(f"{name}_start")
            await asyncio.sleep(duration)
            execution_order.append(f"{name}_end")

    # Launch two tasks concurrently
    # Task 1 starts first but takes longer
    # Task 2 should wait for Task 1 to finish
    # NOTE(review): this relies on gather() scheduling its coroutines in
    # argument order so that A acquires the lock first — confirm this
    # assumption holds for the supported asyncio versions.
    await asyncio.gather(
        task("A", 0.5),
        task("B", 0.1)
    )

    # If locking works, Task B must start AFTER Task A ends
    assert execution_order == ["A_start", "A_end", "B_start", "B_end"]

@pytest.mark.asyncio
async def test_orchestrator_singleton():
    """
    Verify that VRAMOrchestrator follows the singleton pattern:
    repeated construction yields one shared object with one shared lock.
    """
    from src.core.orchestrator import VRAMOrchestrator

    first = VRAMOrchestrator()
    second = VRAMOrchestrator()

    # Both constructions must return the very same instance and lock.
    assert first is second
    assert first.lock is second.lock