Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
.venv
key: uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }}
restore-keys: |
uv-${{ runner.os }}-${{ matrix.python-version }}-
uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }}

- name: Install (uv)
run: |
Expand Down
1 change: 1 addition & 0 deletions commonforms/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def main():

args = parser.parse_args()

print(f"**{args.confidence=}")
prepare_form(
args.input,
args.output,
Expand Down
109 changes: 100 additions & 9 deletions commonforms/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,109 @@
from ultralytics import YOLO
from pathlib import Path
from huggingface_hub import hf_hub_download
from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge

from commonforms.utils import BoundingBox, Page, Widget
from commonforms.form_creator import PyPdfFormCreator
from commonforms.exceptions import EncryptedPdfError

import formalpdf
import pypdfium2
import logging
import PIL


# NOTE(review): this configures the root logger as an import-time side effect, which
# also affects any application importing this module -- confirm that is intended.
logging.basicConfig(level=logging.INFO)


# our mapping from (model_name_upper, fast) to (repo_id, filename) for the huggingface hub.
# keeping it simple and declarative like this because it's not like we're adding a bunch
# of models.
# Keys use the upper-cased model name; the boolean selects between the two
# repositories listed for that model (the True entries point at ".onnx" files).
# NOTE(review): "FFDETR" only has a fast=False entry.
models = {
("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"),
("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"),
("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"),
("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"),
("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"),
}


def batch(lst: list, n: int = 8):
    """Yield consecutive slices of *lst* holding at most *n* items each.

    The final slice is shorter when ``len(lst)`` is not a multiple of *n*;
    an empty *lst* yields nothing.

    Args:
        lst: The sequence to split; it is sliced, not copied element-wise.
        n: Maximum number of items per yielded slice. Must be at least 1.

    Raises:
        ValueError: If *n* is smaller than 1. The error surfaces when the
            generator is first advanced.
    """
    # A non-positive step would either make range() fail with an unrelated
    # message (n == 0) or silently produce no batches at all (n < 0).
    if n < 1:
        raise ValueError(f"batch size must be a positive integer, got {n}")

    # Slicing already clamps to the end of the list, so no explicit min() is needed.
    for start in range(0, len(lst), n):
        yield lst[start : start + n]


class FFDetrDetector:
def __init__(self, model_or_path: str, device: int | str = "cpu") -> None:
self.device = device
self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path))

self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"}

def get_model_path(self, model_or_path: str) -> str:
model_upper = model_or_path.upper()
if model_upper in ["FFDETR"]:
# download the model, will just use the cached version if it already exists
repo_id, filename = models[(model_upper, False)]
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
model_path = model_or_path

return model_path

def resize(
self,
image: PIL.Image.Image,
size: tuple[int, int] | int,
) -> PIL.Image.Image:
if isinstance(size, int):
size = (size, size)

return image.resize(size, PIL.Image.Resampling.LANCZOS)

def extract_widgets(
self,
pages: list[Page],
confidence: float = 0.4,
image_size: int = 1120,
batch_size: int = 3,
) -> dict[int, list[Widget]]:
image_size = 1024
results = []
for b in batch([self.resize(p.image, image_size) for p in pages], n=batch_size):
predictions = self.model.predict(b, threshold=confidence)
if len(pages) == 1 or batch_size == 1:
predictions = [predictions]
results.extend(predictions)

widgets = {}

for page_ix, detections in enumerate(results):
logging.info(f" Page {page_ix}: {len(detections)} fields detected")
detections = detections.with_nms(threshold=0.1, class_agnostic=True)
logging.info(f"\t\t{len(detections)} after nms")
widgets[page_ix] = []

for class_id, box in zip(detections.class_id, detections.xyxy):
x0, x1 = box[[0, 2]] / pages[page_ix].image.width
y0, y1 = box[[1, 3]] / pages[page_ix].image.height

widget_type = self.id_to_cls[class_id]

widgets[page_ix].append(
Widget(
widget_type=widget_type,
bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1),
page=page_ix,
)
)

widgets[page_ix] = sort_widgets(widgets[page_ix])

return widgets


