diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f83175f..5d5edaa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: .venv key: uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }} restore-keys: | - uv-${{ runner.os }}-${{ matrix.python-version }}- + uv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml', 'uv.lock') }} - name: Install (uv) run: | diff --git a/commonforms/__main__.py b/commonforms/__main__.py index f66df1a..08be6a6 100644 --- a/commonforms/__main__.py +++ b/commonforms/__main__.py @@ -58,6 +58,7 @@ def main(): args = parser.parse_args() + print(f"**{args.confidence=}") prepare_form( args.input, args.output, diff --git a/commonforms/inference.py b/commonforms/inference.py index 7ac5827..0d5ccc8 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -2,6 +2,7 @@ from ultralytics import YOLO from pathlib import Path from huggingface_hub import hf_hub_download +from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge from commonforms.utils import BoundingBox, Page, Widget from commonforms.form_creator import PyPdfFormCreator @@ -9,17 +10,101 @@ import formalpdf import pypdfium2 +import logging +import PIL -# our mapping from (model_name, fast) to (repo_id, filename) for the huggingface hub +logging.basicConfig(level=logging.INFO) + + +# our mapping from (model_name_upper, fast) to (repo_id, filename) for the huggingface hub. +# keeping it simple and declarative like this becuase it's not like we're adding a bunch +# of models. models = { ("FFDNET-S", True): ("jbarrow/FFDNet-S-cpu", "FFDNet-S.onnx"), ("FFDNET-S", False): ("jbarrow/FFDNet-S", "FFDNet-S.pt"), ("FFDNET-L", True): ("jbarrow/FFDNet-L-cpu", "FFDNet-L.onnx"), ("FFDNET-L", False): ("jbarrow/FFDNet-L", "FFDNet-L.pt"), + ("FFDETR", False): ("jbarrow/FFDetr", "FFDetr.pth"), } +def batch(lst: list, n: int = 8): + l = len(lst) + for ndx in range(0, l, n): + yield lst[ndx : min(ndx + n, l)] + + +class FFDetrDetector: + def __init__(self, model_or_path: str, device: int | str = "cpu") -> None: + self.device = device + self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path)) + + self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"} + + def get_model_path(self, model_or_path: str) -> str: + model_upper = model_or_path.upper() + if model_upper in ["FFDETR"]: + # download the model, will just use the cached version if it already exists + repo_id, filename = models[(model_upper, False)] + model_path = hf_hub_download(repo_id=repo_id, filename=filename) + else: + model_path = model_or_path + + return model_path + + def resize( + self, + image: PIL.Image.Image, + size: tuple[int, int] | int, + ) -> PIL.Image.Image: + if isinstance(size, int): + size = (size, size) + + return image.resize(size, PIL.Image.Resampling.LANCZOS) + + def extract_widgets( + self, + pages: list[Page], + confidence: float = 0.4, + image_size: int = 1120, + batch_size: int = 3, + ) -> dict[int, list[Widget]]: + image_size = 1024 + results = [] + for b in batch([self.resize(p.image, image_size) for p in pages], n=batch_size): + predictions = self.model.predict(b, threshold=confidence) + if len(pages) == 1 or batch_size == 1: + predictions = [predictions] + results.extend(predictions) + + widgets = {} + + for page_ix, detections in enumerate(results): + logging.info(f" Page {page_ix}: {len(detections)} fields detected") + detections = detections.with_nms(threshold=0.1, class_agnostic=True) + logging.info(f"\t\t{len(detections)} after nms") + widgets[page_ix] = [] + + for class_id, box in zip(detections.class_id, detections.xyxy): + x0, x1 = box[[0, 2]] / pages[page_ix].image.width + y0, y1 = box[[1, 3]] / pages[page_ix].image.height + + widget_type = self.id_to_cls[class_id] + + widgets[page_ix].append( + Widget( + widget_type=widget_type, + bounding_box=BoundingBox(x0=x0, y0=y0, x1=x1, y1=y1), + page=page_ix, + ) + ) + + widgets[page_ix] = sort_widgets(widgets[page_ix]) + + return widgets + + class FFDNetDetector: def __init__( self, model_or_path: str, device: int | str = "cpu", fast: bool = False @@ -43,8 +128,8 @@ def get_model_path( model_upper = model_or_path.upper() if model_upper in ["FFDNET-S", "FFDNET-L"]: # download the model, will just use the cached version if it already exists - repo_id, filename = models[(model_upper, fast)] - model_path = hf_hub_download(repo_id=repo_id, filename=filename) + repo_id, filename = models[(model_upper, fast)] + model_path = hf_hub_download(repo_id=repo_id, filename=filename) else: model_path = model_or_path @@ -148,7 +233,7 @@ def render_pdf(pdf_path: str) -> list[Page]: doc = formalpdf.open(pdf_path) try: for page in doc: - image = page.render() + image = page.render(dpi=144) pages.append(Page(image=image, width=image.width, height=image.height)) return pages finally: @@ -159,16 +244,20 @@ def prepare_form( input_path: str | Path, output_path: str | Path, *, - model_or_path: str = "FFDNet-L", + model_or_path: str = "FFDetr", keep_existing_fields: bool = False, use_signature_fields: bool = False, device: int | str = "cpu", - image_size: int = 1600, - confidence: float = 0.3, + image_size: int = 1024, + confidence: float = 0.4, fast: bool = False, multiline: bool = False, + batch_size: int = 4, ): - detector = FFDNetDetector(model_or_path, device=device, fast=fast) + if "FFDNET" in model_or_path.upper(): + detector = FFDNetDetector(model_or_path, device=device, fast=fast) + else: + detector = FFDetrDetector(model_or_path) try: pages = render_pdf(input_path) @@ -188,7 +277,9 @@ def prepare_form( name = f"{widget.widget_type.lower()}_{widget.page}_{i}" if widget.widget_type == "TextBox": - writer.add_text_box(name, page_ix, widget.bounding_box, multiline=multiline) + writer.add_text_box( + name, page_ix, widget.bounding_box, multiline=multiline + ) elif widget.widget_type == "ChoiceButton": writer.add_checkbox(name, page_ix, widget.bounding_box) elif widget.widget_type == "Signature": diff --git a/pyproject.toml b/pyproject.toml index a613152..bbca4b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ dependencies = [ "pillow>=11.3.0", "pydantic>=2.11.9", "pypdf>=6.1.1", + "rfdetr>=1.3.0", + "transformers>=4.57", "ultralytics>=8.3.204", ] diff --git a/tests/inference_test.py b/tests/inference_test.py index 62b9474..5b8693b 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -8,7 +8,7 @@ def test_inference(tmp_path): # tmp_path is a built-in pythest fixture where we'll write the outputs output_path = tmp_path / "output.pdf" - commonforms.prepare_form("./tests/resources/input.pdf", output_path) + commonforms.prepare_form("./tests/resources/input.pdf", output_path, model_or_path="FFDetr") assert output_path.exists() @@ -20,7 +20,7 @@ def test_inference(tmp_path): def test_inference_fast(tmp_path): output_path = tmp_path / "output.pdf" - commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True) + commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, model_or_path="FFDNet-L") assert output_path.exists() @@ -32,7 +32,9 @@ def test_inference_fast(tmp_path): def test_mutlinline(tmp_path): output_path = tmp_path / "output.pdf" - commonforms.prepare_form("./tests/resources/input.pdf", output_path, fast=True, multiline=True) + commonforms.prepare_form( + "./tests/resources/input.pdf", output_path, fast=True, multiline=True + ) assert output_path.exists() @@ -42,7 +44,6 @@ def test_mutlinline(tmp_path): doc.document.close() - def test_encrypted_failure(tmp_path): # Reminder to future Joe: password for encrypted PDF is "kanbanery" output_path = tmp_path / "output.pdf" @@ -51,7 +52,22 @@ def test_encrypted_failure(tmp_path): commonforms.prepare_form("./tests/resources/encrypted.pdf", output_path) +def test_inference_ffdetr(tmp_path): + # tmp_path is a built-in pythest fixture where we'll write the outputs + output_path = tmp_path / "output.pdf" + commonforms.prepare_form( + "./tests/resources/input.pdf", output_path, model_or_path="FFDetr" + ) + + assert output_path.exists() + + doc = formalpdf.open(output_path) + assert len(doc[0].widgets()) > 0 + + doc.document.close() + + # TODO(joe): future tests around handling encrypted PDFs # 1. add a --password flag and test that inference doesn't fail -# 2. if a password is provided, ensure that the _output_ PDF remains encrpyted -# with the same password +# 2. if a password is provided, ensure that the _output_ PDF remains encrpyted +# with the same password