diff --git a/commonforms/inference.py b/commonforms/inference.py index 9eed964..f8d5a6b 100644 --- a/commonforms/inference.py +++ b/commonforms/inference.py @@ -4,11 +4,10 @@ from huggingface_hub import hf_hub_download from rfdetr import RFDETRNano, RFDETRBase, RFDETRMedium, RFDETRLarge -from commonforms.utils import BoundingBox, Page, Widget +from commonforms.utils import BoundingBox, Page, TextFragment, Widget from commonforms.form_creator import PyPdfFormCreator from commonforms.exceptions import EncryptedPdfError -import formalpdf import pypdfium2 import logging import PIL @@ -38,7 +37,9 @@ def batch(lst: list, n: int = 8): class FFDetrDetector: def __init__(self, model_or_path: str, device: int | str = "cpu") -> None: self.device = device - self.model = RFDETRMedium(pretrain_weights=self.get_model_path(model_or_path)) + self.model = RFDETRMedium( + pretrain_weights=self.get_model_path(model_or_path), device=device + ) self.id_to_cls = {0: "TextBox", 1: "ChoiceButton", 2: "Signature"} @@ -73,7 +74,9 @@ def extract_widgets( image_size = 1024 results = [] for b in batch([p.image for p in pages], n=batch_size): - predictions = self.model.predict(b, threshold=confidence) + predictions = self.model.predict( + b, threshold=confidence, device=self.device + ) if isinstance(predictions, list): results.extend(predictions) else: @@ -229,16 +232,137 @@ def sort_widgets(widgets: list[Widget]) -> list[Widget]: return [widget for line in lines for widget in line] +def extract_text_fragments(page: pypdfium2.PdfPage) -> list[TextFragment]: + textpage = page.get_textpage() + try: + fragments = [] + for term in textpage.get_text_range().splitlines(): + text = term.strip() + if not text: + continue + + searcher = textpage.search(term, match_case=False, consecutive=True) + try: + match = searcher.get_next() + finally: + searcher.close() + + if match is None: + continue + + index, count = match + rect_count = textpage.count_rects(index, count) + rects = [textpage.get_rect(i) for i in range(rect_count)] + if not rects: + continue + + left = min(rect[0] for rect in rects) + top = max(rect[3] for rect in rects) + fragments.append( + TextFragment( + text=text, + x0=left / page.get_width(), + y0=1 - (top / page.get_height()), + ) + ) + + return fragments + finally: + textpage.close() + + def render_pdf(pdf_path: str) -> list[Page]: pages = [] - doc = formalpdf.open(pdf_path) + doc = pypdfium2.PdfDocument(pdf_path) try: for page in doc: - image = page.render(dpi=144) - pages.append(Page(image=image, width=image.width, height=image.height)) + image = page.render(scale=2).to_pil() + pages.append( + Page( + image=image, + width=image.width, + height=image.height, + text_fragments=extract_text_fragments(page), + ) + ) return pages finally: - doc.document.close() + doc.close() + + +def group_widget_rows( + widgets: list[Widget], y_threshold: float = 0.015 +) -> list[list[Widget]]: + rows: list[list[Widget]] = [] + for widget in sorted(widgets, key=lambda item: item.bounding_box.y0): + if ( + rows + and abs(widget.bounding_box.y0 - rows[-1][0].bounding_box.y0) <= y_threshold + ): + rows[-1].append(widget) + else: + rows.append([widget]) + return rows + + +def promote_signature_widgets( + pages: list[Page], + results: dict[int, list[Widget]], + signature_label_terms: tuple[str, ...] = ("signature",), +) -> dict[int, list[Widget]]: + """Promote likely signature fields by matching signature labels to nearby rows.""" + normalized_terms = tuple(term.lower() for term in signature_label_terms) + + for page_ix, widgets in results.items(): + if any(widget.widget_type == "Signature" for widget in widgets): + continue + + signature_labels = [ + fragment + for fragment in pages[page_ix].text_fragments + if any(term in fragment.text.lower() for term in normalized_terms) + ] + if not signature_labels: + continue + + textbox_rows = group_widget_rows( + [widget for widget in widgets if widget.widget_type == "TextBox"] + ) + if not textbox_rows: + continue + + scored_rows = [] + for row in textbox_rows: + row_left = min(widget.bounding_box.x0 for widget in row) + row_right = max(widget.bounding_box.x1 for widget in row) + row_y = sum(widget.bounding_box.y0 for widget in row) / len(row) + row_width = row_right - row_left + + for label in signature_labels: + horizontal_penalty = 0.0 + if label.x0 < row_left: + horizontal_penalty = row_left - label.x0 + elif label.x0 > row_right: + horizontal_penalty = label.x0 - row_right + + score = ( + horizontal_penalty, + abs(row_y - label.y0), + abs(row_left - label.x0), + -row_width, + -row_y, + ) + scored_rows.append((score, row)) + + if not scored_rows: + continue + + best_row = min(scored_rows, key=lambda item: item[0])[1] + candidate = min(best_row, key=lambda widget: widget.bounding_box.x0) + widget_ix = widgets.index(candidate) + widgets[widget_ix] = candidate.model_copy(update={"widget_type": "Signature"}) + + return results def prepare_form( @@ -254,11 +378,12 @@ def prepare_form( fast: bool = False, multiline: bool = False, batch_size: int = 4, + signature_label_terms: tuple[str, ...] = ("signature",), ): if "FFDNET" in model_or_path.upper(): detector = FFDNetDetector(model_or_path, device=device, fast=fast) else: - detector = FFDetrDetector(model_or_path) + detector = FFDetrDetector(model_or_path, device=device) try: pages = render_pdf(input_path) @@ -274,6 +399,11 @@ def prepare_form( pages, confidence=confidence, image_size=image_size ) + if use_signature_fields: + results = promote_signature_widgets( + pages, results, signature_label_terms=signature_label_terms + ) + writer = PyPdfFormCreator(input_path) if not keep_existing_fields: writer.clear_existing_fields() diff --git a/commonforms/utils.py b/commonforms/utils.py index b85ac92..9d5e0f2 100644 --- a/commonforms/utils.py +++ b/commonforms/utils.py @@ -26,8 +26,15 @@ class Widget(BaseModel): page: int +class TextFragment(BaseModel): + text: str + x0: float + y0: float + + @dataclass class Page: image: Image.Image width: float height: float + text_fragments: list[TextFragment] diff --git a/tests/inference_test.py b/tests/inference_test.py index 5b8693b..f70ec5b 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -3,6 +3,10 @@ import formalpdf import pytest +from PIL import Image + +from commonforms.inference import promote_signature_widgets +from commonforms.utils import BoundingBox, Page, TextFragment, Widget def test_inference(tmp_path): @@ -67,6 +71,72 @@ def test_inference_ffdetr(tmp_path): doc.document.close() +def test_promote_signature_widgets_uses_signature_label_on_test_pdf(): + pages = [ + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[], + ), + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[ + TextFragment( + text="POLICYHOLDER/PATIENT SIGNATURE FAMILY RELATIONSHIP, IF NOT POLICYHOLDER DATE", + x0=0.37, + y0=0.61, + ) + ], + ), + ] + results = { + 1: [ + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.089, y0=0.857, x1=0.384, y1=0.895), + page=1, + ), + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.752, y0=0.859, x1=0.927, y1=0.896), + page=1, + ), + ] + } + + promoted = promote_signature_widgets(pages, results) + + assert promoted[1][0].widget_type == "Signature" + assert promoted[1][1].widget_type == "TextBox" + + +def test_promote_signature_widgets_skips_pages_without_signature_label(): + pages = [ + Page( + image=Image.new("RGB", (1, 1)), + width=1, + height=1, + text_fragments=[TextFragment(text="General contact information", x0=0.1, y0=0.2)], + ) + ] + results = { + 0: [ + Widget( + widget_type="TextBox", + bounding_box=BoundingBox(x0=0.1, y0=0.8, x1=0.3, y1=0.84), + page=0, + ) + ] + } + + promoted = promote_signature_widgets(pages, results) + + assert promoted[0][0].widget_type == "TextBox" + + # TODO(joe): future tests around handling encrypted PDFs # 1. add a --password flag and test that inference doesn't fail # 2. if a password is provided, ensure that the _output_ PDF remains encrpyted