From 2fd72ea20bca670a835a8908636d63c8f108ae1e Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Fri, 20 Mar 2026 16:16:18 +0000 Subject: [PATCH 1/2] Breakdown local OCR pipeline to enable batched inference --- .../model/local/nemotron_ocr_v1.py | 306 +++++++++++++++++- 1 file changed, 290 insertions(+), 16 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v1.py b/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v1.py index 81d68cb5a..0a8ebb88f 100644 --- a/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v1.py +++ b/nemo_retriever/src/nemo_retriever/model/local/nemotron_ocr_v1.py @@ -2,7 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Tuple, Union # noqa: F401 +from typing import Any, Dict, List, Optional, Tuple, Union import base64 import io @@ -11,11 +11,17 @@ import numpy as np import torch +import torch.nn.functional as F from nemo_retriever.utils.hf_cache import configure_global_hf_cache_base +from torch import amp +from torchvision.transforms.functional import convert_image_dtype from ..model import BaseModel, RunMode from PIL import Image +# Max images per GPU batch (matches typical nemotron-ocr inference limits). +_OCR_MAX_GPU_BATCH = 32 + class NemotronOCRV1(BaseModel): """ @@ -138,6 +144,247 @@ def _tensor_to_png_b64(img: torch.Tensor) -> str: pil.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode("utf-8") + @staticmethod + def _torch_chw_to_float16_cpu(img: torch.Tensor) -> torch.Tensor: + """RGB CHW tensor -> float16 CHW on CPU in [0, 1], as in ``NemotronOCR._load_image_to_tensor``.""" + if not isinstance(img, torch.Tensor) or img.ndim != 3: + raise ValueError(f"Expected CHW torch.Tensor, got shape {getattr(img, 'shape', None)}") + x = img.detach().cpu() + if x.dtype.is_floating_point: + maxv = float(x.max().item()) if x.numel() else 1.0 + if maxv <= 1.5: + x = x * 255.0 + x = x.clamp(0, 255).to(dtype=torch.uint8) + else: + x = x.clamp(0, 255).to(dtype=torch.uint8) + c = int(x.shape[0]) + if c == 1: + x = x.repeat(3, 1, 1) + elif c != 3: + raise ValueError(f"Expected 1 or 3 channels, got {c}") + return convert_image_dtype(x, dtype=torch.float16) + + @staticmethod + def _numpy_hwc_to_chw_f16(image: np.ndarray) -> torch.Tensor: + """HWC ndarray -> float16 CHW on CPU, matching ``NemotronOCR._load_image_to_tensor``.""" + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + if image.shape[2] == 4: + image = image[..., :3] + img_tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1) + return convert_image_dtype(img_tensor, dtype=torch.float16) + + def _batch_process_chw( + self, + images_chw: List[torch.Tensor], + merge_level: str, + ) -> List[List[Dict[str, Any]]]: + """ + Run detector → NMS → recognizer → relational for a batch of RGB CHW tensors. + + Uses the same building blocks as ``nemotron_ocr.inference.pipeline.NemotronOCR._process_tensor``, + with a shared square side ``M = max_i max(H_i, W_i)`` so all images stack for one detector call. + Quad scaling uses ``M / INFER_LENGTH`` so coordinates match the resized ``M×M`` canvas. + """ + from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square + from nemotron_ocr.inference.pipeline import ( + DETECTOR_DOWNSAMPLE, + INFER_LENGTH, + MERGE_LEVELS, + NMS_IOU_THRESHOLD, + NMS_MAX_REGIONS, + NMS_PROB_THRESHOLD, + PAD_COLOR, + ) + from nemotron_ocr.inference.post_processing.data.text_region import TextBlock + from nemotron_ocr.inference.post_processing.research_ops import ( + parse_relational_results, + reorder_boxes, + ) + from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads + + mdl = self._model + if mdl is None: + raise RuntimeError("Local OCR model was not initialized.") + + bsz = len(images_chw) + if bsz == 0: + return [] + + if merge_level not in MERGE_LEVELS: + raise ValueError(f"Invalid merge level: {merge_level}. Must be one of {MERGE_LEVELS}.") + + original_shapes: List[Tuple[int, int]] = [] + for t in images_chw: + _, h, w = t.shape + original_shapes.append((int(h), int(w))) + + m_side = max(max(h, w) for h, w in original_shapes) + + square_rows: List[torch.Tensor] = [] + for t in images_chw: + sq = pad_to_square(t, m_side, how="bottom_right").unsqueeze(0) + square_rows.append(sq) + batch_square = torch.cat(square_rows, dim=0) + + pad_color = PAD_COLOR.to(device=batch_square.device, dtype=batch_square.dtype) + padded_image = interpolate_and_pad(batch_square, pad_color, INFER_LENGTH) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + det_conf, _, det_rboxes, det_feature_3 = mdl.detector(padded_image.cuda()) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + e2e_det_conf = torch.sigmoid(det_conf) + e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE) + + quads, confidence, region_counts = quad_non_maximal_suppression( + e2e_det_coords, + e2e_det_conf, + prob_threshold=NMS_PROB_THRESHOLD, + iou_threshold=NMS_IOU_THRESHOLD, + kernel_height=2, + kernel_width=3, + max_regions=NMS_MAX_REGIONS, + verbose=False, + )[:3] + + region_counts = region_counts.reshape(-1).to(dtype=torch.int64) + predictions_per_image: List[List[Dict[str, Any]]] = [[] for _ in range(bsz)] + + if quads.shape[0] == 0 or int(region_counts.sum().item()) == 0: + return predictions_per_image + + rec_rectified_quads = mdl.recognizer_quad_rectifier( + quads.detach(), padded_image.shape[2], padded_image.shape[3] + ) + rel_rectified_quads = mdl.relational_quad_rectifier( + quads.cuda().detach(), padded_image.shape[2], padded_image.shape[3] + ) + + input_indices = region_counts_to_indices(region_counts, quads.shape[0]) + + rec_rectified_quads = mdl.grid_sampler(det_feature_3.float(), rec_rectified_quads.float(), input_indices) + rel_rectified_quads = mdl.grid_sampler( + det_feature_3.float().cuda(), + rel_rectified_quads, + input_indices.cuda(), + ) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + rec_output, rec_features = mdl.recognizer(rec_rectified_quads.cuda()) + + rel_output = mdl.relational( + rel_rectified_quads.cuda(), + quads.cuda(), + region_counts.cpu(), + rec_features.cuda(), + ) + words, lines, line_var = ( + rel_output["words"], + rel_output["lines"], + rel_output["line_log_var_unc"], + ) + + with amp.autocast("cuda", enabled=True), torch.no_grad(): + words = [F.softmax(r, dim=1, dtype=torch.float32)[:, 1:] for r in words] + + output: Dict[str, Any] = { + "sequences": F.softmax(rec_output, dim=2, dtype=torch.float32), + "region_counts": region_counts, + "quads": quads, + "raw_detector_confidence": e2e_det_conf, + "confidence": confidence, + "relations": words, + "line_relations": lines, + "line_rel_var": line_var, + "fg_colors": None, + "fonts": None, + "tt_log_var_uncertainty": None, + "e2e_recog_features": rec_features, + } + + quads_scaled = output["quads"] + qscale = float(m_side) / float(INFER_LENGTH) + lengths_tensor = torch.full( + (quads_scaled.shape[0], 1, 1), + qscale, + dtype=torch.float32, + device=quads_scaled.device, + ) + quads_scaled = quads_scaled * lengths_tensor + output["quads"] = quads_scaled + + rec_batch = mdl.recog_encoder.convert_targets_to_labels(output, image_size=None, is_gt=False) + relation_batch = mdl.relation_encoder.convert_targets_to_labels(output, image_size=None, is_gt=False) + + for example, rel_example in zip(rec_batch, relation_batch): + example.relation_graph = rel_example.relation_graph + example.prune_invalid_relations() + + for example in rec_batch: + if example.relation_graph is None: + continue + for paragraph in example.relation_graph: + block: List[Any] = [] + for line in paragraph: + for relational_idx in line: + block.append(example[relational_idx]) + if block: + example.blocks.append(TextBlock(block)) + + for example in rec_batch: + for text_region in example: + text_region.region = text_region.region.vertices + + for ex_idx, example in enumerate(rec_batch): + boxes, texts, scores = parse_relational_results(example, level=merge_level) + boxes, texts, scores = reorder_boxes(boxes, texts, scores, mode="top_left", dbscan_eps=10) + + orig_h, orig_w = original_shapes[ex_idx] + + if len(boxes) == 0: + boxes = ["nan"] + texts = ["nan"] + scores = ["nan"] + else: + boxes_array = np.array(boxes).reshape(-1, 4, 2) + boxes_array[:, :, 0] = boxes_array[:, :, 0] / orig_w + boxes_array[:, :, 1] = boxes_array[:, :, 1] / orig_h + boxes = boxes_array.astype(np.float16).tolist() + + for box, text, conf in zip(boxes, texts, scores): + if box == "nan": + break + predictions_per_image[ex_idx].append( + { + "text": text, + "confidence": conf, + "left": min(p[0] for p in box), + "upper": max(p[1] for p in box), + "right": max(p[0] for p in box), + "lower": min(p[1] for p in box), + } + ) + + return predictions_per_image + + def _invoke_sequential( + self, + inputs: List[Union[torch.Tensor, np.ndarray]], + merge_level: str, + *, + as_numpy: bool, + ) -> List[Any]: + """One pipeline call per item (used when TensorRT compilation fixes detector batch to 1).""" + out: List[Any] = [] + for item in inputs: + if as_numpy: + out.append(self._model(item, merge_level=merge_level)) # type: ignore[arg-type] + else: + b64 = self._tensor_to_png_b64(item) # type: ignore[arg-type] + out.append(self._model(b64.encode("utf-8"), merge_level=merge_level)) + return out + @staticmethod def _extract_text(obj: Any) -> str: if obj is None: @@ -161,7 +408,7 @@ def _extract_text(obj: Any) -> str: def invoke( self, - input_data: Union[torch.Tensor, str, bytes, np.ndarray, io.BytesIO], + input_data: Union[torch.Tensor, str, bytes, np.ndarray, io.BytesIO, List[np.ndarray]], merge_level: str = "paragraph", ) -> Any: """ @@ -171,31 +418,58 @@ def invoke( - file path (str) **only if it exists** - base64 (str/bytes) (str is treated as base64 unless it is an existing file path) - NumPy array (HWC) + - list of NumPy arrays (HWC): batched GPU inference up to ``_OCR_MAX_GPU_BATCH`` per forward - io.BytesIO - - torch.Tensor (CHW/BCHW): converted to base64 PNG internally for compatibility + - torch.Tensor CHW: single image + - torch.Tensor BCHW: batched inference; returns ``list[list[dict]]`` (one inner list per image) """ if self._model is None: raise RuntimeError("Local OCR model was not initialized.") - # Convert torch tensors to base64 bytes (NemotronOCR expects file path/base64/ndarray/BytesIO). + # Batched RGB crops (as produced by page-element OCR in ``ocr.py``). + if isinstance(input_data, list): + if not input_data: + return [] + if not all(isinstance(x, np.ndarray) for x in input_data): + raise TypeError( + "Batched invoke expects each list element to be a numpy.ndarray (HWC RGB). " + f"Got types: {[type(x).__name__ for x in input_data[:8]]}" + (" ..." if len(input_data) > 8 else "") + ) + arrays: List[np.ndarray] = input_data # type: ignore[assignment] + if self._enable_trt: + return self._invoke_sequential(arrays, merge_level, as_numpy=True) + merged: List[List[Dict[str, Any]]] = [] + for start in range(0, len(arrays), _OCR_MAX_GPU_BATCH): + chunk = arrays[start : start + _OCR_MAX_GPU_BATCH] + chw = [self._numpy_hwc_to_chw_f16(a) for a in chunk] + merged.extend(self._batch_process_chw(chw, merge_level)) + return merged + if isinstance(input_data, torch.Tensor): if input_data.ndim == 4: - out: List[Any] = [] - for i in range(int(input_data.shape[0])): - b64 = self._tensor_to_png_b64(input_data[i]) - out.extend(self._model(b64.encode("utf-8"), merge_level=merge_level)) - return out + n = int(input_data.shape[0]) + if self._enable_trt: + return self._invoke_sequential( + [input_data[i] for i in range(n)], + merge_level, + as_numpy=False, + ) + merged_t: List[List[Dict[str, Any]]] = [] + for start in range(0, n, _OCR_MAX_GPU_BATCH): + sl = input_data[start : start + _OCR_MAX_GPU_BATCH] + chw = [self._torch_chw_to_float16_cpu(sl[i]) for i in range(int(sl.shape[0]))] + merged_t.extend(self._batch_process_chw(chw, merge_level)) + return merged_t if input_data.ndim == 3: - b64 = self._tensor_to_png_b64(input_data) - return self._model(b64.encode("utf-8"), merge_level=merge_level) + if self._enable_trt: + b64 = self._tensor_to_png_b64(input_data) + return self._model(b64.encode("utf-8"), merge_level=merge_level) + single = self._batch_process_chw([self._torch_chw_to_float16_cpu(input_data)], merge_level) + return single[0] raise ValueError(f"Unsupported torch tensor shape for OCR: {tuple(input_data.shape)}") # Disambiguate str: existing file path vs base64 string. if isinstance(input_data, str): - # s = input_data.strip() - # breakpoint() - # if s and Path(s).is_file(): - # return self._model(s, merge_level=merge_level) # Treat as base64 string (nemotron_ocr expects bytes for base64). return self._model(input_data.encode("utf-8"), merge_level=merge_level) @@ -265,4 +539,4 @@ def output(self) -> Any: @property def input_batch_size(self) -> int: """Maximum or default input batch size.""" - return 8 + return _OCR_MAX_GPU_BATCH From 74bc23fdde4b26919ba71fb70fd787d4c6a0cddc Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Fri, 20 Mar 2026 17:28:56 +0000 Subject: [PATCH 2/2] Ensure that upstream stages effectively saturate OCR --- .../src/nemo_retriever/ingest_modes/batch.py | 31 ++ nemo_retriever/src/nemo_retriever/ocr/ocr.py | 366 ++++++++++-------- 2 files changed, 235 insertions(+), 162 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index e00037285..4785e331d 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -401,6 +401,7 @@ def _endpoint_count(raw: Any) -> int: ) self._apply_nemotron_parse_overrides(kwargs) + self._apply_user_batch_overrides_to_requested_plan(kwargs) self._append_detection_stages(kwargs) @@ -431,6 +432,35 @@ def _apply_nemotron_parse_overrides(self, kwargs: dict[str, Any]) -> None: if overrides: self._requested_plan = self._requested_plan.model_copy(update=overrides) + def _apply_user_batch_overrides_to_requested_plan(self, kwargs: dict[str, Any]) -> None: + """Apply ``batch_tuning`` sizes from ``ExtractParams`` onto ``_requested_plan``. + + ``_append_detection_stages`` uses ``RequestedPlan.get_ocr_batch_size()`` and + ``get_page_elements_batch_size()`` for Ray ``map_batches(..., batch_size=)`` + and for ``OCRActor`` / ``PageElementDetectionActor`` constructor kwargs + (including ``ocr_page_elements`` streaming crop batching via + ``inference_batch_size``). Without this merge, CLI values such as + ``ocr_inference_batch_size`` from :mod:`nemo_retriever.examples.batch_pipeline` + would be ignored in favour of heuristic defaults from + ``resolve_requested_plan``. + """ + updates: dict[str, Any] = {} + + ocr_bs = kwargs.get("ocr_inference_batch_size") + if not (isinstance(ocr_bs, (int, float)) and int(ocr_bs) > 0): + dbs = kwargs.get("detect_batch_size") + if isinstance(dbs, (int, float)) and int(dbs) > 0: + ocr_bs = int(dbs) + if isinstance(ocr_bs, (int, float)) and int(ocr_bs) > 0: + updates["ocr_batch_size"] = int(ocr_bs) + + pe_bs = kwargs.get("page_elements_batch_size") + if isinstance(pe_bs, (int, float)) and int(pe_bs) > 0: + updates["page_elements_batch_size"] = int(pe_bs) + + if updates: + self._requested_plan = self._requested_plan.model_copy(update=updates) + def _append_detection_stages(self, kwargs: dict[str, Any]) -> None: """Append downstream GPU detection stages (page elements, OCR, table/chart/infographic). @@ -700,6 +730,7 @@ def extract_image_files(self, params: ExtractParams | None = None, **kwargs: Any # Downstream detection stages (page elements, OCR, table/chart/infographic). self._apply_nemotron_parse_overrides(kwargs) + self._apply_user_batch_overrides_to_requested_plan(kwargs) self._append_detection_stages(kwargs) return self diff --git a/nemo_retriever/src/nemo_retriever/ocr/ocr.py b/nemo_retriever/src/nemo_retriever/ocr/ocr.py index 34ae7258b..bc0d5ff3b 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/ocr.py +++ b/nemo_retriever/src/nemo_retriever/ocr/ocr.py @@ -460,18 +460,18 @@ def ocr_page_elements( remote_retry: RemoteRetryParams | None = None, **kwargs: Any, ) -> Any: - retry = remote_retry or RemoteRetryParams( - remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), - remote_max_retries=int(kwargs.get("remote_max_retries", 10)), - remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), - ) """ Run Nemotron OCR v1 on cropped regions detected by PageElements v3. + Crops are accumulated in streaming buffers (per merge level for local + inference, one FIFO buffer for remote NIM), flushed at + ``inference_batch_size`` so GPU batches stay full without holding every + crop for ``batch_df`` in memory at once. + For each row (page) in ``batch_df``: 1. Read ``page_elements_v3`` detections and ``page_image["image_b64"]``. 2. For each detection whose ``label_name`` is a requested type, crop the - page image, invoke OCR, parse the result, and collect text. + page image and enqueue for OCR (buffers flush when full). 3. Write per-type content lists and timing metadata to output columns. Parameters @@ -489,6 +489,11 @@ def ocr_page_elements( Original columns plus ``table``, ``chart``, ``infographic``, and ``ocr_v1``. """ + retry = remote_retry or RemoteRetryParams( + remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), + ) if not isinstance(batch_df, pd.DataFrame): raise NotImplementedError("ocr_page_elements currently only supports pandas.DataFrame input.") @@ -507,24 +512,187 @@ def ocr_page_elements( if extract_infographics: wanted_labels.add("infographic") - # Per-row accumulators. - all_table: List[List[Dict[str, Any]]] = [] - all_chart: List[List[Dict[str, Any]]] = [] - all_infographic: List[List[Dict[str, Any]]] = [] - all_text: List[str] = [] - all_ocr_meta: List[Dict[str, Any]] = [] + n_rows = len(batch_df) + all_table: List[List[Dict[str, Any]]] = [[] for _ in range(n_rows)] + all_chart: List[List[Dict[str, Any]]] = [[] for _ in range(n_rows)] + all_infographic: List[List[Dict[str, Any]]] = [[] for _ in range(n_rows)] + all_text: List[Optional[str]] = [None] * n_rows + all_ocr_meta: List[Dict[str, Any]] = [{"timing": None, "error": None} for _ in range(n_rows)] + all_ocr_blocks: Optional[List[List[Dict[str, Any]]]] = [[] for _ in range(n_rows)] if extract_text else None + + remote_batch_size = max(1, int(kwargs.get("inference_batch_size", 8))) + buf_remote: List[Tuple[int, str, List[float], str]] = [] + + if not use_remote: + if inference_batch_size is None or inference_batch_size < 1: + raise ValueError(f"inference_batch_size must be set and greater than 0. Value: {inference_batch_size}") + local_batch_size = max(1, int(inference_batch_size)) + else: + local_batch_size = 0 # unused + + buf_word: List[Tuple[int, str, List[float], np.ndarray]] = [] + buf_paragraph: List[Tuple[int, str, List[float], np.ndarray]] = [] + + def _row_at(row_idx: int) -> Any: + return batch_df.iloc[row_idx] + + def _append_local_result( + row_idx: int, + label_name: str, + bbox: List[float], + preds: Any, + crop_hw: Tuple[int, int], + ) -> None: + row = _row_at(row_idx) + if label_name == "chart" and use_graphic_elements: + ge_dets = _find_ge_detections_for_bbox(row, bbox) + if ge_dets: + text = join_graphic_elements_and_ocr_output(ge_dets, preds, crop_hw) + if text: + all_chart[row_idx].append({"bbox_xyxy_norm": bbox, "text": text}) + return + blocks = _parse_ocr_result(preds) + if label_name == "table": + text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw) + if not text: + text = _blocks_to_text(blocks) + else: + text = _blocks_to_text(blocks) + entry = {"bbox_xyxy_norm": bbox, "text": text} + if label_name == "table": + all_table[row_idx].append(entry) + elif label_name == "chart": + all_chart[row_idx].append(entry) + elif label_name == "infographic": + all_infographic[row_idx].append(entry) + elif label_name in _TEXT_LABELS and all_ocr_blocks is not None: + all_ocr_blocks[row_idx].extend(blocks) + + def _local_process_chunk(ml: str, chunk: List[Tuple[int, str, List[float], np.ndarray]]) -> None: + if not chunk: + return + batch_crops = [c for _, _, _, c in chunk] + try: + batch_preds = model.invoke(batch_crops, merge_level=ml) + except Exception: + batch_preds = None + if isinstance(batch_preds, list) and len(batch_preds) == len(chunk): + for (row_idx, label_name, bbox, crop_array), preds in zip(chunk, batch_preds): + _append_local_result( + row_idx, + label_name, + bbox, + preds, + crop_hw=(int(crop_array.shape[0]), int(crop_array.shape[1])), + ) + else: + for row_idx, label_name, bbox, crop_array in chunk: + preds = model.invoke(crop_array, merge_level=ml) + _append_local_result( + row_idx, + label_name, + bbox, + preds, + crop_hw=(int(crop_array.shape[0]), int(crop_array.shape[1])), + ) + + def _local_flush_ready(ml: str, buf: List[Tuple[int, str, List[float], np.ndarray]]) -> None: + while len(buf) >= local_batch_size: + chunk = buf[:local_batch_size] + _local_process_chunk(ml, chunk) + del buf[:local_batch_size] + + def _local_flush_remainder(ml: str, buf: List[Tuple[int, str, List[float], np.ndarray]]) -> None: + while buf: + take = min(len(buf), local_batch_size) + chunk = buf[:take] + _local_process_chunk(ml, chunk) + del buf[:take] + + def _append_remote_prediction( + row_idx: int, + label_name: str, + bbox: List[float], + preds: Any, + crop_b64: str, + ) -> None: + row = _row_at(row_idx) + if label_name == "chart" and use_graphic_elements: + ge_dets = _find_ge_detections_for_bbox(row, bbox) + if ge_dets: + crop_hw = (0, 0) + try: + _raw = base64.b64decode(crop_b64) + with Image.open(io.BytesIO(_raw)) as _cim: + _cw, _ch = _cim.size + crop_hw = (_ch, _cw) + except Exception: + pass + text = join_graphic_elements_and_ocr_output(ge_dets, preds, crop_hw) + if text: + all_chart[row_idx].append({"bbox_xyxy_norm": bbox, "text": text}) + return + blocks = _parse_ocr_result(preds) + if label_name == "table": + crop_hw_table = (0, 0) + try: + _raw = base64.b64decode(crop_b64) + with Image.open(io.BytesIO(_raw)) as _cim: + _cw, _ch = _cim.size + crop_hw_table = (_ch, _cw) + except Exception: + pass + text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw_table) or _blocks_to_text(blocks) + else: + text = _blocks_to_text(blocks) + entry = {"bbox_xyxy_norm": bbox, "text": text} + if label_name == "table": + all_table[row_idx].append(entry) + elif label_name == "chart": + all_chart[row_idx].append(entry) + elif label_name == "infographic": + all_infographic[row_idx].append(entry) + elif label_name in _TEXT_LABELS and all_ocr_blocks is not None: + all_ocr_blocks[row_idx].extend(blocks) + + def _remote_process_chunk(chunk: List[Tuple[int, str, List[float], str]]) -> None: + if not chunk: + return + crop_b64s = [entry[3] for entry in chunk] + response_items = invoke_image_inference_batches( + invoke_url=invoke_url, + image_b64_list=crop_b64s, + api_key=api_key, + timeout_s=float(request_timeout_s), + max_batch_size=remote_batch_size, + max_pool_workers=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + if len(response_items) != len(chunk): + raise RuntimeError(f"Expected {len(chunk)} OCR responses, got {len(response_items)}") + for i, (row_idx, label_name, bbox, b64) in enumerate(chunk): + preds = _extract_remote_ocr_item(response_items[i]) + _append_remote_prediction(row_idx, label_name, bbox, preds, b64) + + def _remote_flush_ready() -> None: + while len(buf_remote) >= remote_batch_size: + chunk = buf_remote[:remote_batch_size] + _remote_process_chunk(chunk) + del buf_remote[:remote_batch_size] + + def _remote_flush_remainder() -> None: + while buf_remote: + take = min(len(buf_remote), remote_batch_size) + chunk = buf_remote[:take] + _remote_process_chunk(chunk) + del buf_remote[:take] t0_total = time.perf_counter() - for row in batch_df.itertuples(index=False): - table_items: List[Dict[str, Any]] = [] - chart_items: List[Dict[str, Any]] = [] - infographic_items: List[Dict[str, Any]] = [] - row_ocr_text_blocks: List[Dict[str, Any]] = [] + for row_idx, row in enumerate(batch_df.itertuples(index=False)): row_error: Any = None - try: - # --- get page elements detections --- pe = getattr(row, "page_elements_v3", None) dets: List[Dict[str, Any]] = [] if isinstance(pe, dict): @@ -532,20 +700,12 @@ def ocr_page_elements( if not isinstance(dets, list): dets = [] - # --- get page image --- page_image = getattr(row, "page_image", None) or {} page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None if not isinstance(page_image_b64, str) or not page_image_b64: - # No image available — nothing to crop/OCR. - all_table.append(table_items) - all_chart.append(chart_items) - all_infographic.append(infographic_items) - all_text.append(None) - all_ocr_meta.append({"timing": None, "error": None}) continue - # --- determine per-row labels (text/title only for pages needing OCR) --- row_wanted = wanted_labels if extract_text: meta = getattr(row, "metadata", None) or {} @@ -553,137 +713,18 @@ def ocr_page_elements( if needs_ocr: row_wanted = wanted_labels | _TEXT_LABELS - # --- decode page image once, crop all matching detections --- if use_remote: crops = _crop_all_from_page(page_image_b64, dets, row_wanted, as_b64=True) - crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops] - crop_meta: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops] - - if crop_b64s: - response_items = invoke_image_inference_batches( - invoke_url=invoke_url, - image_b64_list=crop_b64s, - api_key=api_key, - timeout_s=float(request_timeout_s), - max_batch_size=int(kwargs.get("inference_batch_size", 8)), - max_pool_workers=int(retry.remote_max_pool_workers), - max_retries=int(retry.remote_max_retries), - max_429_retries=int(retry.remote_max_429_retries), - ) - if len(response_items) != len(crop_meta): - raise RuntimeError(f"Expected {len(crop_meta)} OCR responses, got {len(response_items)}") - - for i, (label_name, bbox) in enumerate(crop_meta): - preds = _extract_remote_ocr_item(response_items[i]) - - if label_name == "chart" and use_graphic_elements: - ge_dets = _find_ge_detections_for_bbox(row, bbox) - if ge_dets: - # Decode crop dimensions from the b64 PNG for graphic element joining. - crop_hw = (0, 0) - try: - _raw = base64.b64decode(crop_b64s[i]) - with Image.open(io.BytesIO(_raw)) as _cim: - _cw, _ch = _cim.size - crop_hw = (_ch, _cw) - except Exception: - pass - text = join_graphic_elements_and_ocr_output(ge_dets, preds, crop_hw) - if text: - chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) - continue - - blocks = _parse_ocr_result(preds) - if label_name == "table": - crop_hw_table: Tuple[int, int] = (0, 0) - try: - _raw = base64.b64decode(crop_b64s[i]) - with Image.open(io.BytesIO(_raw)) as _cim: - _cw, _ch = _cim.size - crop_hw_table = (_ch, _cw) - except Exception: - pass - text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw_table) or _blocks_to_text(blocks) - else: - text = _blocks_to_text(blocks) - entry = {"bbox_xyxy_norm": bbox, "text": text} - if label_name == "table": - table_items.append(entry) - elif label_name == "chart": - chart_items.append(entry) - elif label_name == "infographic": - infographic_items.append(entry) - elif label_name in _TEXT_LABELS: - row_ocr_text_blocks.extend(blocks) + for label_name, bbox, b64 in crops: + buf_remote.append((row_idx, label_name, bbox, b64)) + _remote_flush_ready() else: crops = _crop_all_from_page(page_image_b64, dets, row_wanted) - - if inference_batch_size is None or inference_batch_size < 1: - raise ValueError( - f"inference_batch_size must be set and greater than 0. Value: {inference_batch_size}" - ) - - local_batch_size = max(1, int(inference_batch_size)) - - # Tables require word-level merging; charts/infographics use paragraph-level. - # Group by merge level so each batched invoke uses one consistent setting. - local_jobs: Dict[str, List[Tuple[str, List[float], np.ndarray]]] = {"word": [], "paragraph": []} for label_name, bbox, crop_array in crops: ml = "word" if label_name == "table" else "paragraph" - local_jobs[ml].append((label_name, bbox, crop_array)) - - def _append_local_result( - label_name: str, bbox: List[float], preds: Any, crop_hw: Tuple[int, int] = (0, 0) - ) -> None: - if label_name == "chart" and use_graphic_elements: - ge_dets = _find_ge_detections_for_bbox(row, bbox) - if ge_dets: - text = join_graphic_elements_and_ocr_output(ge_dets, preds, crop_hw) - if text: - chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) - return - blocks = _parse_ocr_result(preds) - if label_name == "table": - text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw) - if not text: - text = _blocks_to_text(blocks) - else: - text = _blocks_to_text(blocks) - entry = {"bbox_xyxy_norm": bbox, "text": text} - if label_name == "table": - table_items.append(entry) - elif label_name == "chart": - chart_items.append(entry) - elif label_name == "infographic": - infographic_items.append(entry) - elif label_name in _TEXT_LABELS: - row_ocr_text_blocks.extend(blocks) - - for ml, jobs in local_jobs.items(): - if not jobs: - continue - for start in range(0, len(jobs), local_batch_size): - batch_jobs = jobs[start : start + local_batch_size] - batch_crops = [crop_array for _, _, crop_array in batch_jobs] - - # Try batched invoke first; if backend does not return one response - # per input, fall back to per-item to preserve correctness. - try: - batch_preds = model.invoke(batch_crops, merge_level=ml) - except Exception: - batch_preds = None - - if isinstance(batch_preds, list) and len(batch_preds) == len(batch_jobs): - for (label_name, bbox, crop_array), preds in zip(batch_jobs, batch_preds): - _append_local_result( - label_name, bbox, preds, crop_hw=(crop_array.shape[0], crop_array.shape[1]) - ) - else: - for label_name, bbox, crop_array in batch_jobs: - preds = model.invoke(crop_array, merge_level=ml) - _append_local_result( - label_name, bbox, preds, crop_hw=(crop_array.shape[0], crop_array.shape[1]) - ) + buf = buf_word if ml == "word" else buf_paragraph + buf.append((row_idx, label_name, bbox, crop_array)) + _local_flush_ready(ml, buf) except BaseException as e: print(f"Warning: OCR failed: {type(e).__name__}: {e}") @@ -694,17 +735,18 @@ def _append_local_result( "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), } - # Assemble OCR'd text from text/title detections for this row. - # Use None as sentinel for "keep existing native text". - if extract_text and row_ocr_text_blocks: - all_text.append(_blocks_to_text(row_ocr_text_blocks)) - else: - all_text.append(None) + all_ocr_meta[row_idx] = {"timing": None, "error": row_error} - all_table.append(table_items) - all_chart.append(chart_items) - all_infographic.append(infographic_items) - all_ocr_meta.append({"timing": None, "error": row_error}) + if use_remote: + _remote_flush_remainder() + else: + _local_flush_remainder("word", buf_word) + _local_flush_remainder("paragraph", buf_paragraph) + + if extract_text and all_ocr_blocks is not None: + for i, blocks in enumerate(all_ocr_blocks): + if blocks: + all_text[i] = _blocks_to_text(blocks) elapsed = time.perf_counter() - t0_total