diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 6a0ac50db..479721819 100644 --- a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -242,6 +242,36 @@ ingestor = create_ingestor(run_mode="batch") ingestor = ingestor.files([str(INPUT_AUDIO)]).extract_audio() ``` +### Caption extracted images + +Use `.caption()` to generate text descriptions for extracted images using a local VLM. Requires vLLM (see step 3 above). + +```python +ingestor = ( + ingestor.files(documents) + .extract( + extract_text=True, + extract_tables=False, + extract_charts=False, + extract_infographics=False, + extract_images=True, + ) + .caption() + .embed() + .vdb_upload() +) +``` + +By default this uses [Nemotron-Nano-12B-VL](https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16). You can customize the model and prompt: + +```python +.caption( + model_name="nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", + prompt="Describe this image in detail:", + context_text_max_chars=1024, # include surrounding page text as context +) +``` + ### Explore Different Pipeline Options: You can use the [Nemotron RAG VL Embedder](https://huggingface.co/nvidia/llama-nemotron-embed-vl-1b-v2) diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index 7879814cc..7de00f251 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "soundfile>=0.12.0", "scipy>=1.11.0", "nvidia-ml-py", + "vllm==0.16.0", ] [project.optional-dependencies] @@ -103,6 +104,10 @@ nemotron-table-structure-v1 = { index = "test-pypi" } nemotron-ocr = { index = "test-pypi" } torch = { index = "torch-cuda"} torchvision = { index ="torch-cuda"} +vllm = [ + { url = "https://github.com/vllm-project/vllm/releases/download/v0.16.0/vllm-0.16.0+cu130-cp38-abi3-manylinux_2_35_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = 
"https://github.com/vllm-project/vllm/releases/download/v0.16.0/vllm-0.16.0+cu130-cp38-abi3-manylinux_2_35_aarch64.whl", marker = "platform_machine == 'aarch64'" }, +] [[tool.uv.index]] name = "test-pypi" diff --git a/nemo_retriever/src/nemo_retriever/caption/__init__.py b/nemo_retriever/src/nemo_retriever/caption/__init__.py new file mode 100644 index 000000000..6aa2e3d5b --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/caption/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/nemo_retriever/src/nemo_retriever/caption/caption.py b/nemo_retriever/src/nemo_retriever/caption/caption.py new file mode 100644 index 000000000..b55dea563 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/caption/caption.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +from io import BytesIO +from typing import Any, Dict, List, Tuple + +import pandas as pd +from PIL import Image + +from nemo_retriever.params import CaptionParams + +_DEFAULT_MODEL_NAME = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16" +_MAX_CONTEXT_TEXT_CHARS = 4096 +_MIN_IMAGE_DIMENSION = 32 +_cached_local_model = None + + +def _image_meets_min_size(b64: str) -> bool: + """Return True if the base64 image is at least _MIN_IMAGE_DIMENSION on both sides.""" + try: + img = Image.open(BytesIO(base64.b64decode(b64))) + w, h = img.size + return w >= _MIN_IMAGE_DIMENSION and h >= _MIN_IMAGE_DIMENSION + except Exception: + return False + + +def _create_local_model(kwargs: dict) -> "Any": + from nemo_retriever.model.local import NemotronVLMCaptioner + + return NemotronVLMCaptioner( + model_path=kwargs.get("model_name", _DEFAULT_MODEL_NAME), + device=kwargs.get("device"), + hf_cache_dir=kwargs.get("hf_cache_dir"), + 
tensor_parallel_size=kwargs.get("tensor_parallel_size", 1), + gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.5), + ) + + +def _get_cached_local_model(kwargs: dict) -> "Any": + global _cached_local_model + if _cached_local_model is None: + _cached_local_model = _create_local_model(kwargs) + return _cached_local_model + + +class CaptionActor: + """Ray Data actor that holds a local VLM captioner on a single GPU. + + When ``endpoint_url`` is provided, the actor delegates to a remote VLM + endpoint and no local model is loaded. + """ + + def __init__(self, params: CaptionParams) -> None: + self._params = params + self._kwargs = params.model_dump(mode="python") + endpoint = (self._kwargs.get("endpoint_url") or "").strip() + if endpoint: + self._model = None + else: + self._model = _create_local_model(self._kwargs) + + def __call__(self, batch_df: Any) -> Any: + return caption_images(batch_df, model=self._model, **self._kwargs) + + +def _build_prompt_with_context(base_prompt: str, context_text: str) -> str: + """Prepend surrounding page text to the base VLM prompt. + + If *context_text* is empty the *base_prompt* is returned unchanged. 
+ """ + if not context_text: + return base_prompt + return f"Text near this image:\n---\n{context_text}\n---\n\n{base_prompt}" + + +def _create_remote_client(endpoint_url: str, api_key: str | None) -> Any: + """Create a reusable NIM inference client for a remote VLM endpoint.""" + from nv_ingest_api.internal.primitives.nim.model_interface.vlm import VLMModelInterface + from nv_ingest_api.util.nim import create_inference_client + + return create_inference_client( + model_interface=VLMModelInterface(), + endpoints=(None, endpoint_url), + auth_token=api_key, + infer_protocol="http", + ) + + +def _caption_batch_remote( + base64_images: List[str], + *, + nim_client: Any, + model_name: str, + prompt: str, + system_prompt: str | None, + temperature: float, +) -> List[str]: + """Send a batch of images to a remote VLM endpoint and return captions.""" + from nv_ingest_api.util.image_processing.transforms import scale_image_to_encoding_size + + scaled = [scale_image_to_encoding_size(b64)[0] for b64 in base64_images] + + data: Dict[str, Any] = { + "base64_images": scaled, + "prompt": prompt, + } + if system_prompt: + data["system_prompt"] = system_prompt + + return nim_client.infer(data, model_name=model_name, temperature=temperature) + + +def _caption_batch_local( + base64_images: List[str], + *, + model: Any, + prompt: str, + system_prompt: str | None, + temperature: float, +) -> List[str]: + """Generate captions using a local ``NemotronVLMCaptioner`` model.""" + return model.caption_batch( + base64_images, + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + + +def _caption_one( + b64: str, + *, + model: Any, + nim_client: Any | None, + model_name: str, + prompt: str, + system_prompt: str | None, + temperature: float, +) -> str: + """Caption a single image (used when each image gets a unique prompt).""" + if model is not None: + captions = _caption_batch_local( + [b64], + model=model, + prompt=prompt, + system_prompt=system_prompt, + 
temperature=temperature, + ) + else: + captions = _caption_batch_remote( + [b64], + nim_client=nim_client, + model_name=model_name, + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + return captions[0] if captions else "" + + +def caption_images( + batch_df: pd.DataFrame, + *, + model: Any = None, + endpoint_url: str | None = None, + model_name: str = _DEFAULT_MODEL_NAME, + api_key: str | None = None, + prompt: str = "Caption the content of this image:", + system_prompt: str | None = "/no_think", + temperature: float = 1.0, + batch_size: int = 8, + context_text_max_chars: int = 0, + **kwargs: Any, +) -> pd.DataFrame: + """Caption images in the ``images`` column using a VLM. + + Supports two modes: + + * **Remote** (``endpoint_url`` is set): sends images to an HTTP VLM + endpoint via ``create_inference_client`` / ``VLMModelInterface``. + * **Local** (``model`` is set): runs inference through a local + ``NemotronVLMCaptioner`` instance loaded from Hugging Face. + + When ``context_text_max_chars`` is greater than zero, the page's ``text`` + column is prepended to the prompt for each image so the VLM can use + surrounding OCR text as context. In this mode images are captioned + one at a time (each gets its own enriched prompt). + + For each row, any item in the ``images`` list whose ``text`` field is + empty will be captioned. The returned caption is written back into + ``images[i]["text"]``. 
+ """ + if not isinstance(batch_df, pd.DataFrame) or batch_df.empty: + return batch_df + if "images" not in batch_df.columns: + return batch_df + + if model is None and not endpoint_url: + model = _get_cached_local_model(kwargs) + + nim_client = _create_remote_client(endpoint_url, api_key) if endpoint_url and model is None else None + + use_context = context_text_max_chars > 0 + effective_max = min(context_text_max_chars, _MAX_CONTEXT_TEXT_CHARS) if use_context else 0 + + pending: List[Tuple[int, int, str]] = [] + for row_idx, row in batch_df.iterrows(): + images = row.get("images") + if not isinstance(images, list): + continue + for item_idx, item in enumerate(images): + if not isinstance(item, dict): + continue + if item.get("text"): + continue # already captioned + b64 = item.get("image_b64") + if b64 and _image_meets_min_size(b64): + pending.append((row_idx, item_idx, b64)) + + if not pending: + return batch_df + + if use_context: + for row_idx, item_idx, b64 in pending: + page_text = batch_df.at[row_idx, "text"] if "text" in batch_df.columns else "" + context = (page_text or "")[:effective_max] + enriched_prompt = _build_prompt_with_context(prompt, context) + caption = _caption_one( + b64, + model=model, + nim_client=nim_client, + model_name=model_name, + prompt=enriched_prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + batch_df.at[row_idx, "images"][item_idx]["text"] = caption + else: + all_b64 = [b64 for _, _, b64 in pending] + + if model is not None: + all_captions = _caption_batch_local( + all_b64, + model=model, + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + else: + all_captions: List[str] = [] + for start in range(0, len(all_b64), batch_size): + captions = _caption_batch_remote( + all_b64[start : start + batch_size], + nim_client=nim_client, + model_name=model_name, + prompt=prompt, + system_prompt=system_prompt, + temperature=temperature, + ) + all_captions.extend(captions) + + for (row_idx, 
item_idx, _), caption in zip(pending, all_captions): + batch_df.at[row_idx, "images"][item_idx]["text"] = caption + + return batch_df diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 6098f3731..c53ff0db5 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -23,6 +23,7 @@ from nemo_retriever.ingest_modes.batch import BatchIngestor from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema from nemo_retriever.model import resolve_embed_model +from nemo_retriever.params import CaptionParams from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams @@ -511,6 +512,21 @@ def main( "(used when --table-output-format=markdown)." ), ), + extract_text: bool = typer.Option( + True, + "--extract-text/--no-extract-text", + help="Extract text from PDF pages.", + ), + extract_tables: bool = typer.Option( + True, + "--extract-tables/--no-extract-tables", + help="Extract tables from PDF pages.", + ), + extract_charts: bool = typer.Option( + True, + "--extract-charts/--no-extract-charts", + help="Extract charts from PDF pages.", + ), extract_infographics: bool = typer.Option( False, "--extract-infographics/--no-extract-infographics", @@ -521,6 +537,36 @@ def main( "--extract-page-as-image/--no-extract-page-as-image", help="Render and retain full page images for downstream multimodal stages.", ), + caption: bool = typer.Option( + False, + "--caption/--no-caption", + help="Enable image captioning via a local VLM or remote endpoint.", + ), + caption_invoke_url: Optional[str] = typer.Option( + None, + "--caption-invoke-url", + help="Optional VLM endpoint URL for image captioning. 
Implies --caption.", + ), + caption_model_name: str = typer.Option( + "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", + "--caption-model-name", + help="VLM model name / HF model ID for image captioning.", + ), + caption_device: Optional[str] = typer.Option( + None, + "--caption-device", + help="GPU device for the local VLM captioner (e.g. 'cuda:1').", + ), + caption_context_text_max_chars: int = typer.Option( + 0, + "--caption-context-text-max-chars", + help="Max characters of surrounding page text to include in the VLM prompt. 0 disables context.", + ), + caption_gpu_memory_utilization: float = typer.Option( + 0.5, + "--caption-gpu-memory-utilization", + help="Fraction of GPU memory vLLM may use for the caption model (0.0–1.0).", + ), text_chunk: bool = typer.Option( False, "--text-chunk", @@ -708,9 +754,9 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams: return ExtractParams( method=method, dpi=int(dpi), - extract_text=True, - extract_tables=True, - extract_charts=True, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, extract_infographics=extract_infographics, extract_page_as_image=extract_page_as_image, api_key=extract_remote_api_key, @@ -747,6 +793,18 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams: if enable_text_chunk: ingestor = ingestor.split(_text_chunk_params) + enable_caption = caption or caption_invoke_url is not None + if enable_caption: + ingestor = ingestor.caption( + CaptionParams( + endpoint_url=caption_invoke_url, + model_name=caption_model_name, + device=caption_device, + context_text_max_chars=caption_context_text_max_chars, + gpu_memory_utilization=caption_gpu_memory_utilization, + ) + ) + ingestor = ingestor.embed(embed_params) logger.info("Running extraction...") diff --git a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py index 386bff850..d8c77ff1f 100644 
--- a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py @@ -15,6 +15,7 @@ import typer from nemo_retriever import create_ingestor from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second +from nemo_retriever.params import CaptionParams from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams @@ -92,6 +93,26 @@ def main( "--embed-model-name", help="Embedding model name passed to .embed().", ), + extract_text: bool = typer.Option( + True, + "--extract-text/--no-extract-text", + help="Extract text from PDF pages.", + ), + extract_tables: bool = typer.Option( + True, + "--extract-tables/--no-extract-tables", + help="Extract tables from PDF pages.", + ), + extract_charts: bool = typer.Option( + True, + "--extract-charts/--no-extract-charts", + help="Extract charts from PDF pages.", + ), + extract_infographics: bool = typer.Option( + False, + "--extract-infographics/--no-extract-infographics", + help="Extract infographics from PDF pages.", + ), method: str = typer.Option( "pdfium", "--method", @@ -149,6 +170,38 @@ def main( "--graphic-elements-invoke-url", help="Optional remote endpoint URL for graphic-elements model inference.", ), + caption: bool = typer.Option( + False, + "--caption/--no-caption", + help="Enable image captioning. Uses a local model by default, " + "or a remote endpoint if --caption-invoke-url is set.", + ), + caption_invoke_url: Optional[str] = typer.Option( + None, + "--caption-invoke-url", + help="Optional VLM endpoint URL for image captioning (e.g. http://vlm:8000/v1/chat/completions). " + "Implies --caption. 
When omitted, a local HF model is loaded instead.", + ), + caption_model_name: str = typer.Option( + "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", + "--caption-model-name", + help="VLM model name / HF model ID for image captioning.", + ), + caption_device: Optional[str] = typer.Option( + None, + "--caption-device", + help="GPU device for the local VLM captioner (e.g. 'cuda:1'). Defaults to the first --gpu-devices entry.", + ), + caption_context_text_max_chars: int = typer.Option( + 0, + "--caption-context-text-max-chars", + help="Max characters of surrounding page text to include in the VLM prompt. 0 disables context.", + ), + caption_gpu_memory_utilization: float = typer.Option( + 0.5, + "--caption-gpu-memory-utilization", + help="Fraction of GPU memory vLLM may use for the caption model (0.0–1.0).", + ), hybrid: bool = typer.Option( False, "--hybrid/--no-hybrid", @@ -216,10 +269,10 @@ def main( ingestor = ingestor.files(file_patterns).extract_image_files( ExtractParams( method=method, - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, use_graphic_elements=use_graphic_elements, graphic_elements_invoke_url=graphic_elements_invoke_url, use_table_structure=use_table_structure, @@ -233,10 +286,10 @@ def main( ingestor = ingestor.files(file_patterns).extract( ExtractParams( method=method, - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, use_graphic_elements=use_graphic_elements, graphic_elements_invoke_url=graphic_elements_invoke_url, use_table_structure=use_table_structure, @@ -250,10 +303,10 @@ def main( ingestor = ingestor.files(file_patterns).extract( ExtractParams( method=method, - 
extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, use_graphic_elements=use_graphic_elements, graphic_elements_invoke_url=graphic_elements_invoke_url, use_table_structure=use_table_structure, @@ -273,6 +326,18 @@ def main( ) ) + enable_caption = caption or caption_invoke_url is not None + if enable_caption: + ingestor = ingestor.caption( + CaptionParams( + endpoint_url=caption_invoke_url, + model_name=caption_model_name, + device=caption_device, + context_text_max_chars=caption_context_text_max_chars, + gpu_memory_utilization=caption_gpu_memory_utilization, + ) + ) + ingestor = ingestor.embed( EmbedParams( model_name=str(embed_model_name), diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 0a57eebe6..d0cad80fc 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -48,6 +48,7 @@ from ..params import IngestExecuteParams from ..params import PdfSplitParams from ..params import TextChunkParams +from ..params import CaptionParams from ..params import VdbUploadParams logger = logging.getLogger(__name__) @@ -659,7 +660,11 @@ def _append_detection_stages(self, kwargs: dict[str, Any]) -> None: ocr_flags["inference_batch_size"] = self._requested_plan.get_ocr_batch_size() - if ocr_flags: + # Only append OCR stage if at least one content type needs it. + needs_ocr = any( + ocr_flags.get(k) for k in ("extract_text", "extract_tables", "extract_charts", "extract_infographics") + ) + if needs_ocr: self._rd_dataset = self._rd_dataset.map_batches( OCRActor, batch_size=self._requested_plan.get_ocr_batch_size(), @@ -889,10 +894,17 @@ def embed( # We want to create Ray batches that are of the same size as the embed_batch_size. 
self._rd_dataset = self._rd_dataset.repartition(target_num_rows_per_block=embed_batch_size) + from nemo_retriever.ingest_modes.inprocess import _CONTENT_COLUMNS + + content_columns = ( + (_CONTENT_COLUMNS + ("images",)) if getattr(self, "_caption_enabled", False) else _CONTENT_COLUMNS + ) + if embed_granularity == "page": _row_fn = partial( collapse_content_to_page_rows, modality=embed_modality, + content_columns=content_columns, ) else: text_elements_modality = resolved.text_elements_modality or embed_modality @@ -902,6 +914,7 @@ def embed( modality=embed_modality, text_elements_modality=text_elements_modality, structured_elements_modality=structured_elements_modality, + content_columns=content_columns, ) self._rd_dataset = self._rd_dataset.map_batches( _row_fn, @@ -934,6 +947,36 @@ def embed( return self + def caption(self, params: CaptionParams | None = None, **kwargs: Any) -> "BatchIngestor": + """ + Add an image-captioning stage to the batch pipeline. + + Uses a GPU actor pool with a local VLM (vLLM) or delegates to a + remote VLM endpoint when ``endpoint_url`` is set. + """ + if self._rd_dataset is None: + raise RuntimeError("No Ray Dataset to caption. Run .files(...) / .extract(...) first.") + + resolved = _coerce_params(params, CaptionParams, kwargs) + if resolved.endpoint_url and not resolved.api_key: + resolved = resolved.model_copy(update={"api_key": resolve_remote_api_key()}) + + from nemo_retriever.caption.caption import CaptionActor + + caption_num_gpus = 0.0 if resolved.endpoint_url else resolved.gpu_memory_utilization + + self._rd_dataset = self._rd_dataset.map_batches( + CaptionActor, + batch_size=resolved.batch_size or 8, + batch_format="pandas", + num_gpus=caption_num_gpus, + concurrency=1, + fn_constructor_kwargs={"params": resolved}, + ) + + self._caption_enabled = True + return self + def vdb_upload(self, params: VdbUploadParams | None = None, **kwargs: Any) -> "BatchIngestor": """ Add a streaming LanceDB upload stage to the batch pipeline. 
diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 1f1d229a2..d47af16ea 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -25,6 +25,8 @@ from collections.abc import Callable, Iterator from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from nemo_retriever.params import CaptionParams + import pandas as pd from nemo_retriever.model.local import NemotronOCRV1, NemotronPageElementsV3, NemotronParseV12 @@ -958,6 +960,7 @@ def __init__(self, documents: Optional[List[str]] = None) -> None: self._pipeline_type: Literal["pdf", "txt", "html", "image"] = "pdf" self._extract_txt_kwargs: Dict[str, Any] = {} self._extract_html_kwargs: Dict[str, Any] = {} + self._caption_enabled: bool = False def files(self, documents: Union[str, List[str]]) -> "InProcessIngestor": """ @@ -1332,6 +1335,46 @@ def extract_audio( self._tasks.append((apply_asr_to_df, {"asr_params": self._extract_audio_asr_kwargs})) return self + def caption(self, params: "CaptionParams | None" = None, **kwargs: Any) -> "InProcessIngestor": + """ + Configure image captioning via a local VLM model or remote endpoint. + + Sends cropped images (from the ``images`` column populated by + ``extract(extract_images=True)``) to a VLM and writes the returned + captions back as ``images[i]["text"]``. + + When ``endpoint_url`` is set, a remote NIM endpoint is used. + Otherwise a local ``NemotronVLMCaptioner`` is loaded from HF. + """ + from nemo_retriever.caption.caption import caption_images + from nemo_retriever.params import CaptionParams + + resolved = _coerce_params(params, CaptionParams, kwargs) + caption_kwargs = resolved.model_dump(mode="python") + + if resolved.endpoint_url: + # Remote mode. 
+ if not resolved.api_key: + caption_kwargs["api_key"] = resolve_remote_api_key() + else: + # Local mode: defer model creation so the VLM is loaded lazily + # on the device specified by CaptionParams.device. + if not resolved.device: + import warnings + + warnings.warn( + "No caption device specified. The VLM will load on cuda:0, which " + "may conflict with other models. Use device='cuda:1' (or " + "--caption-device from the CLI) to place the captioner on a " + "separate GPU.", + stacklevel=2, + ) + caption_kwargs["model"] = None + + self._caption_enabled = True + self._tasks.append((caption_images, caption_kwargs)) + return self + def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessIngestor": """ Configure embedding for in-process execution. @@ -1349,12 +1392,14 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI embed_modality = resolved.embed_modality embed_granularity = resolved.embed_granularity + content_columns = (_CONTENT_COLUMNS + ("images",)) if self._caption_enabled else _CONTENT_COLUMNS + if embed_granularity == "page": # Page-level: one row per page with concatenated text and full page image. 
self._tasks.append( ( collapse_content_to_page_rows, - {"modality": embed_modality}, + {"modality": embed_modality, "content_columns": content_columns}, ) ) else: @@ -1368,6 +1413,7 @@ def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessI "modality": embed_modality, "text_elements_modality": text_elements_modality, "structured_elements_modality": structured_elements_modality, + "content_columns": content_columns, }, ) ) @@ -1487,12 +1533,21 @@ def ingest(self, params: IngestExecuteParams | None = None, **kwargs: Any) -> li _start = time.perf_counter() - # -- Three-way task classification -------------------------------- + # -- Task classification ------------------------------------------- + from nemo_retriever.caption.caption import caption_images as _caption_images_fn + _post_task_fns = (upload_embeddings_to_lancedb_inprocess, save_dataframe_to_disk_json) _cpu_task_fns = (pdf_extraction,) + # Caption runs on its own device (--caption-device), not in the GPU pool. 
+ _own_device_fns = (_caption_images_fn,) cpu_tasks = [(f, k) for f, k in self._tasks if f in _cpu_task_fns] - gpu_tasks = [(f, k) for f, k in self._tasks if f not in _cpu_task_fns and f not in _post_task_fns] + gpu_tasks = [ + (f, k) + for f, k in self._tasks + if f not in _cpu_task_fns and f not in _post_task_fns and f not in _own_device_fns + ] + own_device_tasks = [(f, k) for f, k in self._tasks if f in _own_device_fns] post_tasks = [(f, k) for f, k in self._tasks if f in _post_task_fns] docs = list(self._documents) @@ -1545,6 +1600,8 @@ def _check_file_done(doc_path: str) -> None: try: result = future.result() if isinstance(result, pd.DataFrame) and not result.empty: + for func, kw in own_device_tasks: + result = func(result, **kw) shard_to_doc[shard_id] = doc gpu_pool.submit(shard_id, result) shard_id += 1 @@ -1639,6 +1696,8 @@ def _on_gpu_done(sid: int) -> None: return results combined = pd.concat(cpu_results, ignore_index=True) + for func, kwargs in own_device_tasks: + combined = func(combined, **kwargs) for func, kwargs in gpu_tasks: combined = func(combined, **kwargs) @@ -1678,6 +1737,8 @@ def _on_gpu_done(sid: int) -> None: else: current = func(current, **kwargs) if isinstance(current, pd.DataFrame) and not current.empty: + for func, kw in own_device_tasks: + current = func(current, **kw) shard_to_doc[shard_id] = doc_path gpu_pool.submit(shard_id, current) shard_id += 1 @@ -1777,7 +1838,7 @@ def _loader(p: str) -> pd.DataFrame: results.append(current) # Run upload/save once on combined results so overwrite=True keeps full corpus. 
- if post_tasks and results and all(isinstance(r, pd.DataFrame) for r in results): + if results and all(isinstance(r, pd.DataFrame) for r in results): combined = pd.concat(results, ignore_index=True) for func, kwargs in post_tasks: combined = func(combined, **kwargs) diff --git a/nemo_retriever/src/nemo_retriever/ingestor.py b/nemo_retriever/src/nemo_retriever/ingestor.py index 7bbc19486..74b6612e6 100644 --- a/nemo_retriever/src/nemo_retriever/ingestor.py +++ b/nemo_retriever/src/nemo_retriever/ingestor.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from nemo_retriever.application.modes.factory import create_runmode_ingestor +from nemo_retriever.params import CaptionParams from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import TextChunkParams @@ -176,8 +177,9 @@ def save_to_disk( """Record result persistence configuration (execution TBD).""" self._not_implemented("save_to_disk") - def caption(self) -> "ingestor": + def caption(self, params: "CaptionParams | None" = None, **kwargs: Any) -> "ingestor": """Record a caption task configuration.""" + _ = _merge_params(params, kwargs) self._not_implemented("caption") def pdf_split_config(self, pages_per_chunk: int = 32) -> "ingestor": diff --git a/nemo_retriever/src/nemo_retriever/model/local/__init__.py b/nemo_retriever/src/nemo_retriever/model/local/__init__.py index 791df4daa..af068fa7d 100644 --- a/nemo_retriever/src/nemo_retriever/model/local/__init__.py +++ b/nemo_retriever/src/nemo_retriever/model/local/__init__.py @@ -18,6 +18,7 @@ "NemotronGraphicElementsV1", "NemotronParseV12", "NemotronRerankV2", + "NemotronVLMCaptioner", "ParakeetCTC1B1ASR", ] @@ -47,6 +48,10 @@ def __getattr__(name: str): from .nemotron_rerank_v2 import NemotronRerankV2 return NemotronRerankV2 + if name == "NemotronVLMCaptioner": + from .nemotron_vlm_captioner import NemotronVLMCaptioner + + return NemotronVLMCaptioner if name == 
"ParakeetCTC1B1ASR": from .parakeet_ctc_1_1b_asr import ParakeetCTC1B1ASR diff --git a/nemo_retriever/src/nemo_retriever/model/local/nemotron_vlm_captioner.py b/nemo_retriever/src/nemo_retriever/model/local/nemotron_vlm_captioner.py new file mode 100644 index 000000000..14b814381 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/model/local/nemotron_vlm_captioner.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import base64 +from io import BytesIO +from typing import Any, List, Optional + +from PIL import Image + +from nemo_retriever.utils.hf_cache import configure_global_hf_cache_base +from ..model import BaseModel, RunMode + + +def _b64_to_pil(b64: str) -> Image.Image: + """Decode a base64-encoded image string to a PIL Image.""" + return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB") + + +class NemotronVLMCaptioner(BaseModel): + """ + Local VLM captioner wrapping Nemotron Nano 12B v2 VL variants. + + Supported models: + + * ``nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`` (default, BFloat16) + * ``nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-FP8`` (FP8 quantised) + * ``nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-NVFP4-QAD`` (NVFP4 quantised, + requires GPU compute capability >= 8.9, e.g. Ada Lovelace / Hopper) + + Uses vLLM for inference with batched scheduling. + + Usage:: + + captioner = NemotronVLMCaptioner() + captions = captioner.caption_batch( + ["", ""], + prompt="Caption the content of this image:", + ) + """ + + SUPPORTED_MODELS: dict[str, str] = { + "BF16": "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", + "FP8": "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-FP8", + "NVFP4-QAD": "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-NVFP4-QAD", + } + + # Pinned HF revision (commit SHA) per model to ensure reproducibility. 
+ _MODEL_REVISIONS: dict[str, str] = { + "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16": "5d250e2e111dc5e1434131bdf3d590c27a878ade", + "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-FP8": "7394488badb786e1decc0e00e308de1cab9560e6", + "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-NVFP4-QAD": "b8d3c170d9ee3a078917ef9bfd508eff988d6de7", + } + + # Map model-name suffixes to vLLM engine kwargs. + # The FP8 HF config ships with quant_method="modelopt" which triggers + # vLLM's ModelOptFp8Config (SM89+). Override to quant_method="fp8" in + # the HF config so vLLM uses its plain FP8 handler (SM80+). + _QUANTIZATION_PROFILES: dict[str, dict[str, Any]] = { + "BF16": {"dtype": "bfloat16"}, + "FP8": { + "dtype": "auto", + "quantization": "fp8", + "hf_overrides": {"quantization_config": {"quant_method": "fp8", "activation_scheme": "static"}}, + }, + "NVFP4-QAD": {"dtype": "auto", "quantization": "modelopt"}, + } + + def __init__( + self, + model_path: str = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16", + device: Optional[str] = None, + hf_cache_dir: Optional[str] = None, + max_new_tokens: int = 1024, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.5, + ) -> None: + super().__init__() + + valid_models = list(self.SUPPORTED_MODELS.values()) + if model_path not in valid_models: + raise ValueError( + f"Unknown caption model: {model_path!r}\n" + f"Supported models:\n" + "\n".join(f" - {m}" for m in valid_models) + ) + + try: + from vllm import LLM, SamplingParams # noqa: F401 + except ImportError as e: + raise ImportError( + 'Local VLM captioning requires vLLM. Install with: pip install "nemo-retriever[vlm-caption]"' + ) from e + + self._model_path = model_path + self._max_new_tokens = max_new_tokens + + if device is not None: + # vLLM uses CUDA_VISIBLE_DEVICES rather than a torch device string. + # Translate e.g. "cuda:1" → "1" so vLLM sees only the requested GPU. 
+ import os + + dev_id = device.split(":")[-1] if ":" in device else device + os.environ["CUDA_VISIBLE_DEVICES"] = dev_id + + configure_global_hf_cache_base(hf_cache_dir) + + revision = self._MODEL_REVISIONS.get(model_path) + + # Pick vLLM engine kwargs based on the model variant. + engine_kwargs: dict[str, Any] = {"dtype": "bfloat16"} # fallback + model_upper = model_path.upper() + for suffix, profile in self._QUANTIZATION_PROFILES.items(): + if model_upper.endswith(suffix): + engine_kwargs = profile + break + + self._llm = LLM( + model=model_path, + revision=revision, + trust_remote_code=True, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + **engine_kwargs, + ) + + def _build_messages( + self, + base64_image: str, + *, + prompt: str, + system_prompt: Optional[str], + ) -> list[dict[str, Any]]: + """Build chat messages in OpenAI format for vLLM.""" + messages: list[dict[str, Any]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append( + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}, + {"type": "text", "text": prompt}, + ], + } + ) + return messages + + def caption( + self, + base64_image: str, + *, + prompt: str = "Caption the content of this image:", + system_prompt: Optional[str] = "/no_think", + temperature: float = 1.0, + ) -> str: + """Generate a caption for a single base64-encoded image.""" + return self.caption_batch([base64_image], prompt=prompt, system_prompt=system_prompt, temperature=temperature)[ + 0 + ] + + def caption_batch( + self, + base64_images: List[str], + *, + prompt: str = "Caption the content of this image:", + system_prompt: Optional[str] = "/no_think", + temperature: float = 1.0, + ) -> List[str]: + """Generate captions for a list of base64-encoded images. + + vLLM batches internally and handles scheduling across images. 
+ """ + from vllm import SamplingParams + + conversations = [self._build_messages(b64, prompt=prompt, system_prompt=system_prompt) for b64 in base64_images] + sampling_params = SamplingParams(temperature=temperature, max_tokens=self._max_new_tokens) + outputs = self._llm.chat(conversations, sampling_params=sampling_params) + return [out.outputs[0].text.strip() for out in outputs] + + # ---- BaseModel abstract interface ---- + + @property + def model_name(self) -> str: + return self._model_path + + @property + def model_type(self) -> str: + return "vlm-captioner" + + @property + def model_runmode(self) -> RunMode: + return "local" + + @property + def input(self) -> Any: + return { + "type": "image", + "format": "base64", + "description": "Base64-encoded image for captioning.", + } + + @property + def output(self) -> Any: + return { + "type": "text", + "format": "string", + "description": "Generated caption for the input image.", + } + + @property + def input_batch_size(self) -> int: + return 1 diff --git a/nemo_retriever/src/nemo_retriever/params/__init__.py b/nemo_retriever/src/nemo_retriever/params/__init__.py index 5f4eef723..bfc65b50c 100644 --- a/nemo_retriever/src/nemo_retriever/params/__init__.py +++ b/nemo_retriever/src/nemo_retriever/params/__init__.py @@ -5,6 +5,7 @@ from .models import ASRParams from .models import AudioChunkParams from .models import BatchTuningParams +from .models import CaptionParams from .models import ChartParams from .models import EmbedParams from .models import ExtractParams @@ -30,6 +31,7 @@ "ASRParams", "AudioChunkParams", "BatchTuningParams", + "CaptionParams", "ChartParams", "EmbedParams", "ExtractParams", diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index f08548cbe..8b92975db 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -302,6 +302,21 @@ class ChartParams(_ParamsModel): 
class CaptionParams(_ParamsModel):
    """Configuration for the image-captioning pipeline stage (``.caption()``)."""

    # NOTE(review): presumably a remote OpenAI-compatible captioning endpoint;
    # None appears to select the local vLLM captioner — confirm against the
    # caption stage implementation.
    endpoint_url: Optional[str] = None
    # HF model id; must be one of NemotronVLMCaptioner.SUPPORTED_MODELS
    # (the captioner raises ValueError otherwise).
    model_name: str = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16"
    # NOTE(review): presumably the API key used with endpoint_url — confirm.
    api_key: Optional[str] = None
    # User prompt sent alongside each image (matches the captioner default).
    prompt: str = "Caption the content of this image:"
    # System message; None sends no system message at all. "/no_think" is a
    # model-specific control string — see the Nemotron model card.
    system_prompt: Optional[str] = "/no_think"
    # Sampling temperature forwarded to vLLM SamplingParams.
    temperature: float = 1.0
    # NOTE(review): batch size for caption requests — confirm where consumed;
    # the local captioner lets vLLM schedule batches internally.
    batch_size: int = 8
    # Torch-style device string (e.g. "cuda:1"); the local captioner translates
    # it into CUDA_VISIBLE_DEVICES.
    device: Optional[str] = None
    # Base directory for the HuggingFace model cache.
    hf_cache_dir: Optional[str] = None
    # When > 0, up to this many characters of surrounding page text are
    # prepended to the prompt as context; 0 (default) disables context.
    context_text_max_chars: int = 0
    # vLLM engine settings: tensor-parallel degree and the fraction of GPU
    # memory the engine may claim.
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.5
# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for image captioning pipeline stage."""

