diff --git a/commonforms/inference.py b/commonforms/inference.py index ce8cc60..9eed964 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -74,9 +74,10 @@ def extract_widgets( results = [] for b in batch([p.image for p in pages], n=batch_size): predictions = self.model.predict(b, threshold=confidence) - if len(pages) == 1 or batch_size == 1: - predictions = [predictions] - results.extend(predictions) + if isinstance(predictions, list): + results.extend(predictions) + else: + results.append(predictions) widgets = {} @@ -264,9 +265,14 @@ def prepare_form( except pypdfium2._helpers.misc.PdfiumError: raise EncryptedPdfError - results = detector.extract_widgets( - pages, confidence=confidence, image_size=image_size - ) + if isinstance(detector, FFDetrDetector): + results = detector.extract_widgets( + pages, confidence=confidence, image_size=image_size, batch_size=batch_size + ) + else: + results = detector.extract_widgets( + pages, confidence=confidence, image_size=image_size + ) writer = PyPdfFormCreator(input_path) if not keep_existing_fields: