diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 6a0ac50db..d0f9711bc 100644 --- a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -395,3 +395,9 @@ retriever-harness sweep --runs-config harness/vidore_sweep.yaml ``` The same commands also work under the main CLI as `retriever harness ...` if you prefer a single top-level command namespace. + +The harness now supports multiple execution modes through a shared structured metrics contract: + +- `run_mode: batch` is the default path. +- `run_mode: inprocess` and `run_mode: fused` use the same `results.json` / `session_summary.json` schema. +- Per-run structured metrics are written under `runtime_metrics/.run_report.json`, and the harness derives its compact summary views from that report rather than scraping console output. diff --git a/nemo_retriever/harness/HANDOFF.md b/nemo_retriever/harness/HANDOFF.md index 281234e12..6d8069054 100644 --- a/nemo_retriever/harness/HANDOFF.md +++ b/nemo_retriever/harness/HANDOFF.md @@ -9,7 +9,8 @@ It captures what exists now, what was intentionally chosen, and what to iterate ## Current Scope and Intent - Harness is standalone under `nemo_retriever` (not based on `tools/harness`). -- It wraps `nemo_retriever.examples.batch_pipeline`. +- It now executes shared run-mode runners directly instead of scraping CLI output. +- `batch` remains the default run mode, with `inprocess` and `fused` supported through the same harness config surface. - Primary use case is benchmark orchestration for local/cluster runs without Docker orchestration. - Vector DB is LanceDB only. - Recall gating is supported and enforced by config (`recall_required`). @@ -17,17 +18,21 @@ It captures what exists now, what was intentionally chosen, and what to iterate ## Key Files - `nemo_retriever/src/nemo_retriever/harness/run.py` - - CLI run/sweep/nightly orchestration, subprocess execution, metrics extraction, artifact writes. + - CLI run/sweep/nightly orchestration, run-mode dispatch, artifact writes. - `nemo_retriever/src/nemo_retriever/harness/config.py` - - YAML + CLI/env merge logic and `HarnessConfig`. -- `nemo_retriever/src/nemo_retriever/harness/parsers.py` - - Stream parsing for ingest/throughput/recall metrics. + - YAML + CLI/env merge logic and `HarnessConfig`, including `run_mode`. - `nemo_retriever/src/nemo_retriever/harness/artifacts.py` - Artifact/session directory creation and session summary writing. - `nemo_retriever/src/nemo_retriever/harness/recall_adapters.py` - Dataset-specific query normalization adapters for recall inputs. +- `nemo_retriever/src/nemo_retriever/application/modes/reports.py` + - Shared `RunReport` / `RunMetrics` schema and artifact persistence helpers. +- `nemo_retriever/src/nemo_retriever/application/modes/run_batch.py` +- `nemo_retriever/src/nemo_retriever/application/modes/run_inprocess.py` +- `nemo_retriever/src/nemo_retriever/application/modes/run_fused.py` + - Shared mode runners that return structured reports consumed by the harness. - `nemo_retriever/harness/test_configs.yaml` - - Active defaults, presets, dataset presets. + - Active defaults, presets, dataset presets, and default `run_mode`. - `nemo_retriever/harness/nightly_config.yaml` - Ordered run list for sweep/nightly. @@ -48,7 +53,6 @@ It captures what exists now, what was intentionally chosen, and what to iterate From repo root: ```bash -source ~/setup_env.sh source .retriever/bin/activate uv pip install -e ./nemo_retriever ``` @@ -92,6 +96,8 @@ Per run: - `command.txt` - `runtime_metrics/` - `lancedb/` +- `runtime_metrics/.run_report.json` +- `runtime_metrics/.runtime.summary.json` Session-level: @@ -113,8 +119,9 @@ Notes: - Kept `session_summary.json`. - Removed `sweep_results.json` generation. -3. **TTY-backed subprocess retained** - - Harness runs batch pipeline through a PTY so Ray progress remains rich/pretty by default. +3. **Structured run reports are authoritative** + - Harness metrics are populated from `RunReport` objects returned by the mode runners. + - Console output is now presentation-only and is no longer scraped for harness metrics. ## Known Behavior to Remember @@ -159,62 +166,41 @@ Harness-focused tests pass: ```bash pytest -q nemo_retriever/tests/test_batch_ingestor.py \ - nemo_retriever/tests/test_batch_pipeline.py \ - nemo_retriever/tests/test_harness_parsers.py \ nemo_retriever/tests/test_harness_config.py \ nemo_retriever/tests/test_harness_run.py \ nemo_retriever/tests/test_harness_reporting.py \ + nemo_retriever/tests/test_harness_nightly.py \ nemo_retriever/tests/test_harness_recall_adapters.py \ nemo_retriever/tests/test_recall_core.py ``` -## Upstream Batch Compatibility (Mar 2026) - -The upstream `nemo_retriever.examples.batch_pipeline` CLI and log output changed -after the initial harness work landed. The harness now carries a compatibility -shim for that newer upstream behavior. - -### CLI compatibility - -- Harness config field names remain unchanged for now. -- `harness.run._build_command()` maps those fields to the newer public batch CLI - flags, including: - - `pdf_extract_workers` -> `--pdf-extract-tasks` - - `pdf_extract_num_cpus` -> `--pdf-extract-cpus-per-task` - - `page_elements_workers` -> `--page-elements-actors` - - `ocr_workers` -> `--ocr-actors` - - `embed_workers` -> `--embed-actors` - - `gpu_page_elements` -> `--page-elements-gpus-per-actor` - - `gpu_ocr` -> `--ocr-gpus-per-actor` - - `gpu_embed` -> `--embed-gpus-per-actor` - -### Artifact / parser semantics - -- Current upstream batch mode no longer emits the old plain `[done]` / `Pages/sec` - lines on the main ingest path. -- Harness parsers now accept: - - the legacy plain-text format when present - - the newer logged line: - - `Ingestion complete. rows procesed in seconds. PPS` - - logger-prefixed recall lines such as: - - `2026-... INFO ... recall@5: 0.9043` -- `results.json` keeps the legacy page fields for backward compatibility: - - `metrics.pages` - - `metrics.pages_per_sec_ingest` -- For current upstream batch runs, those legacy page fields may be `null`. -- The authoritative ingest counters for the current upstream path are: - - `metrics.rows_processed` - - `metrics.rows_per_sec_ingest` -- `metrics.ingest_secs` is still populated from whichever upstream ingest summary - line is available. +## Structured Metrics Contract (Mar 2026) + +The harness no longer relies on stdout or stderr to derive ingest or evaluation +metrics. Instead, each supported run mode produces a shared structured +`RunReport`, and the harness projects that report into `results.json` and +`session_summary.json`. + +### Authoritative metrics sources + +- `results.json["run_report"]` is the canonical per-run payload. +- `results.json["metrics"]` is a flattened compatibility view derived from the report. +- `results.json["summary_metrics"]` is the compact downstream view used by nightly/reporting. +- `runtime_metrics/.run_report.json` mirrors the same report for direct inspection. + +### Legacy compatibility + +- `results.json` still keeps compatibility fields such as `metrics.pages`, + `metrics.pages_per_sec_ingest`, `summary_metrics.recall_5`, and + `summary_metrics.ndcg_10`. +- `command.txt` is retained for reproducibility, but it is no longer a metrics source. ### Remaining caveat -- This compatibility follow-up restores harness operability and recall gating. -- It does **not** solve the larger semantic question of authoritative physical PDF - page counts versus uploaded unique pages versus post-explode rows. -- `runtime_summary` and `detection_summary` remain best-effort side artifacts and - may still be `null` until upstream batch mode writes them consistently again. +- Metric semantics still need care when comparing `input_pages`, + `processed_pages`, and `rows_processed` across modes. +- The harness now preserves all three counters explicitly rather than inferring + one from CLI text. ## Recommended Next Iterations diff --git a/nemo_retriever/harness/test_configs.yaml b/nemo_retriever/harness/test_configs.yaml index 2be2c23df..0920ab4b0 100644 --- a/nemo_retriever/harness/test_configs.yaml +++ b/nemo_retriever/harness/test_configs.yaml @@ -3,6 +3,7 @@ active: dataset: jp20 preset: single_gpu + run_mode: batch query_csv: data/jp20_query_gt.csv input_type: pdf recall_required: true diff --git a/nemo_retriever/src/nemo_retriever/application/modes/__init__.py b/nemo_retriever/src/nemo_retriever/application/modes/__init__.py index a4ca36132..7a822e0f5 100644 --- a/nemo_retriever/src/nemo_retriever/application/modes/__init__.py +++ b/nemo_retriever/src/nemo_retriever/application/modes/__init__.py @@ -2,12 +2,10 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from .executor import run_mode_ingest from .factory import RunMode, create_runmode_ingestor -from .run_batch import run_batch -from .run_fused import run_fused -from .run_inprocess import run_inprocess -from .run_online import run_online __all__ = [ "RunMode", @@ -18,3 +16,23 @@ "run_inprocess", "run_online", ] + + +def __getattr__(name: str): + if name == "run_batch": + from .run_batch import run_batch + + return run_batch + if name == "run_fused": + from .run_fused import run_fused + + return run_fused + if name == "run_inprocess": + from .run_inprocess import run_inprocess + + return run_inprocess + if name == "run_online": + from .run_online import run_online + + return run_online + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/nemo_retriever/src/nemo_retriever/application/modes/reports.py b/nemo_retriever/src/nemo_retriever/application/modes/reports.py new file mode 100644 index 000000000..86cd7b839 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/application/modes/reports.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + +from nemo_retriever.params import RunMode + + +class _ReportModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class RunArtifacts(_ReportModel): + runtime_metrics_dir: str | None = None + report_file: str | None = None + runtime_summary_file: str | None = None + detection_summary_file: str | None = None + log_file: str | None = None + lancedb_uri: str | None = None + lancedb_table: str | None = None + + +class RunMetrics(_ReportModel): + input_files: int | None = None + input_pages: int | None = None + processed_pages: int | None = None + rows_processed: int | None = None + ingest_secs: float | None = None + materialize_secs: float | None = None + vdb_write_secs: float | None = None + evaluation_secs: float | None = None + total_secs: float | None = None + pages_per_sec_ingest: float | None = None + rows_per_sec_ingest: float | None = None + + +class EvaluationSummary(_ReportModel): + label: str = "Recall" + query_count: int | None = None + metrics: dict[str, float] = Field(default_factory=dict) + + +class RunEvaluationConfig(_ReportModel): + evaluation_mode: str = "recall" + query_csv: str | None = None + recall_match_mode: str = "pdf_page" + beir_loader: str | None = None + beir_dataset_name: str | None = None + beir_split: str = "test" + beir_query_language: str | None = None + beir_doc_id_field: str = "pdf_basename" + beir_ks: tuple[int, ...] = (1, 3, 5, 10) + reranker: bool = False + reranker_model_name: str = "nvidia/llama-nemotron-rerank-1b-v2" + + +class RunArtifactConfig(_ReportModel): + lancedb_uri: str = "lancedb" + lancedb_table: str = "nv-ingest" + detection_summary_file: str | None = None + log_file: str | None = None + + +class RunReport(_ReportModel): + run_mode: RunMode + input_path: str + input_type: str + evaluation_mode: str + metrics: RunMetrics = Field(default_factory=RunMetrics) + evaluation: EvaluationSummary = Field(default_factory=EvaluationSummary) + detection_summary: dict[str, Any] | None = None + runtime_summary: dict[str, Any] = Field(default_factory=dict) + artifacts: RunArtifacts = Field(default_factory=RunArtifacts) + extras: dict[str, Any] = Field(default_factory=dict) + + +def normalize_metric_key(key: str) -> str: + metric = str(key).strip().lower() + return metric.replace("@", "_").replace("-", "_") + + +def _safe_ratio(numerator: int | float | None, denominator: int | float | None) -> float | None: + if numerator is None or denominator in {None, 0, 0.0}: + return None + try: + value = float(numerator) / float(denominator) + except (TypeError, ValueError, ZeroDivisionError): + return None + return round(value, 2) + + +def canonical_pages(report: RunReport) -> int | None: + if report.metrics.processed_pages is not None: + return report.metrics.processed_pages + return report.metrics.input_pages + + +def flatten_report_metrics(report: RunReport) -> dict[str, Any]: + flat: dict[str, Any] = { + "files": report.metrics.input_files, + "pages": canonical_pages(report), + "input_files": report.metrics.input_files, + "input_pages": report.metrics.input_pages, + "processed_pages": report.metrics.processed_pages, + "rows_processed": report.metrics.rows_processed, + "ingest_secs": report.metrics.ingest_secs, + "materialize_secs": report.metrics.materialize_secs, + "vdb_write_secs": report.metrics.vdb_write_secs, + "evaluation_secs": report.metrics.evaluation_secs, + "total_secs": report.metrics.total_secs, + "pages_per_sec_ingest": report.metrics.pages_per_sec_ingest, + "rows_per_sec_ingest": report.metrics.rows_per_sec_ingest, + "evaluation_query_count": report.evaluation.query_count, + } + for key, value in report.evaluation.metrics.items(): + flat[normalize_metric_key(key)] = value + return flat + + +def project_summary_metrics(report: RunReport) -> dict[str, Any]: + flat = flatten_report_metrics(report) + return { + "pages": flat.get("pages"), + "ingest_secs": flat.get("ingest_secs"), + "pages_per_sec_ingest": flat.get("pages_per_sec_ingest"), + "recall_5": flat.get("recall_5"), + "ndcg_10": flat.get("ndcg_10"), + } + + +def build_runtime_summary(report: RunReport) -> dict[str, Any]: + summary = dict(report.runtime_summary) + summary.update( + { + "run_mode": report.run_mode, + "input_type": report.input_type, + "input_files": report.metrics.input_files, + "input_pages": report.metrics.input_pages, + "processed_pages": report.metrics.processed_pages, + "rows_processed": report.metrics.rows_processed, + "ingest_secs": report.metrics.ingest_secs, + "materialize_secs": report.metrics.materialize_secs, + "vdb_write_secs": report.metrics.vdb_write_secs, + "evaluation_secs": report.metrics.evaluation_secs, + "elapsed_secs": report.metrics.total_secs, + "pages_per_sec_ingest": report.metrics.pages_per_sec_ingest, + "rows_per_sec_ingest": report.metrics.rows_per_sec_ingest, + } + ) + return summary + + +def update_metric_derivatives(report: RunReport) -> RunReport: + updated_metrics = report.metrics.model_copy( + update={ + "pages_per_sec_ingest": ( + report.metrics.pages_per_sec_ingest + if report.metrics.pages_per_sec_ingest is not None + else _safe_ratio(canonical_pages(report), report.metrics.ingest_secs) + ), + "rows_per_sec_ingest": ( + report.metrics.rows_per_sec_ingest + if report.metrics.rows_per_sec_ingest is not None + else _safe_ratio(report.metrics.rows_processed, report.metrics.ingest_secs) + ), + } + ) + updated_report = report.model_copy(update={"metrics": updated_metrics}) + return updated_report.model_copy(update={"runtime_summary": build_runtime_summary(updated_report)}) + + +def persist_run_report_artifacts( + report: RunReport, *, runtime_metrics_dir: str | None, prefix: str | None +) -> RunReport: + if runtime_metrics_dir is None: + return update_metric_derivatives(report) + + root = Path(runtime_metrics_dir).expanduser().resolve() + root.mkdir(parents=True, exist_ok=True) + run_prefix = str(prefix or "run") + report_path = root / f"{run_prefix}.run_report.json" + runtime_summary_path = root / f"{run_prefix}.runtime.summary.json" + + updated_report = update_metric_derivatives( + report.model_copy( + update={ + "artifacts": report.artifacts.model_copy( + update={ + "runtime_metrics_dir": str(root), + "report_file": str(report_path), + "runtime_summary_file": str(runtime_summary_path), + } + ) + } + ) + ) + + report_path.write_text( + json.dumps(updated_report.model_dump(mode="python"), indent=2, sort_keys=False) + "\n", + encoding="utf-8", + ) + runtime_summary_path.write_text( + json.dumps(updated_report.runtime_summary, indent=2, sort_keys=False) + "\n", + encoding="utf-8", + ) + return updated_report diff --git a/nemo_retriever/src/nemo_retriever/application/modes/run_batch.py b/nemo_retriever/src/nemo_retriever/application/modes/run_batch.py index affa968a8..05f975ec0 100644 --- a/nemo_retriever/src/nemo_retriever/application/modes/run_batch.py +++ b/nemo_retriever/src/nemo_retriever/application/modes/run_batch.py @@ -4,10 +4,225 @@ from __future__ import annotations +import json +from pathlib import Path +import time + +from pydantic import BaseModel, ConfigDict, Field + +from nemo_retriever.ingest_modes.batch import BatchIngestor +from nemo_retriever.params import EmbedParams +from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams +from nemo_retriever.params import TextChunkParams +from nemo_retriever.utils.detection_summary import print_detection_summary, print_run_summary, write_detection_summary +from nemo_retriever.utils.input_files import resolve_input_files, resolve_input_patterns +from nemo_retriever.vector_store.lancedb_store import handle_lancedb from .executor import run_mode_ingest +from .reports import ( + RunArtifactConfig, + RunArtifacts, + RunEvaluationConfig, + RunMetrics, + RunReport, + persist_run_report_artifacts, +) +from .shared import ( + DEFAULT_LANCEDB_TABLE, + DEFAULT_LANCEDB_URI, + ensure_lancedb_table, + estimate_processed_pages, + evaluate_lancedb_metrics, + resolve_input_pages, +) + + +class _RunnerConfigModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class BatchPipelineConfig(_RunnerConfigModel): + input_path: str + input_type: str = "pdf" + file_patterns: list[str] = Field(default_factory=list) + create_params: IngestorCreateParams = Field(default_factory=IngestorCreateParams) + execute_params: IngestExecuteParams = Field(default_factory=IngestExecuteParams) + extract_params: ExtractParams | None = None + embed_params: EmbedParams = Field(default_factory=EmbedParams) + text_chunk_params: TextChunkParams | None = None + enable_text_chunk: bool = False + evaluation: RunEvaluationConfig = Field(default_factory=RunEvaluationConfig) + artifacts: RunArtifactConfig = Field(default_factory=RunArtifactConfig) + hybrid: bool = False + + +def _resolve_file_patterns(cfg: BatchPipelineConfig) -> list[str]: + if cfg.file_patterns: + return list(cfg.file_patterns) + return resolve_input_patterns(Path(cfg.input_path), cfg.input_type) + + +def _build_ingestor(cfg: BatchPipelineConfig): + from nemo_retriever.ingestor import create_ingestor + + file_patterns = _resolve_file_patterns(cfg) + ingestor = create_ingestor(run_mode="batch", params=cfg.create_params) + chunk_params = cfg.text_chunk_params or TextChunkParams() + + if cfg.input_type == "txt": + ingestor = ingestor.files(file_patterns).extract_txt(chunk_params) + elif cfg.input_type == "html": + ingestor = ingestor.files(file_patterns).extract_html(chunk_params) + elif cfg.input_type == "image": + ingestor = ingestor.files(file_patterns).extract_image_files(cfg.extract_params or ExtractParams()) + else: + ingestor = ingestor.files(file_patterns).extract(cfg.extract_params or ExtractParams()) + + if cfg.enable_text_chunk: + ingestor = ingestor.split(chunk_params) + + ingestor = ingestor.embed(cfg.embed_params) + return ingestor, file_patterns + + +def _write_error_rows(error_rows, output_dir: str | None) -> None: + error_count = int(error_rows.count()) + if error_count <= 0: + return + + target_dir = Path(output_dir).expanduser().resolve() if output_dir else Path.cwd() + target_dir.mkdir(parents=True, exist_ok=True) + error_file = target_dir / "ingest_errors.json" + error_rows_to_write = error_rows.take(min(5, error_count)) + with error_file.open("w", encoding="utf-8") as fh: + json.dump(error_rows_to_write, fh, indent=2, default=str) + fh.write("\n") + raise RuntimeError( + "Detected " + f"{error_count} error row(s) in ingest results. " + f"Wrote first {len(error_rows_to_write)} row(s) to {error_file}." + ) + + +def run_batch_pipeline(cfg: BatchPipelineConfig) -> RunReport: + try: + input_path = Path(cfg.input_path).expanduser().resolve() + input_files = resolve_input_files(input_path, cfg.input_type) + resolved_input_pages = resolve_input_pages(cfg.input_type, input_files) + lancedb_uri = str(Path(cfg.artifacts.lancedb_uri or DEFAULT_LANCEDB_URI).expanduser().resolve()) + lancedb_table = str(cfg.artifacts.lancedb_table or DEFAULT_LANCEDB_TABLE) + + ensure_lancedb_table(lancedb_uri, lancedb_table) + ingestor, file_patterns = _build_ingestor(cfg) + + ingest_start = time.perf_counter() + ingest_results = ingestor.ingest(params=cfg.execute_params).get_dataset().materialize() + ingest_secs = time.perf_counter() - ingest_start + + materialize_start = time.perf_counter() + ingest_local_results = ingest_results.take_all() + materialize_secs = time.perf_counter() - materialize_start + + vdb_write_start = time.perf_counter() + handle_lancedb(ingest_local_results, lancedb_uri, lancedb_table, hybrid=cfg.hybrid, mode="overwrite") + vdb_write_secs = time.perf_counter() - vdb_write_start + + if isinstance(ingestor, BatchIngestor): + error_rows = ingestor.get_error_rows(dataset=ingest_results).materialize() + _write_error_rows(error_rows, cfg.execute_params.runtime_metrics_dir) + + detection_summary = estimate_processed_pages(lancedb_uri, lancedb_table) + evaluation_summary, evaluation_secs = evaluate_lancedb_metrics( + cfg.evaluation, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + embed_model_name=str(cfg.embed_params.model_name or "nvidia/llama-nemotron-embed-1b-v2"), + embed_invoke_url=cfg.embed_params.embed_invoke_url, + embed_api_key=cfg.embed_params.api_key, + hybrid=cfg.hybrid, + ) + total_secs = time.perf_counter() - ingest_start + + from nemo_retriever.utils.detection_summary import collect_detection_summary_from_lancedb + + detection_payload = collect_detection_summary_from_lancedb(lancedb_uri, lancedb_table) + if cfg.artifacts.detection_summary_file is not None: + detection_path = Path(cfg.artifacts.detection_summary_file).expanduser().resolve() + write_detection_summary(detection_path, detection_payload) + detection_summary_file = str(detection_path) + else: + detection_summary_file = None + + processed_pages = ( + detection_payload.get("pages_seen") + if isinstance(detection_payload, dict) and detection_payload.get("pages_seen") is not None + else detection_summary + ) + + report = RunReport( + run_mode="batch", + input_path=str(input_path), + input_type=cfg.input_type, + evaluation_mode=cfg.evaluation.evaluation_mode, + metrics=RunMetrics( + input_files=len(input_files) or None, + input_pages=resolved_input_pages, + processed_pages=processed_pages, + rows_processed=len(ingest_local_results), + ingest_secs=ingest_secs, + materialize_secs=materialize_secs, + vdb_write_secs=vdb_write_secs, + evaluation_secs=evaluation_secs, + total_secs=total_secs, + ), + evaluation=evaluation_summary, + detection_summary=detection_payload, + artifacts=RunArtifacts( + detection_summary_file=detection_summary_file, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ), + extras={"file_patterns": file_patterns}, + ) + return persist_run_report_artifacts( + report, + runtime_metrics_dir=cfg.execute_params.runtime_metrics_dir, + prefix=cfg.execute_params.runtime_metrics_prefix, + ) + finally: + try: + import ray + + ray.shutdown() + except Exception: + pass + + +def render_batch_run_report(report: RunReport, *, hybrid: bool) -> None: + if report.detection_summary is not None: + print_detection_summary(report.detection_summary) + + processed_pages = report.metrics.processed_pages or report.metrics.input_pages + if processed_pages is None or report.metrics.ingest_secs is None or report.metrics.total_secs is None: + return + + print_run_summary( + processed_pages, + Path(report.input_path), + hybrid, + str(report.artifacts.lancedb_uri or DEFAULT_LANCEDB_URI), + str(report.artifacts.lancedb_table or DEFAULT_LANCEDB_TABLE), + report.metrics.total_secs, + report.metrics.ingest_secs, + float(report.metrics.materialize_secs or 0.0), + float(report.metrics.vdb_write_secs or 0.0), + float(report.metrics.evaluation_secs or 0.0), + report.evaluation.metrics, + evaluation_label=report.evaluation.label, + evaluation_count=report.evaluation.query_count, + ) def run_batch( diff --git a/nemo_retriever/src/nemo_retriever/application/modes/run_fused.py b/nemo_retriever/src/nemo_retriever/application/modes/run_fused.py index a5496bc44..de63f15fc 100644 --- a/nemo_retriever/src/nemo_retriever/application/modes/run_fused.py +++ b/nemo_retriever/src/nemo_retriever/application/modes/run_fused.py @@ -4,10 +4,169 @@ from __future__ import annotations +from pathlib import Path +import time + +from pydantic import BaseModel, ConfigDict, Field + +from nemo_retriever.params import EmbedParams +from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams +from nemo_retriever.params import VdbUploadParams +from nemo_retriever.utils.detection_summary import ( + print_detection_summary, + print_pages_per_second, + write_detection_summary, +) +from nemo_retriever.utils.input_files import resolve_input_files, resolve_input_patterns from .executor import run_mode_ingest +from .reports import ( + RunArtifactConfig, + RunArtifacts, + RunEvaluationConfig, + RunMetrics, + RunReport, + persist_run_report_artifacts, +) +from .shared import ( + DEFAULT_LANCEDB_TABLE, + DEFAULT_LANCEDB_URI, + count_lancedb_rows, + ensure_lancedb_table, + estimate_processed_pages, + evaluate_lancedb_metrics, + resolve_input_pages, +) + + +class _RunnerConfigModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class FusedPipelineConfig(_RunnerConfigModel): + input_path: str + input_type: str = "pdf" + file_patterns: list[str] = Field(default_factory=list) + create_params: IngestorCreateParams = Field(default_factory=IngestorCreateParams) + execute_params: IngestExecuteParams = Field(default_factory=IngestExecuteParams) + extract_params: ExtractParams = Field(default_factory=ExtractParams) + embed_params: EmbedParams = Field(default_factory=EmbedParams) + vdb_upload_params: VdbUploadParams = Field(default_factory=VdbUploadParams) + evaluation: RunEvaluationConfig = Field(default_factory=RunEvaluationConfig) + artifacts: RunArtifactConfig = Field(default_factory=RunArtifactConfig) + + +def _resolve_file_patterns(cfg: FusedPipelineConfig) -> list[str]: + if cfg.file_patterns: + return list(cfg.file_patterns) + return resolve_input_patterns(Path(cfg.input_path), cfg.input_type) + + +def run_fused_pipeline(cfg: FusedPipelineConfig) -> RunReport: + try: + from nemo_retriever.ingestor import create_ingestor + + if cfg.input_type != "pdf": + raise ValueError("Fused mode currently supports only pdf input_type.") + + input_path = Path(cfg.input_path).expanduser().resolve() + input_files = resolve_input_files(input_path, cfg.input_type) + file_patterns = _resolve_file_patterns(cfg) + lancedb_uri = str( + Path(cfg.vdb_upload_params.lancedb.lancedb_uri or cfg.artifacts.lancedb_uri or DEFAULT_LANCEDB_URI) + .expanduser() + .resolve() + ) + lancedb_table = str( + cfg.vdb_upload_params.lancedb.table_name or cfg.artifacts.lancedb_table or DEFAULT_LANCEDB_TABLE + ) + + ensure_lancedb_table(lancedb_uri, lancedb_table) + ingestor = create_ingestor(run_mode="fused", params=cfg.create_params) + ingestor = ( + ingestor.files(file_patterns) + .extract(cfg.extract_params) + .embed(cfg.embed_params) + .vdb_upload(cfg.vdb_upload_params) + ) + + ingest_start = time.perf_counter() + ingestor.ingest(params=cfg.execute_params) + ingest_secs = time.perf_counter() - ingest_start + + from nemo_retriever.utils.detection_summary import collect_detection_summary_from_lancedb + + detection_payload = collect_detection_summary_from_lancedb(lancedb_uri, lancedb_table) + if cfg.artifacts.detection_summary_file is not None: + detection_path = Path(cfg.artifacts.detection_summary_file).expanduser().resolve() + write_detection_summary(detection_path, detection_payload) + detection_summary_file = str(detection_path) + else: + detection_summary_file = None + + evaluation_summary, evaluation_secs = evaluate_lancedb_metrics( + cfg.evaluation, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + embed_model_name=str(cfg.embed_params.model_name or "nemo_retriever_v1"), + embed_invoke_url=cfg.embed_params.embed_invoke_url, + embed_api_key=cfg.embed_params.api_key, + hybrid=bool(cfg.vdb_upload_params.lancedb.hybrid), + ) + total_secs = ingest_secs + evaluation_secs + + report = RunReport( + run_mode="fused", + input_path=str(input_path), + input_type=cfg.input_type, + evaluation_mode=cfg.evaluation.evaluation_mode, + metrics=RunMetrics( + input_files=len(input_files) or None, + input_pages=resolve_input_pages(cfg.input_type, input_files), + processed_pages=( + detection_payload.get("pages_seen") + if isinstance(detection_payload, dict) and detection_payload.get("pages_seen") is not None + else estimate_processed_pages(lancedb_uri, lancedb_table) + ), + rows_processed=count_lancedb_rows(lancedb_uri, lancedb_table), + ingest_secs=ingest_secs, + evaluation_secs=evaluation_secs, + total_secs=total_secs, + ), + evaluation=evaluation_summary, + detection_summary=detection_payload, + artifacts=RunArtifacts( + detection_summary_file=detection_summary_file, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ), + extras={"file_patterns": file_patterns}, + ) + return persist_run_report_artifacts( + report, + runtime_metrics_dir=cfg.execute_params.runtime_metrics_dir, + prefix=cfg.execute_params.runtime_metrics_prefix, + ) + finally: + try: + import ray + + ray.shutdown() + except Exception: + pass + + +def render_fused_run_report(report: RunReport) -> None: + if report.detection_summary is not None: + print_detection_summary(report.detection_summary) + if report.metrics.ingest_secs is not None: + print_pages_per_second(report.metrics.processed_pages, report.metrics.ingest_secs) + if report.evaluation.metrics: + print(f"\n{report.evaluation.label} metrics:") + for key, value in sorted(report.evaluation.metrics.items()): + print(f" {key}: {value:.4f}") def run_fused( diff --git a/nemo_retriever/src/nemo_retriever/application/modes/run_inprocess.py b/nemo_retriever/src/nemo_retriever/application/modes/run_inprocess.py index b9ed663cf..f74b3706b 100644 --- a/nemo_retriever/src/nemo_retriever/application/modes/run_inprocess.py +++ b/nemo_retriever/src/nemo_retriever/application/modes/run_inprocess.py @@ -4,10 +4,173 @@ from __future__ import annotations +from pathlib import Path +import time + +import pandas as pd +from pydantic import BaseModel, ConfigDict, Field + +from nemo_retriever.params import EmbedParams +from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams +from nemo_retriever.params import TextChunkParams +from nemo_retriever.params import VdbUploadParams +from nemo_retriever.utils.detection_summary import ( + collect_detection_summary_from_df, + print_detection_summary, + print_pages_per_second, + write_detection_summary, +) +from nemo_retriever.utils.input_files import resolve_input_files, resolve_input_patterns from .executor import run_mode_ingest +from .reports import ( + RunArtifactConfig, + RunArtifacts, + RunEvaluationConfig, + RunMetrics, + RunReport, + persist_run_report_artifacts, +) +from .shared import ( + DEFAULT_LANCEDB_TABLE, + DEFAULT_LANCEDB_URI, + ensure_lancedb_table, + evaluate_lancedb_metrics, + resolve_input_pages, +) + + +class _RunnerConfigModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class InProcessPipelineConfig(_RunnerConfigModel): + input_path: str + input_type: str = "pdf" + file_patterns: list[str] = Field(default_factory=list) + execute_params: IngestExecuteParams = Field(default_factory=IngestExecuteParams) + extract_params: ExtractParams | None = None + embed_params: EmbedParams = Field(default_factory=EmbedParams) + text_chunk_params: TextChunkParams | None = None + enable_text_chunk: bool = False + vdb_upload_params: VdbUploadParams = Field(default_factory=VdbUploadParams) + evaluation: RunEvaluationConfig = Field(default_factory=RunEvaluationConfig) + artifacts: RunArtifactConfig = Field(default_factory=RunArtifactConfig) + + +def _resolve_file_patterns(cfg: InProcessPipelineConfig) -> list[str]: + if cfg.file_patterns: + return list(cfg.file_patterns) + return resolve_input_patterns(Path(cfg.input_path), cfg.input_type) + + +def _build_ingestor(cfg: InProcessPipelineConfig): + from nemo_retriever.ingestor import create_ingestor + + file_patterns = _resolve_file_patterns(cfg) + ingestor = create_ingestor(run_mode="inprocess") + chunk_params = cfg.text_chunk_params or TextChunkParams() + + if cfg.input_type == "txt": + ingestor = ingestor.files(file_patterns).extract_txt(chunk_params) + elif cfg.input_type == "html": + ingestor = ingestor.files(file_patterns).extract_html(chunk_params) + elif cfg.input_type == "image": + ingestor = ingestor.files(file_patterns).extract_image_files(cfg.extract_params or ExtractParams()) + else: + ingestor = ingestor.files(file_patterns).extract(cfg.extract_params or ExtractParams()) + + if cfg.enable_text_chunk: + ingestor = ingestor.split(chunk_params) + + ingestor = ingestor.embed(cfg.embed_params).vdb_upload(cfg.vdb_upload_params) + return ingestor, file_patterns + + +def run_inprocess_pipeline(cfg: InProcessPipelineConfig) -> RunReport: + input_path = Path(cfg.input_path).expanduser().resolve() + input_files = resolve_input_files(input_path, cfg.input_type) + ingestor, file_patterns = _build_ingestor(cfg) + + lancedb_uri = str( + Path(cfg.vdb_upload_params.lancedb.lancedb_uri or cfg.artifacts.lancedb_uri or DEFAULT_LANCEDB_URI) + .expanduser() + .resolve() + ) + lancedb_table = str( + cfg.vdb_upload_params.lancedb.table_name or cfg.artifacts.lancedb_table or DEFAULT_LANCEDB_TABLE + ) + ensure_lancedb_table(lancedb_uri, lancedb_table) + + ingest_start = time.perf_counter() + results = ingestor.ingest(params=cfg.execute_params) + ingest_secs = time.perf_counter() - ingest_start + + dataframes = [item for item in (results or []) if isinstance(item, pd.DataFrame) and not item.empty] + combined = pd.concat(dataframes, ignore_index=True) if dataframes else None + detection_payload = collect_detection_summary_from_df(combined) if combined is not None else None + processed_pages = detection_payload.get("pages_seen") if isinstance(detection_payload, dict) else None + rows_processed = int(combined.shape[0]) if combined is not None else None + + if cfg.artifacts.detection_summary_file is not None: + detection_path = Path(cfg.artifacts.detection_summary_file).expanduser().resolve() + write_detection_summary(detection_path, detection_payload) + detection_summary_file = str(detection_path) + else: + detection_summary_file = None + + evaluation_summary, evaluation_secs = evaluate_lancedb_metrics( + cfg.evaluation, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + embed_model_name=str(cfg.embed_params.model_name or "nvidia/llama-nemotron-embed-1b-v2"), + embed_invoke_url=cfg.embed_params.embed_invoke_url, + embed_api_key=cfg.embed_params.api_key, + hybrid=bool(cfg.vdb_upload_params.lancedb.hybrid), + ) + total_secs = ingest_secs + evaluation_secs + + report = RunReport( + run_mode="inprocess", + input_path=str(input_path), + input_type=cfg.input_type, + evaluation_mode=cfg.evaluation.evaluation_mode, + metrics=RunMetrics( + input_files=len(input_files) or None, + input_pages=resolve_input_pages(cfg.input_type, input_files), + processed_pages=processed_pages, + rows_processed=rows_processed, + ingest_secs=ingest_secs, + evaluation_secs=evaluation_secs, + total_secs=total_secs, + ), + evaluation=evaluation_summary, + detection_summary=detection_payload, + artifacts=RunArtifacts( + detection_summary_file=detection_summary_file, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ), + extras={"file_patterns": file_patterns}, + ) + return persist_run_report_artifacts( + report, + runtime_metrics_dir=cfg.execute_params.runtime_metrics_dir, + prefix=cfg.execute_params.runtime_metrics_prefix, + ) + + +def render_inprocess_run_report(report: RunReport, *, include_ingest_summary: bool = True) -> None: + if include_ingest_summary and report.detection_summary is not None: + print_detection_summary(report.detection_summary) + if include_ingest_summary and report.metrics.ingest_secs is not None: + print_pages_per_second(report.metrics.processed_pages, report.metrics.ingest_secs) + if report.evaluation.metrics: + print(f"\n{report.evaluation.label} metrics:") + for key, value in sorted(report.evaluation.metrics.items()): + print(f" {key}: {value:.4f}") def run_inprocess( diff --git a/nemo_retriever/src/nemo_retriever/application/modes/shared.py b/nemo_retriever/src/nemo_retriever/application/modes/shared.py new file mode 100644 index 000000000..60c24cfc2 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/application/modes/shared.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from importlib import import_module +import logging +from pathlib import Path +import sys +import time +from typing import Any, Optional, TextIO + +from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema +from nemo_retriever.model import resolve_embed_model +from nemo_retriever.recall.beir import BeirConfig, evaluate_lancedb_beir +from nemo_retriever.recall.core import RecallConfig, retrieve_and_score + +DEFAULT_LANCEDB_URI = "lancedb" +DEFAULT_LANCEDB_TABLE = "nv-ingest" + +logger = logging.getLogger(__name__) + + +class _TeeStream: + """Write stream output to the terminal and an optional file.""" + + def __init__(self, primary: TextIO, mirror: TextIO) -> None: + self._primary = primary + self._mirror = mirror + + def write(self, data: str) -> int: + self._primary.write(data) + self._mirror.write(data) + return len(data) + + def flush(self) -> None: + self._primary.flush() + self._mirror.flush() + + def isatty(self) -> bool: + return bool(getattr(self._primary, "isatty", lambda: False)()) + + def fileno(self) -> int: + return int(getattr(self._primary, "fileno")()) + + def writable(self) -> bool: + return bool(getattr(self._primary, "writable", lambda: True)()) + + @property + def encoding(self) -> str: + return str(getattr(self._primary, "encoding", "utf-8")) + + +def configure_cli_logging(log_file: Optional[Path], *, debug: bool = False) -> tuple[Optional[TextIO], TextIO, TextIO]: + """Configure root logging and optionally tee stdout/stderr to a file.""" + + original_stdout = sys.stdout + original_stderr = sys.stderr + log_level = logging.DEBUG if debug else logging.INFO + if log_file is None: + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + force=True, + ) + return None, original_stdout, original_stderr + + target = Path(log_file).expanduser().resolve() + target.parent.mkdir(parents=True, exist_ok=True) + handle = open(target, "a", encoding="utf-8", buffering=1) + + sys.stdout = _TeeStream(sys.__stdout__, handle) + sys.stderr = _TeeStream(sys.__stderr__, handle) + logging.basicConfig( + level=log_level, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + force=True, + ) + logging.getLogger(__name__).info("Writing combined pipeline logs to %s", str(target)) + return handle, original_stdout, original_stderr + + +def restore_cli_logging( + handle: Optional[TextIO], + original_stdout: TextIO, + original_stderr: TextIO, +) -> None: + sys.stdout = original_stdout + sys.stderr = original_stderr + if handle is not None: + try: + handle.flush() + finally: + handle.close() + + +def lancedb_module() -> Any: + return import_module("lancedb") + + +def ensure_lancedb_table(uri: str, table_name: str) -> None: + """Ensure the LanceDB URI exists and the target table can be opened.""" + + Path(uri).mkdir(parents=True, exist_ok=True) + db = lancedb_module().connect(uri) + try: + db.open_table(table_name) + return + except Exception: + pass + + import pyarrow as pa # type: ignore + + schema = lancedb_schema() + empty = pa.table({field.name: [] for field in schema}, schema=schema) + db.create_table(table_name, data=empty, schema=schema, mode="create") + + +def open_lancedb_table_with_retry( + uri: str, + table_name: str, + *, + retries: int = 3, + sleep_seconds: float = 2.0, +) -> Any: + db = lancedb_module().connect(uri) + open_err: Exception | None = None + for _ in range(max(1, retries)): + try: + return db.open_table(table_name) + except Exception as exc: + open_err = exc + ensure_lancedb_table(uri, table_name) + time.sleep(sleep_seconds) + raise RuntimeError(f"Could not open LanceDB table {table_name!r} at {uri!r}") from open_err + + +def count_lancedb_rows(uri: str, table_name: str) -> int | None: + try: + table = open_lancedb_table_with_retry(uri, table_name, retries=1, sleep_seconds=0.0) + return int(table.count_rows()) + except Exception: + return None + + +def estimate_processed_pages(uri: str, table_name: str) -> int | None: + """Estimate processed pages from unique `(source_id, page_number)` pairs.""" + + try: + table = open_lancedb_table_with_retry(uri, table_name, retries=1, sleep_seconds=0.0) + df = table.to_pandas()[["source_id", "page_number"]] + return int(df.dropna(subset=["source_id", "page_number"]).drop_duplicates().shape[0]) + except Exception: + return count_lancedb_rows(uri, table_name) + + +def safe_pdf_page_count(path: Path) -> int | None: + try: + import pypdfium2 as pdfium # type: ignore + + doc = pdfium.PdfDocument(str(path)) + try: + try: + count = int(len(doc)) + except Exception: + count = int(doc.get_page_count()) # type: ignore[attr-defined] + finally: + try: + doc.close() + except Exception: + pass + return max(count, 0) + except Exception: + return None + + +def resolve_input_pages(input_type: str, input_files: list[Path]) -> int | None: + if input_type == "image": + return len(input_files) + if input_type != "pdf": + return None + + total_pages = 0 + counted_any = False + for path in input_files: + page_count = safe_pdf_page_count(path) + if page_count is None: + continue + counted_any = True + total_pages += page_count + if counted_any: + return total_pages + return None + + +def evaluate_lancedb_metrics( + evaluation_cfg, + *, + lancedb_uri: str, + lancedb_table: str, + embed_model_name: str, + embed_invoke_url: str | None, + embed_api_key: str | None, + hybrid: bool, +): + from .reports import EvaluationSummary + + if evaluation_cfg.evaluation_mode not in {"recall", "beir"}: + raise ValueError(f"Unsupported evaluation mode: {evaluation_cfg.evaluation_mode}") + + table = open_lancedb_table_with_retry(lancedb_uri, lancedb_table) + try: + if int(table.count_rows()) == 0: + logger.warning( + "LanceDB table %r exists but is empty; skipping %s evaluation.", + lancedb_table, + evaluation_cfg.evaluation_mode, + ) + return EvaluationSummary(label="BEIR" if evaluation_cfg.evaluation_mode == "beir" else "Recall"), 0.0 + except Exception: + pass + + resolved_model = resolve_embed_model(embed_model_name) + evaluation_label = "BEIR" if evaluation_cfg.evaluation_mode == "beir" else "Recall" + + if evaluation_cfg.evaluation_mode == "beir": + if not evaluation_cfg.beir_loader: + raise ValueError("--beir-loader is required when --evaluation-mode=beir") + if not evaluation_cfg.beir_dataset_name: + raise ValueError("--beir-dataset-name is required when --evaluation-mode=beir") + + beir_cfg = BeirConfig( + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + embedding_model=resolved_model, + loader=str(evaluation_cfg.beir_loader), + dataset_name=str(evaluation_cfg.beir_dataset_name), + split=str(evaluation_cfg.beir_split), + query_language=evaluation_cfg.beir_query_language, + doc_id_field=str(evaluation_cfg.beir_doc_id_field), + ks=tuple(evaluation_cfg.beir_ks) if evaluation_cfg.beir_ks else (1, 3, 5, 10), + embedding_http_endpoint=embed_invoke_url, + embedding_api_key=(embed_api_key or "").strip(), + hybrid=hybrid, + reranker=bool(evaluation_cfg.reranker), + reranker_model_name=str(evaluation_cfg.reranker_model_name), + ) + evaluation_start = time.perf_counter() + beir_dataset, _raw_hits, _run, metrics = evaluate_lancedb_beir(beir_cfg) + return ( + EvaluationSummary(label=evaluation_label, query_count=len(beir_dataset.query_ids), metrics=metrics), + time.perf_counter() - evaluation_start, + ) + + query_csv = Path(str(evaluation_cfg.query_csv)).expanduser() if evaluation_cfg.query_csv else None + if query_csv is None or not query_csv.exists(): + logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv) + return EvaluationSummary(label=evaluation_label), 0.0 + + recall_cfg = RecallConfig( + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + embedding_model=resolved_model, + embedding_http_endpoint=embed_invoke_url, + embedding_api_key=(embed_api_key or "").strip(), + top_k=10, + ks=(1, 5, 10), + hybrid=hybrid, + match_mode=str(evaluation_cfg.recall_match_mode), + reranker=str(evaluation_cfg.reranker_model_name) if evaluation_cfg.reranker else None, + ) + evaluation_start = time.perf_counter() + df_query, _gold, _raw_hits, _retrieved_keys, metrics = retrieve_and_score(query_csv=query_csv, cfg=recall_cfg) + return ( + EvaluationSummary(label=evaluation_label, query_count=len(df_query.index), metrics=metrics), + time.perf_counter() - evaluation_start, + ) diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index 6098f3731..61f81db71 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -11,27 +11,24 @@ import logging import os import sys -import time from importlib import import_module from pathlib import Path from typing import Any, Optional, TextIO -from nemo_retriever.utils.detection_summary import print_run_summary -import ray import typer -from nemo_retriever import create_ingestor -from nemo_retriever.ingest_modes.batch import BatchIngestor from nemo_retriever.ingest_modes.lancedb_utils import lancedb_schema -from nemo_retriever.model import resolve_embed_model from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import TextChunkParams -from nemo_retriever.recall.beir import BeirConfig, evaluate_lancedb_beir -from nemo_retriever.recall.core import RecallConfig, retrieve_and_score +from nemo_retriever.application.modes.reports import RunArtifactConfig, RunEvaluationConfig +from nemo_retriever.application.modes.run_batch import ( + BatchPipelineConfig, + render_batch_run_report, + run_batch_pipeline, +) from nemo_retriever.utils.remote_auth import resolve_remote_api_key -from nemo_retriever.vector_store.lancedb_store import handle_lancedb logger = logging.getLogger(__name__) @@ -650,13 +647,6 @@ def main( else: raise typer.BadParameter(f"Path does not exist: {input_path}") - ingestor = create_ingestor( - run_mode="batch", - params=IngestorCreateParams( - ray_address=ray_address, ray_log_to_driver=ray_log_to_driver, debug=bool(debug) - ), - ) - # -- Shared params used by multiple input-type branches ---------------- embed_params = EmbedParams( model_name=str(embed_model_name), @@ -730,192 +720,55 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams: overlap_tokens=text_chunk_overlap_tokens if text_chunk_overlap_tokens is not None else 150, ) - if input_type == "txt": - ingestor = ingestor.files(file_patterns).extract_txt(_text_chunk_params) - elif input_type == "html": - ingestor = ingestor.files(file_patterns).extract_html(_text_chunk_params) + enable_text_chunk = text_chunk or text_chunk_max_tokens is not None or text_chunk_overlap_tokens is not None + if input_type == "txt" or input_type == "html": + resolved_extract_params = None elif input_type == "image": - ingestor = ingestor.files(file_patterns).extract_image_files(_extract_params(_detection_batch_tuning)) + resolved_extract_params = _extract_params(_detection_batch_tuning) elif input_type == "doc": - ingestor = ingestor.files(file_patterns).extract(_extract_params(_pdf_batch_tuning)) + resolved_extract_params = _extract_params(_pdf_batch_tuning) else: - ingestor = ingestor.files(file_patterns).extract( - _extract_params(_pdf_batch_tuning, inference_batch_size=page_elements_batch_size) - ) - - enable_text_chunk = text_chunk or text_chunk_max_tokens is not None or text_chunk_overlap_tokens is not None - if enable_text_chunk: - ingestor = ingestor.split(_text_chunk_params) - - ingestor = ingestor.embed(embed_params) - - logger.info("Running extraction...") - ingest_start = time.perf_counter() - - ingest_results = ( - ingestor.ingest( - params=IngestExecuteParams( + resolved_extract_params = _extract_params(_pdf_batch_tuning, inference_batch_size=page_elements_batch_size) + + report = run_batch_pipeline( + BatchPipelineConfig( + input_path=str(input_path), + input_type=input_type, + file_patterns=file_patterns, + create_params=IngestorCreateParams( + ray_address=ray_address, ray_log_to_driver=ray_log_to_driver, debug=bool(debug) + ), + execute_params=IngestExecuteParams( runtime_metrics_dir=str(runtime_metrics_dir) if runtime_metrics_dir is not None else None, runtime_metrics_prefix=runtime_metrics_prefix, - ) - ) - .get_dataset() - .materialize() - ) - - ingestion_only_total_time = time.perf_counter() - ingest_start - - # Capture the time it takes to download the Ray dataset to the local machine for reporting. - ray_dataset_download_start = time.perf_counter() - ingest_local_results = ingest_results.take_all() - ray_dataset_download_time = time.perf_counter() - ray_dataset_download_start - - # Write to lancedb and capture the time it takes. - lancedb_write_start = time.perf_counter() - handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") - lancedb_write_time = time.perf_counter() - lancedb_write_start - - if isinstance(ingestor, BatchIngestor): - error_rows = ingestor.get_error_rows(dataset=ingest_results).materialize() - error_count = int(error_rows.count()) - - # Error out, stop processing, and write top 5 errors rows to a local file for analysis. - if error_count > 0: - error_file = Path("ingest_errors.json").resolve() - max_error_rows_to_write = 5 - error_rows_to_write = error_rows.take(min(max_error_rows_to_write, error_count)) - with error_file.open("w", encoding="utf-8") as fh: - json.dump(error_rows_to_write, fh, indent=2, default=str) - fh.write("\n") - logger.error( - "Detected %d error row(s) in ingest results. Wrote first %d row(s) " - "to %s. Showing top 5 extracted errors and exiting before recall." - " Writing top(%d) error rows to %s", - error_count, - len(error_rows_to_write), - str(error_file), - int(max_error_rows_to_write), - str(error_file), - ) - - ray.shutdown() - logger.error(f"Exiting with code 1 due to {error_count} error rows in ingest results.") - raise typer.Exit(code=1) - - # --------------------------------------------------------------------------- - # Evaluation calculation - # --------------------------------------------------------------------------- - evaluation_label = "Recall" - evaluation_total_time = 0.0 - evaluation_metrics: dict[str, float] = {} - evaluation_query_count: Optional[int] = None - - if evaluation_mode == "recall": - query_csv = Path(query_csv) - if not query_csv.exists(): - logger.warning(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - return - - db = _lancedb().connect(lancedb_uri) - table = None - open_err: Optional[Exception] = None - for _ in range(3): - try: - table = db.open_table(LANCEDB_TABLE) - open_err = None - break - except Exception as e: - open_err = e - # Create table if missing, then retry open. - _ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE) - time.sleep(2) - if table is None: - raise RuntimeError( - f"Recall stage requires LanceDB table {LANCEDB_TABLE!r} at {lancedb_uri!r}, " f"but it was not found." - ) from open_err - try: - if int(table.count_rows()) == 0: - logger.warning( - f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping {evaluation_mode} evaluation." - ) - return - except Exception: - pass - - _recall_model = resolve_embed_model(str(embed_model_name)) - if evaluation_mode == "beir": - if not beir_loader: - raise ValueError("--beir-loader is required when --evaluation-mode=beir") - if not beir_dataset_name: - raise ValueError("--beir-dataset-name is required when --evaluation-mode=beir") - - beir_cfg = BeirConfig( - lancedb_uri=str(lancedb_uri), - lancedb_table=str(LANCEDB_TABLE), - embedding_model=_recall_model, - loader=str(beir_loader), - dataset_name=str(beir_dataset_name), - split=str(beir_split), - query_language=beir_query_language, - doc_id_field=str(beir_doc_id_field), - ks=tuple(beir_k) if beir_k else (1, 3, 5, 10), - embedding_http_endpoint=embed_invoke_url, - embedding_api_key=embed_remote_api_key or "", + ), + extract_params=resolved_extract_params, + embed_params=embed_params, + text_chunk_params=_text_chunk_params, + enable_text_chunk=enable_text_chunk, + evaluation=RunEvaluationConfig( + evaluation_mode=evaluation_mode, + query_csv=str(query_csv), + recall_match_mode=recall_match_mode, + beir_loader=beir_loader, + beir_dataset_name=beir_dataset_name, + beir_split=beir_split, + beir_query_language=beir_query_language, + beir_doc_id_field=beir_doc_id_field, + beir_ks=tuple(beir_k) if beir_k else (1, 3, 5, 10), + reranker=bool(reranker), + reranker_model_name=str(reranker_model_name), + ), + artifacts=RunArtifactConfig( + lancedb_uri=lancedb_uri, + lancedb_table=LANCEDB_TABLE, + detection_summary_file=str(detection_summary_file) if detection_summary_file is not None else None, + log_file=str(log_file) if log_file is not None else None, + ), hybrid=hybrid, - reranker=bool(reranker), - reranker_model_name=str(reranker_model_name), - ) - evaluation_start = time.perf_counter() - beir_dataset, _raw_hits, _run, evaluation_metrics = evaluate_lancedb_beir(beir_cfg) - evaluation_total_time = time.perf_counter() - evaluation_start - evaluation_label = "BEIR" - evaluation_query_count = len(beir_dataset.query_ids) - else: - recall_cfg = RecallConfig( - lancedb_uri=str(lancedb_uri), - lancedb_table=str(LANCEDB_TABLE), - embedding_model=_recall_model, - embedding_http_endpoint=embed_invoke_url, - embedding_api_key=embed_remote_api_key or "", - top_k=10, - ks=(1, 5, 10), - hybrid=hybrid, - match_mode=recall_match_mode, - reranker=reranker_model_name if reranker else None, - ) - - evaluation_start = time.perf_counter() - _df_query, _gold, _raw_hits, _retrieved_keys, evaluation_metrics = retrieve_and_score( - query_csv=query_csv, - cfg=recall_cfg, ) - evaluation_total_time = time.perf_counter() - evaluation_start - evaluation_query_count = len(_df_query.index) - - total_time = time.perf_counter() - ingest_start - - # This processing has nothing to do with processing or performance so we exclude - # it from the runtimes. Just getting row counts for metrics ... - num_rows = ingest_results.groupby("source_id").count().count() - - ray.shutdown() - - # Print runtimes for easy user viewing at end - print_run_summary( - num_rows, - input_path, - hybrid, - lancedb_uri, - LANCEDB_TABLE, - total_time, - ingestion_only_total_time, - ray_dataset_download_time, - lancedb_write_time, - evaluation_total_time, - evaluation_metrics, - evaluation_label=evaluation_label, - evaluation_count=evaluation_query_count, ) + render_batch_run_report(report, hybrid=hybrid) finally: # Restore real stdio before closing the mirror file so exception hooks diff --git a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py index 55c4b8706..6fe01f1fb 100644 --- a/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/fused_pipeline.py @@ -11,38 +11,27 @@ import os import subprocess -import time from pathlib import Path from typing import Optional -import lancedb -import ray import typer -from nemo_retriever import create_ingestor +from nemo_retriever.application.modes.reports import RunArtifactConfig, RunEvaluationConfig +from nemo_retriever.application.modes.run_fused import ( + FusedPipelineConfig, + render_fused_run_report, + run_fused_pipeline, +) +from nemo_retriever.application.modes.shared import ( + DEFAULT_LANCEDB_TABLE as LANCEDB_TABLE, + DEFAULT_LANCEDB_URI as LANCEDB_URI, + configure_cli_logging, + restore_cli_logging, +) from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.examples.batch_pipeline import ( - LANCEDB_TABLE, - LANCEDB_URI, - _configure_logging, - _ensure_lancedb_table, -) -from nemo_retriever.utils.detection_summary import ( - collect_detection_summary_from_lancedb, - print_detection_summary, - write_detection_summary, -) -from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second -from nemo_retriever.recall.core import ( - RecallConfig, - gold_to_doc_page, - hit_key_and_distance, - is_hit_at_k, - retrieve_and_score, -) app = typer.Typer() @@ -175,11 +164,10 @@ def main( help="Embedding granularity: 'element' (one row per table/chart/text) or 'page' (one row per page).", ), ) -> None: - log_handle, original_stdout, original_stderr = _configure_logging(log_file) + log_handle, original_stdout, original_stderr = configure_cli_logging(log_file) try: os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0" lancedb_uri = str(Path(lancedb_uri).expanduser().resolve()) - _ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE) if start_ray: subprocess.run(["ray", "start", "--head"], check=True, env=os.environ) @@ -193,14 +181,17 @@ def main( else: raise typer.BadParameter(f"Path does not exist: {input_path}") - ingestor = create_ingestor( - run_mode="fused", - params=IngestorCreateParams(ray_address=ray_address, ray_log_to_driver=ray_log_to_driver), - ) - ingestor = ( - ingestor.files(file_patterns) - .extract( - ExtractParams( + report = run_fused_pipeline( + FusedPipelineConfig( + input_path=str(input_path), + input_type="pdf", + file_patterns=file_patterns, + create_params=IngestorCreateParams(ray_address=ray_address, ray_log_to_driver=ray_log_to_driver), + execute_params=IngestExecuteParams( + runtime_metrics_dir=str(runtime_metrics_dir) if runtime_metrics_dir is not None else None, + runtime_metrics_prefix=runtime_metrics_prefix, + ), + extract_params=ExtractParams( extract_text=True, extract_tables=True, extract_charts=True, @@ -211,10 +202,8 @@ def main( "pdf_split_batch_size": int(pdf_split_batch_size), "pdf_extract_batch_size": int(pdf_extract_batch_size), }, - ) - ) - .embed( - EmbedParams( + ), + embed_params=EmbedParams( model_name="nemo_retriever_v1", embed_granularity=embed_granularity, fused_tuning={ @@ -223,145 +212,30 @@ def main( "fused_cpus_per_actor": float(fused_cpus_per_actor), "fused_gpus_per_actor": float(fused_gpus_per_actor), }, - ) - ) - .vdb_upload( - VdbUploadParams( + ), + vdb_upload_params=VdbUploadParams( lancedb={ "lancedb_uri": lancedb_uri, "table_name": LANCEDB_TABLE, "overwrite": True, "create_index": True, } - ) - ) - ) - - print("Running extraction...") - ingest_start = time.perf_counter() - ingestor.ingest( - params=IngestExecuteParams( - runtime_metrics_dir=str(runtime_metrics_dir) if runtime_metrics_dir is not None else None, - runtime_metrics_prefix=runtime_metrics_prefix, + ), + evaluation=RunEvaluationConfig( + evaluation_mode="recall", + query_csv=str(query_csv), + ), + artifacts=RunArtifactConfig( + lancedb_uri=lancedb_uri, + lancedb_table=LANCEDB_TABLE, + detection_summary_file=str(detection_summary_file) if detection_summary_file is not None else None, + log_file=str(log_file) if log_file is not None else None, + ), ) ) - ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = estimate_processed_pages(lancedb_uri, LANCEDB_TABLE) - detection_summary = collect_detection_summary_from_lancedb(lancedb_uri, LANCEDB_TABLE) - print("Extraction complete.") - print_detection_summary(detection_summary) - if detection_summary_file is not None: - write_detection_summary(detection_summary_file, detection_summary) - print(f"Wrote detection summary JSON to {Path(detection_summary_file).expanduser().resolve()}") - - ray.shutdown() - - query_csv = Path(query_csv) - if not query_csv.exists(): - print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - print_pages_per_second(processed_pages, ingest_elapsed_s) - return - - db = lancedb.connect(lancedb_uri) - table = None - open_err: Optional[Exception] = None - for _ in range(3): - try: - table = db.open_table(LANCEDB_TABLE) - open_err = None - break - except Exception as e: - open_err = e - _ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE) - time.sleep(2) - if table is None: - raise RuntimeError( - f"Recall stage requires LanceDB table {LANCEDB_TABLE!r} at {lancedb_uri!r}, " f"but it was not found." - ) from open_err - try: - if int(table.count_rows()) == 0: - print(f"LanceDB table {LANCEDB_TABLE!r} exists but is empty; skipping recall evaluation.") - print_pages_per_second(processed_pages, ingest_elapsed_s) - return - except Exception: - pass - - unique_basenames = table.to_pandas()["pdf_basename"].unique() - print(f"Unique basenames: {unique_basenames}") - - cfg = RecallConfig( - lancedb_uri=str(lancedb_uri), - lancedb_table=str(LANCEDB_TABLE), - embedding_model="nvidia/llama-nemotron-embed-1b-v2", - top_k=10, - ks=(1, 5, 10), - ) - - _df_query, _gold, _raw_hits, _retrieved_keys, metrics = retrieve_and_score(query_csv=query_csv, cfg=cfg) - - if not no_recall_details: - print("\nPer-query retrieval details:") - missed_gold: list[tuple[str, str]] = [] - for i, (q, g, hits) in enumerate( - zip( - _df_query["query"].astype(str).tolist(), - _gold, - _raw_hits, - ) - ): - doc, page = gold_to_doc_page(g) - - scored_hits: list[tuple[str, float | None]] = [] - for h in hits: - key, dist = hit_key_and_distance(h) - if key: - scored_hits.append((key, dist)) - - top_keys = [k for (k, _d) in scored_hits] - hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") - - if not no_recall_details: - print(f"\nQuery {i}: {q}") - print(f" Gold: {g} (file: {doc}.pdf, page: {page})") - print(f" Hit@{cfg.top_k}: {hit}") - print(" Top hits:") - if not scored_hits: - print(" (no hits)") - else: - for rank, (key, dist) in enumerate(scored_hits[: int(cfg.top_k)], start=1): - if dist is None: - print(f" {rank:02d}. {key}") - else: - print(f" {rank:02d}. {key} distance={dist:.6f}") - - if not hit: - missed_gold.append((f"{doc}.pdf", str(page))) - - missed_unique = sorted(set(missed_gold), key=lambda x: (x[0], x[1])) - print("\nMissed gold (unique pdf/page):") - if not missed_unique: - print(" (none)") - else: - for pdf, page in missed_unique: - print(f" {pdf} page {page}") - print(f"\nTotal missed: {len(missed_unique)} / {len(_gold)}") - - print("\nRecall metrics (matching nemo_retriever.recall.core):") - for k, v in metrics.items(): - print(f" {k}: {v:.4f}") - print_pages_per_second(processed_pages, ingest_elapsed_s) + render_fused_run_report(report) finally: - # Restore real stdio before closing the mirror file so exception hooks - # and late flushes never write to a closed stream wrapper. - import sys - - sys.stdout = original_stdout - sys.stderr = original_stderr - if log_handle is not None: - try: - log_handle.flush() - finally: - log_handle.close() + restore_cli_logging(log_handle, original_stdout, original_stderr) if __name__ == "__main__": diff --git a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py index 386bff850..5d94093c3 100644 --- a/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/inprocess_pipeline.py @@ -7,32 +7,28 @@ Run with: uv run python -m nemo_retriever.examples.inprocess_pipeline """ -import time from pathlib import Path from typing import Optional -import lancedb import typer -from nemo_retriever import create_ingestor -from nemo_retriever.examples.common import estimate_processed_pages, print_pages_per_second +from nemo_retriever.application.modes.reports import RunArtifactConfig, RunEvaluationConfig +from nemo_retriever.application.modes.run_inprocess import ( + InProcessPipelineConfig, + render_inprocess_run_report, + run_inprocess_pipeline, +) +from nemo_retriever.application.modes.shared import ( + DEFAULT_LANCEDB_TABLE as LANCEDB_TABLE, + DEFAULT_LANCEDB_URI as LANCEDB_URI, +) from nemo_retriever.params import EmbedParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params import VdbUploadParams -from nemo_retriever.recall.core import ( - RecallConfig, - gold_to_doc_page, - hit_key_and_distance, - is_hit_at_k, - retrieve_and_score, -) app = typer.Typer() -LANCEDB_URI = "lancedb" -LANCEDB_TABLE = "nv-ingest" - @app.command() def main( @@ -197,210 +193,68 @@ def main( else: raise typer.BadParameter(f"Path does not exist: {input_path}") - ingestor = create_ingestor(run_mode="inprocess") - if input_type == "txt": - ingestor = ingestor.files(file_patterns).extract_txt( - TextChunkParams( - max_tokens=text_chunk_max_tokens or 1024, - overlap_tokens=text_chunk_overlap_tokens if text_chunk_overlap_tokens is not None else 150, - ) - ) - elif input_type == "html": - ingestor = ingestor.files(file_patterns).extract_html( - TextChunkParams( - max_tokens=text_chunk_max_tokens or 1024, - overlap_tokens=text_chunk_overlap_tokens if text_chunk_overlap_tokens is not None else 150, - ) - ) - elif input_type == "image": - ingestor = ingestor.files(file_patterns).extract_image_files( - ExtractParams( - method=method, - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, - use_graphic_elements=use_graphic_elements, - graphic_elements_invoke_url=graphic_elements_invoke_url, - use_table_structure=use_table_structure, - table_output_format=table_output_format, - table_structure_invoke_url=table_structure_invoke_url, - page_elements_invoke_url=page_elements_invoke_url, - ocr_invoke_url=ocr_invoke_url, - ) - ) - elif input_type == "doc": - ingestor = ingestor.files(file_patterns).extract( - ExtractParams( - method=method, - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, - use_graphic_elements=use_graphic_elements, - graphic_elements_invoke_url=graphic_elements_invoke_url, - use_table_structure=use_table_structure, - table_output_format=table_output_format, - table_structure_invoke_url=table_structure_invoke_url, - page_elements_invoke_url=page_elements_invoke_url, - ocr_invoke_url=ocr_invoke_url, - ) - ) - else: - ingestor = ingestor.files(file_patterns).extract( - ExtractParams( - method=method, - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=False, - use_graphic_elements=use_graphic_elements, - graphic_elements_invoke_url=graphic_elements_invoke_url, - use_table_structure=use_table_structure, - table_output_format=table_output_format, - table_structure_invoke_url=table_structure_invoke_url, - page_elements_invoke_url=page_elements_invoke_url, - ocr_invoke_url=ocr_invoke_url, - ) - ) - enable_text_chunk = text_chunk or text_chunk_max_tokens is not None or text_chunk_overlap_tokens is not None - if enable_text_chunk: - ingestor = ingestor.split( - TextChunkParams( - max_tokens=text_chunk_max_tokens or 1024, - overlap_tokens=text_chunk_overlap_tokens if text_chunk_overlap_tokens is not None else 150, - ) - ) - - ingestor = ingestor.embed( - EmbedParams( - model_name=str(embed_model_name), - embed_invoke_url=embed_invoke_url, - embed_modality=embed_modality, - text_elements_modality=text_elements_modality, - structured_elements_modality=structured_elements_modality, - embed_granularity=embed_granularity, - ) - ).vdb_upload( - VdbUploadParams( - lancedb={ - "lancedb_uri": LANCEDB_URI, - "table_name": LANCEDB_TABLE, - "overwrite": True, - "create_index": True, - "hybrid": hybrid, - } - ) - ) - - print("Running extraction...") - ingest_start = time.perf_counter() - ingestor.ingest( - params=IngestExecuteParams( - parallel=True, - max_workers=max_workers, - gpu_devices=gpu_device_list, - show_progress=True, - ) + extract_params = ExtractParams( + method=method, + extract_text=True, + extract_tables=True, + extract_charts=True, + extract_infographics=False, + use_graphic_elements=use_graphic_elements, + graphic_elements_invoke_url=graphic_elements_invoke_url, + use_table_structure=use_table_structure, + table_output_format=table_output_format, + table_structure_invoke_url=table_structure_invoke_url, + page_elements_invoke_url=page_elements_invoke_url, + ocr_invoke_url=ocr_invoke_url, ) - ingest_elapsed_s = time.perf_counter() - ingest_start - processed_pages = estimate_processed_pages(LANCEDB_URI, LANCEDB_TABLE) - print("Extraction complete.") - - # --------------------------------------------------------------------------- - # Recall calculation (optional) - # --------------------------------------------------------------------------- - query_csv = Path(query_csv) - if not query_csv.exists(): - print(f"Query CSV not found at {query_csv}; skipping recall evaluation.") - print_pages_per_second(processed_pages, ingest_elapsed_s) - return - - db = lancedb.connect(f"./{LANCEDB_URI}") - table = db.open_table(LANCEDB_TABLE) - unique_basenames = table.to_pandas()["pdf_basename"].unique() - print(f"Unique basenames: {unique_basenames}") - - # Resolve the HF model ID for recall query embedding so aliases - # (e.g. "nemo_retriever_v1") map to the correct model. - from nemo_retriever.model import resolve_embed_model - - _recall_model = resolve_embed_model(str(embed_model_name)) - - cfg = RecallConfig( - lancedb_uri=str(LANCEDB_URI), - lancedb_table=str(LANCEDB_TABLE), - embedding_model=_recall_model, - embedding_http_endpoint=embed_invoke_url, - top_k=10, - ks=(1, 5, 10), - hybrid=hybrid, + chunk_params = TextChunkParams( + max_tokens=text_chunk_max_tokens or 1024, + overlap_tokens=text_chunk_overlap_tokens if text_chunk_overlap_tokens is not None else 150, ) - _df_query, _gold, _raw_hits, _retrieved_keys, metrics = retrieve_and_score(query_csv=query_csv, cfg=cfg) - - if not no_recall_details: - print("\nPer-query retrieval details:") - missed_gold: list[tuple[str, str]] = [] - for i, (q, g, hits) in enumerate( - zip( - _df_query["query"].astype(str).tolist(), - _gold, - _raw_hits, + report = run_inprocess_pipeline( + InProcessPipelineConfig( + input_path=str(input_path), + input_type=input_type, + file_patterns=file_patterns, + execute_params=IngestExecuteParams( + parallel=True, + max_workers=max_workers, + gpu_devices=gpu_device_list, + show_progress=True, + ), + extract_params=extract_params, + embed_params=EmbedParams( + model_name=str(embed_model_name), + embed_invoke_url=embed_invoke_url, + embed_modality=embed_modality, + text_elements_modality=text_elements_modality, + structured_elements_modality=structured_elements_modality, + embed_granularity=embed_granularity, + ), + text_chunk_params=chunk_params, + enable_text_chunk=enable_text_chunk, + vdb_upload_params=VdbUploadParams( + lancedb={ + "lancedb_uri": LANCEDB_URI, + "table_name": LANCEDB_TABLE, + "overwrite": True, + "create_index": True, + "hybrid": hybrid, + } + ), + evaluation=RunEvaluationConfig( + evaluation_mode="recall", + query_csv=str(query_csv), + ), + artifacts=RunArtifactConfig( + lancedb_uri=LANCEDB_URI, + lancedb_table=LANCEDB_TABLE, + ), ) - ): - doc, page = gold_to_doc_page(g) - - scored_hits: list[tuple[str, float | None]] = [] - for h in hits: - key, dist = hit_key_and_distance(h) - if key: - scored_hits.append((key, dist)) - - top_keys = [k for (k, _d) in scored_hits] - hit = is_hit_at_k(g, top_keys, cfg.top_k, match_mode="pdf_page") - - if not no_recall_details: - ext = ( - ".txt" - if input_type == "txt" - else (".html" if input_type == "html" else (".docx" if input_type == "doc" else ".pdf")) - ) - print(f"\nQuery {i}: {q}") - print(f" Gold: {g} (file: {doc}{ext}, page: {page})") - print(f" Hit@{cfg.top_k}: {hit}") - print(" Top hits:") - if not scored_hits: - print(" (no hits)") - else: - for rank, (key, dist) in enumerate(scored_hits[: int(cfg.top_k)], start=1): - if dist is None: - print(f" {rank:02d}. {key}") - else: - print(f" {rank:02d}. {key} distance={dist:.6f}") - - if not hit: - ext = ( - ".txt" - if input_type == "txt" - else (".html" if input_type == "html" else (".docx" if input_type == "doc" else ".pdf")) - ) - missed_gold.append((f"{doc}{ext}", str(page))) - - missed_unique = sorted(set(missed_gold), key=lambda x: (x[0], x[1])) - print("\nMissed gold (unique doc/page):") - if not missed_unique: - print(" (none)") - else: - for doc_page, page in missed_unique: - print(f" {doc_page} page {page}") - print(f"\nTotal missed: {len(missed_unique)} / {len(_gold)}") - - print("\nRecall metrics (matching nemo_retriever.recall.core):") - for k, v in metrics.items(): - print(f" {k}: {v:.4f}") - print_pages_per_second(processed_pages, ingest_elapsed_s) + ) + render_inprocess_run_report(report, include_ingest_summary=False) if __name__ == "__main__": diff --git a/nemo_retriever/src/nemo_retriever/harness/config.py b/nemo_retriever/src/nemo_retriever/harness/config.py index 8370e9423..f4ce52144 100644 --- a/nemo_retriever/src/nemo_retriever/harness/config.py +++ b/nemo_retriever/src/nemo_retriever/harness/config.py @@ -15,6 +15,7 @@ REPO_ROOT = NEMO_RETRIEVER_ROOT.parent DEFAULT_TEST_CONFIG_PATH = NEMO_RETRIEVER_ROOT / "harness" / "test_configs.yaml" DEFAULT_NIGHTLY_CONFIG_PATH = NEMO_RETRIEVER_ROOT / "harness" / "nightly_config.yaml" +VALID_RUN_MODES = {"batch", "inprocess", "fused"} VALID_EVALUATION_MODES = {"recall", "beir"} VALID_RECALL_ADAPTERS = {"none", "page_plus_one", "financebench_json"} VALID_BEIR_LOADERS = {"vidore_hf"} @@ -29,7 +30,7 @@ "recall_5", ] -TUNING_FIELDS = { +BATCH_TUNING_FIELDS = { "pdf_extract_workers", "pdf_extract_num_cpus", "pdf_extract_batch_size", @@ -47,6 +48,29 @@ "gpu_ocr", "gpu_embed", } +INPROCESS_TUNING_FIELDS = { + "max_workers", + "gpu_devices", +} +FUSED_TUNING_FIELDS = { + "pdf_extract_workers", + "pdf_extract_num_cpus", + "pdf_extract_batch_size", + "pdf_split_batch_size", + "fused_workers", + "fused_batch_size", + "fused_cpus_per_actor", + "fused_gpus_per_actor", +} +TUNING_FIELDS = BATCH_TUNING_FIELDS + + +def tuning_fields_for_run_mode(run_mode: str) -> set[str]: + if run_mode == "inprocess": + return set(INPROCESS_TUNING_FIELDS) + if run_mode == "fused": + return set(FUSED_TUNING_FIELDS) + return set(BATCH_TUNING_FIELDS) @dataclass @@ -54,6 +78,7 @@ class HarnessConfig: dataset_dir: str dataset_label: str preset: str + run_mode: str = "batch" query_csv: str | None = None input_type: str = "pdf" @@ -72,6 +97,8 @@ class HarnessConfig: ray_address: str | None = None lancedb_uri: str = "lancedb" hybrid: bool = False + max_workers: int = 16 + gpu_devices: str | None = None embed_model_name: str = "nvidia/llama-nemotron-embed-1b-v2" embed_modality: str = "text" embed_granularity: str = "element" @@ -95,6 +122,10 @@ class HarnessConfig: gpu_page_elements: float = 0.1 gpu_ocr: float = 0.1 gpu_embed: float = 0.25 + fused_workers: int = 1 + fused_batch_size: int = 64 + fused_cpus_per_actor: float = 1.0 + fused_gpus_per_actor: float = 1.0 def validate(self) -> list[str]: errors: list[str] = [] @@ -106,14 +137,20 @@ def validate(self) -> list[str]: if self.query_csv is not None and not Path(self.query_csv).exists(): errors.append(f"query_csv does not exist: {self.query_csv}") + if self.run_mode not in VALID_RUN_MODES: + errors.append(f"run_mode must be one of {sorted(VALID_RUN_MODES)}") + if self.evaluation_mode not in VALID_EVALUATION_MODES: errors.append(f"evaluation_mode must be one of {sorted(VALID_EVALUATION_MODES)}") if self.evaluation_mode == "recall" and self.recall_required and not self.query_csv: errors.append("recall_required=true requires query_csv") - if self.input_type not in {"pdf", "txt", "html", "doc"}: - errors.append(f"input_type must be one of pdf/txt/html/doc, got '{self.input_type}'") + if self.input_type not in {"pdf", "txt", "html", "doc", "image"}: + errors.append(f"input_type must be one of pdf/txt/html/doc/image, got '{self.input_type}'") + + if self.run_mode == "fused" and self.input_type != "pdf": + errors.append("fused run_mode currently supports only input_type=pdf") if self.evaluation_mode == "recall": if self.recall_match_mode not in {"pdf_page", "pdf_only"}: @@ -148,12 +185,16 @@ def validate(self) -> list[str]: if self.embed_granularity not in VALID_EMBED_GRANULARITIES: errors.append(f"embed_granularity must be one of {sorted(VALID_EMBED_GRANULARITIES)}") - for name in TUNING_FIELDS: + for name in tuning_fields_for_run_mode(self.run_mode): val = getattr(self, name) + if name == "gpu_devices": + continue if name.startswith("gpu_") and float(val) < 0.0: errors.append(f"{name} must be >= 0.0") elif name.endswith("_workers") and int(val) < 1: errors.append(f"{name} must be >= 1") + elif name in {"max_workers", "fused_batch_size"} and int(val) < 1: + errors.append(f"{name} must be >= 1") return errors @@ -246,6 +287,7 @@ def _apply_env_overrides(config_dict: dict[str, Any]) -> None: "HARNESS_DATASET": ("dataset", str), "HARNESS_DATASET_DIR": ("dataset_dir", str), "HARNESS_PRESET": ("preset", str), + "HARNESS_RUN_MODE": ("run_mode", str), "HARNESS_QUERY_CSV": ("query_csv", str), "HARNESS_INPUT_TYPE": ("input_type", str), "HARNESS_RECALL_REQUIRED": ("recall_required", _parse_bool), @@ -261,15 +303,21 @@ def _apply_env_overrides(config_dict: dict[str, Any]) -> None: "HARNESS_RAY_ADDRESS": ("ray_address", str), "HARNESS_LANCEDB_URI": ("lancedb_uri", str), "HARNESS_HYBRID": ("hybrid", _parse_bool), + "HARNESS_MAX_WORKERS": ("max_workers", _parse_number), + "HARNESS_GPU_DEVICES": ("gpu_devices", str), "HARNESS_EMBED_MODEL_NAME": ("embed_model_name", str), "HARNESS_EMBED_MODALITY": ("embed_modality", str), "HARNESS_EMBED_GRANULARITY": ("embed_granularity", str), "HARNESS_EXTRACT_PAGE_AS_IMAGE": ("extract_page_as_image", _parse_bool), "HARNESS_EXTRACT_INFOGRAPHICS": ("extract_infographics", _parse_bool), "HARNESS_WRITE_DETECTION_FILE": ("write_detection_file", _parse_bool), + "HARNESS_FUSED_WORKERS": ("fused_workers", _parse_number), + "HARNESS_FUSED_BATCH_SIZE": ("fused_batch_size", _parse_number), + "HARNESS_FUSED_CPUS_PER_ACTOR": ("fused_cpus_per_actor", _parse_number), + "HARNESS_FUSED_GPUS_PER_ACTOR": ("fused_gpus_per_actor", _parse_number), } - for key in TUNING_FIELDS: + for key in BATCH_TUNING_FIELDS: env_map[f"HARNESS_{key.upper()}"] = (key, _parse_number) for env_key, (cfg_key, parser) in env_map.items(): diff --git a/nemo_retriever/src/nemo_retriever/harness/parsers.py b/nemo_retriever/src/nemo_retriever/harness/parsers.py deleted file mode 100644 index c7638b313..000000000 --- a/nemo_retriever/src/nemo_retriever/harness/parsers.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import re -from dataclasses import dataclass, field - -DONE_RE = re.compile(r"\[done\]\s+(?P\d+)\s+files,\s+(?P\d+)\s+pages\s+in\s+(?P[0-9.]+)s") -INGEST_ROWS_RE = re.compile( - r"Ingestion complete\.\s+(?P\d+)\s+rows\s+proces+ed\s+in\s+(?P[0-9.]+)\s+seconds\.\s+" - r"(?P[0-9.]+)\s+PPS" -) -PAGES_PER_SEC_RE = re.compile(r"Pages/sec \(ingest only; excludes Ray startup and recall\):\s*(?P[0-9.]+)") -METRIC_RE = re.compile(r"(?P[A-Za-z_]+@\d+):\s*(?P[0-9.]+)\s*$") - - -@dataclass -class StreamMetrics: - files: int | None = None - pages: int | None = None - ingest_secs: float | None = None - pages_per_sec_ingest: float | None = None - rows_processed: int | None = None - rows_per_sec_ingest: float | None = None - recall_metrics: dict[str, float] = field(default_factory=dict) - evaluation_metrics: dict[str, float] = field(default_factory=dict) - _metric_block: str | None = None - - def consume(self, line: str) -> None: - done_match = DONE_RE.search(line) - if done_match: - self.files = int(done_match.group("files")) - self.pages = int(done_match.group("pages")) - self.ingest_secs = float(done_match.group("secs")) - - ingest_rows_match = INGEST_ROWS_RE.search(line) - if ingest_rows_match: - self.rows_processed = int(ingest_rows_match.group("rows")) - self.ingest_secs = float(ingest_rows_match.group("secs")) - self.rows_per_sec_ingest = float(ingest_rows_match.group("pps")) - - pps_match = PAGES_PER_SEC_RE.search(line) - if pps_match: - self.pages_per_sec_ingest = float(pps_match.group("val")) - - normalized_line = line.strip().lower() - if "recall metrics" in normalized_line: - self._metric_block = "recall" - return - - if "beir metrics" in normalized_line: - self._metric_block = "beir" - return - - if self._metric_block is not None: - metric_match = METRIC_RE.search(line) - if metric_match: - metric = metric_match.group("metric").lower() - value = float(metric_match.group("val")) - if self._metric_block == "recall": - self.recall_metrics[metric] = value - else: - self.evaluation_metrics[metric] = value - return - - if line.strip() and not line.startswith(" "): - self._metric_block = None - - -def parse_stream_text(stdout_text: str) -> StreamMetrics: - metrics = StreamMetrics() - for line in stdout_text.splitlines(): - metrics.consume(line) - return metrics diff --git a/nemo_retriever/src/nemo_retriever/harness/run.py b/nemo_retriever/src/nemo_retriever/harness/run.py index 76e11f8d9..eb3c7e0bf 100644 --- a/nemo_retriever/src/nemo_retriever/harness/run.py +++ b/nemo_retriever/src/nemo_retriever/harness/run.py @@ -4,13 +4,7 @@ from __future__ import annotations -import errno from importlib import metadata -import json -import os -import pty -import re -import select import shlex import socket import subprocess @@ -20,6 +14,16 @@ import typer +from nemo_retriever.application.modes.reports import ( + RunArtifactConfig, + RunEvaluationConfig, + flatten_report_metrics, + normalize_metric_key, +) +from nemo_retriever.application.modes.run_batch import BatchPipelineConfig, run_batch_pipeline +from nemo_retriever.application.modes.run_fused import FusedPipelineConfig, run_fused_pipeline +from nemo_retriever.application.modes.run_inprocess import InProcessPipelineConfig, run_inprocess_pipeline +from nemo_retriever.application.modes.shared import DEFAULT_LANCEDB_TABLE from nemo_retriever.harness.artifacts import ( create_run_artifact_dir, create_session_dir, @@ -31,16 +35,19 @@ from nemo_retriever.harness.config import ( DEFAULT_NIGHTLY_CONFIG_PATH, HarnessConfig, - TUNING_FIELDS, load_harness_config, load_nightly_config, + tuning_fields_for_run_mode, ) -from nemo_retriever.harness.parsers import StreamMetrics from nemo_retriever.harness.recall_adapters import prepare_recall_query_file +from nemo_retriever.params import EmbedParams +from nemo_retriever.params import ExtractParams +from nemo_retriever.params import IngestExecuteParams +from nemo_retriever.params import IngestorCreateParams +from nemo_retriever.params import TextChunkParams +from nemo_retriever.params import VdbUploadParams from nemo_retriever.utils.input_files import resolve_input_files -ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") - def _collect_gpu_metadata() -> tuple[int | None, str | None]: try: @@ -94,20 +101,27 @@ def _collect_run_metadata() -> dict[str, Any]: def _normalize_tags(tags: list[str] | None) -> list[str]: normalized: list[str] = [] seen: set[str] = set() - for raw in tags or []: tag = str(raw).strip() if not tag or tag in seen: continue seen.add(tag) normalized.append(tag) - return normalized def _normalize_recall_metric_key(key: str) -> str: - metric = str(key).strip().lower() - return metric.replace("@", "_").replace("-", "_") + return normalize_metric_key(key) + + +def _empty_summary_metrics() -> dict[str, Any]: + return { + "pages": None, + "ingest_secs": None, + "pages_per_sec_ingest": None, + "recall_5": None, + "ndcg_10": None, + } def _safe_pdf_page_count(path: Path) -> int | None: @@ -135,7 +149,7 @@ def _resolve_summary_metrics( metrics_payload: dict[str, Any], runtime_summary: dict[str, Any] | None, ) -> dict[str, Any]: - summary_metrics: dict[str, Any] = { + summary_metrics = { "pages": metrics_payload.get("pages"), "ingest_secs": metrics_payload.get("ingest_secs"), "pages_per_sec_ingest": metrics_payload.get("pages_per_sec_ingest"), @@ -144,7 +158,7 @@ def _resolve_summary_metrics( } if summary_metrics["pages"] is None and isinstance(runtime_summary, dict): - runtime_pages = runtime_summary.get("num_pages") + runtime_pages = runtime_summary.get("processed_pages") if runtime_pages is None: runtime_pages = runtime_summary.get("input_pages") if runtime_pages is not None: @@ -187,204 +201,373 @@ def _resolve_lancedb_uri(cfg: HarnessConfig, artifact_dir: Path) -> str: return str(p) -def _build_command(cfg: HarnessConfig, artifact_dir: Path, run_id: str) -> tuple[list[str], Path, Path, Path | None]: +def _parse_gpu_devices(raw: str | None) -> list[str]: + if raw is None: + return [] + return [token.strip() for token in str(raw).split(",") if token.strip()] + + +def _prepare_run_paths( + cfg: HarnessConfig, + artifact_dir: Path, +) -> tuple[Path, Path | None, Path | None]: runtime_dir = artifact_dir / "runtime_metrics" runtime_dir.mkdir(parents=True, exist_ok=True) - if cfg.write_detection_file: - detection_summary_file = artifact_dir / "detection_summary.json" - else: - # Keep detection summary out of top-level artifacts unless explicitly requested. - detection_summary_file = runtime_dir / ".detection_summary.json" + detection_summary_file = ( + artifact_dir / "detection_summary.json" if cfg.write_detection_file else runtime_dir / ".detection_summary.json" + ) effective_query_csv: Path | None = None - - cmd = [ - sys.executable, - "-m", - "nemo_retriever.examples.batch_pipeline", - str(Path(cfg.dataset_dir).resolve()), - "--input-type", - cfg.input_type, - "--evaluation-mode", - cfg.evaluation_mode, - "--pdf-extract-tasks", - str(cfg.pdf_extract_workers), - "--pdf-extract-cpus-per-task", - str(cfg.pdf_extract_num_cpus), - "--pdf-extract-batch-size", - str(cfg.pdf_extract_batch_size), - "--pdf-split-batch-size", - str(cfg.pdf_split_batch_size), - "--page-elements-batch-size", - str(cfg.page_elements_batch_size), - "--page-elements-actors", - str(cfg.page_elements_workers), - "--ocr-actors", - str(cfg.ocr_workers), - "--ocr-batch-size", - str(cfg.ocr_batch_size), - "--embed-actors", - str(cfg.embed_workers), - "--embed-batch-size", - str(cfg.embed_batch_size), - "--page-elements-cpus-per-actor", - str(cfg.page_elements_cpus_per_actor), - "--ocr-cpus-per-actor", - str(cfg.ocr_cpus_per_actor), - "--embed-cpus-per-actor", - str(cfg.embed_cpus_per_actor), - "--page-elements-gpus-per-actor", - str(cfg.gpu_page_elements), - "--ocr-gpus-per-actor", - str(cfg.gpu_ocr), - "--embed-gpus-per-actor", - str(cfg.gpu_embed), - "--embed-model-name", - cfg.embed_model_name, - "--embed-modality", - cfg.embed_modality, - "--embed-granularity", - cfg.embed_granularity, - "--runtime-metrics-dir", - str(runtime_dir), - "--runtime-metrics-prefix", - run_id, - "--detection-summary-file", - str(detection_summary_file), - "--lancedb-uri", - _resolve_lancedb_uri(cfg, artifact_dir), - ] - - if cfg.evaluation_mode == "beir": - cmd += [ - "--beir-loader", - str(cfg.beir_loader), - "--beir-dataset-name", - str(cfg.beir_dataset_name or cfg.dataset_label), - "--beir-split", - cfg.beir_split, - "--beir-doc-id-field", - cfg.beir_doc_id_field, - ] - if cfg.beir_query_language: - cmd += ["--beir-query-language", cfg.beir_query_language] - for k in cfg.beir_ks: - cmd += ["--beir-k", str(int(k))] - else: + if cfg.evaluation_mode != "beir": effective_query_csv = prepare_recall_query_file( query_csv=Path(cfg.query_csv) if cfg.query_csv else None, recall_adapter=cfg.recall_adapter, output_dir=runtime_dir, ) + return runtime_dir, detection_summary_file, effective_query_csv + + +def _build_command( + cfg: HarnessConfig, + artifact_dir: Path, + run_id: str, +) -> tuple[list[str], Path, Path | None, Path | None]: + runtime_dir, detection_summary_file, effective_query_csv = _prepare_run_paths(cfg, artifact_dir) + lancedb_uri = _resolve_lancedb_uri(cfg, artifact_dir) + module_name = { + "batch": "nemo_retriever.examples.batch_pipeline", + "inprocess": "nemo_retriever.examples.inprocess_pipeline", + "fused": "nemo_retriever.examples.fused_pipeline", + }[cfg.run_mode] + cmd = [sys.executable, "-m", module_name, str(Path(cfg.dataset_dir).resolve())] + + if cfg.input_type != "pdf": + cmd += ["--input-type", cfg.input_type] + + if cfg.run_mode == "batch": cmd += [ - "--query-csv", - str(effective_query_csv), - "--recall-match-mode", - cfg.recall_match_mode, - "--no-recall-details", + "--evaluation-mode", + cfg.evaluation_mode, + "--pdf-extract-tasks", + str(cfg.pdf_extract_workers), + "--pdf-extract-cpus-per-task", + str(cfg.pdf_extract_num_cpus), + "--pdf-extract-batch-size", + str(cfg.pdf_extract_batch_size), + "--pdf-split-batch-size", + str(cfg.pdf_split_batch_size), + "--page-elements-actors", + str(cfg.page_elements_workers), + "--page-elements-batch-size", + str(cfg.page_elements_batch_size), + "--ocr-actors", + str(cfg.ocr_workers), + "--ocr-batch-size", + str(cfg.ocr_batch_size), + "--embed-actors", + str(cfg.embed_workers), + "--embed-batch-size", + str(cfg.embed_batch_size), + "--page-elements-cpus-per-actor", + str(cfg.page_elements_cpus_per_actor), + "--ocr-cpus-per-actor", + str(cfg.ocr_cpus_per_actor), + "--embed-cpus-per-actor", + str(cfg.embed_cpus_per_actor), + "--page-elements-gpus-per-actor", + str(cfg.gpu_page_elements), + "--ocr-gpus-per-actor", + str(cfg.gpu_ocr), + "--embed-gpus-per-actor", + str(cfg.gpu_embed), + "--embed-model-name", + cfg.embed_model_name, + "--embed-modality", + cfg.embed_modality, + "--embed-granularity", + cfg.embed_granularity, + "--runtime-metrics-dir", + str(runtime_dir), + "--runtime-metrics-prefix", + run_id, + "--lancedb-uri", + lancedb_uri, ] - - cmd += ["--extract-page-as-image" if cfg.extract_page_as_image else "--no-extract-page-as-image"] - if cfg.extract_infographics: - cmd += ["--extract-infographics"] - if cfg.embed_modality: - cmd += ["--structured-elements-modality", cfg.embed_modality] - if cfg.ray_address: - cmd += ["--ray-address", cfg.ray_address] - if cfg.hybrid: - cmd += ["--hybrid"] + cmd += ["--detection-summary-file", str(detection_summary_file)] + if cfg.evaluation_mode == "beir": + cmd += [ + "--beir-loader", + str(cfg.beir_loader), + "--beir-dataset-name", + str(cfg.beir_dataset_name or cfg.dataset_label), + "--beir-split", + cfg.beir_split, + "--beir-doc-id-field", + cfg.beir_doc_id_field, + ] + if cfg.beir_query_language: + cmd += ["--beir-query-language", cfg.beir_query_language] + for k in cfg.beir_ks: + cmd += ["--beir-k", str(int(k))] + elif effective_query_csv is not None: + cmd += ["--query-csv", str(effective_query_csv), "--recall-match-mode", cfg.recall_match_mode] + cmd += ["--extract-page-as-image" if cfg.extract_page_as_image else "--no-extract-page-as-image"] + if cfg.extract_infographics: + cmd += ["--extract-infographics"] + if cfg.embed_modality: + cmd += ["--structured-elements-modality", cfg.embed_modality] + if cfg.hybrid: + cmd += ["--hybrid"] + if cfg.ray_address: + cmd += ["--ray-address", cfg.ray_address] + elif cfg.run_mode == "inprocess": + cmd += ["--max-workers", str(cfg.max_workers)] + if cfg.gpu_devices: + cmd += ["--gpu-devices", cfg.gpu_devices] + if effective_query_csv is not None: + cmd += ["--query-csv", str(effective_query_csv)] + if cfg.hybrid: + cmd += ["--hybrid"] + cmd += ["--embed-modality", cfg.embed_modality, "--embed-granularity", cfg.embed_granularity] + else: + cmd += [ + "--pdf-extract-tasks", + str(cfg.pdf_extract_workers), + "--pdf-extract-cpus-per-task", + str(cfg.pdf_extract_num_cpus), + "--pdf-extract-batch-size", + str(cfg.pdf_extract_batch_size), + "--pdf-split-batch-size", + str(cfg.pdf_split_batch_size), + "--fused-workers", + str(cfg.fused_workers), + "--fused-batch-size", + str(cfg.fused_batch_size), + "--fused-cpus-per-actor", + str(cfg.fused_cpus_per_actor), + "--fused-gpus-per-actor", + str(cfg.fused_gpus_per_actor), + "--lancedb-uri", + lancedb_uri, + "--runtime-metrics-dir", + str(runtime_dir), + "--runtime-metrics-prefix", + run_id, + ] + if effective_query_csv is not None: + cmd += ["--query-csv", str(effective_query_csv)] + if detection_summary_file is not None: + cmd += ["--detection-summary-file", str(detection_summary_file)] + if cfg.ray_address: + cmd += ["--ray-address", cfg.ray_address] return cmd, runtime_dir, detection_summary_file, effective_query_csv -def _evaluate_run_outcome( - process_rc: int, - evaluation_mode: str, - recall_required: bool, - recall_metrics: dict[str, float], - evaluation_metrics: dict[str, float] | None = None, -) -> tuple[int, str, bool]: - if process_rc != 0: - reason = f"subprocess_exit_{process_rc}" - return process_rc, reason, False - if evaluation_mode == "beir" and not (evaluation_metrics or {}): - return 97, "missing_beir_metrics", False - if evaluation_mode == "recall" and recall_required and not recall_metrics: - return 98, "missing_recall_metrics", False - return 0, "", True +def _common_evaluation_config(cfg: HarnessConfig, effective_query_csv: Path | None) -> RunEvaluationConfig: + return RunEvaluationConfig( + evaluation_mode=cfg.evaluation_mode, + query_csv=str(effective_query_csv) if effective_query_csv is not None else None, + recall_match_mode=cfg.recall_match_mode, + beir_loader=cfg.beir_loader, + beir_dataset_name=cfg.beir_dataset_name or cfg.dataset_label, + beir_split=cfg.beir_split, + beir_query_language=cfg.beir_query_language, + beir_doc_id_field=cfg.beir_doc_id_field, + beir_ks=tuple(cfg.beir_ks), + ) -def _read_json_if_exists(path: Path) -> dict[str, Any] | None: - if not path.exists(): - return None - try: - data = json.loads(path.read_text(encoding="utf-8")) - except (json.JSONDecodeError, OSError): - return None - if not isinstance(data, dict): - return None - return data - +def _common_extract_params(cfg: HarnessConfig, *, batch_tuning: dict[str, Any] | None = None) -> ExtractParams: + kwargs: dict[str, Any] = { + "extract_text": True, + "extract_tables": True, + "extract_charts": True, + "extract_infographics": cfg.extract_infographics, + "extract_page_as_image": cfg.extract_page_as_image, + } + if batch_tuning is not None: + kwargs["batch_tuning"] = batch_tuning + return ExtractParams(**kwargs) -def _consume_parseable_output(metrics: StreamMetrics, parse_buffer: str) -> str: - while "\n" in parse_buffer: - line, parse_buffer = parse_buffer.split("\n", 1) - cleaned = ANSI_ESCAPE_RE.sub("", line) - metrics.consume(cleaned + "\n") - return parse_buffer +def _common_embed_params( + cfg: HarnessConfig, + *, + batch_tuning: dict[str, Any] | None = None, + fused_tuning: dict[str, Any] | None = None, + model_name: str | None = None, +) -> EmbedParams: + kwargs: dict[str, Any] = { + "model_name": model_name or cfg.embed_model_name, + "embed_modality": cfg.embed_modality, + "embed_granularity": cfg.embed_granularity, + } + if batch_tuning is not None: + kwargs["batch_tuning"] = batch_tuning + if fused_tuning is not None: + kwargs["fused_tuning"] = fused_tuning + return EmbedParams(**kwargs) -def _run_subprocess_with_tty(cmd: list[str], metrics: StreamMetrics) -> int: - """ - Run command in a pseudo-terminal so Ray renders rich progress. - We still parse lines from the PTY stream to extract benchmark metrics. - """ - master_fd, slave_fd = pty.openpty() - parse_buffer = "" - try: - proc = subprocess.Popen( - cmd, - stdin=None, - stdout=slave_fd, - stderr=slave_fd, - close_fds=True, +def _build_runner_config( + cfg: HarnessConfig, + artifact_dir: Path, + run_id: str, + runtime_dir: Path, + detection_summary_file: Path | None, + effective_query_csv: Path | None, +): + lancedb_uri = _resolve_lancedb_uri(cfg, artifact_dir) + artifacts = RunArtifactConfig( + lancedb_uri=lancedb_uri, + lancedb_table=DEFAULT_LANCEDB_TABLE, + detection_summary_file=str(detection_summary_file) if detection_summary_file is not None else None, + ) + evaluation = _common_evaluation_config(cfg, effective_query_csv) + runtime_metrics_dir = str(runtime_dir) + text_chunk_params = TextChunkParams() + + if cfg.run_mode == "batch": + batch_tuning = { + "pdf_extract_workers": cfg.pdf_extract_workers, + "pdf_extract_num_cpus": cfg.pdf_extract_num_cpus, + "pdf_extract_batch_size": cfg.pdf_extract_batch_size, + "pdf_split_batch_size": cfg.pdf_split_batch_size, + "page_elements_batch_size": cfg.page_elements_batch_size, + "page_elements_workers": cfg.page_elements_workers, + "detect_workers": cfg.ocr_workers, + "ocr_inference_batch_size": cfg.ocr_batch_size, + "detect_batch_size": cfg.ocr_batch_size, + "embed_workers": cfg.embed_workers, + "embed_batch_size": cfg.embed_batch_size, + "page_elements_cpus_per_actor": cfg.page_elements_cpus_per_actor, + "ocr_cpus_per_actor": cfg.ocr_cpus_per_actor, + "embed_cpus_per_actor": cfg.embed_cpus_per_actor, + "gpu_page_elements": cfg.gpu_page_elements, + "gpu_ocr": cfg.gpu_ocr, + "gpu_embed": cfg.gpu_embed, + } + extract_params = ( + None if cfg.input_type in {"txt", "html"} else _common_extract_params(cfg, batch_tuning=batch_tuning) + ) + return BatchPipelineConfig( + input_path=cfg.dataset_dir, + input_type=cfg.input_type, + create_params=IngestorCreateParams(ray_address=cfg.ray_address, ray_log_to_driver=True, debug=False), + execute_params=IngestExecuteParams( + runtime_metrics_dir=runtime_metrics_dir, + runtime_metrics_prefix=run_id, + ), + extract_params=extract_params, + embed_params=_common_embed_params(cfg, batch_tuning=batch_tuning), + text_chunk_params=text_chunk_params, + enable_text_chunk=False, + evaluation=evaluation, + artifacts=artifacts, + hybrid=cfg.hybrid, ) - finally: - os.close(slave_fd) - try: - while True: - read_fds, _, _ = select.select([master_fd], [], [], 0.1) - if master_fd not in read_fds: - if proc.poll() is not None: - break - continue + if cfg.run_mode == "inprocess": + return InProcessPipelineConfig( + input_path=cfg.dataset_dir, + input_type=cfg.input_type, + execute_params=IngestExecuteParams( + parallel=True, + max_workers=cfg.max_workers, + gpu_devices=_parse_gpu_devices(cfg.gpu_devices), + show_progress=False, + runtime_metrics_dir=runtime_metrics_dir, + runtime_metrics_prefix=run_id, + ), + extract_params=_common_extract_params(cfg), + embed_params=_common_embed_params(cfg), + text_chunk_params=text_chunk_params, + enable_text_chunk=False, + vdb_upload_params=VdbUploadParams( + lancedb={ + "lancedb_uri": lancedb_uri, + "table_name": DEFAULT_LANCEDB_TABLE, + "overwrite": True, + "create_index": True, + "hybrid": cfg.hybrid, + } + ), + evaluation=evaluation, + artifacts=artifacts, + ) - try: - chunk = os.read(master_fd, 4096) - except OSError as exc: - # PTY EOF on Linux often appears as EIO. - if exc.errno == errno.EIO: - break - raise + fused_model_name = cfg.embed_model_name + if fused_model_name == "nvidia/llama-nemotron-embed-1b-v2": + fused_model_name = "nemo_retriever_v1" + return FusedPipelineConfig( + input_path=cfg.dataset_dir, + input_type=cfg.input_type, + create_params=IngestorCreateParams(ray_address=cfg.ray_address, ray_log_to_driver=True), + execute_params=IngestExecuteParams( + runtime_metrics_dir=runtime_metrics_dir, + runtime_metrics_prefix=run_id, + ), + extract_params=_common_extract_params( + cfg, + batch_tuning={ + "pdf_extract_workers": cfg.pdf_extract_workers, + "pdf_extract_num_cpus": cfg.pdf_extract_num_cpus, + "pdf_extract_batch_size": cfg.pdf_extract_batch_size, + "pdf_split_batch_size": cfg.pdf_split_batch_size, + }, + ), + embed_params=_common_embed_params( + cfg, + fused_tuning={ + "fused_workers": cfg.fused_workers, + "fused_batch_size": cfg.fused_batch_size, + "fused_cpus_per_actor": cfg.fused_cpus_per_actor, + "fused_gpus_per_actor": cfg.fused_gpus_per_actor, + }, + model_name=fused_model_name, + ), + vdb_upload_params=VdbUploadParams( + lancedb={ + "lancedb_uri": lancedb_uri, + "table_name": DEFAULT_LANCEDB_TABLE, + "overwrite": True, + "create_index": True, + "hybrid": cfg.hybrid, + } + ), + evaluation=evaluation, + artifacts=artifacts, + ) - if not chunk: - break - text = chunk.decode("utf-8", errors="replace") - sys.stdout.write(text) - sys.stdout.flush() +def _execute_runner(cfg: HarnessConfig, runner_cfg): + if cfg.run_mode == "batch": + return run_batch_pipeline(runner_cfg) + if cfg.run_mode == "inprocess": + return run_inprocess_pipeline(runner_cfg) + return run_fused_pipeline(runner_cfg) - parse_buffer += text.replace("\r", "\n") - parse_buffer = _consume_parseable_output(metrics, parse_buffer) - if parse_buffer: - cleaned_tail = ANSI_ESCAPE_RE.sub("", parse_buffer) - metrics.consume(cleaned_tail) +def _evaluate_run_outcome( + process_rc: int | None = None, + evaluation_mode: str | None = None, + recall_required: bool | None = None, + recall_metrics: dict[str, float] | None = None, + evaluation_metrics: dict[str, float] | None = None, + *, + runner_error: Exception | None = None, +) -> tuple[int, str, bool]: + if runner_error is not None: + return 1, f"runner_exception_{type(runner_error).__name__}", False + if process_rc not in {None, 0}: + return int(process_rc), f"subprocess_exit_{int(process_rc)}", False - return proc.wait() - finally: - os.close(master_fd) + metrics = evaluation_metrics or recall_metrics or {} + if evaluation_mode == "beir" and not metrics: + return 97, "missing_beir_metrics", False + if evaluation_mode == "recall" and recall_required and not metrics: + return 98, "missing_recall_metrics", False + return 0, "", True def _run_single(cfg: HarnessConfig, artifact_dir: Path, run_id: str, tags: list[str] | None = None) -> dict[str, Any]: @@ -395,40 +578,36 @@ def _run_single(cfg: HarnessConfig, artifact_dir: Path, run_id: str, tags: list[ typer.echo(f"\n=== Running {run_id} ===") typer.echo(command_text) - metrics = StreamMetrics() - process_rc = _run_subprocess_with_tty(cmd, metrics) run_metadata = _collect_run_metadata() - runtime_summary_path = runtime_dir / f"{run_id}.runtime.summary.json" - runtime_summary = _read_json_if_exists(runtime_summary_path) - detection_summary = _read_json_if_exists(detection_summary_file) - if not cfg.write_detection_file and detection_summary_file.exists(): - detection_summary_file.unlink() - - recall_metrics_normalized: dict[str, float] = {} - for key, val in metrics.recall_metrics.items(): - recall_metrics_normalized[_normalize_recall_metric_key(key)] = val - evaluation_metrics_normalized: dict[str, float] = {} - for key, val in metrics.evaluation_metrics.items(): - evaluation_metrics_normalized[_normalize_recall_metric_key(key)] = val - + runner_error: Exception | None = None + report = None + try: + runner_cfg = _build_runner_config( + cfg, artifact_dir, run_id, runtime_dir, detection_summary_file, effective_query_csv + ) + report = _execute_runner(cfg, runner_cfg) + except Exception as exc: + runner_error = exc + typer.echo(f"Run failed: {type(exc).__name__}: {exc}", err=True) + + metrics_payload = flatten_report_metrics(report) if report is not None else {} + runtime_summary = report.runtime_summary if report is not None else None + summary_metrics = ( + _resolve_summary_metrics(cfg, metrics_payload, runtime_summary) + if report is not None + else _empty_summary_metrics() + ) + detection_summary = report.detection_summary if report is not None else None + evaluation_metrics = dict(report.evaluation.metrics) if report is not None else {} effective_rc, failure_reason, success = _evaluate_run_outcome( - process_rc=process_rc, + process_rc=0, evaluation_mode=cfg.evaluation_mode, recall_required=bool(cfg.recall_required), - recall_metrics=metrics.recall_metrics, - evaluation_metrics=metrics.evaluation_metrics, + recall_metrics=evaluation_metrics, + evaluation_metrics=evaluation_metrics, + runner_error=runner_error, ) - metrics_payload = { - "files": metrics.files, - "pages": metrics.pages, - "ingest_secs": metrics.ingest_secs, - "pages_per_sec_ingest": metrics.pages_per_sec_ingest, - **recall_metrics_normalized, - **evaluation_metrics_normalized, - } - summary_metrics = _resolve_summary_metrics(cfg, metrics_payload, runtime_summary) - result_payload: dict[str, Any] = { "timestamp": now_timestr(), "latest_commit": last_commit(), @@ -439,6 +618,7 @@ def _run_single(cfg: HarnessConfig, artifact_dir: Path, run_id: str, tags: list[ "dataset_label": cfg.dataset_label, "dataset_dir": cfg.dataset_dir, "preset": cfg.preset, + "run_mode": cfg.run_mode, "query_csv": cfg.query_csv, "effective_query_csv": str(effective_query_csv) if effective_query_csv is not None else None, "input_type": cfg.input_type, @@ -461,18 +641,9 @@ def _run_single(cfg: HarnessConfig, artifact_dir: Path, run_id: str, tags: list[ "extract_infographics": cfg.extract_infographics, "write_detection_file": cfg.write_detection_file, "lancedb_uri": _resolve_lancedb_uri(cfg, artifact_dir), - "tuning": {field: getattr(cfg, field) for field in sorted(TUNING_FIELDS)}, - }, - "metrics": { - "files": metrics.files, - "pages": metrics.pages, - "ingest_secs": metrics.ingest_secs, - "pages_per_sec_ingest": metrics.pages_per_sec_ingest, - "rows_processed": metrics.rows_processed, - "rows_per_sec_ingest": metrics.rows_per_sec_ingest, - **recall_metrics_normalized, - **evaluation_metrics_normalized, + "tuning": {field: getattr(cfg, field) for field in sorted(tuning_fields_for_run_mode(cfg.run_mode))}, }, + "metrics": metrics_payload, "summary_metrics": summary_metrics, "run_metadata": run_metadata, "runtime_summary": runtime_summary, @@ -482,8 +653,16 @@ def _run_single(cfg: HarnessConfig, artifact_dir: Path, run_id: str, tags: list[ "runtime_metrics_dir": str(runtime_dir.resolve()), }, } - if cfg.write_detection_file: - result_payload["artifacts"]["detection_summary_file"] = str(detection_summary_file.resolve()) + if report is not None: + result_payload["run_report"] = report.model_dump(mode="python") + if report.artifacts.report_file: + result_payload["artifacts"]["mode_run_report_file"] = report.artifacts.report_file + if report.artifacts.runtime_summary_file: + result_payload["artifacts"]["runtime_summary_file"] = report.artifacts.runtime_summary_file + if cfg.write_detection_file and report.artifacts.detection_summary_file: + result_payload["artifacts"]["detection_summary_file"] = report.artifacts.detection_summary_file + if runner_error is not None: + result_payload["error"] = {"type": type(runner_error).__name__, "message": str(runner_error)} if tags: result_payload["tags"] = list(tags) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index e00037285..8e73f2209 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -80,6 +80,9 @@ def _runtime_env_vars() -> dict[str, str]: env_vars = { "NEMO_RETRIEVER_HF_CACHE_DIR": resolve_hf_cache_dir(), "LOG_LEVEL": "INFO", + # Allow per-run Ray tuning via the caller environment without baking + # a shared repo default into the batch runtime. + "RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION": os.environ.get("RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION"), } return {key: value for key, value in env_vars.items() if isinstance(value, str)} diff --git a/nemo_retriever/src/nemo_retriever/utils/detection_summary.py b/nemo_retriever/src/nemo_retriever/utils/detection_summary.py index 6faee1e5e..a17e567c9 100644 --- a/nemo_retriever/src/nemo_retriever/utils/detection_summary.py +++ b/nemo_retriever/src/nemo_retriever/utils/detection_summary.py @@ -13,6 +13,7 @@ from __future__ import annotations +import ast from datetime import datetime import json from collections import defaultdict @@ -119,7 +120,12 @@ def iter_lancedb_rows(uri: str, table_name: str): if isinstance(parsed, dict): meta = parsed except Exception: - pass + try: + parsed = ast.literal_eval(raw_metadata) + if isinstance(parsed, dict): + meta = parsed + except Exception: + pass yield (source_id, page_number), meta, {} diff --git a/nemo_retriever/src/nemo_retriever/utils/input_files.py b/nemo_retriever/src/nemo_retriever/utils/input_files.py index 2660c7e0c..c5267426a 100644 --- a/nemo_retriever/src/nemo_retriever/utils/input_files.py +++ b/nemo_retriever/src/nemo_retriever/utils/input_files.py @@ -7,6 +7,7 @@ "txt": ("*.txt",), "html": ("*.html",), "doc": ("*.docx", "*.pptx"), + "image": ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff", "*.tif", "*.svg"), } diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index 2b46ecbb5..f269e3cc3 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py @@ -90,6 +90,38 @@ def _safe_str(x: Any) -> str: return "" if x is None else str(x) +def _build_detection_metadata_from_row(row: Dict[str, Any]) -> Dict[str, Any]: + """Extract per-page detection counters into the stored metadata payload.""" + + out: Dict[str, Any] = {} + + pe_num = row.get("page_elements_v3_num_detections") + if pe_num is not None: + try: + out["page_elements_v3_num_detections"] = int(pe_num) + except Exception: + pass + + pe_counts = row.get("page_elements_v3_counts_by_label") + if isinstance(pe_counts, dict): + counts_by_label: Dict[str, int] = {} + for key, value in pe_counts.items(): + if not isinstance(key, str) or value is None: + continue + try: + counts_by_label[str(key)] = int(value) + except Exception: + continue + out["page_elements_v3_counts_by_label"] = counts_by_label + + for ocr_col in ("table", "chart", "infographic"): + entries = row.get(ocr_col) + if isinstance(entries, list): + out[f"ocr_{ocr_col}_detections"] = int(len(entries)) + + return out + + def _extract_source_path_and_id(meta: Dict[str, Any]) -> Tuple[str, str]: """ Extract a stable source path/id from metadata. @@ -138,6 +170,7 @@ def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, An meta = row.get("metadata") if not isinstance(meta, dict): continue + meta = dict(meta) embedding = meta.get("embedding") if embedding is None: @@ -150,6 +183,7 @@ def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, An except Exception: continue meta.pop("embedding", None) # Remove embedding from metadata to save space in LanceDB. + meta.update(_build_detection_metadata_from_row(row)) # path, source_id = _extract_source_path_and_id(meta) path = row.get("path", "") source_id = meta.get("source_path", path) @@ -174,7 +208,7 @@ def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, An "source_id": source_id, "path": path, "text": row.get("text", ""), - "metadata": str(meta), + "metadata": json.dumps(meta, ensure_ascii=False), } ) diff --git a/nemo_retriever/tests/test_batch_ingestor.py b/nemo_retriever/tests/test_batch_ingestor.py index 21b9a7dde..5a27046b9 100644 --- a/nemo_retriever/tests/test_batch_ingestor.py +++ b/nemo_retriever/tests/test_batch_ingestor.py @@ -56,3 +56,34 @@ def test_batch_ingestor_filters_none_runtime_env_vars(monkeypatch) -> None: } assert dummy_ctx.enable_rich_progress_bars is True assert dummy_ctx.use_ray_tqdm is False + + +def test_batch_ingestor_passes_through_object_store_env_var(monkeypatch) -> None: + captured: dict[str, object] = {} + dummy_ctx = SimpleNamespace(enable_rich_progress_bars=False, use_ray_tqdm=True) + + monkeypatch.setenv("RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION", "0.85") + monkeypatch.setattr( + "nemo_retriever.ingest_modes.batch.resolve_hf_cache_dir", + lambda: "/tmp/hf-cache", + ) + monkeypatch.setattr( + "nemo_retriever.ingest_modes.batch.ray.init", + lambda **kwargs: captured.update(kwargs), + ) + monkeypatch.setattr( + "nemo_retriever.ingest_modes.batch.rd.DataContext.get_current", + lambda: dummy_ctx, + ) + monkeypatch.setattr( + "nemo_retriever.ingest_modes.batch.gather_cluster_resources", + lambda _ray: _DummyClusterResources(), + ) + monkeypatch.setattr( + "nemo_retriever.ingest_modes.batch.resolve_requested_plan", + lambda cluster_resources: {"plan": "dummy"}, + ) + + BatchIngestor(documents=[]) + + assert captured["runtime_env"]["env_vars"]["RAY_DEFAULT_OBJECT_STORE_MEMORY_PROPORTION"] == "0.85" diff --git a/nemo_retriever/tests/test_batch_pipeline.py b/nemo_retriever/tests/test_batch_pipeline.py deleted file mode 100644 index 9dacac7d2..000000000 --- a/nemo_retriever/tests/test_batch_pipeline.py +++ /dev/null @@ -1,182 +0,0 @@ -from typer.testing import CliRunner - -import nemo_retriever.examples.batch_pipeline as batch_pipeline -from nemo_retriever.utils.input_files import resolve_input_patterns - -RUNNER = CliRunner() - - -class _FakeDataset: - def materialize(self): - return self - - def take_all(self): - return [] - - def groupby(self, _key): - class _FakeGrouped: - @staticmethod - def count(): - class _FakeCounted: - @staticmethod - def count(): - return 1 - - return _FakeCounted() - - return _FakeGrouped() - - -class _FakeIngestResult: - def get_dataset(self): - return _FakeDataset() - - -class _FakeErrorRows: - def materialize(self): - return self - - def count(self) -> int: - return 0 - - -class _FakeIngestor: - def __init__(self) -> None: - self.extract_params = None - self.embed_params = None - self.file_patterns = None - - def files(self, file_patterns): - self.file_patterns = file_patterns - return self - - def extract(self, params): - self.extract_params = params - return self - - def extract_image_files(self, params): - self.extract_params = params - return self - - def extract_txt(self, params): - return self - - def extract_html(self, params): - return self - - def split(self, params): - return self - - def embed(self, params): - self.embed_params = params - return self - - def ingest(self, params=None): - return _FakeIngestResult() - - def get_error_rows(self, dataset=None): - return _FakeErrorRows() - - -def test_resolve_input_file_patterns_recurses_for_directory_inputs(tmp_path) -> None: - dataset_dir = tmp_path / "earnings_consulting" - dataset_dir.mkdir() - - pdf_patterns = resolve_input_patterns(dataset_dir, "pdf") - txt_patterns = resolve_input_patterns(dataset_dir, "txt") - doc_patterns = resolve_input_patterns(dataset_dir, "doc") - - assert pdf_patterns == [str(dataset_dir / "**" / "*.pdf")] - assert txt_patterns == [str(dataset_dir / "**" / "*.txt")] - assert doc_patterns == [str(dataset_dir / "**" / "*.docx"), str(dataset_dir / "**" / "*.pptx")] - - -def test_batch_pipeline_accepts_multimodal_embed_and_page_image_flags(tmp_path, monkeypatch) -> None: - dataset_dir = tmp_path / "dataset" - dataset_dir.mkdir() - (dataset_dir / "sample.pdf").write_text("placeholder", encoding="utf-8") - missing_query_csv = tmp_path / "missing.csv" - - fake_ingestor = _FakeIngestor() - monkeypatch.setattr(batch_pipeline, "create_ingestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) - - result = RUNNER.invoke( - batch_pipeline.app, - [ - str(dataset_dir), - "--query-csv", - str(missing_query_csv), - "--embed-modality", - "text_image", - "--embed-granularity", - "page", - "--extract-infographics", - "--no-extract-page-as-image", - ], - ) - - assert result.exit_code == 0 - assert isinstance(fake_ingestor.file_patterns, list) - assert fake_ingestor.extract_params.extract_infographics is True - assert fake_ingestor.extract_params.extract_page_as_image is False - assert fake_ingestor.embed_params.embed_modality == "text_image" - assert fake_ingestor.embed_params.embed_granularity == "page" - - -def test_batch_pipeline_routes_beir_mode_to_evaluator(tmp_path, monkeypatch) -> None: - dataset_dir = tmp_path / "dataset" - dataset_dir.mkdir() - (dataset_dir / "sample.pdf").write_text("placeholder", encoding="utf-8") - - fake_ingestor = _FakeIngestor() - monkeypatch.setattr(batch_pipeline, "create_ingestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "print_run_summary", lambda *args, **kwargs: None) - - class _FakeTable: - def count_rows(self) -> int: - return 1 - - class _FakeDb: - def open_table(self, _name): - return _FakeTable() - - class _FakeLanceModule: - @staticmethod - def connect(_uri): - return _FakeDb() - - monkeypatch.setattr(batch_pipeline, "_lancedb", lambda: _FakeLanceModule()) - - captured = {} - - def _fake_evaluate(cfg): - captured["cfg"] = cfg - return type("Dataset", (), {"query_ids": ["1", "2"]})(), [], {}, {"ndcg@10": 0.75, "recall@5": 0.6} - - monkeypatch.setattr(batch_pipeline, "evaluate_lancedb_beir", _fake_evaluate) - - result = RUNNER.invoke( - batch_pipeline.app, - [ - str(dataset_dir), - "--evaluation-mode", - "beir", - "--beir-loader", - "vidore_hf", - "--beir-dataset-name", - "vidore_v3_computer_science", - "--beir-k", - "5", - "--beir-k", - "10", - ], - ) - - assert result.exit_code == 0 - assert captured["cfg"].loader == "vidore_hf" - assert captured["cfg"].dataset_name == "vidore_v3_computer_science" - assert tuple(captured["cfg"].ks) == (5, 10) diff --git a/nemo_retriever/tests/test_harness_config.py b/nemo_retriever/tests/test_harness_config.py index 40f0a1c4e..803464a4f 100644 --- a/nemo_retriever/tests/test_harness_config.py +++ b/nemo_retriever/tests/test_harness_config.py @@ -184,6 +184,65 @@ def test_load_harness_config_supports_recall_adapter_and_match_mode(tmp_path: Pa assert cfg.recall_match_mode == "pdf_page" +def test_load_harness_config_supports_inprocess_run_mode(tmp_path: Path) -> None: + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + query_csv = tmp_path / "query.csv" + query_csv.write_text("query,pdf_page\nq,doc_1\n", encoding="utf-8") + cfg_path = tmp_path / "test_configs.yaml" + cfg_path.write_text( + "\n".join( + [ + "active:", + " dataset: tiny", + " preset: base", + "presets:", + " base: {}", + "datasets:", + " tiny:", + f" path: {dataset_dir}", + f" query_csv: {query_csv}", + " recall_required: false", + " run_mode: inprocess", + " max_workers: 4", + " gpu_devices: 0,1", + ] + ), + encoding="utf-8", + ) + + cfg = load_harness_config(config_file=str(cfg_path)) + assert cfg.run_mode == "inprocess" + assert cfg.max_workers == 4 + assert cfg.gpu_devices == "0,1" + + +def test_load_harness_config_rejects_invalid_run_mode(tmp_path: Path) -> None: + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + cfg_path = tmp_path / "test_configs.yaml" + cfg_path.write_text( + "\n".join( + [ + "active:", + " dataset: tiny", + " preset: base", + "presets:", + " base: {}", + "datasets:", + " tiny:", + f" path: {dataset_dir}", + " recall_required: false", + " run_mode: nope", + ] + ), + encoding="utf-8", + ) + + with pytest.raises(ValueError, match="run_mode must be one of"): + load_harness_config(config_file=str(cfg_path)) + + def test_load_harness_config_supports_multimodal_embedding_options(tmp_path: Path) -> None: dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() diff --git a/nemo_retriever/tests/test_harness_parsers.py b/nemo_retriever/tests/test_harness_parsers.py deleted file mode 100644 index 8f2ea49dc..000000000 --- a/nemo_retriever/tests/test_harness_parsers.py +++ /dev/null @@ -1,71 +0,0 @@ -from nemo_retriever.harness.parsers import StreamMetrics, parse_stream_text - - -def test_parse_stream_text_extracts_done_recall_and_pages_per_second() -> None: - stdout = """ -[done] 20 files, 3181 pages in 122.3s -Recall metrics (matching nemo_retriever.recall.core): - recall@1: 0.6087 - recall@5: 0.9043 - recall@10: 0.9565 -Pages/sec (ingest only; excludes Ray startup and recall): 15.77 -""" - metrics = parse_stream_text(stdout) - assert metrics.files == 20 - assert metrics.pages == 3181 - assert metrics.ingest_secs == 122.3 - assert metrics.pages_per_sec_ingest == 15.77 - assert metrics.recall_metrics == { - "recall@1": 0.6087, - "recall@5": 0.9043, - "recall@10": 0.9565, - } - assert metrics.rows_processed is None - assert metrics.rows_per_sec_ingest is None - - -def test_parse_stream_text_extracts_rows_and_logger_prefixed_recall_lines() -> None: - stdout = """ -2026-03-06 20:12:42,120 INFO nemo_retriever.examples.batch_pipeline: \ -Ingestion complete. 3181 rows procesed in 151.40 seconds. 21.01 PPS -2026-03-06 20:12:42,121 INFO nemo_retriever.examples.batch_pipeline: \ -Recall metrics (matching nemo_retriever.recall.core): -2026-03-06 20:12:42,122 INFO nemo_retriever.examples.batch_pipeline: recall@1: 0.6087 -2026-03-06 20:12:42,123 INFO nemo_retriever.examples.batch_pipeline: recall@5: 0.9043 -2026-03-06 20:12:42,124 INFO nemo_retriever.examples.batch_pipeline: recall@10: 0.9565 -""" - metrics = parse_stream_text(stdout) - assert metrics.files is None - assert metrics.pages is None - assert metrics.ingest_secs == 151.40 - assert metrics.pages_per_sec_ingest is None - assert metrics.rows_processed == 3181 - assert metrics.rows_per_sec_ingest == 21.01 - assert metrics.recall_metrics == { - "recall@1": 0.6087, - "recall@5": 0.9043, - "recall@10": 0.9565, - } - - -def test_stream_metrics_handles_non_recall_lines_after_recall_block() -> None: - metrics = StreamMetrics() - metrics.consume("Recall metrics (matching nemo_retriever.recall.core):\n") - metrics.consume(" recall@5: 0.9043\n") - metrics.consume("Pages processed: 1933\n") - metrics.consume(" recall@10: 0.9565\n") - assert metrics.recall_metrics == {"recall@5": 0.9043} - - -def test_parse_stream_text_extracts_beir_metrics_block() -> None: - stdout = """ -BEIR metrics: - ndcg@10: 0.7421 - recall@5: 0.6234 -""" - metrics = parse_stream_text(stdout) - assert metrics.recall_metrics == {} - assert metrics.evaluation_metrics == { - "ndcg@10": 0.7421, - "recall@5": 0.6234, - } diff --git a/nemo_retriever/tests/test_harness_run.py b/nemo_retriever/tests/test_harness_run.py index 62b227580..15ce0c841 100644 --- a/nemo_retriever/tests/test_harness_run.py +++ b/nemo_retriever/tests/test_harness_run.py @@ -1,4 +1,3 @@ -import json from pathlib import Path from typer.testing import CliRunner @@ -191,55 +190,67 @@ def test_build_command_applies_page_plus_one_adapter(tmp_path: Path) -> None: assert "q,doc_name_1" in csv_contents -def test_normalize_recall_metric_key_removes_duplicate_prefix() -> None: - assert _normalize_recall_metric_key("recall@1") == "recall_1" - assert _normalize_recall_metric_key("recall@10") == "recall_10" - - -def test_run_single_writes_tags_to_results_json(monkeypatch, tmp_path: Path) -> None: +def test_build_command_supports_inprocess_run_mode(tmp_path: Path) -> None: dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() query_csv = tmp_path / "query.csv" query_csv.write_text("query,pdf_page\nq,doc_1\n", encoding="utf-8") - runtime_dir = tmp_path / "runtime_metrics" - runtime_dir.mkdir() cfg = HarnessConfig( dataset_dir=str(dataset_dir), dataset_label="jp20", preset="single_gpu", + run_mode="inprocess", query_csv=str(query_csv), + max_workers=8, + gpu_devices="0,1", ) + cmd, _runtime_dir, _detection_file, effective_query_csv = _build_command(cfg, tmp_path, run_id="r1") - monkeypatch.setattr( - harness_run, - "_build_command", - lambda *_args, **_kwargs: (["python", "-V"], runtime_dir, runtime_dir / ".detection_summary.json", query_csv), - ) + assert "nemo_retriever.examples.inprocess_pipeline" in cmd + assert "--max-workers" in cmd + assert cmd[cmd.index("--max-workers") + 1] == "8" + assert "--gpu-devices" in cmd + assert cmd[cmd.index("--gpu-devices") + 1] == "0,1" + assert "--query-csv" in cmd + assert str(effective_query_csv) in cmd - def _fake_run_subprocess(_cmd: list[str], metrics) -> int: - metrics.files = 20 - metrics.pages = 100 - metrics.ingest_secs = 10.0 - metrics.pages_per_sec_ingest = 10.0 - metrics.recall_metrics = {"recall@1": 0.5, "recall@5": 0.8} - return 0 - monkeypatch.setattr(harness_run, "_run_subprocess_with_tty", _fake_run_subprocess) - monkeypatch.setattr(harness_run, "last_commit", lambda: "abc123") - monkeypatch.setattr(harness_run, "now_timestr", lambda: "20260305_000000_UTC") +def test_build_command_supports_fused_run_mode(tmp_path: Path) -> None: + dataset_dir = tmp_path / "dataset" + dataset_dir.mkdir() + query_csv = tmp_path / "query.csv" + query_csv.write_text("query,pdf_page\nq,doc_1\n", encoding="utf-8") - captured: dict[str, dict] = {} + cfg = HarnessConfig( + dataset_dir=str(dataset_dir), + dataset_label="jp20", + preset="single_gpu", + run_mode="fused", + query_csv=str(query_csv), + fused_workers=2, + fused_batch_size=32, + fused_cpus_per_actor=2.0, + fused_gpus_per_actor=1.0, + ) + cmd, runtime_dir, detection_file, effective_query_csv = _build_command(cfg, tmp_path, run_id="r1") - def _fake_write_json(_path: Path, payload: dict) -> None: - captured["payload"] = payload + assert "nemo_retriever.examples.fused_pipeline" in cmd + assert "--fused-workers" in cmd + assert cmd[cmd.index("--fused-workers") + 1] == "2" + assert "--fused-batch-size" in cmd + assert cmd[cmd.index("--fused-batch-size") + 1] == "32" + assert "--runtime-metrics-dir" in cmd + assert str(runtime_dir) in cmd + assert "--detection-summary-file" in cmd + assert str(detection_file) in cmd + assert "--query-csv" in cmd + assert str(effective_query_csv) in cmd - monkeypatch.setattr(harness_run, "write_json", _fake_write_json) - harness_run._run_single(cfg, tmp_path, run_id="r1", tags=["nightly", "candidate"]) - assert captured["payload"]["tags"] == ["nightly", "candidate"] - assert captured["payload"]["metrics"]["recall_1"] == 0.5 - assert captured["payload"]["metrics"]["recall_5"] == 0.8 +def test_normalize_recall_metric_key_removes_duplicate_prefix() -> None: + assert _normalize_recall_metric_key("recall@1") == "recall_1" + assert _normalize_recall_metric_key("recall@10") == "recall_10" def test_run_entry_session_artifact_dir_uses_run_name(monkeypatch, tmp_path: Path) -> None: @@ -427,199 +438,6 @@ def _raise_package_not_found(_name: str) -> str: } -def test_run_single_writes_results_with_run_metadata(monkeypatch, tmp_path: Path) -> None: - artifact_dir = tmp_path / "run_artifacts" - artifact_dir.mkdir() - dataset_dir = tmp_path / "dataset" - dataset_dir.mkdir() - query_csv = tmp_path / "query.csv" - query_csv.write_text("q,s,p\nx,y,1\n", encoding="utf-8") - - runtime_dir = artifact_dir / "runtime_metrics" - runtime_dir.mkdir() - detection_file = artifact_dir / "detection_summary.json" - detection_file.write_text(json.dumps({"total_detections": 7}), encoding="utf-8") - runtime_summary_file = runtime_dir / "jp20_single.runtime.summary.json" - runtime_summary_file.write_text(json.dumps({"elapsed_secs": 12.5}), encoding="utf-8") - - cfg = HarnessConfig( - dataset_dir=str(dataset_dir), - dataset_label="jp20", - preset="single_gpu", - query_csv=str(query_csv), - write_detection_file=True, - ) - - monkeypatch.setattr( - harness_run, - "_build_command", - lambda _cfg, _artifact_dir, _run_id: ( - ["python", "-m", "nemo_retriever.examples.batch_pipeline", str(dataset_dir)], - runtime_dir, - detection_file, - query_csv, - ), - ) - - def _fake_run_subprocess(_cmd: list[str], metrics) -> int: - metrics.files = None - metrics.pages = None - metrics.ingest_secs = 12.5 - metrics.pages_per_sec_ingest = None - metrics.rows_processed = 3181 - metrics.rows_per_sec_ingest = 254.48 - metrics.recall_metrics = {"recall@5": 0.9} - return 0 - - monkeypatch.setattr(harness_run, "_run_subprocess_with_tty", _fake_run_subprocess) - monkeypatch.setattr(harness_run, "now_timestr", lambda: "20260305_120000_UTC") - monkeypatch.setattr(harness_run, "last_commit", lambda: "abc1234") - monkeypatch.setattr( - harness_run, - "_collect_run_metadata", - lambda: { - "host": "builder-01", - "gpu_count": 2, - "cuda_driver": "550.54.15", - "ray_version": "2.49.0", - "python_version": "3.12.4", - }, - ) - - result = harness_run._run_single(cfg, artifact_dir, run_id="jp20_single") - payload = json.loads((artifact_dir / "results.json").read_text(encoding="utf-8")) - - expected = { - "timestamp": "20260305_120000_UTC", - "latest_commit": "abc1234", - "success": True, - "return_code": 0, - "failure_reason": None, - "test_config": { - "dataset_label": "jp20", - "dataset_dir": str(dataset_dir), - "preset": "single_gpu", - "query_csv": str(query_csv), - "effective_query_csv": str(query_csv), - "input_type": cfg.input_type, - "recall_required": cfg.recall_required, - "recall_match_mode": cfg.recall_match_mode, - "recall_adapter": cfg.recall_adapter, - "evaluation_mode": cfg.evaluation_mode, - "beir_loader": cfg.beir_loader, - "beir_dataset_name": cfg.beir_dataset_name, - "beir_split": cfg.beir_split, - "beir_query_language": cfg.beir_query_language, - "beir_doc_id_field": cfg.beir_doc_id_field, - "beir_ks": list(cfg.beir_ks), - "ray_address": cfg.ray_address, - "hybrid": cfg.hybrid, - "embed_model_name": cfg.embed_model_name, - "embed_modality": cfg.embed_modality, - "embed_granularity": cfg.embed_granularity, - "extract_page_as_image": cfg.extract_page_as_image, - "extract_infographics": cfg.extract_infographics, - "write_detection_file": True, - "lancedb_uri": str((artifact_dir / "lancedb").resolve()), - "tuning": {field: getattr(cfg, field) for field in sorted(harness_run.TUNING_FIELDS)}, - }, - "metrics": { - "files": None, - "pages": None, - "ingest_secs": 12.5, - "pages_per_sec_ingest": None, - "rows_processed": 3181, - "rows_per_sec_ingest": 254.48, - "recall_5": 0.9, - }, - "summary_metrics": { - "pages": None, - "ingest_secs": 12.5, - "pages_per_sec_ingest": None, - "recall_5": 0.9, - "ndcg_10": None, - }, - "run_metadata": { - "host": "builder-01", - "gpu_count": 2, - "cuda_driver": "550.54.15", - "ray_version": "2.49.0", - "python_version": "3.12.4", - }, - "runtime_summary": {"elapsed_secs": 12.5}, - "detection_summary": {"total_detections": 7}, - "artifacts": { - "command_file": str((artifact_dir / "command.txt").resolve()), - "runtime_metrics_dir": str(runtime_dir.resolve()), - "detection_summary_file": str(detection_file.resolve()), - }, - } - - assert result == expected - assert payload == expected - - -def test_run_single_allows_missing_optional_summary_files(monkeypatch, tmp_path: Path) -> None: - artifact_dir = tmp_path / "run_artifacts" - artifact_dir.mkdir() - dataset_dir = tmp_path / "dataset" - dataset_dir.mkdir() - query_csv = tmp_path / "query.csv" - query_csv.write_text("q,s,p\nx,y,1\n", encoding="utf-8") - - runtime_dir = artifact_dir / "runtime_metrics" - runtime_dir.mkdir() - detection_file = runtime_dir / ".detection_summary.json" - - cfg = HarnessConfig( - dataset_dir=str(dataset_dir), - dataset_label="jp20", - preset="single_gpu", - query_csv=str(query_csv), - write_detection_file=False, - recall_required=False, - ) - - monkeypatch.setattr( - harness_run, - "_build_command", - lambda _cfg, _artifact_dir, _run_id: ( - ["python", "-m", "nemo_retriever.examples.batch_pipeline", str(dataset_dir)], - runtime_dir, - detection_file, - query_csv, - ), - ) - - def _fake_run_subprocess(_cmd: list[str], metrics) -> int: - metrics.rows_processed = 42 - metrics.rows_per_sec_ingest = 3.5 - metrics.ingest_secs = 12.0 - return 0 - - monkeypatch.setattr(harness_run, "_run_subprocess_with_tty", _fake_run_subprocess) - monkeypatch.setattr(harness_run, "now_timestr", lambda: "20260306_210000_UTC") - monkeypatch.setattr(harness_run, "last_commit", lambda: "abc1234") - monkeypatch.setattr(harness_run, "_collect_run_metadata", lambda: {"host": "builder-01"}) - - result = harness_run._run_single(cfg, artifact_dir, run_id="jp20_single") - - assert result["success"] is True - assert result["runtime_summary"] is None - assert result["detection_summary"] is None - assert result["metrics"]["rows_processed"] == 42 - assert result["metrics"]["rows_per_sec_ingest"] == 3.5 - assert result["metrics"]["pages"] is None - assert result["summary_metrics"] == { - "pages": None, - "ingest_secs": 12.0, - "pages_per_sec_ingest": None, - "recall_5": None, - "ndcg_10": None, - } - assert "detection_summary_file" not in result["artifacts"] - - def test_resolve_summary_metrics_falls_back_to_dataset_page_count(monkeypatch, tmp_path: Path) -> None: dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() diff --git a/nemo_retriever/tests/test_lancedb_store.py b/nemo_retriever/tests/test_lancedb_store.py new file mode 100644 index 000000000..55298073f --- /dev/null +++ b/nemo_retriever/tests/test_lancedb_store.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json + +from nemo_retriever.vector_store.lancedb_store import _build_lancedb_rows_from_df + + +def test_build_lancedb_rows_from_df_serializes_metadata_as_json_with_detection_counts() -> None: + rows = [ + { + "metadata": { + "embedding": [0.1, 0.2], + "has_text": True, + "source_path": "/docs/sample.pdf", + }, + "path": "/docs/sample.pdf", + "page_number": 3, + "text": "sample text", + "page_elements_v3_num_detections": 5, + "page_elements_v3_counts_by_label": {"text": 3, "table": 2}, + "table": [{}, {}], + "chart": [{}], + "infographic": [], + } + ] + + out = _build_lancedb_rows_from_df(rows) + assert len(out) == 1 + + payload = out[0] + metadata = json.loads(payload["metadata"]) + assert payload["source"] == "/docs/sample.pdf" + assert metadata["has_text"] is True + assert metadata["source_path"] == "/docs/sample.pdf" + assert metadata["page_elements_v3_num_detections"] == 5 + assert metadata["page_elements_v3_counts_by_label"] == {"text": 3, "table": 2} + assert metadata["ocr_table_detections"] == 2 + assert metadata["ocr_chart_detections"] == 1 + assert metadata["ocr_infographic_detections"] == 0