import base64
import io
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

PIL = pytest.importorskip("PIL")
from PIL import Image  # noqa: E402


def _make_test_png_b64(size: tuple[int, int] = (64, 64)) -> str:
    """Create a solid-red PNG of *size* and return it base64-encoded."""
    img = Image.new("RGB", size, color=(255, 0, 0))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


def _make_page_df(num_images=2, captioned=False):
    """Build a one-page DataFrame carrying *num_images* extracted images."""
    b64 = _make_test_png_b64()
    images = [
        {"bbox_xyxy_norm": [0.1, 0.2, 0.5, 0.8], "text": "done" if captioned else "", "image_b64": b64}
        for _ in range(num_images)
    ]
    return pd.DataFrame([{"text": "page", "images": images, "tables": [], "charts": [], "infographics": []}])


def test_caption_images_writes_back():
    """Generated captions are written into each image dict's "text" field."""
    from nemo_retriever.caption.caption import caption_images

    mock_model = MagicMock()
    mock_model.caption_batch.return_value = ["cap1", "cap2"]
    result = caption_images(_make_page_df(), model=mock_model)
    assert result.iloc[0]["images"][0]["text"] == "cap1"
    assert result.iloc[0]["images"][1]["text"] == "cap2"


def test_caption_images_skips_already_captioned():
    """Images that already have caption text are not re-captioned."""
    from nemo_retriever.caption.caption import caption_images

    mock_model = MagicMock()
    result = caption_images(_make_page_df(captioned=True), model=mock_model)
    mock_model.caption_batch.assert_not_called()
    assert result.iloc[0]["images"][0]["text"] == "done"