class FFDNetDetector:
def __init__(
self, model_or_path: str, device: int | str = "cpu", fast: bool = False
Expand All @@ -43,8 +128,8 @@ def get_model_path(
model_upper = model_or_path.upper()
if model_upper in ["FFDNET-S", "FFDNET-L"]:
# download the model, will just use the cached version if it already exists
repo_id, filename = models[(model_upper, fast)]
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
repo_id, filename = models[(model_upper, fast)]
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
model_path = model_or_path

Expand Down Expand Up @@ -148,7 +233,7 @@ def render_pdf(pdf_path: str) -> list[Page]:
doc = formalpdf.open(pdf_path)
try:
for page in doc:
image = page.render()
image = page.render(dpi=144)
pages.append(Page(image=image, width=image.width, height=image.height))
return pages
finally:
Expand All @@ -159,16 +244,20 @@ def prepare_form(
input_path: str | Path,
output_path: str | Path,
*,
model_or_path: str = "FFDNet-L",
model_or_path: str = "FFDetr",
keep_existing_fields: bool = False,
use_signature_fields: bool = False,
device: int | str = "cpu",
image_size: int = 1600,
confidence: float = 0.3,
image_size: int = 1024,
confidence: float = 0.4,
fast: bool = False,
multiline: bool = False,
batch_size: int = 4,
):
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
if "FFDNET" in model_or_path.upper():
detector = FFDNetDetector(model_or_path, device=device, fast=fast)
else:
detector = FFDetrDetector(model_or_path)

try:
pages = render_pdf(input_path)
Expand All @@ -188,7 +277,9 @@ def prepare_form(
name = f"{widget.widget_type.lower()}_{widget.page}_{i}"

if widget.widget_type == "TextBox":
writer.add_text_box(name, page_ix, widget.bounding_box, multiline=multiline)
writer.add_text_box(
name, page_ix, widget.bounding_box, multiline=multiline
)
elif widget.widget_type == "ChoiceButton":
writer.add_checkbox(name, page_ix, widget.bounding_box)
elif widget.widget_type == "Signature":
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ dependencies = [
"pillow>=11.3.0",
"pydantic>=2.11.9",
"pypdf>=6.1.1",
"rfdetr>=1.3.0",
"transformers>=4.57",
"ultralytics>=8.3.204",
]

Expand Down
28 changes: 22 additions & 6 deletions tests/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def test_inference(tmp_path):
# tmp_path is a built-in pytest fixture where we'll write the outputs
output_path = tmp_path / "output.pdf"
commonforms.prepare_form("./tests/resources/input.pdf", output_path)
commonforms.prepare_form("./tests/resources/input.pdf", output_path, model_or_path="FFDetr")

assert output_path.exists()

Expand All @@ -20,7 +20,7 @@ def test_inference(tmp_path):

def test_inference_fast(tmp_path):
output_path = tmp_path / "output.pdf"
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True)
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, model_or_path="FFDNet-L")

assert output_path.exists()

Expand All @@ -32,7 +32,9 @@ def test_inference_fast(tmp_path):

def test_mutlinline(tmp_path):
output_path = tmp_path / "output.pdf"
commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, multiline=True)
commonforms.prepare_form(
"./tests/resources/input.pdf", output_path, fast=True, multiline=True
)

assert output_path.exists()

Expand All @@ -42,7 +44,6 @@ def test_mutlinline(tmp_path):
doc.document.close()



def test_encrypted_failure(tmp_path):
# Reminder to future Joe: password for encrypted PDF is "kanbanery"
output_path = tmp_path / "output.pdf"
Expand All @@ -51,7 +52,22 @@ def test_encrypted_failure(tmp_path):
commonforms.prepare_form("./tests/resources/encrypted.pdf", output_path)


def test_inference_ffdetr(tmp_path):
    """Run prepare_form with the FFDetr model and check page 0 gained widgets."""
    # tmp_path is a built-in pytest fixture where we'll write the outputs
    output_path = tmp_path / "output.pdf"
    commonforms.prepare_form(
        "./tests/resources/input.pdf", output_path, model_or_path="FFDetr"
    )

    assert output_path.exists()

    doc = formalpdf.open(output_path)
    try:
        assert len(doc[0].widgets()) > 0
    finally:
        # Close the document even when the assertion fails so the handle on
        # the temporary file is not leaked.
        doc.document.close()


# TODO(joe): future tests around handling encrypted PDFs
# 1. add a --password flag and test that inference doesn't fail
# 2. if a password is provided, ensure that the _output_ PDF remains encrypted
# with the same password
# 2. if a password is provided, ensure that the _output_ PDF remains encrypted
# with the same password