def test_pdf_extraction_populates_images():
    """pdf_extraction copies detected images (with normalised bboxes) into the page record."""
    # Bug fix: the previous @patch decorator imported nemo_retriever.pdf.extract
    # during test *setup*, before these importorskip calls ran — a missing
    # optional dependency produced a test ERROR instead of a skip. Skip first,
    # then patch via a context manager.
    _ext = pytest.importorskip("nemo_retriever.pdf.extract")
    pdfium = pytest.importorskip("pypdfium2")

    mock_img = MagicMock(image=_make_test_png_b64(), bbox=(10, 20, 100, 200), max_width=612, max_height=792)

    # One blank US-letter page, round-tripped through bytes like the pipeline does.
    doc = pdfium.PdfDocument.new()
    doc.new_page(612, 792)
    buf = io.BytesIO()
    doc.save(buf)
    doc.close()

    with patch(
        "nemo_retriever.pdf.extract.extract_image_like_objects_from_pdfium_page",
        return_value=[mock_img],
    ):
        result = _ext.pdf_extraction(
            pd.DataFrame([{"bytes": buf.getvalue(), "path": "t.pdf", "page_number": 1}]), extract_images=True
        )
    images = result.iloc[0]["images"]
    assert len(images) == 1
    assert images[0]["text"] == ""
    assert abs(images[0]["bbox_xyxy_norm"][0] - 10 / 612) < 1e-6


def test_explode_includes_captioned_images():
    """Captioned images become rows only when "images" is in content_columns."""
    from nemo_retriever.ingest_modes.inprocess import explode_content_to_rows

    b64 = _make_test_png_b64()
    df = pd.DataFrame(
        [
            {
                "text": "page",
                "page_image": {"image_b64": b64},
                "images": [{"text": "a dog", "bbox_xyxy_norm": [0.1, 0.2, 0.5, 0.8], "image_b64": b64}],
                "tables": [],
                "charts": [],
                "infographics": [],
            }
        ]
    )
    result = explode_content_to_rows(df, content_columns=("table", "chart", "infographic", "images"))
    assert len(result) == 2  # page text + image caption

    # Default columns exclude images
    result2 = explode_content_to_rows(df)
    assert len(result2) == 1


def test_context_text_prepended_to_prompt():
    """With context_text_max_chars > 0, nearby page text is folded into the prompt."""
    from nemo_retriever.caption.caption import caption_images

    mock_model = MagicMock()
    mock_model.caption_batch.return_value = ["captioned with context"]

    df = _make_page_df(num_images=1)
    df.at[0, "text"] = "The quick brown fox jumps over the lazy dog."

    result = caption_images(df, model=mock_model, context_text_max_chars=100)

    assert result.iloc[0]["images"][0]["text"] == "captioned with context"
    # The prompt passed to caption_batch should contain the page text.
    call_kwargs = mock_model.caption_batch.call_args[1]
    assert "quick brown fox" in call_kwargs["prompt"]
    assert "Text near this image:" in call_kwargs["prompt"]


def test_caption_images_skips_small_images():
    """Images below the minimum dimension are left uncaptioned and the model is never called."""
    from nemo_retriever.caption.caption import caption_images

    tiny_b64 = _make_test_png_b64(size=(1, 1))
    images = [{"bbox_xyxy_norm": [0.1, 0.2, 0.5, 0.8], "text": "", "image_b64": tiny_b64}]
    df = pd.DataFrame([{"text": "page", "images": images, "tables": [], "charts": [], "infographics": []}])

    mock_model = MagicMock()
    result = caption_images(df, model=mock_model)
    mock_model.caption_batch.assert_not_called()
    assert result.iloc[0]["images"][0]["text"] == ""