From 3525a2baa3a8dc861df701d38d9a9d851243de1b Mon Sep 17 00:00:00 2001 From: eurekayuan Date: Tue, 2 Jun 2026 11:29:49 -0700 Subject: [PATCH 1/4] Add long-context windowing for detection, replace, and rewrite --- docs/concepts/detection.md | 3 + docs/concepts/long-context.md | 212 ++++++++++ docs/concepts/replace.md | 3 + docs/concepts/rewrite.md | 3 + mkdocs.yml | 1 + src/anonymizer/config/anonymizer_config.py | 38 ++ .../engine/detection/chunked_augmentation.py | 371 +++++++++++++++++ .../engine/detection/chunked_detection.py | 261 ++++++++++++ .../engine/detection/chunked_latent.py | 272 +++++++++++++ .../engine/detection/detection_workflow.py | 91 ++++- .../engine/replace/chunked_replace.py | 377 ++++++++++++++++++ .../engine/replace/llm_replace_workflow.py | 33 +- .../engine/replace/replace_runner.py | 4 + .../engine/rewrite/chunked_rewrite.py | 251 ++++++++++++ .../engine/rewrite/chunked_steps.py | 137 +++++++ .../engine/rewrite/domain_classification.py | 33 +- .../engine/rewrite/qa_generation.py | 149 ++++++- .../engine/rewrite/rewrite_generation.py | 33 +- .../engine/rewrite/rewrite_workflow.py | 6 + .../engine/rewrite/sensitivity_disposition.py | 70 +++- src/anonymizer/engine/windowing.py | 55 +++ src/anonymizer/interface/anonymizer.py | 7 + tests/engine/test_chunked_augmentation.py | 179 +++++++++ tests/engine/test_chunked_detection.py | 195 +++++++++ tests/engine/test_chunked_latent.py | 116 ++++++ tests/engine/test_chunked_replace.py | 145 +++++++ tests/engine/test_chunked_rewrite.py | 82 ++++ tests/engine/test_chunked_steps.py | 208 ++++++++++ tests/engine/test_detection_workflow.py | 5 +- tests/engine/test_domain_classification.py | 9 +- tests/engine/test_qa_generation.py | 13 +- tests/engine/test_rewrite_generation.py | 22 +- tests/engine/test_sensitivity_disposition.py | 8 +- tests/engine/test_windowing.py | 50 +++ 34 files changed, 3357 insertions(+), 85 deletions(-) create mode 100644 docs/concepts/long-context.md create mode 100644 
src/anonymizer/engine/detection/chunked_augmentation.py create mode 100644 src/anonymizer/engine/detection/chunked_detection.py create mode 100644 src/anonymizer/engine/detection/chunked_latent.py create mode 100644 src/anonymizer/engine/replace/chunked_replace.py create mode 100644 src/anonymizer/engine/rewrite/chunked_rewrite.py create mode 100644 src/anonymizer/engine/rewrite/chunked_steps.py create mode 100644 src/anonymizer/engine/windowing.py create mode 100644 tests/engine/test_chunked_augmentation.py create mode 100644 tests/engine/test_chunked_detection.py create mode 100644 tests/engine/test_chunked_latent.py create mode 100644 tests/engine/test_chunked_replace.py create mode 100644 tests/engine/test_chunked_rewrite.py create mode 100644 tests/engine/test_chunked_steps.py create mode 100644 tests/engine/test_windowing.py diff --git a/docs/concepts/detection.md b/docs/concepts/detection.md index 83fbd68d..24b1d195 100644 --- a/docs/concepts/detection.md +++ b/docs/concepts/detection.md @@ -52,6 +52,9 @@ config = AnonymizerConfig( ## Chunked validation +!!! info "Part of a broader long-context story" + Chunked validation is one of four strategies Anonymizer uses to keep large records within an LLM call's limits. For augmentation and latent windowing, and how they relate to validation, see [Long-context handling](long-context.md). + When a row yields many entity candidates, validating them in a single LLM call can often exceed the model's context window or the provider's rate limits (tokens-per-minute or requests-per-minute quotas that many hosted models enforce). Anonymizer automatically splits validation for such rows: candidates are grouped in position order into chunks of at most `validation_max_entities_per_call`, and each chunk is validated independently with its own bounded text excerpt (`validation_excerpt_window_chars` before and after the chunk's span). Decisions are merged back into a single per-row set. 
The chunked path is always on; if a row has fewer candidates than the limit, it runs as a single call and is exactly equivalent to the unchunked behavior. Tuning guidance: diff --git a/docs/concepts/long-context.md b/docs/concepts/long-context.md new file mode 100644 index 00000000..586e489e --- /dev/null +++ b/docs/concepts/long-context.md @@ -0,0 +1,212 @@ + + + +# Long-context handling + +A single record can be larger than what one LLM call can take — either it exceeds the model's context window, or it exceeds the renderer's prompt-size cap, or it just makes a single call slow and rate-limit-prone. Anonymizer never silently truncates such inputs. Instead, every LLM-backed stage splits a large record into pieces, processes the pieces, and reassembles the result. + +Different stages have different bottlenecks, so they split in different ways. This page collects all of them in one place. Each stage has a **fast path**: if a record's rendered prompt already fits under the cap, it runs as a single call that is exactly equivalent to the un-split behavior, and none of the machinery below applies. + +--- + +## The shared primitive: boundary-aware windowing + +The [Substitute](replace.md) map and [Rewrite](rewrite.md) paths tile a document into **sequential, non-overlapping** windows. Rather than cut at an arbitrary character offset, each window backs off to the last newline inside it, so a chunk boundary lands on a natural break instead of mid-line or mid-word. If a window contains no newline at all, it hard-cuts at the size limit so progress is always made. + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart LR + Doc([Long document]):::io --> P["Take up to max_chars
from the current position"] + P --> Q{"Newline inside
the window?"} + Q -->|yes| BO["Back off to the
last newline"] + Q -->|no| HC["Hard cut at max_chars
(always makes progress)"] + BO --> EM["Emit window,
advance to its end"] + HC --> EM + EM --> M{"More text left?"} + M -->|yes| P + M -->|no| Done(["Windows tile the doc:
no gaps, no overlap"]):::io + classDef io fill:#76B900,stroke:#3d6b00,color:#ffffff,font-weight:bold; +``` + +The result is a sequence of abutting windows that reconstruct the document exactly — no character dropped, none duplicated: + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart LR + subgraph Document + direction LR + W1["Window 1
ends on a newline"] --- W2["Window 2
ends on a newline"] --- W3["Window 3
to end of doc"] + end +``` + +Implementation: `next_window_end()` / `iter_boundary_windows()` in `src/anonymizer/engine/windowing.py`. + +A floor of **4,000 characters** (`_MIN_WINDOW_CHARS`) prevents a pathological document from shrinking windows to nothing. + +--- + +## The render cap + +All stages size their windows against one ceiling: the maximum number of characters a single rendered prompt may contain. This defaults to **131,072 characters** (128 KiB), deliberately kept well below DataDesigner's hard `MAX_RENDERED_LEN` render cap (512,000 characters) so that each window can be processed within a single LLM request without timing out. A **safety margin** (default **8,000 characters**) is subtracted from the cap when sizing a window, leaving headroom for prompt scaffolding, entity tags, and seed JSON that get added on top of the raw text: + +``` +initial_window = max(4,000, cap − safety_margin) = max(4,000, 131,072 − 8,000) = 123,072 +``` + +--- + +## Four strategies, by stage + +The right way to split depends on what the stage is doing. Stages that *detect* things statelessly can process windows independently and just merge the results; stages that *transform* text must carry state across windows so the output stays consistent. + +| Stage | Mode(s) | Splits by | Windows | Carries state across windows? 
| +|-------|---------|-----------|---------|-------------------------------| +| [Validation](#1-validation-split-by-entity-count) | both | entity **count** | n/a (splits a list) | no | +| [Augmentation / Latent](#2-augmentation-latent-overlapping-windows) | both / rewrite-only | character **windows** | **overlapping** | no — overlap + dedupe | +| [Substitute map](#3-substitute-map-sequential-windows-rolling-state) | replace | character **windows** | **abutting** | yes — running map + summary | +| [Rewrite generation](#4-rewrite-generation-sequential-windows-continuity) | rewrite | character **windows** | **abutting** | yes — continuity summary | + +The mental model: **stateless detection → overlapping windows, just dedupe the seam; stateful transformation → abutting windows, thread a rolling summary through the seam; validation → split a list, not a window.** + +--- + +### 1. Validation: split by entity count + +The validator's bottleneck is *how many candidate entities* it must judge, not document size — so this stage splits a **list of candidates**, not the text. The full document is never sent; each entity travels with a bounded excerpt of surrounding context. + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart LR + Cand([250 candidate entities]):::io --> S["Split in position order,
≤ 100 per chunk"] + S --> A["Chunk A · e1–e100
each with a ±500-char excerpt"] + S --> B["Chunk B · e101–e200"] + S --> C["Chunk C · e201–e250"] + A --> Pool{{"Validator pool
round-robin + failover"}} + B --> Pool + C --> Pool + Pool --> Merge(["Merge keep / drop / reclassify
into one per-row set"]):::io + classDef io fill:#76B900,stroke:#3d6b00,color:#ffffff,font-weight:bold; +``` + +A chunk carries only its entities plus a bounded excerpt around each — never the whole document. + +- **`validation_max_entities_per_call`** (default `100`) — candidates per chunk. +- **`validation_excerpt_window_chars`** (default `500`) — characters of context included before and after each chunk's spans. + +Both are fields on [`Detect`](detection.md#detect-fields). Chunks are dispatched across a [validator pool](models.md#validator-pools) for load-spreading and failover. See [Chunked validation](detection.md#chunked-validation) for the full treatment, including what happens when a row can't be validated. + +--- + +### 2. Augmentation / Latent: overlapping windows + +[Augmentation](detection.md) (finding entities the NER model missed) and **latent-entity detection** (rewrite mode only) both scan the text for things to flag. Each window is processed independently — the only thing connecting adjacent windows is **overlap**, which guarantees an entity sitting on a boundary lands fully inside at least one window. Results are then deduplicated. + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart TB + subgraph Doc ["Document, tiled into overlapping windows"] + direction TB + W1["Window 1
chars 0 → 123K"] + W2["Window 2
122K → 245K
starts overlap_chars before W1 ends"] + W3["Window 3
244K → end"] + end + W1 -. "overlap 1K" .-> W2 + W2 -. "overlap 1K" .-> W3 + W1 --> D["Dedupe by (value, label)
— the overlap duplicate collapses"] + W2 --> D + W3 --> D + D --> Out([Final entity set]):::io + classDef io fill:#76B900,stroke:#3d6b00,color:#ffffff,font-weight:bold; +``` + +Overlap is what makes independent windows safe. An entity that straddles a boundary — say `...the patient Maria Garcia was admitted...`, split by Window 1 — still appears **whole** inside Window 2, because Window 2 starts `overlap_chars` before Window 1 ends. The duplicate it produces is removed by the dedupe step. + +If a particular window is tag-dense and still renders over the cap, **only that window** shrinks (proportionally to the overage) and retries; the others stay full size. + +- **`detection_window_max_render_chars`** (default `131,072`, i.e. 128 KiB) — the render cap. +- **`detection_window_safety_margin_chars`** (default `8,000`) — headroom subtracted when sizing windows. +- **`detection_window_overlap_chars`** (default `1,000`) — overlap between adjacent windows. + +All three are fields on [`Detect`](detection.md#detect-fields) and apply to both augmentation and latent detection. + +!!! note "Cross-window inference limit" + Because latent detection works one window at a time, a latent fact that is *only* deducible by combining details from distant parts of a very long document can be missed. Overlap mitigates boundary cases, not arbitrarily long-range inference. Implementation: `src/anonymizer/engine/detection/chunked_latent.py`. + +--- + +### 3. Substitute map: sequential windows + rolling state + +[Substitute](replace.md) must map each entity to a *consistent* synthetic value across the whole record — "Alice" → "Nadia" everywhere. Windows therefore can **not** be independent. The map path uses abutting [boundary windows](#the-shared-primitive-boundary-aware-windowing) and threads two pieces of state through the seams: the **accumulated map so far** and a **rolling summary**. + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart LR + K1["Chunk 1
propose for new entities"] --> M1["map: Alice→Nadia"] + M1 -. "map + summary so far" .-> K2["Chunk 2
propose for NEW entities only"] + K2 --> M2["map: + Bob→Tom"] + M2 -. "map + summary so far" .-> K3["Chunk 3
propose for NEW entities only"] + K3 --> M3["map: + Cory→Wes"] + M3 --> Out(["One merged map for the whole record
deduped by (original, label)"]):::io + classDef io fill:#76B900,stroke:#3d6b00,color:#ffffff,font-weight:bold; +``` + +Each chunk proposes replacements **only** for entities not already in the map, and the accumulated map plus a rolling summary travel forward through every seam — so `Alice` maps to `Nadia` everywhere. + +The rolling summary is capped at **`summary_max_chars`** (default **2,000**) and is meant to hold "only facts useful for keeping entity replacements consistent." + +Implementation: `generate_replacement_map_row()` in `src/anonymizer/engine/replace/chunked_replace.py`. + +--- + +### 4. Rewrite generation: sequential windows + continuity + +[Rewrite](rewrite.md) produces new prose, so it must keep the *narrative* coherent across chunk seams — not just entity names. Like Substitute, it uses abutting boundary windows, but instead of merging a map it concatenates rewritten text, and the carried state is a **continuity preamble** built from a rolling narrative summary. + +```mermaid +%%{init: {"flowchart": {"useMaxWidth": false}}}%% +flowchart LR + R1["Chunk 1
rewrite"] --> O1["rewritten 1"] + O1 -. "continuity summary" .-> R2["Chunk 2
preamble + rewrite"] + R2 --> O2["rewritten 2"] + O2 -. "continuity summary" .-> R3["Chunk 3
preamble + rewrite"] + R3 --> O3["rewritten 3"] + O1 --> Stitch["Stitch parts
(join with newlines)"] + O2 --> Stitch + O3 --> Stitch + Stitch --> Out([Final rewritten record]):::io + classDef io fill:#76B900,stroke:#3d6b00,color:#ffffff,font-weight:bold; +``` + +The continuity preamble — built from a rolling narrative summary — is prepended to every chunk after the first, so pseudonyms and narrative state (e.g. `Alice→Nadia`, "morning in NYC") stay consistent across the seams. The rewritten parts are then stitched in order. + +The same render cap (131,072 by default), safety margin (8,000), and `summary_max_chars` (2,000) apply. The replacement map built in [step 3](#3-substitute-map-sequential-windows-rolling-state) is also passed into **every** chunk so "replace"-classified entities stay consistent across the whole rewrite — see [How the replacement map is used in rewrite mode](rewrite.md). + +Implementation: `generate_rewrite_row()` in `src/anonymizer/engine/rewrite/chunked_rewrite.py`. + +--- + +## What's tunable + +| Knob | Default | Affects | Where | +|------|---------|---------|-------| +| `validation_max_entities_per_call` | `100` | Validation | [`Detect`](detection.md#detect-fields) | +| `validation_excerpt_window_chars` | `500` | Validation | [`Detect`](detection.md#detect-fields) | +| `detection_window_max_render_chars` | `131,072` | Augmentation, Latent, Substitute map, Rewrite generation | [`Detect`](detection.md#detect-fields) | +| `detection_window_safety_margin_chars` | `8,000` | Augmentation, Latent, Substitute map, Rewrite generation | [`Detect`](detection.md#detect-fields) | +| `detection_window_overlap_chars` | `1,000` | Augmentation, Latent | [`Detect`](detection.md#detect-fields) | + +Setting `detection_window_max_render_chars` (and `detection_window_safety_margin_chars`) on your `Detect` config resizes **all four** windowed stages — including the Substitute-map and Rewrite-generation paths, which derive their per-call window size from these values. 
Lowering the cap is the main lever for entity-dense documents: it puts fewer entities (Substitute map) and less text (Rewrite) into each LLM call, which avoids per-request timeouts. `detection_window_overlap_chars` applies only to the overlapping stages (augmentation, latent); the Substitute-map and Rewrite paths use abutting windows. The 2,000-character rolling-summary cap and the 4,000-character window floor remain internal constants. + +--- + +## Source map + +| Concern | File | +|---------|------| +| Boundary-aware windowing primitive | `src/anonymizer/engine/windowing.py` | +| Validation chunking | `src/anonymizer/engine/detection/chunked_validation.py` | +| Augmentation windowing | `src/anonymizer/engine/detection/chunked_augmentation.py` | +| Latent windowing | `src/anonymizer/engine/detection/chunked_latent.py` | +| Substitute map chunking | `src/anonymizer/engine/replace/chunked_replace.py` | +| Rewrite generation chunking | `src/anonymizer/engine/rewrite/chunked_rewrite.py` | +| Window-config defaults | `src/anonymizer/config/anonymizer_config.py` (`Detect`) | diff --git a/docs/concepts/replace.md b/docs/concepts/replace.md index 69c30d8f..24d14c9b 100644 --- a/docs/concepts/replace.md +++ b/docs/concepts/replace.md @@ -22,6 +22,9 @@ Replace mode replaces each [detected entity](detection.md) with an alternative t Replaces entities with LLM-generated synthetic values that are contextually plausible. This is the only replacement strategy that requires an LLM call. +!!! info "Long records" + For records too large to fit in a single LLM call, Substitute builds its replacement map in sequential windows, threading the map and a rolling summary through each so the same entity maps to the same synthetic value throughout. See [Long-context handling](long-context.md#3-substitute-map-sequential-windows-rolling-state). 
+ ```python from anonymizer import AnonymizerConfig, Substitute diff --git a/docs/concepts/rewrite.md b/docs/concepts/rewrite.md index 50a540ce..20ee8565 100644 --- a/docs/concepts/rewrite.md +++ b/docs/concepts/rewrite.md @@ -13,6 +13,9 @@ Instead of replacing individual entities, rewrite mode transforms the entire tex The text is then rewritten to reduce identifiability, applying targeted transformations that disrupt inference (e.g., weakening or removing linking details) rather than simply rewording content. The rewritten output is evaluated for both quality and privacy leakage using adversarial testing. If thresholds are exceeded, the system automatically refines the rewrite. A final judge provides a qualitative assessment of the rewritten record. Any records that failed to meet standards are tagged for human review. +!!! info "Long records" + A record too large to rewrite in a single LLM call is rewritten in sequential windows, with a continuity summary carried across each so the narrative and pseudonyms stay consistent. See [Long-context handling](long-context.md#4-rewrite-generation-sequential-windows-continuity). 
+ --- ## Key concepts diff --git a/mkdocs.yml b/mkdocs.yml index 29d11c01..d2871911 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -158,6 +158,7 @@ nav: - Detect: concepts/detection.md - Replace: concepts/replace.md - Rewrite: concepts/rewrite.md + - Long-Context Handling: concepts/long-context.md - Choosing a Strategy: concepts/choosing-a-strategy.md - Troubleshooting: troubleshooting.md - Tutorials: diff --git a/src/anonymizer/config/anonymizer_config.py b/src/anonymizer/config/anonymizer_config.py index 6852bdbc..1ab4b0bf 100644 --- a/src/anonymizer/config/anonymizer_config.py +++ b/src/anonymizer/config/anonymizer_config.py @@ -20,6 +20,17 @@ logger = logging.getLogger(__name__) +try: + from data_designer.engine.processing.ginja.environment import MAX_RENDERED_LEN as _NDD_MAX_RENDERED_LEN +except Exception: # pragma: no cover - fall back if NDD internals move + _NDD_MAX_RENDERED_LEN = 512_000 + +# Default per-call render cap for the windowed long-context stages. Kept well below +# NDD's hard render cap so each window stays small enough to map/rewrite within a +# single LLM request — large windows on entity-dense documents otherwise time out. +# Clamped so it never exceeds NDD's cap if that is ever lowered. +_DEFAULT_WINDOW_MAX_RENDER_CHARS = min(128 * 1024, _NDD_MAX_RENDERED_LEN) + def is_remote_input_source(value: str) -> bool: """Return True when the input source is an HTTP(S) URL.""" @@ -100,6 +111,33 @@ class Detect(BaseModel): "validator sees per chunk; it is NOT the LLM's context window limit." ), ) + detection_window_max_render_chars: int = Field( + default=_DEFAULT_WINDOW_MAX_RENDER_CHARS, + gt=0, + description=( + "Upper bound on a single rendered prompt (characters) for the windowed " + "augmentation, latent, substitute-map, and rewrite stages. Documents whose " + "rendered prompt would exceed this are processed in windows. 
Defaults to 128 KiB " + "(131072), kept below NDD's MAX_RENDERED_LEN render cap so each window maps or " + "rewrites within a single LLM request without timing out on long documents." + ), + ) + detection_window_safety_margin_chars: int = Field( + default=8_000, + ge=0, + description=( + "Headroom subtracted from detection_window_max_render_chars to leave room for " + "prompt scaffolding and tags when sizing augmentation/latent windows." + ), + ) + detection_window_overlap_chars: int = Field( + default=1_000, + ge=0, + description=( + "Overlap between adjacent augmentation/latent windows so an entity straddling a " + "window boundary is fully visible in at least one window." + ), + ) @field_validator("entity_labels") @classmethod diff --git a/src/anonymizer/engine/detection/chunked_augmentation.py b/src/anonymizer/engine/detection/chunked_augmentation.py new file mode 100644 index 00000000..43bf61d5 --- /dev/null +++ b/src/anonymizer/engine/detection/chunked_augmentation.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Windowed LLM augmentation for the entity detection pipeline. + +Augmentation finds entities the detector missed by showing the LLM the full +(tagged) document. For very long documents that prompt can blow past model +context / render budgets, so this module tiles the document into overlapping +windows, runs the augmenter per window, and unions the proposed entities. + +Augmented entities carry **no offsets** -- positions are assigned later by +``apply_augmented_entities``, which locates each value in the full text. So the +per-window merge is a simple union/dedupe by ``(value, label)`` and produces the +same final result as a single-pass augmentation. + +Rendering the prompt here (instead of via an ``LLMStructuredColumnConfig``) also +sidesteps NDD's ginja per-render length cap, exactly like chunked validation. 
+ +Public entry point: :func:`make_windowed_augmentation_generator`. The helpers +below are exposed for unit testing. +""" + +from __future__ import annotations + +import functools +import json +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_INITIAL_TAGGED_TEXT, + COL_SEED_ENTITIES_JSON, + COL_TAG_NOTATION, + COL_TEXT, + COL_VALIDATED_SEED_ENTITIES, +) +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation, build_tagged_text +from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema + +logger = logging.getLogger("anonymizer.detection.chunked_augmentation") + +# Floor on window size so a pathologically entity-dense slice still makes +# progress instead of shrinking toward zero. +_MIN_WINDOW_CHARS = 4000 + +# Jinja2 environment used to render the per-window augmentation prompt. Mirrors +# chunked_validation: same template, same placeholders, per-window values. +_PROMPT_ENV = Environment( + loader=BaseLoader(), + autoescape=False, + undefined=StrictUndefined, + keep_trailing_newline=True, +) + + +@functools.lru_cache(maxsize=4) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +class WindowedAugmentationParams(BaseModel): + """Parameters supplied to :func:`augment_row` via DD's ``generator_params``. + + Attributes: + alias: Augmenter model alias (must also be in the decorator's + ``model_aliases`` so DataDesigner materialises the facade). + prompt_template: Jinja2 source for the augmentation prompt (with + ``_initial_tagged_text``, ``_seed_entities_json``, ``_tag_notation`` + placeholders). Typically produced by ``_get_augment_prompt``. 
+ max_render_chars: Upper bound on a single rendered prompt's length; + windows are sized so each render stays under + ``max_render_chars - safety_margin_chars``. + safety_margin_chars: Headroom subtracted from ``max_render_chars``. + overlap_chars: Overlap between adjacent windows so an entity straddling + a boundary is fully visible in at least one window. + system_prompt: Optional system prompt forwarded to each call. + + ``prompt_template``/``system_prompt`` are ``repr=False`` because DD logs this + model and the prompt is multi-kB. + """ + + alias: str = Field(min_length=1) + prompt_template: str = Field(repr=False) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + overlap_chars: int = Field(default=1000, ge=0) + system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers (no DataDesigner, no LLM). Tested directly. +# --------------------------------------------------------------------------- + + +def parse_validated_seed_spans(raw_payload: object) -> list[EntitySpan]: + """Parse ``COL_VALIDATED_SEED_ENTITIES`` (EntitiesSchema) into ``EntitySpan``s.""" + parsed = EntitiesSchema.from_raw(raw_payload) + return [ + EntitySpan( + entity_id=e.id, + value=e.value, + label=e.label, + start_position=e.start_position, + end_position=e.end_position, + score=e.score, + source=e.source, + ) + for e in parsed.entities + ] + + +def build_window_inputs( + *, + text: str, + all_spans: list[EntitySpan], + start: int, + end: int, + notation: TagNotation, +) -> tuple[str, str]: + """Build the per-window tagged text and seed-entities JSON. + + Only seed spans fully contained in ``[start, end)`` are tagged (with offsets + re-based to the window), so the window's tagged view matches the full-doc + view. 
The seed JSON lists those same in-window seeds so the augmenter does + not re-propose already-detected entities; positions are window-local but the + augmenter only reads value/label. + """ + window_raw = text[start:end] + in_window = [ + EntitySpan( + entity_id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position - start, + end_position=span.end_position - start, + score=span.score, + source=span.source, + ) + for span in all_spans + if span.start_position >= start and span.end_position <= end + ] + tagged = build_tagged_text(window_raw, in_window, notation=notation) + seed_json = json.dumps([span.as_dict() for span in in_window]) + return tagged, seed_json + + +def render_augment_prompt( + *, + template: str, + tagged_text: str, + seed_entities_json: str, + notation: TagNotation, +) -> str: + """Render the augmentation prompt for a single window via Jinja2.""" + compiled = _compile_template(template) + return compiled.render( + **{ + COL_INITIAL_TAGGED_TEXT: tagged_text, + COL_SEED_ENTITIES_JSON: seed_entities_json, + COL_TAG_NOTATION: notation.value, + } + ) + + +def merge_augmented(results: list[AugmentedEntitiesSchema]) -> AugmentedEntitiesSchema: + """Union per-window augmented entities, deduping by (normalized value, label).""" + seen: set[tuple[str, str]] = set() + merged: list[dict[str, Any]] = [] + for result in results: + for entity in result.entities: + value = entity.value.strip() + if not value: + continue + key = (value.casefold(), entity.label) + if key in seen: + continue + seen.add(key) + merged.append({"value": entity.value, "label": entity.label, "reason": entity.reason}) + return AugmentedEntitiesSchema.model_validate({"entities": merged}) + + +def iter_windows(text_len: int, window: int, overlap: int) -> list[tuple[int, int]]: + """Tile ``[0, text_len)`` into ``[start, end)`` windows of size ``window`` with ``overlap``.""" + if text_len <= 0: + return [] + window = max(window, _MIN_WINDOW_CHARS) + step = 
max(1, window - overlap) + bounds: list[tuple[int, int]] = [] + pos = 0 + while pos < text_len: + end = min(text_len, pos + window) + bounds.append((pos, end)) + if end >= text_len: + break + pos += step + return bounds + + +# --------------------------------------------------------------------------- +# Dispatch. Testable by passing a fake facade. +# --------------------------------------------------------------------------- + + +def _call_augmenter(*, facade: Any, prompt: str, system_prompt: str | None, purpose: str) -> AugmentedEntitiesSchema: + """Call the augmenter facade with structured output and return the parsed schema.""" + recipe = PydanticResponseRecipe(data_type=AugmentedEntitiesSchema) + final_prompt = recipe.apply_recipe_to_user_prompt(prompt) + final_system = recipe.apply_recipe_to_system_prompt(system_prompt) + output, _messages = facade.generate( + prompt=final_prompt, + parser=recipe.parse, + system_prompt=final_system, + purpose=purpose, + ) + return output + + +def augment_row( + row: dict[str, Any], + params: WindowedAugmentationParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Run (possibly windowed) augmentation for a single row, writing ``COL_AUGMENTED_ENTITIES``. + + Call directly in tests with a fake ``models`` dict; the DataDesigner-decorated + wrapper from :func:`make_windowed_augmentation_generator` just forwards here. + """ + if params.alias not in models: + raise KeyError( + f"Augmenter alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_augmentation_generator was invoked with the same alias " + "passed in WindowedAugmentationParams.alias." + ) + facade = models[params.alias] + + text = str(row.get(COL_TEXT, "")) + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + # ``cap`` is the hard ceiling a rendered prompt may not exceed. 
``initial_window`` + # sizes the first raw window below it by ``safety_margin_chars`` so the per-window + # render overhead (scaffolding + tags + seed JSON) usually fits without shrinking. + cap = params.max_render_chars + initial_window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) + + # Fast path: the full tagged document fits under the cap, so behave exactly like + # the pre-windowing single-call augmentation. + full_tagged = str(row.get(COL_INITIAL_TAGGED_TEXT, "")) + full_seed_json = str(row.get(COL_SEED_ENTITIES_JSON) or "[]") + full_rendered = render_augment_prompt( + template=params.prompt_template, + tagged_text=full_tagged, + seed_entities_json=full_seed_json, + notation=notation, + ) + if len(full_rendered) <= cap: + logger.debug("augmentation: single-call fast path (rendered=%d chars <= cap=%d)", len(full_rendered), cap) + output = _call_augmenter( + facade=facade, + prompt=full_rendered, + system_prompt=params.system_prompt, + purpose="entity-augmentation", + ) + row[COL_AUGMENTED_ENTITIES] = output.model_dump(mode="json") + return row + + # Windowed path: tile the document, shrinking a window only if its render + # exceeds the cap (e.g. an entity-dense slice with many tags). 
+ all_spans = parse_validated_seed_spans(row.get(COL_VALIDATED_SEED_ENTITIES, {})) + results: list[AugmentedEntitiesSchema] = [] + window = initial_window + pos = 0 + text_len = len(text) + logger.info( + "augmentation: rendered prompt %d chars > cap %d; tiling %d-char document into " + "overlapping windows (initial_window=%d, overlap=%d, min_window=%d)", + len(full_rendered), cap, text_len, initial_window, params.overlap_chars, _MIN_WINDOW_CHARS, + ) + window_index = 0 + while pos < text_len: + end = min(text_len, pos + window) + tagged, seed_json = build_window_inputs( + text=text, all_spans=all_spans, start=pos, end=end, notation=notation + ) + rendered = render_augment_prompt( + template=params.prompt_template, + tagged_text=tagged, + seed_entities_json=seed_json, + notation=notation, + ) + if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: + # Shrink proportionally to the measured overage so the next try lands + # just under the cap (0.95 = small safety margin), instead of halving + # and overshooting. Strictly decreasing -> converges, then the floor stops it. 
+ shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) + logger.debug( + "augmentation window %d @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", + window_index, pos, len(rendered), cap, window, shrunk, + ) + window = shrunk + continue + logger.debug( + "augmentation window %d: chars [%d, %d) size=%d, rendered=%d/%d chars", + window_index, pos, end, end - pos, len(rendered), cap, + ) + result = _call_augmenter( + facade=facade, + prompt=rendered, + system_prompt=params.system_prompt, + purpose=f"entity-augmentation-window-{pos}", + ) + logger.debug("augmentation window %d: augmenter proposed %d entities", window_index, len(result.entities)) + results.append(result) + if end >= text_len: + break + next_pos = max(pos + 1, end - params.overlap_chars) + logger.debug( + "augmentation window %d: advancing pos %d -> %d (overlap back %d chars)", + window_index, pos, next_pos, end - next_pos, + ) + pos = next_pos + window_index += 1 + + merged = merge_augmented(results) + logger.info( + "augmentation: %d window(s) over %d chars -> %d unique entities after dedupe (cap=%d, overlap=%d)", + len(results), text_len, len(merged.entities), cap, params.overlap_chars, + ) + row[COL_AUGMENTED_ENTITIES] = merged.model_dump(mode="json") + return row + + +# --------------------------------------------------------------------------- +# DataDesigner wiring factory. +# --------------------------------------------------------------------------- + + +def make_windowed_augmentation_generator(alias: str) -> Any: + """Build a ``@custom_column_generator``-decorated function bound to ``alias``. + + ``model_aliases`` must be declared statically so DataDesigner materialises the + augmenter facade. ``required_columns`` are exhaustive for DD's DAG ordering: + the generator reads the raw text, the full tagged text + seed JSON (fast + path), the validated seed spans (to rebuild per-window tagged text), and the + tag notation. 
+ """ + if not alias: + raise ValueError("Cannot build windowed augmentation generator: alias is empty.") + + @custom_column_generator( + required_columns=[ + COL_TEXT, + COL_INITIAL_TAGGED_TEXT, + COL_SEED_ENTITIES_JSON, + COL_VALIDATED_SEED_ENTITIES, + COL_TAG_NOTATION, + ], + model_aliases=[alias], + ) + def windowed_augment( + row: dict[str, Any], + generator_params: WindowedAugmentationParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return augment_row(row, generator_params, models) + + return windowed_augment diff --git a/src/anonymizer/engine/detection/chunked_detection.py b/src/anonymizer/engine/detection/chunked_detection.py new file mode 100644 index 00000000..e9c5545c --- /dev/null +++ b/src/anonymizer/engine/detection/chunked_detection.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Windowed GLiNER (seed) detection for long documents. + +The seed detector normally runs on the whole document via a single DataDesigner +``LLMTextColumnConfig`` whose prompt *is* the raw text. DataDesigner renders that +prompt through its ginja engine, which enforces a hard ``MAX_RENDERED_LEN`` cap +(512000 chars) on every template — so a document larger than that fails outright +before any of the later (already windowed) stages run. + +This module tiles the raw text into overlapping windows, runs the detector per +window, rebases each window's character offsets back to the full document, and +merges the results — bypassing the ginja cap exactly like the chunked +augmentation/validation steps. The detector returns char offsets relative to the +submitted text, so rebasing is ``+window_start``. 
+ +Boundary handling (an entity cut by a window edge) +-------------------------------------------------- +Windows overlap by ``overlap_chars`` so any entity straddling a cut is fully +*interior* to at least one adjacent window (this holds as long as the overlap is +larger than the longest entity — true for PII spans at the 1000-char default). +On top of that: + +* Spans that touch an **artificial** window edge (a left edge where + ``window_start > 0``, or a right edge where ``window_end < len(text)``) are + dropped: they are the truncated half of a straddling entity, and the full span + is recovered from the neighbouring window where it sits interior. +* After merging, :func:`resolve_overlaps` keeps the longest span among any that + overlap, which both de-duplicates the identical copies produced in the overlap + region and, as a backstop, discards any partial that still overlaps a full span. + +The output column (``COL_RAW_DETECTED``) is re-emitted in the detector's own JSON +shape (``{"entities": [{"text", "label", "start", "end", "score"}]}``) with global +offsets, so the downstream ``parse_detected_entities`` step is unchanged. + +Public entry point: :func:`make_windowed_detection_generator`. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import COL_RAW_DETECTED, COL_TEXT +from anonymizer.engine.detection.chunked_augmentation import iter_windows +from anonymizer.engine.detection.postprocess import EntitySpan, parse_raw_entities, resolve_overlaps + +logger = logging.getLogger("anonymizer.detection.chunked_detection") + +# Floor on window size so a pathologically small cap still makes progress. 
+_MIN_WINDOW_CHARS = 4000 + + +class WindowedDetectionParams(BaseModel): + """Parameters supplied to :func:`detect_row` via DD's ``generator_params``. + + Attributes: + alias: Detector model alias (must also be in the decorator's + ``model_aliases`` so DataDesigner materialises the facade). + max_render_chars: Upper bound on the text submitted in one detector call; + windows are sized to ``max_render_chars - safety_margin_chars``. + safety_margin_chars: Headroom subtracted from ``max_render_chars``. + overlap_chars: Overlap between adjacent windows. Must exceed the longest + expected entity so an entity straddling a window boundary is fully + visible (interior) in at least one window. + system_prompt: Optional system prompt; the hosted detector takes none. + """ + + alias: str = Field(min_length=1) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + overlap_chars: int = Field(default=1000, ge=0) + system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers (no DataDesigner, no LLM). Tested directly. +# --------------------------------------------------------------------------- + + +def rebase_spans(spans: list[EntitySpan], offset: int) -> list[EntitySpan]: + """Shift window-local spans to full-document coordinates by ``offset``.""" + if offset == 0: + return list(spans) + return [ + EntitySpan( + entity_id=s.entity_id, + value=s.value, + label=s.label, + start_position=s.start_position + offset, + end_position=s.end_position + offset, + score=s.score, + source=s.source, + ) + for s in spans + ] + + +def drop_boundary_spans( + spans: list[EntitySpan], *, window_start: int, window_end: int, text_len: int +) -> list[EntitySpan]: + """Drop spans touching an *artificial* window edge (a likely-truncated entity). + + ``spans`` are in full-document coordinates. 
A span touching the left edge is + dropped only when this window does not start at the document start + (``window_start > 0``); likewise the right edge only when it is not the + document end (``window_end < text_len``). The dropped span's full form is + recovered from the overlapping neighbour window, where it sits interior. + """ + left_artificial = window_start > 0 + right_artificial = window_end < text_len + kept: list[EntitySpan] = [] + for s in spans: + if left_artificial and s.start_position <= window_start: + continue + if right_artificial and s.end_position >= window_end: + continue + kept.append(s) + return kept + + +def spans_to_detector_payload(spans: list[EntitySpan]) -> str: + """Serialize merged spans back into the detector's raw JSON shape.""" + return json.dumps( + { + "entities": [ + { + "text": s.value, + "label": s.label, + "start": s.start_position, + "end": s.end_position, + "score": s.score, + } + for s in spans + ] + } + ) + + +# --------------------------------------------------------------------------- +# Dispatch. Testable by passing a fake facade. +# --------------------------------------------------------------------------- + + +def _call_detector(*, facade: Any, prompt: str, system_prompt: str | None, purpose: str) -> str: + """Call the detector facade with the text as a plain prompt; return raw text. + + Uses ``TextResponseRecipe`` (a pass-through that adds no task instructions), + so the submitted prompt is exactly the window text — identical to the + ``LLMTextColumnConfig(prompt=_jinja(COL_TEXT))`` this replaces. 
+ """ + recipe = TextResponseRecipe() + final_prompt = recipe.apply_recipe_to_user_prompt(prompt) + final_system = recipe.apply_recipe_to_system_prompt(system_prompt) + output, _messages = facade.generate( + prompt=final_prompt, + parser=recipe.parse, + system_prompt=final_system, + purpose=purpose, + ) + return str(output) + + +def detect_row( + row: dict[str, Any], + params: WindowedDetectionParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Run (possibly windowed) seed detection for one row, writing ``COL_RAW_DETECTED``.""" + if params.alias not in models: + raise KeyError( + f"Detector alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_detection_generator was invoked with the same alias " + "passed in WindowedDetectionParams.alias." + ) + facade = models[params.alias] + + text = str(row.get(COL_TEXT, "")) + cap = params.max_render_chars + + # Fast path: the whole document fits in one call. Identical to the + # pre-windowing LLMTextColumnConfig behaviour (raw detector JSON passed through). + if len(text) <= cap: + logger.debug("detection: single-call fast path (text=%d chars <= cap=%d)", len(text), cap) + row[COL_RAW_DETECTED] = _call_detector( + facade=facade, prompt=text, system_prompt=params.system_prompt, purpose="entity-detection" + ) + return row + + # Windowed path: overlapping windows, rebase offsets, drop truncated edge + # spans, then resolve_overlaps to dedupe overlap-region copies. 
+ text_len = len(text) + window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) + windows = iter_windows(text_len, window, params.overlap_chars) + logger.info( + "detection: text %d chars > cap %d; tiling into %d overlapping window(s) " + "(window=%d, overlap=%d, min_window=%d)", + text_len, cap, len(windows), window, params.overlap_chars, _MIN_WINDOW_CHARS, + ) + + all_spans: list[EntitySpan] = [] + for window_index, (start, end) in enumerate(windows): + window_text = text[start:end] + raw = _call_detector( + facade=facade, + prompt=window_text, + system_prompt=params.system_prompt, + purpose=f"entity-detection-window-{start}", + ) + # parse_raw_entities validates + resolves overlaps within the window + # (offsets are window-local because the model only saw window_text). + local_spans = parse_raw_entities(raw_response=raw, text=window_text) + global_spans = rebase_spans(local_spans, start) + kept = drop_boundary_spans(global_spans, window_start=start, window_end=end, text_len=text_len) + logger.debug( + "detection window %d: chars [%d, %d) size=%d -> %d span(s), %d after edge-drop", + window_index, start, end, end - start, len(local_spans), len(kept), + ) + all_spans.extend(kept) + + merged = resolve_overlaps(all_spans) + logger.info( + "detection: %d window(s) over %d chars -> %d unique span(s) after merge", + len(windows), text_len, len(merged), + ) + row[COL_RAW_DETECTED] = spans_to_detector_payload(merged) + return row + + +# --------------------------------------------------------------------------- +# DataDesigner wiring factory. +# --------------------------------------------------------------------------- + + +def make_windowed_detection_generator(alias: str) -> Any: + """Build a ``@custom_column_generator``-decorated function bound to ``alias``. + + ``model_aliases`` must be declared statically so DataDesigner materialises the + detector facade. The only required column is the raw text. 
+ """ + if not alias: + raise ValueError("Cannot build windowed detection generator: alias is empty.") + + @custom_column_generator( + required_columns=[COL_TEXT], + model_aliases=[alias], + ) + def windowed_detect( + row: dict[str, Any], + generator_params: WindowedDetectionParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return detect_row(row, generator_params, models) + + return windowed_detect diff --git a/src/anonymizer/engine/detection/chunked_latent.py b/src/anonymizer/engine/detection/chunked_latent.py new file mode 100644 index 00000000..571e01ca --- /dev/null +++ b/src/anonymizer/engine/detection/chunked_latent.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Windowed latent-entity detection (rewrite-mode only) for long documents. + +Latent entities are *inferred*, non-explicit sensitive attributes; they carry no +offsets, so per-window results merge by union (dedupe by ``(label, value)``). +Like windowed augmentation, rendering the prompt here (instead of via an +``LLMStructuredColumnConfig``) bypasses NDD's ginja per-render length cap. + +Caveat: windowing necessarily limits cross-window inference. A latent fact only +deducible by combining distant parts of a very long document may be missed. This +is a pragmatic trade so long documents do not fail outright; for documents that +fit in one window behaviour is unchanged (fast path). + +Public entry point: :func:`make_windowed_latent_generator`. 
+""" + +from __future__ import annotations + +import functools +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_DETECTED_ENTITIES, + COL_LATENT_ENTITIES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.chunked_augmentation import build_window_inputs, iter_windows +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation +from anonymizer.engine.schemas import EntitiesSchema, LatentEntitiesSchema + +logger = logging.getLogger("anonymizer.detection.chunked_latent") + +# Floor on window size so a pathologically entity-dense slice still progresses. +_MIN_WINDOW_CHARS = 4000 + +_PROMPT_ENV = Environment( + loader=BaseLoader(), + autoescape=False, + undefined=StrictUndefined, + keep_trailing_newline=True, +) + + +@functools.lru_cache(maxsize=4) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +class WindowedLatentParams(BaseModel): + """Parameters supplied to :func:`latent_row` via DD's ``generator_params``. + + Mirrors ``WindowedAugmentationParams`` but for the latent prompt, which reads + ``_tagged_text`` (the finalized tagged document) and ``_tag_notation``. 
+ """ + + alias: str = Field(min_length=1) + prompt_template: str = Field(repr=False) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + overlap_chars: int = Field(default=1000, ge=0) + system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def parse_detected_spans(raw_payload: object) -> list[EntitySpan]: + """Parse ``COL_DETECTED_ENTITIES`` (EntitiesSchema) into ``EntitySpan``s.""" + parsed = EntitiesSchema.from_raw(raw_payload) + return [ + EntitySpan( + entity_id=e.id, + value=e.value, + label=e.label, + start_position=e.start_position, + end_position=e.end_position, + score=e.score, + source=e.source, + ) + for e in parsed.entities + ] + + +def render_latent_prompt(*, template: str, tagged_text: str, notation: TagNotation) -> str: + """Render the latent prompt for a single window via Jinja2.""" + compiled = _compile_template(template) + return compiled.render(**{COL_TAGGED_TEXT: tagged_text, COL_TAG_NOTATION: notation.value}) + + +def merge_latent(results: list[LatentEntitiesSchema]) -> LatentEntitiesSchema: + """Union per-window latent entities, deduping by ``(label, normalized value)``.""" + seen: set[tuple[str, str]] = set() + merged: list[dict[str, Any]] = [] + for result in results: + for entity in result.latent_entities: + value = entity.value.strip() + if not value: + continue + key = (entity.label.casefold(), value.casefold()) + if key in seen: + continue + seen.add(key) + merged.append(entity.model_dump(mode="json")) + return LatentEntitiesSchema.model_validate({"latent_entities": merged}) + + +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- + + +def _call_latent(*, facade: Any, prompt: str, 
system_prompt: str | None, purpose: str) -> LatentEntitiesSchema: + recipe = PydanticResponseRecipe(data_type=LatentEntitiesSchema) + final_prompt = recipe.apply_recipe_to_user_prompt(prompt) + final_system = recipe.apply_recipe_to_system_prompt(system_prompt) + output, _messages = facade.generate( + prompt=final_prompt, + parser=recipe.parse, + system_prompt=final_system, + purpose=purpose, + ) + return output + + +def latent_row( + row: dict[str, Any], + params: WindowedLatentParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Run (possibly windowed) latent detection for a single row, writing ``COL_LATENT_ENTITIES``.""" + if params.alias not in models: + raise KeyError( + f"Latent alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_latent_generator was invoked with the same alias passed " + "in WindowedLatentParams.alias." + ) + facade = models[params.alias] + + text = str(row.get(COL_TEXT, "")) + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + # ``cap`` is the hard ceiling a rendered prompt may not exceed. ``initial_window`` + # sizes the first raw window below it by ``safety_margin_chars`` so per-window + # render overhead (scaffolding + tags) usually fits without shrinking. + cap = params.max_render_chars + initial_window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) + + # Fast path: the full finalized tagged document fits under the cap. + full_tagged = str(row.get(COL_TAGGED_TEXT, "")) + full_rendered = render_latent_prompt(template=params.prompt_template, tagged_text=full_tagged, notation=notation) + if len(full_rendered) <= cap: + logger.debug("latent: single-call fast path (rendered=%d chars <= cap=%d)", len(full_rendered), cap) + output = _call_latent( + facade=facade, + prompt=full_rendered, + system_prompt=params.system_prompt, + purpose="latent-detection", + ) + row[COL_LATENT_ENTITIES] = output.model_dump(mode="json") + return row + + # Windowed path. 
+ spans = parse_detected_spans(row.get(COL_DETECTED_ENTITIES, {})) + results: list[LatentEntitiesSchema] = [] + window = initial_window + pos = 0 + text_len = len(text) + logger.info( + "latent: rendered prompt %d chars > cap %d; tiling %d-char document into " + "overlapping windows (initial_window=%d, overlap=%d, min_window=%d)", + len(full_rendered), cap, text_len, initial_window, params.overlap_chars, _MIN_WINDOW_CHARS, + ) + window_index = 0 + while pos < text_len: + end = min(text_len, pos + window) + # Reuse augmentation's span-rebasing + tagging; latent ignores the seed JSON. + tagged, _seed_json = build_window_inputs( + text=text, all_spans=spans, start=pos, end=end, notation=notation + ) + rendered = render_latent_prompt(template=params.prompt_template, tagged_text=tagged, notation=notation) + if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: + # Shrink proportionally to the measured overage so the next try lands + # just under the cap (0.95 = small safety margin), instead of halving + # and overshooting. Strictly decreasing -> converges, then the floor stops it. 
+ shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) + logger.debug( + "latent window %d @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", + window_index, pos, len(rendered), cap, window, shrunk, + ) + window = shrunk + continue + logger.debug( + "latent window %d: chars [%d, %d) size=%d, rendered=%d/%d chars", + window_index, pos, end, end - pos, len(rendered), cap, + ) + result = _call_latent( + facade=facade, + prompt=rendered, + system_prompt=params.system_prompt, + purpose=f"latent-detection-window-{pos}", + ) + logger.debug("latent window %d: detector proposed %d latent entities", window_index, len(result.latent_entities)) + results.append(result) + if end >= text_len: + break + next_pos = max(pos + 1, end - params.overlap_chars) + logger.debug( + "latent window %d: advancing pos %d -> %d (overlap back %d chars)", + window_index, pos, next_pos, end - next_pos, + ) + pos = next_pos + window_index += 1 + + merged = merge_latent(results) + logger.info( + "latent: %d window(s) over %d chars -> %d unique latent entities after dedupe (cap=%d, overlap=%d)", + len(results), text_len, len(merged.latent_entities), cap, params.overlap_chars, + ) + row[COL_LATENT_ENTITIES] = merged.model_dump(mode="json") + return row + + +# --------------------------------------------------------------------------- +# DataDesigner wiring factory. 
+# --------------------------------------------------------------------------- + + +def make_windowed_latent_generator(alias: str) -> Any: + """Build a ``@custom_column_generator``-decorated function bound to ``alias``.""" + if not alias: + raise ValueError("Cannot build windowed latent generator: alias is empty.") + + @custom_column_generator( + required_columns=[ + COL_TEXT, + COL_TAGGED_TEXT, + COL_DETECTED_ENTITIES, + COL_TAG_NOTATION, + ], + model_aliases=[alias], + ) + def windowed_latent( + row: dict[str, Any], + generator_params: WindowedLatentParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return latent_row(row, generator_params, models) + + return windowed_latent + + +# Re-exported for symmetry/tests; latent windowing reuses the augmentation tiler. +__all__ = [ + "WindowedLatentParams", + "iter_windows", + "latent_row", + "make_windowed_latent_generator", + "merge_latent", + "parse_detected_spans", + "render_latent_prompt", +] diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index 87eb644b..7c9a041c 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -8,7 +8,7 @@ from dataclasses import dataclass import pandas as pd -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.models import ModelConfig from anonymizer.config.anonymizer_config import Detect as AnonymizerDetectConfig @@ -29,7 +29,6 @@ COL_SEED_VALIDATION_CANDIDATES, COL_TAG_NOTATION, COL_TAGGED_TEXT, - COL_TEXT, COL_VALIDATED_ENTITIES, COL_VALIDATION_DECISIONS, COL_VALIDATION_SKELETON, @@ -37,6 +36,18 @@ ENTITY_LABEL_EXAMPLES, _jinja, ) +from anonymizer.engine.detection.chunked_augmentation import ( + WindowedAugmentationParams, + make_windowed_augmentation_generator, +) +from 
anonymizer.engine.detection.chunked_detection import ( + WindowedDetectionParams, + make_windowed_detection_generator, +) +from anonymizer.engine.detection.chunked_latent import ( + WindowedLatentParams, + make_windowed_latent_generator, +) from anonymizer.engine.detection.chunked_validation import ( ChunkedValidationParams, make_chunked_validation_generator, @@ -54,23 +65,31 @@ from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases from anonymizer.engine.prompt_utils import substitute_placeholders from anonymizer.engine.schemas import ( - AugmentedEntitiesSchema, EntitiesByValueSchema, EntitiesSchema, - LatentEntitiesSchema, ) logger = logging.getLogger("anonymizer.detection") -# Defaults for the two chunked-validation knobs. Sourced from the Detect config -# so there is a single source of truth; the workflow method defaults exist so -# internal tests and ad-hoc callers do not have to wire plumbing by hand. +# Defaults for the chunked-validation and windowed augmentation/latent knobs. +# Sourced from the Detect config so there is a single source of truth; the +# workflow method defaults exist so internal tests and ad-hoc callers do not +# have to wire plumbing by hand. 
_DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL: int = AnonymizerDetectConfig.model_fields[ "validation_max_entities_per_call" ].default _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS: int = AnonymizerDetectConfig.model_fields[ "validation_excerpt_window_chars" ].default +_DEFAULT_WINDOW_MAX_RENDER_CHARS: int = AnonymizerDetectConfig.model_fields[ + "detection_window_max_render_chars" +].default +_DEFAULT_WINDOW_SAFETY_MARGIN_CHARS: int = AnonymizerDetectConfig.model_fields[ + "detection_window_safety_margin_chars" +].default +_DEFAULT_WINDOW_OVERLAP_CHARS: int = AnonymizerDetectConfig.model_fields[ + "detection_window_overlap_chars" +].default @dataclass(frozen=True) @@ -94,6 +113,9 @@ def detect_and_validate_entities( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + detection_window_max_render_chars: int = _DEFAULT_WINDOW_MAX_RENDER_CHARS, + detection_window_safety_margin_chars: int = _DEFAULT_WINDOW_SAFETY_MARGIN_CHARS, + detection_window_overlap_chars: int = _DEFAULT_WINDOW_OVERLAP_CHARS, entity_labels: list[str] | None = None, data_summary: str | None = None, preview_num_records: int | None = None, @@ -150,10 +172,15 @@ def detect_and_validate_entities( dataframe, model_configs=workflow_model_configs, columns=[ - LLMTextColumnConfig( + CustomColumnConfig( name=COL_RAW_DETECTED, - prompt=_jinja(COL_TEXT), - model_alias=detection_alias, + generator_function=make_windowed_detection_generator(detection_alias), + generator_params=WindowedDetectionParams( + alias=detection_alias, + max_render_chars=detection_window_max_render_chars, + safety_margin_chars=detection_window_safety_margin_chars, + overlap_chars=detection_window_overlap_chars, + ), ), CustomColumnConfig( name=COL_SEED_ENTITIES, @@ -177,13 +204,18 @@ def detect_and_validate_entities( name=COL_SEED_ENTITIES_JSON, generator_function=apply_validation_to_seed_entities, ), 
- LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_AUGMENTED_ENTITIES, - prompt=_get_augment_prompt( - data_summary=data_summary, labels=labels, strict_labels=entity_labels is not None + generator_function=make_windowed_augmentation_generator(augmenter_alias), + generator_params=WindowedAugmentationParams( + alias=augmenter_alias, + prompt_template=_get_augment_prompt( + data_summary=data_summary, labels=labels, strict_labels=entity_labels is not None + ), + max_render_chars=detection_window_max_render_chars, + safety_margin_chars=detection_window_safety_margin_chars, + overlap_chars=detection_window_overlap_chars, ), - model_alias=augmenter_alias, - output_format=AugmentedEntitiesSchema, ), CustomColumnConfig( name=COL_MERGED_ENTITIES, @@ -210,6 +242,9 @@ def identify_latent_entities( entity_labels: list[str] | None = None, privacy_goal: PrivacyGoal | None, data_summary: str | None = None, + detection_window_max_render_chars: int = _DEFAULT_WINDOW_MAX_RENDER_CHARS, + detection_window_safety_margin_chars: int = _DEFAULT_WINDOW_SAFETY_MARGIN_CHARS, + detection_window_overlap_chars: int = _DEFAULT_WINDOW_OVERLAP_CHARS, preview_num_records: int | None = None, ) -> EntityDetectionResult: """Detect latent/inferred entities that could enable re-identification. 
@@ -229,14 +264,19 @@ def identify_latent_entities( dataframe, model_configs=workflow_model_configs, columns=[ - LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_LATENT_ENTITIES, - prompt=_get_latent_prompt( - data_summary=data_summary, - privacy_goal=privacy_goal, + generator_function=make_windowed_latent_generator(latent_alias), + generator_params=WindowedLatentParams( + alias=latent_alias, + prompt_template=_get_latent_prompt( + data_summary=data_summary, + privacy_goal=privacy_goal, + ), + max_render_chars=detection_window_max_render_chars, + safety_margin_chars=detection_window_safety_margin_chars, + overlap_chars=detection_window_overlap_chars, ), - model_alias=latent_alias, - output_format=LatentEntitiesSchema, ) ], workflow_name="latent-entity-detection", @@ -253,6 +293,9 @@ def run( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + detection_window_max_render_chars: int = _DEFAULT_WINDOW_MAX_RENDER_CHARS, + detection_window_safety_margin_chars: int = _DEFAULT_WINDOW_SAFETY_MARGIN_CHARS, + detection_window_overlap_chars: int = _DEFAULT_WINDOW_OVERLAP_CHARS, entity_labels: list[str] | None = None, privacy_goal: PrivacyGoal | None = None, data_summary: str | None = None, @@ -277,6 +320,9 @@ def run( gliner_detection_threshold=gliner_detection_threshold, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, + detection_window_max_render_chars=detection_window_max_render_chars, + detection_window_safety_margin_chars=detection_window_safety_margin_chars, + detection_window_overlap_chars=detection_window_overlap_chars, entity_labels=entity_labels, data_summary=data_summary, preview_num_records=preview_num_records, @@ -291,6 +337,9 @@ def run( entity_labels=entity_labels, privacy_goal=privacy_goal, data_summary=data_summary, + 
detection_window_max_render_chars=detection_window_max_render_chars, + detection_window_safety_margin_chars=detection_window_safety_margin_chars, + detection_window_overlap_chars=detection_window_overlap_chars, preview_num_records=preview_num_records, ) final_df = latent_result.dataframe.copy() diff --git a/src/anonymizer/engine/replace/chunked_replace.py b/src/anonymizer/engine/replace/chunked_replace.py new file mode 100644 index 00000000..0b81dfdc --- /dev/null +++ b/src/anonymizer/engine/replace/chunked_replace.py @@ -0,0 +1,377 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Long-context (chunked) replacement-map generation for Substitute. + +Substitute normally builds the per-record replacement map in a single LLM call +that embeds the whole tagged document. For documents whose rendered prompt would +exceed the render cap, this module instead processes the document in +boundary-aligned chunks, carrying context forward so replacements stay consistent +across chunks: + +For chunk *k* the model is given + (1) a rolling LLM **summary** of chunks 1..k-1, + (2) the **already-generated replacement map** (to reuse, never re-map), and + (3) the **new entities detected within chunk k**, +and asked to produce replacements only for the new entities. After the chunk, the +rolling summary is refreshed with chunk *k* via a second LLM call. + +Rendering the prompts here (rather than via an ``LLMStructuredColumnConfig``) +also sidesteps NDD's ginja per-render length cap, like chunked detection. 
+""" + +from __future__ import annotations + +import functools +import json +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe, TextResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_ENTITIES_FOR_REPLACE, + COL_ENTITY_EXAMPLES, + COL_FINAL_ENTITIES, + COL_REPLACEMENT_MAP, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation, build_tagged_text +from anonymizer.engine.schemas import EntitiesSchema, EntityReplacementMapSchema +from anonymizer.engine.windowing import DEFAULT_DELIMITER, iter_boundary_windows + +logger = logging.getLogger("anonymizer.replace.chunked") + +_MIN_WINDOW_CHARS = 4000 + +# Max characters of free-form text (e.g. a rolling summary) to emit in a single +# debug line, so logs stay readable even when the underlying value is large. +_LOG_CLIP_CHARS = 800 + + +def _clip(text: str, limit: int = _LOG_CLIP_CHARS) -> str: + """Single-line, length-bounded rendering of ``text`` for debug logs.""" + flat = " ".join(text.split()) + return flat if len(flat) <= limit else f"{flat[:limit]}… (+{len(flat) - limit} chars)" + +_PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) + + +@functools.lru_cache(maxsize=8) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +# --- prompts owned by the chunked path (the single-call prompt is passed in) --- + +_CHUNK_MAP_PROMPT = """Generate synthetic replacements for sensitive entities in ONE section of a longer document. +Output ONE replacement per NEW entity listed below. 
Replacements must: +- prevent re-identification, stay plausible in context, match grammatical role and class label +- NOT be a synonym/near-synonym of the original; shift to a distinct but plausible value +- keep related entities mutually consistent (geographic, personal name/email, org/domain, temporal, contact) +- preserve format/patterns and wildcards (* % ?); never return original unchanged + +Summary of earlier sections (for context and consistency): +{{ summary }} + +Already-generated replacements from earlier sections — REUSE these EXACTLY if the same entity +recurs here, and keep new replacements consistent with them (do NOT restate them in your output): +{{ existing_map }} + +Current section (entities tagged inline): +{{ chunk_tagged_text }} + +NEW entities to replace in this section: +{%- for entity in chunk_entities %} +- "{{ entity.value }}" ({{ entity.labels_str }}) +{%- endfor %} + +Examples: {{ examples }} + +Return replacements ONLY for the NEW entities listed above. +""" + +_SUMMARY_PROMPT = """You maintain a concise running summary of a long document that is being anonymized. +Update the summary to incorporate the NEW section, keeping only facts useful for keeping entity +replacements consistent across sections (who/what/where, relationships between entities, ongoing +context). Be terse; hard limit {{ summary_max_chars }} characters. Return only the updated summary. 
+ +Previous summary: +{{ prev_summary }} + +New section: +{{ chunk_text }} +""" + + +class WindowedReplaceParams(BaseModel): + """Params for chunked Substitute map generation (via DD ``generator_params``).""" + + alias: str = Field(min_length=1) + single_call_prompt_template: str = Field(repr=False) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + summary_max_chars: int = Field(default=2000, gt=0) + delimiter: str = Field(default=DEFAULT_DELIMITER) + system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def _parse_spans(raw: object) -> list[EntitySpan]: + parsed = EntitiesSchema.from_raw(raw) + return [ + EntitySpan( + entity_id=e.id, + value=e.value, + label=e.label, + start_position=e.start_position, + end_position=e.end_position, + score=e.score, + source=e.source, + ) + for e in parsed.entities + ] + + +def chunk_tagged_text(text: str, spans: list[EntitySpan], start: int, end: int, notation: TagNotation) -> str: + """Build the inline-tagged text for the window ``[start, end)`` (spans re-based).""" + window_raw = text[start:end] + in_window = [ + EntitySpan( + entity_id=s.entity_id, + value=s.value, + label=s.label, + start_position=s.start_position - start, + end_position=s.end_position - start, + score=s.score, + source=s.source, + ) + for s in spans + if s.start_position >= start and s.end_position <= end + ] + return build_tagged_text(window_raw, in_window, notation=notation) + + +def new_chunk_entities( + spans: list[EntitySpan], + start: int, + end: int, + already_mapped: set[tuple[str, str]], +) -> list[dict[str, Any]]: + """Distinct (value, label) entities whose span starts in ``[start, end)`` and aren't mapped yet. + + Returned shape matches the prompt's ``chunk_entities`` loop (value, labels, labels_str). 
+ """ + by_value: dict[str, set[str]] = {} + for span in spans: + if not (start <= span.start_position < end): + continue + if not span.value or not span.label: + continue + if (span.value, span.label) in already_mapped: + continue + by_value.setdefault(span.value, set()).add(span.label) + out: list[dict[str, Any]] = [] + for value in sorted(by_value): + labels = sorted(by_value[value]) + out.append({"value": value, "labels": labels, "labels_str": ", ".join(labels)}) + return out + + +def merge_replacements(existing: list[dict[str, str]], new: EntityReplacementMapSchema) -> list[dict[str, str]]: + """Append new replacements to ``existing``, deduping by (original, label); earlier wins.""" + seen = {(r["original"], r["label"]) for r in existing} + merged = list(existing) + for r in new.replacements: + key = (r.original, r.label) + if key in seen: + continue + seen.add(key) + merged.append({"original": r.original, "label": r.label, "synthetic": r.synthetic}) + return merged + + +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- + + +def _generate_chunk_map( + *, + facade: Any, + chunk_tagged: str, + chunk_entities: list[dict[str, Any]], + existing_map: list[dict[str, str]], + summary: str, + examples: str, + system_prompt: str | None, + purpose: str, +) -> EntityReplacementMapSchema: + recipe = PydanticResponseRecipe(data_type=EntityReplacementMapSchema) + rendered = _compile_template(_CHUNK_MAP_PROMPT).render( + summary=summary or "(none yet)", + existing_map=json.dumps(existing_map) if existing_map else "(none yet)", + chunk_tagged_text=chunk_tagged, + chunk_entities=chunk_entities, + examples=examples or "{}", + ) + output, _ = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(rendered), + parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt), + purpose=purpose, + ) + return output + + +def 
_update_summary( + *, + facade: Any, + prev_summary: str, + chunk_text: str, + summary_max_chars: int, + system_prompt: str | None, + purpose: str, +) -> str: + recipe = TextResponseRecipe() + rendered = _compile_template(_SUMMARY_PROMPT).render( + prev_summary=prev_summary or "(none yet)", + chunk_text=chunk_text, + summary_max_chars=summary_max_chars, + ) + output, _ = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(rendered), + parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt), + purpose=purpose, + ) + return str(output)[:summary_max_chars] + + +def generate_replacement_map_row( + row: dict[str, Any], + params: WindowedReplaceParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Build ``COL_REPLACEMENT_MAP`` for one row, chunking long documents with context carry-over.""" + if params.alias not in models: + raise KeyError( + f"Replacement alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_replace_generator was invoked with the same alias." + ) + facade = models[params.alias] + cap = params.max_render_chars + initial_window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) + + # Fast path: the single-call prompt (full tagged doc + all entities) fits under the cap. + single_rendered = _compile_template(params.single_call_prompt_template).render(**row) + if len(single_rendered) <= cap: + logger.debug("replace-map: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) + recipe = PydanticResponseRecipe(data_type=EntityReplacementMapSchema) + output, _ = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(single_rendered), + parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(params.system_prompt), + purpose="replace-map-generation", + ) + row[COL_REPLACEMENT_MAP] = output.model_dump(mode="json") + return row + + # Chunked path with rolling summary + carried map. 
+ text = str(row.get(COL_TEXT, "")) + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + examples = str(row.get(COL_ENTITY_EXAMPLES) or "{}") + spans = _parse_spans(row.get(COL_FINAL_ENTITIES, {})) + + windows = iter_boundary_windows(text, initial_window, delimiter=params.delimiter) + logger.info( + "replace-map: rendered prompt %d chars > cap %d; chunking %d-char document into %d boundary " + "window(s) (initial_window=%d, delimiter=%r, summary_max=%d)", + len(single_rendered), cap, len(text), len(windows), initial_window, params.delimiter, params.summary_max_chars, + ) + accumulated: list[dict[str, str]] = [] + summary = "" + for i, (start, end) in enumerate(windows): + already = {(r["original"], r["label"]) for r in accumulated} + chunk_entities = new_chunk_entities(spans, start, end, already) + logger.debug( + "replace-map window %d/%d: chars [%d, %d) size=%d, %d new entit(y/ies), %d already mapped", + i + 1, len(windows), start, end, end - start, len(chunk_entities), len(already), + ) + if chunk_entities: + tagged = chunk_tagged_text(text, spans, start, end, notation) + chunk_map = _generate_chunk_map( + facade=facade, + chunk_tagged=tagged, + chunk_entities=chunk_entities, + existing_map=accumulated, + summary=summary, + examples=examples, + system_prompt=params.system_prompt, + purpose=f"replace-map-chunk-{start}", + ) + accumulated = merge_replacements(accumulated, chunk_map) + logger.debug( + "replace-map window %d/%d: chunk produced %d replacement(s); accumulated map now %d entries", + i + 1, len(windows), len(chunk_map.replacements), len(accumulated), + ) + else: + logger.debug("replace-map window %d/%d: no new entities, skipping map call", i + 1, len(windows)) + # Refresh the rolling summary for subsequent chunks (skip after the last chunk). 
+ if i < len(windows) - 1: + summary = _update_summary( + facade=facade, + prev_summary=summary, + chunk_text=text[start:end], + summary_max_chars=params.summary_max_chars, + system_prompt=params.system_prompt, + purpose=f"replace-summary-{start}", + ) + logger.debug( + "replace-map window %d/%d: rolling summary updated -> %d chars: %s", + i + 1, len(windows), len(summary), _clip(summary), + ) + + logger.info( + "replace-map: %d window(s) over %d chars -> %d total replacement(s)", + len(windows), len(text), len(accumulated), + ) + row[COL_REPLACEMENT_MAP] = EntityReplacementMapSchema.model_validate( + {"replacements": accumulated} + ).model_dump(mode="json") + return row + + +def make_windowed_replace_generator(alias: str) -> Any: + """Build a ``@custom_column_generator`` for chunked replacement-map generation bound to ``alias``.""" + if not alias: + raise ValueError("Cannot build windowed replace generator: alias is empty.") + + @custom_column_generator( + required_columns=[ + COL_TEXT, + COL_TAGGED_TEXT, + COL_FINAL_ENTITIES, + COL_TAG_NOTATION, + COL_ENTITY_EXAMPLES, + COL_ENTITIES_FOR_REPLACE, + ], + model_aliases=[alias], + ) + def windowed_replace( + row: dict[str, Any], + generator_params: WindowedReplaceParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return generate_replacement_map_row(row, generator_params, models) + + return windowed_replace diff --git a/src/anonymizer/engine/replace/llm_replace_workflow.py b/src/anonymizer/engine/replace/llm_replace_workflow.py index ccd5cb1d..6f04b7a7 100644 --- a/src/anonymizer/engine/replace/llm_replace_workflow.py +++ b/src/anonymizer/engine/replace/llm_replace_workflow.py @@ -9,9 +9,10 @@ from dataclasses import dataclass import pandas as pd -from data_designer.config.column_configs import LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.models import ModelConfig +from anonymizer.config.anonymizer_config import Detect as _DetectConfig 
from anonymizer.config.models import ReplaceModelSelection from anonymizer.engine.constants import ( COL_ENTITIES_BY_VALUE, @@ -24,11 +25,16 @@ from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter from anonymizer.engine.ndd.model_loader import resolve_model_alias from anonymizer.engine.prompt_utils import substitute_placeholders +from anonymizer.engine.replace.chunked_replace import WindowedReplaceParams, make_windowed_replace_generator from anonymizer.engine.row_partitioning import merge_and_reorder, split_rows from anonymizer.engine.schemas import EntitiesByValueSchema, EntityReplacementMapSchema logger = logging.getLogger("anonymizer.replace.llm_workflow") +# Long-context window defaults shared with detection (single source of truth). +_DEFAULT_MAX_RENDER_CHARS: int = _DetectConfig.model_fields["detection_window_max_render_chars"].default +_DEFAULT_SAFETY_MARGIN_CHARS: int = _DetectConfig.model_fields["detection_window_safety_margin_chars"].default + # Workflow-internal scratch columns used only to build the replacement-generator # prompt. Created in `generate_map_only` and dropped before returning — nothing # downstream consumes them, and they carry pyarrow-backed pandas extension @@ -57,8 +63,17 @@ def generate_map_only( instructions: str | None = None, entities_column: str = COL_ENTITIES_BY_VALUE, preview_num_records: int | None = None, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> LlmReplaceResult: replace_alias = resolve_model_alias("replacement_generator", selected_models) + # Long-context window sizing: honor the caller's (Detect-config-derived) + # values, falling back to the shared defaults. Smaller windows map fewer + # entities per LLM call, which avoids timeouts on entity-dense documents. 
+ max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) working_df = dataframe.copy() @@ -87,14 +102,18 @@ def generate_map_only( entity_rows, model_configs=model_configs, columns=[ - LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_REPLACEMENT_MAP, - prompt=_get_replacement_mapping_prompt( - entities_column=COL_ENTITIES_FOR_REPLACE, - instructions=instructions, + generator_function=make_windowed_replace_generator(replace_alias), + generator_params=WindowedReplaceParams( + alias=replace_alias, + single_call_prompt_template=_get_replacement_mapping_prompt( + entities_column=COL_ENTITIES_FOR_REPLACE, + instructions=instructions, + ), + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, ), - model_alias=replace_alias, - output_format=EntityReplacementMapSchema, ) ], workflow_name="replace-map-generation", diff --git a/src/anonymizer/engine/replace/replace_runner.py b/src/anonymizer/engine/replace/replace_runner.py index d6501834..f60ba85a 100644 --- a/src/anonymizer/engine/replace/replace_runner.py +++ b/src/anonymizer/engine/replace/replace_runner.py @@ -67,6 +67,8 @@ def run( model_configs: list[ModelConfig], selected_models: ReplaceModelSelection, preview_num_records: int | None = None, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> ReplacementResult: """Apply the replacement strategy (no LLM judges). 
@@ -87,6 +89,8 @@ def run( selected_models=selected_models, instructions=replace_method.instructions, preview_num_records=preview_num_records, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, ) local_df = apply_replacement_map(map_result.dataframe) failed_records = list(map_result.failed_records) diff --git a/src/anonymizer/engine/rewrite/chunked_rewrite.py b/src/anonymizer/engine/rewrite/chunked_rewrite.py new file mode 100644 index 00000000..6c893e09 --- /dev/null +++ b/src/anonymizer/engine/rewrite/chunked_rewrite.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Long-context (chunked) rewrite generation. + +The rewrite-generation step normally rewrites the whole tagged document in a +single LLM call. For documents whose rendered prompt would exceed the render +cap, this module rewrites the document in boundary-aligned chunks and stitches +the results, carrying a rolling summary of what has already been rewritten so the +narrative stays coherent across chunks. The (already global, consistency-checked) +replacement map and the protection-disposition block are passed per chunk, +filtered to the entities that occur in that chunk. + +Rendering the prompt here (instead of via an ``LLMStructuredColumnConfig``) also +sidesteps NDD's ginja per-render length cap, like the chunked detection steps. 
+""" + +from __future__ import annotations + +import functools +import logging +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe, TextResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_FULL_REWRITE, + COL_REPLACEMENT_MAP_FOR_PROMPT, + COL_REWRITE_DISPOSITION_BLOCK, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.postprocess import TagNotation +from anonymizer.engine.schemas import RewriteOutputSchema +from anonymizer.engine.windowing import DEFAULT_DELIMITER, iter_boundary_windows + +logger = logging.getLogger("anonymizer.rewrite.chunked") + +_MIN_WINDOW_CHARS = 4000 + +# Max characters of free-form text (e.g. a rolling summary) to emit in a single +# debug line, so logs stay readable even when the underlying value is large. +_LOG_CLIP_CHARS = 800 + + +def _clip(text: str, limit: int = _LOG_CLIP_CHARS) -> str: + """Single-line, length-bounded rendering of ``text`` for debug logs.""" + flat = " ".join(text.split()) + return flat if len(flat) <= limit else f"{flat[:limit]}… (+{len(flat) - limit} chars)" + +_PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) + + +@functools.lru_cache(maxsize=8) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +# Preamble injected ahead of the per-chunk rewrite prompt to carry continuity. +_CONTINUITY_PREAMBLE = """ +You are rewriting ONE section of a longer document. Below is a summary of how the +EARLIER sections have already been rewritten. Keep this section consistent with it +(same pseudonyms, tone, and narrative); do NOT repeat earlier content. Output only +the rewritten text for THIS section. 
+ +Summary so far: +{{ summary }} + + +""" + +_SUMMARY_PROMPT = """You maintain a concise running summary of a long document being rewritten for privacy. +Update the summary to incorporate the newly rewritten section, keeping only what is needed to keep +later sections consistent (narrative state, established pseudonyms/relationships). Be terse; hard +limit {{ summary_max_chars }} characters. Return only the updated summary. + +Previous summary: +{{ prev_summary }} + +Newly rewritten section: +{{ rewritten_chunk }} +""" + + +class WindowedRewriteParams(BaseModel): + """Params for chunked rewrite generation (via DD ``generator_params``).""" + + alias: str = Field(min_length=1) + single_call_prompt_template: str = Field(repr=False) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + summary_max_chars: int = Field(default=2000, gt=0) + delimiter: str = Field(default=DEFAULT_DELIMITER) + system_prompt: str | None = Field(default=None, repr=False) + + +def _filter_disposition_to_chunk(block: list[dict[str, Any]], chunk_raw: str) -> list[dict[str, Any]]: + """Keep disposition entries whose entity value appears in this chunk's raw text.""" + if not isinstance(block, list): + return [] + return [e for e in block if isinstance(e, dict) and str(e.get("entity_value", "")) and str(e["entity_value"]) in chunk_raw] + + +def _render_chunk_prompt(*, template: str, chunk_row: dict[str, Any], summary: str) -> str: + """Render the rewrite prompt for one chunk, prepended with the continuity preamble.""" + preamble = _compile_template(_CONTINUITY_PREAMBLE).render(summary=summary or "(this is the first section)") + body = _compile_template(template).render(**chunk_row) + return preamble + body + + +def _rewrite_chunk(*, facade: Any, prompt: str, system_prompt: str | None, purpose: str) -> str: + recipe = PydanticResponseRecipe(data_type=RewriteOutputSchema) + output, _ = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(prompt), + 
parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt), + purpose=purpose, + ) + text = "" + if output is not None: + dumped = output.model_dump(mode="python") if hasattr(output, "model_dump") else output + text = str(dumped.get("rewritten_text", "")) if isinstance(dumped, dict) else "" + return text + + +def _update_summary( + *, facade: Any, prev_summary: str, rewritten_chunk: str, summary_max_chars: int, system_prompt: str | None, purpose: str +) -> str: + recipe = TextResponseRecipe() + rendered = _compile_template(_SUMMARY_PROMPT).render( + prev_summary=prev_summary or "(none yet)", rewritten_chunk=rewritten_chunk, summary_max_chars=summary_max_chars + ) + output, _ = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(rendered), + parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt), + purpose=purpose, + ) + return str(output)[:summary_max_chars] + + +def generate_rewrite_row( + row: dict[str, Any], + params: WindowedRewriteParams, + models: dict[str, Any], +) -> dict[str, Any]: + """Produce ``COL_FULL_REWRITE`` for one row, chunking long documents with rolling-summary continuity.""" + if params.alias not in models: + raise KeyError( + f"Rewriter alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_rewrite_generator was invoked with the same alias." + ) + facade = models[params.alias] + cap = params.max_render_chars + initial_window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) + + # Fast path: the full single-call rewrite prompt fits under the cap. 
+ single_rendered = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=row, summary="") + if len(single_rendered) <= cap: + logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) + text = _rewrite_chunk( + facade=facade, prompt=_compile_template(params.single_call_prompt_template).render(**row), + system_prompt=params.system_prompt, purpose="rewrite-generation", + ) + row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=text).model_dump(mode="json") + return row + + # Chunked path: rewrite each boundary window with continuity carry-over, then stitch. + tagged = str(row.get(COL_TAGGED_TEXT, "")) + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + disposition_block = row.get(COL_REWRITE_DISPOSITION_BLOCK, []) + replacement_map = row.get(COL_REPLACEMENT_MAP_FOR_PROMPT, {"replacements": []}) + + # Chunk on the already-tagged text so tag boundaries stay intact; the delimiter + # default ("\n") keeps cuts on line breaks. 
+ windows = iter_boundary_windows(tagged, initial_window, delimiter=params.delimiter) + logger.info( + "rewrite: rendered prompt %d chars > cap %d; chunking %d-char tagged document into %d boundary " + "window(s) (initial_window=%d, delimiter=%r, summary_max=%d)", + len(single_rendered), cap, len(tagged), len(windows), initial_window, params.delimiter, params.summary_max_chars, + ) + rewritten_parts: list[str] = [] + summary = "" + for i, (start, end) in enumerate(windows): + chunk_tagged = tagged[start:end] + chunk_disposition = _filter_disposition_to_chunk(disposition_block, chunk_tagged) + chunk_row = { + **row, + COL_TAGGED_TEXT: chunk_tagged, + COL_TAG_NOTATION: notation.value, + COL_REWRITE_DISPOSITION_BLOCK: chunk_disposition, + COL_REPLACEMENT_MAP_FOR_PROMPT: replacement_map, + } + prompt = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=chunk_row, summary=summary) + logger.debug( + "rewrite window %d/%d: chars [%d, %d) size=%d, %d in-chunk disposition entr(y/ies), prompt=%d chars", + i + 1, len(windows), start, end, end - start, len(chunk_disposition), len(prompt), + ) + rewritten_chunk = _rewrite_chunk( + facade=facade, prompt=prompt, system_prompt=params.system_prompt, purpose=f"rewrite-generation-chunk-{start}" + ) + logger.debug( + "rewrite window %d/%d: produced %d chars of rewritten text", i + 1, len(windows), len(rewritten_chunk), + ) + rewritten_parts.append(rewritten_chunk) + if i < len(windows) - 1: + summary = _update_summary( + facade=facade, prev_summary=summary, rewritten_chunk=rewritten_chunk, + summary_max_chars=params.summary_max_chars, system_prompt=params.system_prompt, + purpose=f"rewrite-summary-{start}", + ) + logger.debug( + "rewrite window %d/%d: continuity summary updated -> %d chars: %s", + i + 1, len(windows), len(summary), _clip(summary), + ) + + stitched = "\n".join(part for part in rewritten_parts if part) + logger.info( + "rewrite: %d window(s) over %d chars -> %d chars stitched output", 
len(windows), len(tagged), len(stitched), + ) + row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=stitched).model_dump(mode="json") + return row + + +def make_windowed_rewrite_generator(alias: str) -> Any: + """Build a ``@custom_column_generator`` for chunked rewrite generation bound to ``alias``.""" + if not alias: + raise ValueError("Cannot build windowed rewrite generator: alias is empty.") + + @custom_column_generator( + required_columns=[ + COL_TEXT, + COL_TAGGED_TEXT, + COL_TAG_NOTATION, + COL_REWRITE_DISPOSITION_BLOCK, + COL_REPLACEMENT_MAP_FOR_PROMPT, + ], + model_aliases=[alias], + ) + def windowed_rewrite( + row: dict[str, Any], + generator_params: WindowedRewriteParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return generate_rewrite_row(row, generator_params, models) + + return windowed_rewrite diff --git a/src/anonymizer/engine/rewrite/chunked_steps.py b/src/anonymizer/engine/rewrite/chunked_steps.py new file mode 100644 index 00000000..22d8f234 --- /dev/null +++ b/src/anonymizer/engine/rewrite/chunked_steps.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Generic boundary-windowed LLM step for rewrite-pipeline metadata columns. + +Several rewrite steps (domain classification, sensitivity disposition, QA / +meaning-unit extraction, evaluate, final judge) embed the full document into a +single LLM call. For long documents this exceeds the render cap. 
This module +runs such a step over boundary-aligned windows and merges the per-window outputs +with a step-specific ``merge_fn``: + +- domain -> classify the first window only (coarse, doc-level label) +- disposition -> union per-entity protection decisions +- meaning units -> concatenate per-window units +- evaluate -> aggregate metrics +- final judge -> OR across windows + +Rendering here (instead of via an ``LLMStructuredColumnConfig``) also sidesteps +NDD's ginja per-render length cap, like the chunked detection steps. +""" + +from __future__ import annotations + +import functools +import logging +from collections.abc import Callable +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.windowing import DEFAULT_DELIMITER, iter_boundary_windows + +logger = logging.getLogger("anonymizer.rewrite.chunked_steps") + +_MIN_WINDOW_CHARS = 4000 + +_PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) + + +@functools.lru_cache(maxsize=16) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +# A merge function takes the list of per-window parsed schema objects and returns +# a JSON-serializable value to store in the output column. 
def run_windowed_step(
    row: dict[str, Any],
    params: WindowedStepParams,
    models: dict[str, Any],
    *,
    schema: type[BaseModel],
    merge_fn: MergeFn,
    purpose_prefix: str,
) -> dict[str, Any]:
    """Run ``params.prompt_template`` over boundary windows of ``params.text_column``.

    Fast path: when the prompt rendered from the whole row fits within
    ``params.max_render_chars``, a single LLM call is made. Otherwise the text
    column is tiled with ``iter_boundary_windows`` and the prompt is rendered
    once per window (only the text column is substituted; every other column is
    passed through unchanged). The parsed per-window outputs are handed to
    ``merge_fn`` and the result is stored in ``row[params.output_column]``.

    Args:
        row: Record being processed; mutated in place and returned.
        params: Prompt, column names and window sizing.
        models: Mapping of model alias to a facade exposing ``generate``.
        schema: Pydantic model each LLM response is parsed into.
        merge_fn: Combines the list of parsed outputs into a JSON-serializable value.
        purpose_prefix: Label used for the ``purpose`` of each LLM call and in logs.

    Raises:
        KeyError: if ``params.alias`` is missing from ``models``.
    """
    if params.alias not in models:
        raise KeyError(f"Alias {params.alias!r} not present in models dict for step {purpose_prefix!r}.")
    facade = models[params.alias]
    recipe = PydanticResponseRecipe(data_type=schema)
    # Compile once; every render below reuses the same template object.
    template = _compile_template(params.prompt_template)
    cap = params.max_render_chars

    def _call(prompt: str, purpose: str) -> Any:
        output, _ = facade.generate(
            prompt=recipe.apply_recipe_to_user_prompt(prompt),
            parser=recipe.parse,
            system_prompt=recipe.apply_recipe_to_system_prompt(params.system_prompt),
            purpose=purpose,
        )
        return output

    full_rendered = template.render(**row)
    if len(full_rendered) <= cap:
        row[params.output_column] = merge_fn([_call(full_rendered, purpose_prefix)])
        return row

    raw_text = row.get(params.text_column)
    # A missing/None column is treated as empty rather than the literal "None".
    text = "" if raw_text is None else str(raw_text)
    # NOTE(review): the window budget counts only text-column characters; the
    # template and the other columns are not subtracted, and the floor can
    # exceed ``cap`` for small caps, so a window render may still be oversize.
    window_chars = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars)
    windows = iter_boundary_windows(text, window_chars, delimiter=params.delimiter)
    if params.first_only:
        windows = windows[:1]

    if not windows:
        # The prompt is oversize because of columns other than the text column,
        # so there is nothing to window. Make the single call anyway instead of
        # handing ``merge_fn`` an empty list (first-window merges index [0]).
        logger.warning(
            "windowed step %s: render is %d chars (cap %d) but %r is empty; using a single call",
            purpose_prefix,
            len(full_rendered),
            cap,
            params.text_column,
        )
        row[params.output_column] = merge_fn([_call(full_rendered, purpose_prefix)])
        return row

    outputs = []
    for start, end in windows:
        rendered = template.render(**{**row, params.text_column: text[start:end]})
        if len(rendered) > cap:
            logger.warning(
                "windowed step %s: window at %d still renders to %d chars (cap %d)",
                purpose_prefix,
                start,
                len(rendered),
                cap,
            )
        outputs.append(_call(rendered, f"{purpose_prefix}-{start}"))
    logger.debug("windowed step %s: %d window(s) over %d chars", purpose_prefix, len(windows), len(text))
    row[params.output_column] = merge_fn(outputs)
    return row
def _first_output(outputs: list[Any]) -> dict[str, Any]:
    """Merge step for domain classification: the leading window's label wins.

    The domain is one document-level label, so only ``outputs[0]`` is
    serialized (JSON mode); any further window outputs are ignored.
    """
    leading_window = outputs[0]
    return leading_window.model_dump(mode="json")
def _concat_meaning_units(outputs: list[Any]) -> dict[str, Any]:
    """Join the meaning units of every window into one list with fresh IDs.

    Units keep their window order; each unit's ``id`` is overwritten with its
    1-based position in the combined list before the result is validated and
    dumped through ``MeaningUnitsSchema``.
    """
    in_document_order = (unit for window_output in outputs for unit in window_output.units)
    renumbered = [
        {**unit.model_dump(mode="json"), "id": position}
        for position, unit in enumerate(in_document_order, start=1)
    ]
    return MeaningUnitsSchema.model_validate({"units": renumbered}).model_dump(mode="json")
def _batch_units_by_size(units: list[dict[str, Any]], base_len: int, budget: int) -> list[list[dict[str, Any]]]:
    """Pack ``units`` greedily, in order, into batches that fit the prompt budget.

    ``base_len`` is the prompt length rendered with no units, which leaves
    ``budget - base_len`` characters for the serialized units. Every unit is
    charged ``len(json.dumps(unit)) + 1``. A batch is closed as soon as the next
    unit would push its running charge past that allowance; a batch always
    holds at least one unit, so an oversize unit (or a non-positive allowance)
    simply yields one-unit batches.
    """
    room = budget - base_len
    batches: list[list[dict[str, Any]]] = []
    used = 0
    for unit in units:
        # NOTE(review): json.dumps joins list items with ", " (two chars), so the
        # +1 slightly undercounts; callers pass a budget with a safety margin.
        cost = len(json.dumps(unit, ensure_ascii=False)) + 1
        if not batches or used + cost > room:
            # Open a new batch: either the very first unit or the current batch is full.
            batches.append([])
            used = 0
        batches[-1].append(unit)
        used += cost
    return batches
def _make_quality_qa_column(qa_generator_alias: str, max_render_chars: int, safety_margin_chars: int) -> Any:
    """Create the quality-QA column generator bound to ``qa_generator_alias``.

    The prompt template is resolved once here; each row is then delegated to
    ``generate_quality_qa_row`` with the given render cap and safety margin.
    """
    qa_prompt = _get_quality_qa_prompt()

    def _quality_qa(row: dict[str, Any], generator_params: None, models: dict[str, Any]) -> dict[str, Any]:
        return generate_quality_qa_row(
            row,
            models,
            alias=qa_generator_alias,
            prompt_template=qa_prompt,
            max_render_chars=max_render_chars,
            safety_margin_chars=safety_margin_chars,
        )

    # Explicit form of the decorator application.
    register = custom_column_generator(
        required_columns=[COL_MEANING_UNITS_SERIALIZED],
        model_aliases=[qa_generator_alias],
    )
    return register(_quality_qa)
generator_function=_format_disposition_block, ), - LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_MEANING_UNITS, - prompt=_get_meaning_unit_extraction_prompt(), - model_alias=meaning_extractor_alias, - output_format=MeaningUnitsSchema, + generator_function=make_windowed_metadata_generator( + alias=meaning_extractor_alias, + required_columns=[ + COL_TEXT, + COL_SENSITIVITY_DISPOSITION_BLOCK, + COL_DOMAIN, + COL_DOMAIN_SUPPLEMENT, + ], + schema=MeaningUnitsSchema, + merge_fn=_concat_meaning_units, + purpose_prefix="meaning-unit-extraction", + ), + generator_params=WindowedStepParams( + alias=meaning_extractor_alias, + prompt_template=_get_meaning_unit_extraction_prompt(), + output_column=COL_MEANING_UNITS, + text_column=COL_TEXT, + max_render_chars=_DEFAULT_MAX_RENDER_CHARS, + safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + ), ), CustomColumnConfig( name=COL_MEANING_UNITS_SERIALIZED, generator_function=_serialize_meaning_units, ), - LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_QUALITY_QA, - prompt=_get_quality_qa_prompt(), - model_alias=qa_generator_alias, - output_format=QualityQAPairsSchema, + generator_function=_make_quality_qa_column( + qa_generator_alias, + _DEFAULT_MAX_RENDER_CHARS, + _DEFAULT_SAFETY_MARGIN_CHARS, + ), ), CustomColumnConfig( name=COL_PRIVACY_QA, diff --git a/src/anonymizer/engine/rewrite/rewrite_generation.py b/src/anonymizer/engine/rewrite/rewrite_generation.py index 7892b1cf..ffb34c0b 100644 --- a/src/anonymizer/engine/rewrite/rewrite_generation.py +++ b/src/anonymizer/engine/rewrite/rewrite_generation.py @@ -7,9 +7,10 @@ from typing import Any from data_designer.config import custom_column_generator -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from data_designer.config.column_types import ColumnConfigT +from anonymizer.config.anonymizer_config import Detect as _DetectConfig from 
anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import PrivacyGoal from anonymizer.engine.constants import ( @@ -25,14 +26,16 @@ ) from anonymizer.engine.ndd.model_loader import resolve_model_alias from anonymizer.engine.prompt_utils import substitute_placeholders +from anonymizer.engine.rewrite.chunked_rewrite import WindowedRewriteParams, make_windowed_rewrite_generator from anonymizer.engine.rewrite.parsers import normalize_payload, parse_sensitivity_disposition -from anonymizer.engine.schemas import ( - EntityReplacementMapSchema, - RewriteOutputSchema, -) +from anonymizer.engine.schemas import EntityReplacementMapSchema logger = logging.getLogger("anonymizer.rewrite.generation") +# Long-context window defaults shared with detection (single source of truth). +_DEFAULT_MAX_RENDER_CHARS: int = _DetectConfig.model_fields["detection_window_max_render_chars"].default +_DEFAULT_SAFETY_MARGIN_CHARS: int = _DetectConfig.model_fields["detection_window_safety_margin_chars"].default + # --------------------------------------------------------------------------- # Prompt @@ -218,8 +221,16 @@ def columns( selected_models: RewriteModelSelection, privacy_goal: PrivacyGoal, data_summary: str | None = None, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> list[ColumnConfigT]: rewriter_alias = resolve_model_alias("rewriter", selected_models) + # Honor caller-provided (Detect-config-derived) window sizing; fall back + # to the shared defaults. Smaller windows rewrite less text per LLM call. 
+ max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) return [ CustomColumnConfig( name=COL_REWRITE_DISPOSITION_BLOCK, @@ -229,11 +240,15 @@ def columns( name=COL_REPLACEMENT_MAP_FOR_PROMPT, generator_function=_filter_replacement_map_for_prompt, ), - LLMStructuredColumnConfig( + CustomColumnConfig( name=COL_FULL_REWRITE, - prompt=_get_rewrite_prompt(privacy_goal, data_summary), - model_alias=rewriter_alias, - output_format=RewriteOutputSchema, + generator_function=make_windowed_rewrite_generator(rewriter_alias), + generator_params=WindowedRewriteParams( + alias=rewriter_alias, + single_call_prompt_template=_get_rewrite_prompt(privacy_goal, data_summary), + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, + ), ), CustomColumnConfig( name=COL_REWRITTEN_TEXT, diff --git a/src/anonymizer/engine/rewrite/rewrite_workflow.py b/src/anonymizer/engine/rewrite/rewrite_workflow.py index 88c2b9c3..94f95976 100644 --- a/src/anonymizer/engine/rewrite/rewrite_workflow.py +++ b/src/anonymizer/engine/rewrite/rewrite_workflow.py @@ -195,6 +195,8 @@ def run( data_summary: str | None = None, preview_num_records: int | None = None, strict_entity_protection: bool = False, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> RewriteResult: all_failed: list[FailedRecord] = [] @@ -212,6 +214,8 @@ def run( entity_rows, model_configs=model_configs, selected_models=replace_model_selection, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, ) entity_rows = _join_new_columns(entity_rows, replace_result.dataframe) all_failed.extend(replace_result.failed_records) @@ -227,6 +231,8 @@ def run( ), *self._qa_wf.columns(selected_models=selected_models), 
_RISK_RANK = {"low": 0, "medium": 1, "high": 2}


def _risk_rank(level: Any) -> int:
    """Map a risk level to its rank; unknown levels rank lowest (0).

    Accepts plain strings as well as Enum members: ``str()`` of a non-``StrEnum``
    member is ``"Class.member"``, which would never match ``_RISK_RANK``, so the
    member's ``value`` is used when present. Matching ignores case and
    surrounding whitespace.
    """
    raw = getattr(level, "value", level)
    return _RISK_RANK.get(str(raw).strip().lower(), 0)


def _make_disposition_merge(container_schema: type) -> Any:
    """Build a merge function that unions per-entity disposition entries across windows.

    Entries are de-duplicated by ``(source, entity_label, entity_value)``. For
    duplicates the entry with the strictly highest combined risk is kept (ties
    keep the earliest window's entry), and the output preserves first-seen
    order. The merged list is validated and dumped through ``container_schema``.
    """

    def _merge(outputs: list[Any]) -> dict[str, Any]:
        # Dicts preserve insertion order, so updating a value in place keeps the
        # entity at the position where it was first seen.
        best: dict[tuple[str, str, str], tuple[int, Any]] = {}
        for out in outputs:
            for entry in out.sensitivity_disposition:
                key = (str(entry.source), entry.entity_label, entry.entity_value)
                rank = _risk_rank(entry.combined_risk_level)
                if key not in best or rank > best[key][0]:
                    best[key] = (rank, entry)
        entries = [entry.model_dump(mode="json") for _, entry in best.values()]
        return container_schema.model_validate({"sensitivity_disposition": entries}).model_dump(mode="json")

    return _merge
strict_entity_protection=strict_entity_protection, + ), + output_column=COL_SENSITIVITY_DISPOSITION, + text_column=COL_TAGGED_TEXT, + max_render_chars=_DEFAULT_MAX_RENDER_CHARS, + safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, ), - model_alias=disposition_alias, - output_format=output_schema, ), ] diff --git a/src/anonymizer/engine/windowing.py b/src/anonymizer/engine/windowing.py new file mode 100644 index 00000000..6524a6db --- /dev/null +++ b/src/anonymizer/engine/windowing.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Boundary-aware text windowing shared across long-context workflows. + +Splits a long document into sequential, non-overlapping windows of at most +``max_chars`` each. Rather than cutting at an arbitrary character offset, each +window is backed off to the last ``delimiter`` (default a newline) within the +window so a chunk boundary lands on a natural break instead of mid-line / +mid-token. If no delimiter occurs in the window, it falls back to a hard cut at +``max_chars`` so progress is always made. + +Used by the chunked Substitute (map generation) and Rewrite long-context paths. +""" + +from __future__ import annotations + +DEFAULT_DELIMITER = "\n" + + +def next_window_end(text: str, start: int, max_chars: int, *, delimiter: str = DEFAULT_DELIMITER) -> int: + """Return the end offset for a window starting at ``start``. + + The window is at most ``max_chars`` long; when it does not reach the end of + ``text`` it is backed off to just after the last ``delimiter`` inside the + window. If the window contains no delimiter (other than possibly at the very + start), a hard cut at ``start + max_chars`` is returned. 
+ """ + if max_chars <= 0: + raise ValueError("max_chars must be positive") + hard_end = min(len(text), start + max_chars) + if hard_end >= len(text): + return len(text) + window = text[start:hard_end] + idx = window.rfind(delimiter) + # idx > 0 ensures we make progress (a delimiter at offset 0 would not advance). + if delimiter and idx > 0: + return start + idx + len(delimiter) + return hard_end + + +def iter_boundary_windows(text: str, max_chars: int, *, delimiter: str = DEFAULT_DELIMITER) -> list[tuple[int, int]]: + """Tile ``[0, len(text))`` into sequential boundary-aligned ``(start, end)`` windows.""" + n = len(text) + if n == 0: + return [] + bounds: list[tuple[int, int]] = [] + start = 0 + while start < n: + end = next_window_end(text, start, max_chars, delimiter=delimiter) + if end <= start: # defensive: always advance + end = min(n, start + max_chars) + bounds.append((start, end)) + start = end + return bounds diff --git a/src/anonymizer/interface/anonymizer.py b/src/anonymizer/interface/anonymizer.py index ec08164a..f4314b6c 100644 --- a/src/anonymizer/interface/anonymizer.py +++ b/src/anonymizer/interface/anonymizer.py @@ -380,6 +380,9 @@ def _run_internal( gliner_detection_threshold=config.detect.gliner_threshold, validation_max_entities_per_call=config.detect.validation_max_entities_per_call, validation_excerpt_window_chars=config.detect.validation_excerpt_window_chars, + detection_window_max_render_chars=config.detect.detection_window_max_render_chars, + detection_window_safety_margin_chars=config.detect.detection_window_safety_margin_chars, + detection_window_overlap_chars=config.detect.detection_window_overlap_chars, entity_labels=config.detect.entity_labels, privacy_goal=config.rewrite.privacy_goal if config.rewrite else None, data_summary=data.data_summary, @@ -412,6 +415,8 @@ def _run_internal( model_configs=self._model_configs, selected_models=self._selected_models.replace, preview_num_records=preview_num_records, + 
window_max_render_chars=config.detect.detection_window_max_render_chars, + window_safety_margin_chars=config.detect.detection_window_safety_margin_chars, ) replace_elapsed = time.perf_counter() - t0 final_df = replace_result.dataframe @@ -434,6 +439,8 @@ def _run_internal( data_summary=data.data_summary, preview_num_records=preview_num_records, strict_entity_protection=config.rewrite.strict_entity_protection, + window_max_render_chars=config.detect.detection_window_max_render_chars, + window_safety_margin_chars=config.detect.detection_window_safety_margin_chars, ) rewrite_elapsed = time.perf_counter() - t0 final_df = rewrite_result.dataframe diff --git a/tests/engine/test_chunked_augmentation.py b/tests/engine/test_chunked_augmentation.py new file mode 100644 index 00000000..cdef0790 --- /dev/null +++ b/tests/engine/test_chunked_augmentation.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for windowed LLM augmentation. + +Pure helpers (windowing, per-window inputs, merge) are tested directly; the +per-window dispatch is tested via a fake ``ModelFacade`` that records calls and +replays canned responses (mirrors the chunked-validation tests). 
+""" + +from __future__ import annotations + +import itertools +import json +from typing import Any, Callable + +import pytest + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_INITIAL_TAGGED_TEXT, + COL_SEED_ENTITIES_JSON, + COL_TAG_NOTATION, + COL_TEXT, + COL_VALIDATED_SEED_ENTITIES, +) +from anonymizer.engine.detection.chunked_augmentation import ( + WindowedAugmentationParams, + augment_row, + build_window_inputs, + iter_windows, + merge_augmented, + render_augment_prompt, +) +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation +from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema + +# Small stand-in for the real (multi-kB) augment prompt so window sizes are +# controllable in tests. References the same placeholders the real prompt uses. +_TEMPLATE = "AUG[{{ _tag_notation }}] {{ _initial_tagged_text }} || SEEDS: {{ _seed_entities_json }}" + + +class FakeFacade: + """Records invocations and replays canned responses through the recipe parser.""" + + def __init__(self, response: dict | str | Callable[[str], dict | str]) -> None: + self._response = response + self.calls: list[dict[str, Any]] = [] + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + self.calls.append({"prompt": prompt, "system_prompt": system_prompt, "purpose": purpose}) + response = self._response + if callable(response): + response = response(prompt) + raw = response if isinstance(response, str) else f"```json\n{json.dumps(response)}\n```" + return parser(raw), [] + + +def _span(value: str, label: str, start: int, end: int) -> EntitySpan: + return EntitySpan( + entity_id=f"{label}_{start}_{end}", + value=value, + label=label, + start_position=start, + end_position=end, + score=1.0, + source="detector", + ) + + +def _aug(*pairs: tuple[str, str]) -> AugmentedEntitiesSchema: + return AugmentedEntitiesSchema.model_validate({"entities": [{"value": v, "label": lab} for v, lab in pairs]}) + + 
+def _params(max_render_chars: int, **kw: Any) -> WindowedAugmentationParams: + return WindowedAugmentationParams(alias="aug", prompt_template=_TEMPLATE, max_render_chars=max_render_chars, **kw) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +class TestIterWindows: + def test_tiles_with_overlap(self) -> None: + assert iter_windows(10000, window=4000, overlap=1000) == [(0, 4000), (3000, 7000), (6000, 10000)] + + def test_single_window_when_text_fits(self) -> None: + assert iter_windows(500, window=4000, overlap=1000) == [(0, 500)] + + def test_empty(self) -> None: + assert iter_windows(0, window=4000, overlap=1000) == [] + + +class TestMergeAugmented: + def test_dedupes_by_value_and_label_case_insensitively(self) -> None: + first = _aug(("Alice", "first_name"), ("bob", "first_name")) + second = AugmentedEntitiesSchema.model_validate( + { + "entities": [ + {"value": "alice", "label": "first_name"}, # case-dup of Alice -> dropped + {"value": "Alice", "label": "city"}, # same value, different label -> kept + {"value": " ", "label": "first_name"}, # blank after strip -> dropped + ] + } + ) + merged = merge_augmented([first, second]) + pairs = {(e.value, e.label) for e in merged.entities} + assert pairs == {("Alice", "first_name"), ("bob", "first_name"), ("Alice", "city")} + + +class TestBuildWindowInputs: + def test_tags_only_in_window_spans_rebased(self) -> None: + text = "Alice met Bob in Paris" # Alice 0-5, Bob 10-13, Paris 17-22 + spans = [ + _span("Alice", "first_name", 0, 5), + _span("Bob", "first_name", 10, 13), + _span("Paris", "city", 17, 22), + ] + tagged, seed_json = build_window_inputs(text=text, all_spans=spans, start=0, end=14, notation=TagNotation.xml) + assert "Alice" in tagged + assert "Bob" in tagged + assert "Paris" not in tagged # out-of-window span excluded + seeds = json.loads(seed_json) + assert {s["value"] for s 
in seeds} == {"Alice", "Bob"} + assert all(0 <= s["start_position"] < s["end_position"] <= 14 for s in seeds) # window-local offsets + + +def test_render_includes_notation_and_inputs() -> None: + rendered = render_augment_prompt( + template=_TEMPLATE, tagged_text="TAGGED", seed_entities_json="[]", notation=TagNotation.bracket + ) + assert "TAGGED" in rendered + assert "bracket" in rendered + + +# --------------------------------------------------------------------------- +# augment_row +# --------------------------------------------------------------------------- + + +class TestAugmentRowFastPath: + def test_single_call_when_under_budget(self) -> None: + facade = FakeFacade({"entities": [{"value": "Alice", "label": "first_name"}]}) + row = { + COL_TEXT: "Alice in Paris", + COL_INITIAL_TAGGED_TEXT: "Alice in Paris", + COL_SEED_ENTITIES_JSON: "[]", + COL_TAG_NOTATION: "xml", + } + out = augment_row(row, _params(max_render_chars=1_000_000), {"aug": facade}) + assert len(facade.calls) == 1 + result = AugmentedEntitiesSchema.model_validate(out[COL_AUGMENTED_ENTITIES]) + assert [(e.value, e.label) for e in result.entities] == [("Alice", "first_name")] + + +class TestAugmentRowWindowed: + def test_multiple_windows_unioned(self) -> None: + text = ("A" * 5000) + ("B" * 5000) # 10k chars -> forces several windows + counter = itertools.count() + # Each call returns a distinct entity so the union size == number of windows. 
+ facade = FakeFacade(lambda _prompt: {"entities": [{"value": f"v{next(counter)}", "label": "name"}]}) + row = { + COL_TEXT: text, + COL_INITIAL_TAGGED_TEXT: text, # full render exceeds budget -> windowed path + COL_SEED_ENTITIES_JSON: "[]", + COL_VALIDATED_SEED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + out = augment_row( + row, _params(max_render_chars=4000, safety_margin_chars=0, overlap_chars=1000), {"aug": facade} + ) + assert len(facade.calls) > 1 + result = AugmentedEntitiesSchema.model_validate(out[COL_AUGMENTED_ENTITIES]) + assert len(result.entities) == len(facade.calls) + + def test_missing_alias_raises(self) -> None: + with pytest.raises(KeyError, match="not present in models"): + augment_row({COL_TEXT: "x"}, _params(max_render_chars=10), {}) diff --git a/tests/engine/test_chunked_detection.py b/tests/engine/test_chunked_detection.py new file mode 100644 index 00000000..b94ff4e8 --- /dev/null +++ b/tests/engine/test_chunked_detection.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for windowed (chunked) GLiNER seed detection. + +Pure helpers (rebasing, boundary-span dropping, payload serialization) are tested +directly. The per-window dispatch is tested via a fake detector facade that +simulates a NER endpoint over the *submitted window text*: it emits every full +occurrence of a target string plus a truncated prefix/suffix when an entity is +cut by the window edge — exercising the offset-rebasing and boundary handling. 
+""" + +from __future__ import annotations + +import json +from typing import Any + +from anonymizer.engine.constants import COL_RAW_DETECTED, COL_TEXT +from anonymizer.engine.detection.chunked_detection import ( + WindowedDetectionParams, + detect_row, + drop_boundary_spans, + rebase_spans, + spans_to_detector_payload, +) +from anonymizer.engine.detection.postprocess import EntitySpan + + +def _span(value: str, start: int, end: int, label: str = "person") -> EntitySpan: + return EntitySpan( + entity_id=f"{label}_{start}_{end}", + value=value, + label=label, + start_position=start, + end_position=end, + score=1.0, + source="detector", + ) + + +# --------------------------------------------------------------------------- +# Fake detector facade +# --------------------------------------------------------------------------- + + +class FakeDetectorFacade: + """Simulates a NER endpoint over the submitted window text. + + Emits window-local offsets for every full occurrence of each target, plus a + truncated prefix at the window's right edge / suffix at the left edge to mimic + an entity straddling a window boundary. 
+ """ + + def __init__(self, targets: list[str], label: str = "person") -> None: + self.targets = targets + self.label = label + self.calls: list[dict[str, Any]] = [] + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + self.calls.append({"prompt": prompt, "purpose": purpose}) + wt = prompt + ents: list[dict[str, Any]] = [] + for t in self.targets: + i = wt.find(t) + while i != -1: + ents.append({"text": t, "label": self.label, "start": i, "end": i + len(t), "score": 0.99}) + i = wt.find(t, i + 1) + for k in range(len(t) - 1, 2, -1): # truncated prefix at right edge + if wt.endswith(t[:k]): + ents.append({"text": t[:k], "label": self.label, "start": len(wt) - k, "end": len(wt), "score": 0.5}) + break + for k in range(len(t) - 1, 2, -1): # truncated suffix at left edge + if wt.startswith(t[-k:]): + ents.append({"text": t[-k:], "label": self.label, "start": 0, "end": k, "score": 0.5}) + break + raw = json.dumps({"entities": ents}) + return parser(raw), [] + + +def _place(buf: list[str], s: str, pos: int) -> None: + for j, ch in enumerate(s): + buf[pos + j] = ch + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def test_rebase_spans_shifts_offsets(): + out = rebase_spans([_span("X", 5, 7)], 100) + assert (out[0].start_position, out[0].end_position) == (105, 107) + assert out[0].value == "X" + + +def test_rebase_spans_zero_offset_is_noop(): + spans = [_span("X", 5, 7)] + assert rebase_spans(spans, 0) == spans + + +def test_drop_boundary_spans_drops_artificial_edges_only(): + text_len = 1000 + spans = [ + _span("left-edge", 200, 260), # touches left edge (window_start=200) + _span("right-edge", 540, 600), # touches right edge (window_end=600) + _span("interior", 300, 320), # safe + ] + kept = drop_boundary_spans(spans, window_start=200, window_end=600, text_len=text_len) + assert [s.value for s in kept] 
== ["interior"] + + +def test_drop_boundary_spans_keeps_true_document_edges(): + # window_start == 0 (true doc start) and window_end == text_len (true doc end): + # spans touching those are real, not truncated, so they are kept. + spans = [_span("doc-start", 0, 40), _span("doc-end", 460, 500)] + kept = drop_boundary_spans(spans, window_start=0, window_end=500, text_len=500) + assert [s.value for s in kept] == ["doc-start", "doc-end"] + + +def test_spans_to_detector_payload_roundtrips_detector_shape(): + payload = json.loads(spans_to_detector_payload([_span("Alice", 3, 8, label="first_name")])) + assert payload == {"entities": [{"text": "Alice", "label": "first_name", "start": 3, "end": 8, "score": 1.0}]} + + +# --------------------------------------------------------------------------- +# detect_row: fast path +# --------------------------------------------------------------------------- + + +def test_fast_path_single_call_passes_raw_through(): + facade = FakeDetectorFacade(["Alice"]) + text = "hello Alice world" + row = {COL_TEXT: text} + params = WindowedDetectionParams(alias="det", max_render_chars=10_000) + detect_row(row, params, {"det": facade}) + + assert len(facade.calls) == 1 # one call, whole document + assert facade.calls[0]["prompt"] == text + payload = json.loads(row[COL_RAW_DETECTED]) + assert payload["entities"][0]["text"] == "Alice" + assert text[payload["entities"][0]["start"] : payload["entities"][0]["end"]] == "Alice" + + +# --------------------------------------------------------------------------- +# detect_row: windowed path (the boundary-cut scenario) +# --------------------------------------------------------------------------- + + +def test_windowed_recovers_straddling_entity_and_dedupes_overlap(): + targets = ["Maria Garcia", "Bob Smith", "Acme Corporation"] + buf = ["a"] * 12_000 + _place(buf, "Maria Garcia", 4_995) # straddles the window-A boundary at 5000 + _place(buf, "Bob Smith", 4_200) # inside overlap region [4000, 5000] -> seen 
twice + _place(buf, "Acme Corporation", 9_100) # only inside window C + text = "".join(buf) + + facade = FakeDetectorFacade(targets) + row = {COL_TEXT: text} + # cap=5000, margin=0 -> window=5000; overlap=1000 -> windows [0,5000)[4000,9000)[8000,12000) + params = WindowedDetectionParams(alias="det", max_render_chars=5_000, safety_margin_chars=0, overlap_chars=1_000) + detect_row(row, params, {"det": facade}) + + ents = json.loads(row[COL_RAW_DETECTED])["entities"] + spans = {(e["start"], e["end"], e["text"]) for e in ents} + + # straddling entity recovered with correct GLOBAL offsets + assert (4_995, 5_007, "Maria Garcia") in spans + # overlap-region entity present exactly once (deduped across windows A and B) + assert (4_200, 4_209, "Bob Smith") in spans + assert sum(1 for e in ents if e["text"] == "Bob Smith") == 1 + # window-C-only entity present + assert (9_100, 9_116, "Acme Corporation") in spans + # every emitted span maps to a real, full target — no truncated partial leaked + assert all(text[e["start"] : e["end"]] in targets for e in ents) + assert len(ents) == 3 + + +def test_windowed_emits_empty_entities_when_nothing_detected(): + facade = FakeDetectorFacade(["Nonexistent Name"]) + text = "b" * 12_000 + row = {COL_TEXT: text} + params = WindowedDetectionParams(alias="det", max_render_chars=5_000, safety_margin_chars=0, overlap_chars=1_000) + detect_row(row, params, {"det": facade}) + assert json.loads(row[COL_RAW_DETECTED]) == {"entities": []} + assert len(facade.calls) >= 2 # actually windowed + + +def test_missing_alias_raises(): + params = WindowedDetectionParams(alias="det", max_render_chars=5_000) + try: + detect_row({COL_TEXT: "x"}, params, {}) + except KeyError as exc: + assert "det" in str(exc) + else: # pragma: no cover + raise AssertionError("expected KeyError for missing alias") diff --git a/tests/engine/test_chunked_latent.py b/tests/engine/test_chunked_latent.py new file mode 100644 index 00000000..d40e7233 --- /dev/null +++ 
b/tests/engine/test_chunked_latent.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for windowed latent-entity detection.""" + +from __future__ import annotations + +import itertools +import json +from typing import Any, Callable + +import pytest + +from anonymizer.engine.constants import ( + COL_DETECTED_ENTITIES, + COL_LATENT_ENTITIES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.chunked_latent import ( + WindowedLatentParams, + latent_row, + merge_latent, + render_latent_prompt, +) +from anonymizer.engine.detection.postprocess import TagNotation +from anonymizer.engine.schemas import EntitiesSchema, LatentEntitiesSchema + +# Build from the real column constants so the template can't drift from them. +_TEMPLATE = "LATENT[{{ " + COL_TAG_NOTATION + " }}] {{ " + COL_TAGGED_TEXT + " }}" + + +def _latent(label: str, value: str) -> dict[str, Any]: + return { + "category": "latent_identifier", + "label": label, + "value": value, + "confidence": "high", + "evidence": [f"context mentioning {value}"], + "rationale": "The surrounding text strongly implies this attribute about the subject.", + } + + +def _latent_schema(*pairs: tuple[str, str]) -> LatentEntitiesSchema: + return LatentEntitiesSchema.model_validate({"latent_entities": [_latent(lab, val) for lab, val in pairs]}) + + +class FakeFacade: + def __init__(self, response: dict | str | Callable[[str], dict | str]) -> None: + self._response = response + self.calls: list[str | None] = [] + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + self.calls.append(purpose) + response = self._response + if callable(response): + response = response(prompt) + raw = response if isinstance(response, str) else f"```json\n{json.dumps(response)}\n```" + return parser(raw), [] + + +def _params(max_render_chars: int, **kw: Any) -> 
WindowedLatentParams: + return WindowedLatentParams(alias="lat", prompt_template=_TEMPLATE, max_render_chars=max_render_chars, **kw) + + +class TestMergeLatent: + def test_dedupes_by_label_and_value(self) -> None: + first = _latent_schema(("employer", "Acme"), ("home_location", "Boston")) + second = _latent_schema(("employer", "acme"), ("employer", "Globex")) # acme dup, Globex new + merged = merge_latent([first, second]) + pairs = {(e.label, e.value) for e in merged.latent_entities} + assert pairs == {("employer", "Acme"), ("home_location", "Boston"), ("employer", "Globex")} + + +def test_render_includes_notation_and_text() -> None: + rendered = render_latent_prompt(template=_TEMPLATE, tagged_text="TAGGED", notation=TagNotation.xml) + assert "TAGGED" in rendered + assert "xml" in rendered + + +class TestLatentRowFastPath: + def test_single_call_when_under_budget(self) -> None: + facade = FakeFacade({"latent_entities": [_latent("employer", "Acme")]}) + row = { + COL_TEXT: "She works there", + COL_TAGGED_TEXT: "She works there", + COL_TAG_NOTATION: "xml", + } + out = latent_row(row, _params(max_render_chars=1_000_000), {"lat": facade}) + assert len(facade.calls) == 1 + result = LatentEntitiesSchema.model_validate(out[COL_LATENT_ENTITIES]) + assert [(e.label, e.value) for e in result.latent_entities] == [("employer", "Acme")] + + +class TestLatentRowWindowed: + def test_multiple_windows_unioned(self) -> None: + text = ("A" * 5000) + ("B" * 5000) + counter = itertools.count() + facade = FakeFacade(lambda _p: {"latent_entities": [_latent("employer", f"Org{next(counter)}")]}) + row = { + COL_TEXT: text, + COL_TAGGED_TEXT: text, # full render exceeds budget -> windowed + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + out = latent_row( + row, _params(max_render_chars=4000, safety_margin_chars=0, overlap_chars=1000), {"lat": facade} + ) + assert len(facade.calls) > 1 + result = 
LatentEntitiesSchema.model_validate(out[COL_LATENT_ENTITIES]) + assert len(result.latent_entities) == len(facade.calls) + + def test_missing_alias_raises(self) -> None: + with pytest.raises(KeyError, match="not present in models"): + latent_row({COL_TEXT: "x"}, _params(max_render_chars=10), {}) diff --git a/tests/engine/test_chunked_replace.py b/tests/engine/test_chunked_replace.py new file mode 100644 index 00000000..d1c8f36f --- /dev/null +++ b/tests/engine/test_chunked_replace.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for chunked (long-context) Substitute replacement-map generation.""" + +from __future__ import annotations + +import json +import re +from typing import Any + +import pytest + +from anonymizer.engine.constants import ( + COL_ENTITIES_FOR_REPLACE, + COL_ENTITY_EXAMPLES, + COL_FINAL_ENTITIES, + COL_REPLACEMENT_MAP, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.postprocess import EntitySpan, TagNotation +from anonymizer.engine.replace.chunked_replace import ( + WindowedReplaceParams, + chunk_tagged_text, + generate_replacement_map_row, + merge_replacements, + new_chunk_entities, +) +from anonymizer.engine.schemas import EntitiesSchema, EntityReplacementMapSchema + +# Single-call prompt stand-in (the real one is large); references the columns the +# fast path renders with the row. 
+_SINGLE_PROMPT = "MAP {{ tagged_text }} || {% for e in _entities_for_replace %}{{ e.value }};{% endfor %} || {{ _entity_examples }}" + + +def _span(value: str, label: str, start: int, end: int) -> EntitySpan: + return EntitySpan(entity_id=f"{label}_{start}", value=value, label=label, + start_position=start, end_position=end, score=1.0, source="d") + + +class TestNewChunkEntities: + def test_selects_in_window_excludes_mapped(self) -> None: + spans = [_span("Alice", "name", 0, 5), _span("Bob", "name", 50, 53), _span("Alice", "name", 60, 65)] + # window [0,55): Alice(0) + Bob(50); Alice already mapped -> excluded + got = new_chunk_entities(spans, 0, 55, already_mapped={("Alice", "name")}) + assert [(e["value"], e["labels_str"]) for e in got] == [("Bob", "name")] + + def test_groups_labels_per_value(self) -> None: + spans = [_span("Wash", "city", 0, 4), _span("Wash", "last_name", 10, 14)] + got = new_chunk_entities(spans, 0, 20, already_mapped=set()) + assert got == [{"value": "Wash", "labels": ["city", "last_name"], "labels_str": "city, last_name"}] + + +class TestMergeReplacements: + def test_dedupes_by_original_label_earlier_wins(self) -> None: + existing = [{"original": "Alice", "label": "name", "synthetic": "Jane"}] + new = EntityReplacementMapSchema.model_validate( + {"replacements": [ + {"original": "Alice", "label": "name", "synthetic": "DIFFERENT"}, # dup -> ignored + {"original": "Bob", "label": "name", "synthetic": "Mike"}, + ]} + ) + merged = merge_replacements(existing, new) + assert merged == [ + {"original": "Alice", "label": "name", "synthetic": "Jane"}, + {"original": "Bob", "label": "name", "synthetic": "Mike"}, + ] + + +def test_chunk_tagged_text_rebases_and_tags_in_window() -> None: + text = "Alice met Bob in Paris" # Alice 0-5, Bob 10-13 + spans = [_span("Alice", "first_name", 0, 5), _span("Bob", "first_name", 10, 13), _span("Paris", "city", 17, 22)] + tagged = chunk_tagged_text(text, spans, 0, 14, TagNotation.xml) + assert "Alice" in tagged and 
"Bob" in tagged + assert "Paris" not in tagged + + +class _Fake: + def __init__(self) -> None: + self.map_calls = 0 + self.summary_calls = 0 + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + if "summary" in (purpose or ""): + self.summary_calls += 1 + return parser("rolling summary"), [] + self.map_calls += 1 + pairs = re.findall(r'- "([^"]+)" \(([^)]+)\)', prompt) + reps = {"replacements": [{"original": v, "label": lab, "synthetic": v + "_S"} for v, lab in pairs]} + return parser("```json\n" + json.dumps(reps) + "\n```"), [] + + +def _line(s: str) -> str: + return s + "x" * (100 - len(s) - 1) + "\n" + + +class TestGenerateReplacementMapRow: + def test_fast_path_single_call(self) -> None: + facade = _Fake() + row = { + COL_TEXT: "Alice here", COL_TAGGED_TEXT: "Alice here", COL_TAG_NOTATION: "xml", + COL_ENTITY_EXAMPLES: "{}", + COL_ENTITIES_FOR_REPLACE: [{"value": "Alice", "labels": ["name"], "labels_str": "name"}], + COL_FINAL_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + # NB: fast path renders the single prompt; our stand-in lists entities so the fake maps them. 
+ params = WindowedReplaceParams(alias="r", single_call_prompt_template='MAP - "Alice" (name)', max_render_chars=1_000_000) + out = generate_replacement_map_row(row, params, {"r": facade}) + assert facade.map_calls == 1 and facade.summary_calls == 0 + m = EntityReplacementMapSchema.model_validate(out[COL_REPLACEMENT_MAP]) + assert [(r.original, r.synthetic) for r in m.replacements] == [("Alice", "Alice_S")] + + def test_chunked_with_rolling_summary_and_dedupe(self) -> None: + lines = [_line("") for _ in range(120)] + lines[0] = _line("Alice") + lines[50] = _line("Alice") # recurs in window 2 -> must NOT be re-mapped + lines[60] = _line("Bob") + lines[119] = _line("Carol") + text = "".join(lines) + spans = [] + for val in ["Alice", "Bob", "Carol"]: + for mobj in re.finditer(re.escape(val), text): + spans.append({"id": f"name_{mobj.start()}", "value": val, "label": "name", + "start_position": mobj.start(), "end_position": mobj.end(), "score": 1.0, "source": "d"}) + final = EntitiesSchema.model_validate({"entities": spans}).model_dump(mode="json") + + facade = _Fake() + row = { + COL_TEXT: text, COL_TAGGED_TEXT: text, COL_TAG_NOTATION: "xml", COL_ENTITY_EXAMPLES: "{}", + COL_ENTITIES_FOR_REPLACE: [{"value": v, "labels": ["name"], "labels_str": "name"} for v in ["Alice", "Bob", "Carol"]], + COL_FINAL_ENTITIES: final, + } + params = WindowedReplaceParams(alias="r", single_call_prompt_template=_SINGLE_PROMPT, max_render_chars=4000, safety_margin_chars=0) + out = generate_replacement_map_row(row, params, {"r": facade}) + result = EntityReplacementMapSchema.model_validate(out[COL_REPLACEMENT_MAP]) + assert facade.map_calls == 3 # one per window with new entities + assert facade.summary_calls == 2 # after windows 1 and 2, not the last + assert sorted(r.original for r in result.replacements) == ["Alice", "Bob", "Carol"] + assert sum(r.original == "Alice" for r in result.replacements) == 1 # deduped across chunks + + def test_missing_alias_raises(self) -> None: + with 
pytest.raises(KeyError, match="not present in models"): + generate_replacement_map_row({COL_TEXT: "x"}, WindowedReplaceParams(alias="r", single_call_prompt_template="x", max_render_chars=10), {}) diff --git a/tests/engine/test_chunked_rewrite.py b/tests/engine/test_chunked_rewrite.py new file mode 100644 index 00000000..cee4a74c --- /dev/null +++ b/tests/engine/test_chunked_rewrite.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for chunked (long-context) rewrite generation.""" + +from __future__ import annotations + +import pytest + +from anonymizer.engine.constants import ( + COL_FULL_REWRITE, + COL_REPLACEMENT_MAP_FOR_PROMPT, + COL_REWRITE_DISPOSITION_BLOCK, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.rewrite.chunked_rewrite import ( + WindowedRewriteParams, + _filter_disposition_to_chunk, + generate_rewrite_row, +) +from anonymizer.engine.schemas import RewriteOutputSchema + +_TEMPLATE = "REWRITE[" + COL_TAG_NOTATION + "] {{ " + COL_TAGGED_TEXT + " }}" + + +class _Fake: + def __init__(self) -> None: + self.rewrite_calls = 0 + self.summary_calls = 0 + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + if "summary" in (purpose or ""): + self.summary_calls += 1 + return parser("running summary"), [] + self.rewrite_calls += 1 + return parser('```json\n{"rewritten_text":"OUT%d"}\n```' % self.rewrite_calls), [] + + +def _row(tagged: str) -> dict: + return { + COL_TEXT: tagged, + COL_TAGGED_TEXT: tagged, + COL_TAG_NOTATION: "xml", + COL_REWRITE_DISPOSITION_BLOCK: [], + COL_REPLACEMENT_MAP_FOR_PROMPT: {"replacements": []}, + } + + +def test_filter_disposition_to_chunk() -> None: + block = [{"entity_value": "Alice"}, {"entity_value": "Bob"}, {"entity_value": "Carol"}] + assert _filter_disposition_to_chunk(block, "... 
Alice and Carol ...") == [ + {"entity_value": "Alice"}, + {"entity_value": "Carol"}, + ] + + +class TestGenerateRewriteRow: + def test_fast_path_single_call(self) -> None: + facade = _Fake() + params = WindowedRewriteParams(alias="w", single_call_prompt_template=_TEMPLATE, max_render_chars=1_000_000) + out = generate_rewrite_row(_row("short tagged text"), params, {"w": facade}) + assert facade.rewrite_calls == 1 and facade.summary_calls == 0 + assert RewriteOutputSchema.model_validate(out[COL_FULL_REWRITE]).rewritten_text == "OUT1" + + def test_chunked_stitches_with_rolling_summary(self) -> None: + tagged = ("X" * 4000 + "\n") * 3 # ~12k chars -> several windows + facade = _Fake() + params = WindowedRewriteParams( + alias="w", single_call_prompt_template=_TEMPLATE, max_render_chars=4000, safety_margin_chars=0 + ) + out = generate_rewrite_row(_row(tagged), params, {"w": facade}) + text = RewriteOutputSchema.model_validate(out[COL_FULL_REWRITE]).rewritten_text + assert facade.rewrite_calls > 1 + assert facade.summary_calls == facade.rewrite_calls - 1 # summary after each chunk except the last + # stitched in order + assert text.split("\n") == [f"OUT{i}" for i in range(1, facade.rewrite_calls + 1)] + + def test_missing_alias_raises(self) -> None: + with pytest.raises(KeyError, match="not present in models"): + generate_rewrite_row(_row("x"), WindowedRewriteParams(alias="w", single_call_prompt_template="x", max_render_chars=10), {}) diff --git a/tests/engine/test_chunked_steps.py b/tests/engine/test_chunked_steps.py new file mode 100644 index 00000000..2531df62 --- /dev/null +++ b/tests/engine/test_chunked_steps.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from typing import Any + +import pytest + +from anonymizer.engine.constants import ( + COL_DOMAIN, + COL_MEANING_UNITS_SERIALIZED, + COL_QUALITY_QA, + COL_SENSITIVITY_DISPOSITION, + COL_TEXT, +) +from anonymizer.engine.rewrite.chunked_steps import ( + WindowedStepParams, + _compile_template, + run_windowed_step, +) +from anonymizer.engine.rewrite.domain_classification import _first_output, _get_domain_classification_prompt +from anonymizer.engine.rewrite.qa_generation import ( + _batch_units_by_size, + _concat_meaning_units, + _get_quality_qa_prompt, + generate_quality_qa_row, +) +from anonymizer.engine.rewrite.sensitivity_disposition import _make_disposition_merge +from anonymizer.engine.schemas import ( + DomainClassificationSchema, + MeaningUnitsSchema, + QualityQAPairsSchema, + SensitivityDispositionSchema, +) + + +class _FakeFacade: + """Facade stub that parses a fixed JSON response on every ``generate`` call.""" + + def __init__(self, response_obj: dict[str, Any]) -> None: + self._payload = "```json\n" + json.dumps(response_obj) + "\n```" + self.calls = 0 + + def generate(self, *, prompt: Any, parser: Any, system_prompt: Any = None, purpose: Any = None, **_: Any) -> Any: + self.calls += 1 + return parser(self._payload), [] + + +def _disposition_entry(value: str, risk: str, method: str) -> dict[str, Any]: + return { + "id": 1, + "source": "tagged", + "category": "direct_identifier", + "sensitivity": "high", + "entity_label": "first_name", + "entity_value": value, + "protection_reason": "name is identifying in this document", + "protection_method_suggestion": method, + "combined_risk_level": risk, + } + + +# --------------------------------------------------------------------------- +# run_windowed_step +# --------------------------------------------------------------------------- + + +def test_fast_path_single_call_when_under_cap() -> None: + facade = 
_FakeFacade({"domain": "OTHER", "domain_confidence": 0.9}) + row = run_windowed_step( + {COL_TEXT: "short"}, + WindowedStepParams( + alias="d", + prompt_template=_get_domain_classification_prompt(None), + output_column=COL_DOMAIN, + text_column=COL_TEXT, + max_render_chars=1_000_000, + first_only=True, + ), + {"d": facade}, + schema=DomainClassificationSchema, + merge_fn=_first_output, + purpose_prefix="domain", + ) + assert facade.calls == 1 + assert DomainClassificationSchema.model_validate(row[COL_DOMAIN]).domain.value == "OTHER" + + +def test_first_only_classifies_just_one_window() -> None: + facade = _FakeFacade({"domain": "OTHER", "domain_confidence": 0.9}) + long_text = ("x" * 4000 + "\n") * 4 + run_windowed_step( + {COL_TEXT: long_text}, + WindowedStepParams( + alias="d", + prompt_template=_get_domain_classification_prompt(None), + output_column=COL_DOMAIN, + text_column=COL_TEXT, + max_render_chars=4000, + safety_margin_chars=0, + first_only=True, + ), + {"d": facade}, + schema=DomainClassificationSchema, + merge_fn=_first_output, + purpose_prefix="domain", + ) + assert facade.calls == 1 + + +def test_missing_alias_raises() -> None: + with pytest.raises(KeyError): + run_windowed_step( + {COL_TEXT: "x"}, + WindowedStepParams( + alias="missing", + prompt_template="{{ _text }}", + output_column=COL_DOMAIN, + text_column=COL_TEXT, + max_render_chars=1000, + ), + {}, + schema=DomainClassificationSchema, + merge_fn=_first_output, + purpose_prefix="domain", + ) + + +# --------------------------------------------------------------------------- +# Merges +# --------------------------------------------------------------------------- + + +def test_disposition_merge_keeps_highest_risk_and_reids() -> None: + low = SensitivityDispositionSchema.model_validate( + {"sensitivity_disposition": [_disposition_entry("Alice", "low", "leave_as_is")]} + ) + high = SensitivityDispositionSchema.model_validate( + {"sensitivity_disposition": [_disposition_entry("Alice", "high", 
"replace")]} + ) + merged = _make_disposition_merge(SensitivityDispositionSchema)([low, high]) + entries = SensitivityDispositionSchema.model_validate(merged).sensitivity_disposition + assert len(entries) == 1 # deduped by (source, label, value) + assert entries[0].combined_risk_level == "high" + assert entries[0].id == 1 + + +def test_meaning_units_concat_reids_sequentially() -> None: + out_a = MeaningUnitsSchema.model_validate( + {"units": [{"id": 1, "aspect": "role", "unit": "works in tech", "importance": "critical"}]} + ) + out_b = MeaningUnitsSchema.model_validate( + {"units": [{"id": 1, "aspect": "process", "unit": "follows a workflow", "importance": "important"}]} + ) + merged = _concat_meaning_units([out_a, out_b]) + units = MeaningUnitsSchema.model_validate(merged).units + assert [u.id for u in units] == [1, 2] + + +# --------------------------------------------------------------------------- +# Quality-QA batching +# --------------------------------------------------------------------------- + + +def test_batch_units_by_size_splits_when_over_allowance() -> None: + units = [{"id": i, "aspect": "role", "unit": f"unit number {i}", "importance": "important"} for i in range(10)] + unit_len = len(json.dumps(units[0], ensure_ascii=False)) + 1 + base = 1000 + budget = base + 3 * unit_len # ~3 units per batch + batches = _batch_units_by_size(units, base, budget) + assert len(batches) > 1 + assert sum(len(b) for b in batches) == len(units) + assert all(len(b) <= 3 for b in batches) + + +def test_quality_qa_batches_and_reids() -> None: + prompt = _get_quality_qa_prompt() + units = [ + {"id": i + 1, "aspect": "role", "unit": f"unit text number {i} about work", "importance": "important"} + for i in range(12) + ] + row = {COL_MEANING_UNITS_SERIALIZED: json.dumps(units, ensure_ascii=False)} + base = len(_compile_template(prompt).render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"})) + unit_len = len(json.dumps(units[0], ensure_ascii=False)) + 1 + cap = base + 3 * 
unit_len + facade = _FakeFacade( + { + "items": [ + {"id": 1, "aspect": "role", "importance": "critical", "question": "q1?", "reference_answer": "a1"}, + {"id": 2, "aspect": "role", "importance": "important", "question": "q2?", "reference_answer": "a2"}, + ] + } + ) + out = generate_quality_qa_row( + dict(row), + {"q": facade}, + alias="q", + prompt_template=prompt, + max_render_chars=cap, + safety_margin_chars=0, + ) + items = QualityQAPairsSchema.model_validate(out[COL_QUALITY_QA]).items + assert facade.calls > 1 # batched + assert [it.id for it in items] == list(range(1, len(items) + 1)) + assert len(items) == 2 * facade.calls + assert COL_SENSITIVITY_DISPOSITION not in out # sanity: only quality QA written diff --git a/tests/engine/test_detection_workflow.py b/tests/engine/test_detection_workflow.py index 3f45f1b1..6c247ceb 100644 --- a/tests/engine/test_detection_workflow.py +++ b/tests/engine/test_detection_workflow.py @@ -124,7 +124,10 @@ def test_run_with_latent_detection_calls_second_workflow( assert adapter.run_workflow.call_count == 2 second_columns = adapter.run_workflow.call_args_list[1].kwargs["columns"] assert len(second_columns) == 1 - assert isinstance(second_columns[0], LLMStructuredColumnConfig) + # Latent detection is windowed (chunks long docs) so it is a custom generator + # rather than an LLMStructuredColumnConfig. 
+ assert isinstance(second_columns[0], CustomColumnConfig) + assert not isinstance(second_columns[0], LLMStructuredColumnConfig) assert second_columns[0].name == COL_LATENT_ENTITIES assert COL_LATENT_ENTITIES in result.dataframe.columns assert COL_FINAL_ENTITIES in result.dataframe.columns diff --git a/tests/engine/test_domain_classification.py b/tests/engine/test_domain_classification.py index dc6b4200..b814fb60 100644 --- a/tests/engine/test_domain_classification.py +++ b/tests/engine/test_domain_classification.py @@ -3,7 +3,7 @@ from __future__ import annotations -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.engine.constants import ( @@ -31,13 +31,16 @@ def test_columns_returns_exactly_three_in_order( ) -> None: cols = DomainClassificationWorkflow().columns(selected_models=stub_rewrite_model_selection) assert len(cols) == 3 - assert isinstance(cols[0], LLMStructuredColumnConfig) + # COL_DOMAIN is now a windowed custom generator (first-chunk classification) + # instead of an LLMStructuredColumnConfig, so it can bypass the render cap. 
+ assert isinstance(cols[0], CustomColumnConfig) assert isinstance(cols[1], CustomColumnConfig) assert isinstance(cols[2], CustomColumnConfig) assert cols[0].name == COL_DOMAIN assert cols[1].name == COL_DOMAIN_SUPPLEMENT assert cols[2].name == COL_DOMAIN_SUPPLEMENT_PRIVACY - assert cols[0].model_alias == stub_rewrite_model_selection.domain_classifier + assert cols[0].generator_params.alias == stub_rewrite_model_selection.domain_classifier + assert cols[0].generator_params.first_only is True def test_enrich_domain_populates_supplement_for_known_domain() -> None: diff --git a/tests/engine/test_qa_generation.py b/tests/engine/test_qa_generation.py index 74831f9d..0a5fae19 100644 --- a/tests/engine/test_qa_generation.py +++ b/tests/engine/test_qa_generation.py @@ -5,7 +5,7 @@ import json -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.engine.constants import ( @@ -94,10 +94,12 @@ def test_columns_returns_exactly_five_in_order( ) -> None: cols = QAGenerationWorkflow().columns(selected_models=stub_rewrite_model_selection) assert len(cols) == 5 + # Meaning-unit extraction (cols[1]) and quality-QA (cols[3]) are now windowed + # custom generators (chunk text / batch units) so they bypass the render cap. 
assert isinstance(cols[0], CustomColumnConfig) - assert isinstance(cols[1], LLMStructuredColumnConfig) + assert isinstance(cols[1], CustomColumnConfig) assert isinstance(cols[2], CustomColumnConfig) - assert isinstance(cols[3], LLMStructuredColumnConfig) + assert isinstance(cols[3], CustomColumnConfig) assert isinstance(cols[4], CustomColumnConfig) assert cols[0].name == COL_SENSITIVITY_DISPOSITION_BLOCK assert cols[1].name == COL_MEANING_UNITS @@ -110,14 +112,15 @@ def test_meaning_extractor_alias_used( stub_rewrite_model_selection: RewriteModelSelection, ) -> None: cols = QAGenerationWorkflow().columns(selected_models=stub_rewrite_model_selection) - assert cols[1].model_alias == stub_rewrite_model_selection.meaning_extractor + assert cols[1].generator_params.alias == stub_rewrite_model_selection.meaning_extractor def test_qa_generator_alias_used( stub_rewrite_model_selection: RewriteModelSelection, ) -> None: cols = QAGenerationWorkflow().columns(selected_models=stub_rewrite_model_selection) - assert cols[3].model_alias == stub_rewrite_model_selection.qa_generator + metadata = cols[3].generator_function.custom_column_metadata + assert metadata["model_aliases"] == [stub_rewrite_model_selection.qa_generator] def test_format_disposition_block_produces_valid_json() -> None: diff --git a/tests/engine/test_rewrite_generation.py b/tests/engine/test_rewrite_generation.py index be32ad44..9a68f657 100644 --- a/tests/engine/test_rewrite_generation.py +++ b/tests/engine/test_rewrite_generation.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import PrivacyGoal @@ -291,25 +291,19 @@ def test_columns_returns_four_configs( assert len(cols) == 4 -def test_columns_has_llm_config_with_rewriter_alias( - 
stub_rewrite_model_selection: RewriteModelSelection, - privacy_goal: PrivacyGoal, -) -> None: - workflow = RewriteGenerationWorkflow() - cols = workflow.columns(selected_models=stub_rewrite_model_selection, privacy_goal=privacy_goal) - llm_cols = [c for c in cols if isinstance(c, LLMStructuredColumnConfig)] - assert len(llm_cols) == 1 - assert llm_cols[0].name == COL_FULL_REWRITE - - -def test_columns_full_rewrite_uses_rewrite_output_schema( +def test_columns_full_rewrite_is_windowed_generator_with_rewriter_alias( stub_rewrite_model_selection: RewriteModelSelection, privacy_goal: PrivacyGoal, ) -> None: + # COL_FULL_REWRITE is now a windowed custom generator (chunks long docs with + # rolling-summary continuity) instead of an LLMStructuredColumnConfig. workflow = RewriteGenerationWorkflow() cols = workflow.columns(selected_models=stub_rewrite_model_selection, privacy_goal=privacy_goal) full_rewrite_col = next(c for c in cols if c.name == COL_FULL_REWRITE) - assert full_rewrite_col.output_format == RewriteOutputSchema.model_json_schema() + assert isinstance(full_rewrite_col, CustomColumnConfig) + assert full_rewrite_col.generator_params.alias == stub_rewrite_model_selection.rewriter + metadata = full_rewrite_col.generator_function.custom_column_metadata + assert metadata["model_aliases"] == [stub_rewrite_model_selection.rewriter] def test_columns_includes_custom_configs_for_disposition_and_text_extraction( diff --git a/tests/engine/test_sensitivity_disposition.py b/tests/engine/test_sensitivity_disposition.py index 8b8a77ff..6685180e 100644 --- a/tests/engine/test_sensitivity_disposition.py +++ b/tests/engine/test_sensitivity_disposition.py @@ -3,7 +3,7 @@ from __future__ import annotations -from data_designer.config.column_configs import LLMStructuredColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import PrivacyGoal @@ -34,8 +34,10 @@ def 
test_columns_uses_disposition_analyzer_alias( privacy_goal=_STUB_PRIVACY_GOAL, ) assert len(cols) == 1 - assert isinstance(cols[0], LLMStructuredColumnConfig) - assert cols[0].model_alias == stub_rewrite_model_selection.disposition_analyzer + # Disposition is now a windowed custom generator (chunks the tagged text, + # unions per-entity decisions) so it can bypass the render cap. + assert isinstance(cols[0], CustomColumnConfig) + assert cols[0].generator_params.alias == stub_rewrite_model_selection.disposition_analyzer assert cols[0].name == COL_SENSITIVITY_DISPOSITION diff --git a/tests/engine/test_windowing.py b/tests/engine/test_windowing.py new file mode 100644 index 00000000..42507693 --- /dev/null +++ b/tests/engine/test_windowing.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for boundary-aware windowing.""" + +from __future__ import annotations + +import pytest + +from anonymizer.engine.windowing import iter_boundary_windows, next_window_end + + +class TestNextWindowEnd: + def test_backs_off_to_last_delimiter(self) -> None: + # "aaaa\nbbbb\ncccc" — cap at 12 lands inside "cccc"; back off to after second "\n" (offset 10). 
+ text = "aaaa\nbbbb\ncccc" + assert next_window_end(text, 0, 12) == 10 # after the "\n" at index 9 + assert text[:10] == "aaaa\nbbbb\n" + + def test_hard_cut_when_no_delimiter(self) -> None: + text = "abcdefghij" # no newline + assert next_window_end(text, 0, 4) == 4 + + def test_returns_end_when_remainder_fits(self) -> None: + text = "short\ntext" + assert next_window_end(text, 0, 1000) == len(text) + + def test_custom_delimiter(self) -> None: + text = "a;b;c;d;e" + assert next_window_end(text, 0, 4, delimiter=";") == 4 # "a;b;" (last ';' within [0,4)) + + +class TestIterBoundaryWindows: + def test_tiles_on_newlines(self) -> None: + text = "line1\nline2\nline3\nline4" + windows = iter_boundary_windows(text, 12) + # backs off to the LAST newline within the cap, e.g. first window "line1\nline2\n" + assert windows[0] == (0, 12) + assert all((e - s) <= 12 for s, e in windows) + assert all(text[s:e].endswith("\n") for s, e in windows[:-1]) # non-last windows end on a newline + assert windows[-1][1] == len(text) + assert "".join(text[s:e] for s, e in windows) == text # exact reconstruction + + def test_empty(self) -> None: + assert iter_boundary_windows("", 10) == [] + + def test_no_delimiter_hard_cuts(self) -> None: + text = "x" * 25 + windows = iter_boundary_windows(text, 10) + assert windows == [(0, 10), (10, 20), (20, 25)] From 968ec3722e178904259f0bf28583a2c3a819a299 Mon Sep 17 00:00:00 2001 From: eurekayuan Date: Wed, 3 Jun 2026 11:00:19 -0700 Subject: [PATCH 2/4] chunking final judge --- src/anonymizer/config/anonymizer_config.py | 2 +- src/anonymizer/engine/constants.py | 6 + .../engine/detection/chunked_augmentation.py | 163 ++++++++---- .../engine/detection/chunked_detection.py | 76 ++++-- .../engine/detection/chunked_latent.py | 151 +++++++---- .../engine/rewrite/chunked_final_judge.py | 234 ++++++++++++++++++ src/anonymizer/engine/rewrite/final_judge.py | 51 +++- tests/engine/test_chunked_augmentation.py | 30 +++ tests/engine/test_chunked_final_judge.py | 
159 ++++++++++++ tests/engine/test_chunked_latent.py | 28 +++ tests/engine/test_final_judge.py | 42 ++-- 11 files changed, 789 insertions(+), 153 deletions(-) create mode 100644 src/anonymizer/engine/rewrite/chunked_final_judge.py create mode 100644 tests/engine/test_chunked_final_judge.py diff --git a/src/anonymizer/config/anonymizer_config.py b/src/anonymizer/config/anonymizer_config.py index 1ab4b0bf..9fe3f09d 100644 --- a/src/anonymizer/config/anonymizer_config.py +++ b/src/anonymizer/config/anonymizer_config.py @@ -29,7 +29,7 @@ # NDD's hard render cap so each window stays small enough to map/rewrite within a # single LLM request — large windows on entity-dense documents otherwise time out. # Clamped so it never exceeds NDD's cap if that is ever lowered. -_DEFAULT_WINDOW_MAX_RENDER_CHARS = min(128 * 1024, _NDD_MAX_RENDERED_LEN) +_DEFAULT_WINDOW_MAX_RENDER_CHARS = min(128_000, _NDD_MAX_RENDERED_LEN) def is_remote_input_source(value: str) -> bool: diff --git a/src/anonymizer/engine/constants.py b/src/anonymizer/engine/constants.py index fdfecadf..0e4e8da6 100644 --- a/src/anonymizer/engine/constants.py +++ b/src/anonymizer/engine/constants.py @@ -24,6 +24,10 @@ # Step 3: LLM augmentation COL_AUGMENTED_ENTITIES = "_augmented_entities" +# Count of augmentation windows skipped due to a per-window failure (graceful +# degradation on long docs); 0 on the single-call fast path. Surfaced in +# trace_dataframe so callers can flag partially-degraded records. +COL_AUGMENTATION_FAILED_WINDOWS = "_augmentation_failed_windows" # Step 3b: prepare_validation_inputs (seed-only, pre-augmentation) COL_SEED_TAGGED_TEXT = "_seed_tagged_text" @@ -55,6 +59,8 @@ # Latent detection (optional second workflow) COL_LATENT_ENTITIES = "_latent_entities" +# Count of latent windows skipped due to a per-window failure; 0 on the fast path. 
+COL_LATENT_FAILED_WINDOWS = "_latent_failed_windows" # Final output COL_FINAL_ENTITIES = "final_entities" diff --git a/src/anonymizer/engine/detection/chunked_augmentation.py b/src/anonymizer/engine/detection/chunked_augmentation.py index 43bf61d5..3fe41b59 100644 --- a/src/anonymizer/engine/detection/chunked_augmentation.py +++ b/src/anonymizer/engine/detection/chunked_augmentation.py @@ -25,6 +25,7 @@ import functools import json import logging +from concurrent.futures import ThreadPoolExecutor from typing import Any from data_designer.config import custom_column_generator @@ -33,6 +34,7 @@ from pydantic import BaseModel, Field from anonymizer.engine.constants import ( + COL_AUGMENTATION_FAILED_WINDOWS, COL_AUGMENTED_ENTITIES, COL_INITIAL_TAGGED_TEXT, COL_SEED_ENTITIES_JSON, @@ -49,6 +51,11 @@ # progress instead of shrinking toward zero. _MIN_WINDOW_CHARS = 4000 +# Upper bound on augmentation windows dispatched concurrently for one record. The +# per-alias rate limit (``max_parallel_requests`` on the facade) still caps the +# real in-flight count; this just bounds thread creation on very long inputs. +_MAX_PARALLEL_WINDOWS = 8 + # Jinja2 environment used to render the per-window augmentation prompt. Mirrors # chunked_validation: same template, same placeholders, per-window values. _PROMPT_ENV = Environment( @@ -221,6 +228,75 @@ def _call_augmenter(*, facade: Any, prompt: str, system_prompt: str | None, purp return output +def plan_augmentation_windows( + *, + text: str, + all_spans: list[EntitySpan], + notation: TagNotation, + params: WindowedAugmentationParams, + cap: int, + initial_window: int, +) -> list[tuple[int, int, str]]: + """Walk the document applying the shrink rule, returning ``(start, end, rendered_prompt)``. + + No LLM calls — only Jinja renders + length checks — so this is cheap and runs + serially to fix the (data-dependent) window boundaries before the parallel LLM + pass. 
The window shrinks when a render exceeds ``cap`` and the shrunk size + carries forward, exactly as the previous in-line loop did. + """ + text_len = len(text) + windows: list[tuple[int, int, str]] = [] + window = initial_window + pos = 0 + while pos < text_len: + end = min(text_len, pos + window) + tagged, seed_json = build_window_inputs(text=text, all_spans=all_spans, start=pos, end=end, notation=notation) + rendered = render_augment_prompt( + template=params.prompt_template, tagged_text=tagged, seed_entities_json=seed_json, notation=notation + ) + if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: + # Shrink proportionally to the measured overage so the next try lands just + # under the cap (0.95 = small safety margin). Strictly decreasing -> converges. + shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) + logger.debug( + "augmentation @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", + pos, len(rendered), cap, window, shrunk, + ) + window = shrunk + continue + logger.debug( + "augmentation window @[%d, %d) size=%d, rendered=%d/%d chars", pos, end, end - pos, len(rendered), cap + ) + windows.append((pos, end, rendered)) + if end >= text_len: + break + pos = max(pos + 1, end - params.overlap_chars) + return windows + + +def _augment_window( + *, facade: Any, rendered: str, system_prompt: str | None, start: int +) -> AugmentedEntitiesSchema | None: + """Run one augmentation window; return its result, or ``None`` if the call fails. + + A single window's failure (e.g. an unparseable model response, or a timeout) + must not drop the whole record: augmentation only *adds* entities GLiNER missed, + so a skipped window degrades gracefully. The error is logged at WARNING with its + type and message so it stays fully visible. Safe to run per window in a pool. 
+ """ + try: + result = _call_augmenter( + facade=facade, prompt=rendered, system_prompt=system_prompt, purpose=f"entity-augmentation-window-{start}" + ) + except Exception as exc: # noqa: BLE001 — one window must not sink the record + logger.warning( + "augmentation window @%d failed (%s: %s); skipping this window's entities", start, type(exc).__name__, exc + ) + return None + logger.debug("augmentation window @%d: augmenter proposed %d entities", start, len(result.entities)) + return result + + def augment_row( row: dict[str, Any], params: WindowedAugmentationParams, @@ -266,71 +342,49 @@ def augment_row( purpose="entity-augmentation", ) row[COL_AUGMENTED_ENTITIES] = output.model_dump(mode="json") + row[COL_AUGMENTATION_FAILED_WINDOWS] = 0 return row - # Windowed path: tile the document, shrinking a window only if its render - # exceeds the cap (e.g. an entity-dense slice with many tags). + # Windowed path. First plan the windows serially (cheap renders + the + # data-dependent shrink), then run the LLM calls in parallel. A single + # window's failure is logged and skipped rather than dropping the record. 
all_spans = parse_validated_seed_spans(row.get(COL_VALIDATED_SEED_ENTITIES, {})) - results: list[AugmentedEntitiesSchema] = [] - window = initial_window - pos = 0 text_len = len(text) + windows = plan_augmentation_windows( + text=text, all_spans=all_spans, notation=notation, params=params, cap=cap, initial_window=initial_window + ) + max_workers = min(len(windows), _MAX_PARALLEL_WINDOWS) logger.info( - "augmentation: rendered prompt %d chars > cap %d; tiling %d-char document into " - "overlapping windows (initial_window=%d, overlap=%d, min_window=%d)", - len(full_rendered), cap, text_len, initial_window, params.overlap_chars, _MIN_WINDOW_CHARS, + "augmentation: rendered prompt %d chars > cap %d; tiling %d-char document into %d overlapping " + "window(s) (initial_window=%d, overlap=%d, min_window=%d, max_workers=%d)", + len(full_rendered), cap, text_len, len(windows), initial_window, params.overlap_chars, + _MIN_WINDOW_CHARS, max_workers, ) - window_index = 0 - while pos < text_len: - end = min(text_len, pos + window) - tagged, seed_json = build_window_inputs( - text=text, all_spans=all_spans, start=pos, end=end, notation=notation - ) - rendered = render_augment_prompt( - template=params.prompt_template, - tagged_text=tagged, - seed_entities_json=seed_json, - notation=notation, - ) - if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: - # Shrink proportionally to the measured overage so the next try lands - # just under the cap (0.95 = small safety margin), instead of halving - # and overshooting. Strictly decreasing -> converges, then the floor stops it. 
- shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) - logger.debug( - "augmentation window %d @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", - window_index, pos, len(rendered), cap, window, shrunk, - ) - window = shrunk - continue - logger.debug( - "augmentation window %d: chars [%d, %d) size=%d, rendered=%d/%d chars", - window_index, pos, end, end - pos, len(rendered), cap, - ) - result = _call_augmenter( - facade=facade, - prompt=rendered, - system_prompt=params.system_prompt, - purpose=f"entity-augmentation-window-{pos}", - ) - logger.debug("augmentation window %d: augmenter proposed %d entities", window_index, len(result.entities)) - results.append(result) - if end >= text_len: - break - next_pos = max(pos + 1, end - params.overlap_chars) - logger.debug( - "augmentation window %d: advancing pos %d -> %d (overlap back %d chars)", - window_index, pos, next_pos, end - next_pos, - ) - pos = next_pos - window_index += 1 + results: list[AugmentedEntitiesSchema] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _augment_window, facade=facade, rendered=rendered, system_prompt=params.system_prompt, start=start + ) + for start, _end, rendered in windows + ] + for future in futures: + result = future.result() # _augment_window swallows errors -> never raises + if result is not None: + results.append(result) + + failed = len(windows) - len(results) + if failed: + logger.warning("augmentation: %d of %d window(s) failed and were skipped", failed, len(windows)) merged = merge_augmented(results) logger.info( - "augmentation: %d window(s) over %d chars -> %d unique entities after dedupe (cap=%d, overlap=%d)", - len(results), text_len, len(merged.entities), cap, params.overlap_chars, + "augmentation: %d window(s) over %d chars -> %d unique entities after dedupe " + "(cap=%d, overlap=%d, %d failed)", + len(windows), text_len, len(merged.entities), cap, 
params.overlap_chars, failed, ) row[COL_AUGMENTED_ENTITIES] = merged.model_dump(mode="json") + row[COL_AUGMENTATION_FAILED_WINDOWS] = failed return row @@ -359,6 +413,7 @@ def make_windowed_augmentation_generator(alias: str) -> Any: COL_VALIDATED_SEED_ENTITIES, COL_TAG_NOTATION, ], + side_effect_columns=[COL_AUGMENTATION_FAILED_WINDOWS], model_aliases=[alias], ) def windowed_augment( diff --git a/src/anonymizer/engine/detection/chunked_detection.py b/src/anonymizer/engine/detection/chunked_detection.py index e9c5545c..3a9fa372 100644 --- a/src/anonymizer/engine/detection/chunked_detection.py +++ b/src/anonymizer/engine/detection/chunked_detection.py @@ -41,6 +41,7 @@ import json import logging +from concurrent.futures import ThreadPoolExecutor from typing import Any from data_designer.config import custom_column_generator @@ -56,6 +57,12 @@ # Floor on window size so a pathologically small cap still makes progress. _MIN_WINDOW_CHARS = 4000 +# Upper bound on detector windows dispatched concurrently for one record. Windows +# are independent LLM calls, so they run in parallel; the per-alias rate limit +# (``max_parallel_requests`` on the facade's ThrottledModelClient) still caps the +# real in-flight count, so this just bounds thread creation on very long inputs. +_MAX_PARALLEL_WINDOWS = 16 + class WindowedDetectionParams(BaseModel): """Parameters supplied to :func:`detect_row` via DD's ``generator_params``. @@ -167,6 +174,34 @@ def _call_detector(*, facade: Any, prompt: str, system_prompt: str | None, purpo return str(output) +def _detect_window( + *, facade: Any, text: str, start: int, end: int, text_len: int, system_prompt: str | None +) -> list[EntitySpan]: + """Detect entities in one window and return them in full-document coordinates. 
+ + Runs the detector on ``text[start:end]`` (window-local offsets), rebases to + global offsets, then drops spans touching an artificial window edge (truncated + halves of straddling entities — recovered interior to the overlapping + neighbour). Pure w.r.t. shared state, so it is safe to run per window in a + thread pool. + """ + window_text = text[start:end] + raw = _call_detector( + facade=facade, + prompt=window_text, + system_prompt=system_prompt, + purpose=f"entity-detection-window-{start}", + ) + local_spans = parse_raw_entities(raw_response=raw, text=window_text) + global_spans = rebase_spans(local_spans, start) + kept = drop_boundary_spans(global_spans, window_start=start, window_end=end, text_len=text_len) + logger.debug( + "detection window [%d, %d) size=%d -> %d span(s), %d after edge-drop", + start, end, end - start, len(local_spans), len(kept), + ) + return kept + + def detect_row( row: dict[str, Any], params: WindowedDetectionParams, @@ -198,31 +233,32 @@ def detect_row( text_len = len(text) window = max(_MIN_WINDOW_CHARS, cap - params.safety_margin_chars) windows = iter_windows(text_len, window, params.overlap_chars) + max_workers = min(len(windows), _MAX_PARALLEL_WINDOWS) logger.info( "detection: text %d chars > cap %d; tiling into %d overlapping window(s) " - "(window=%d, overlap=%d, min_window=%d)", - text_len, cap, len(windows), window, params.overlap_chars, _MIN_WINDOW_CHARS, + "(window=%d, overlap=%d, min_window=%d, max_workers=%d)", + text_len, cap, len(windows), window, params.overlap_chars, _MIN_WINDOW_CHARS, max_workers, ) + # Windows are independent LLM calls -> dispatch concurrently. ``future.result()`` + # re-raises the first failing window (same fail-the-row semantics as a serial + # loop); the merge is order-independent because resolve_overlaps sorts. 
all_spans: list[EntitySpan] = [] - for window_index, (start, end) in enumerate(windows): - window_text = text[start:end] - raw = _call_detector( - facade=facade, - prompt=window_text, - system_prompt=params.system_prompt, - purpose=f"entity-detection-window-{start}", - ) - # parse_raw_entities validates + resolves overlaps within the window - # (offsets are window-local because the model only saw window_text). - local_spans = parse_raw_entities(raw_response=raw, text=window_text) - global_spans = rebase_spans(local_spans, start) - kept = drop_boundary_spans(global_spans, window_start=start, window_end=end, text_len=text_len) - logger.debug( - "detection window %d: chars [%d, %d) size=%d -> %d span(s), %d after edge-drop", - window_index, start, end, end - start, len(local_spans), len(kept), - ) - all_spans.extend(kept) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _detect_window, + facade=facade, + text=text, + start=start, + end=end, + text_len=text_len, + system_prompt=params.system_prompt, + ) + for start, end in windows + ] + for future in futures: + all_spans.extend(future.result()) merged = resolve_overlaps(all_spans) logger.info( diff --git a/src/anonymizer/engine/detection/chunked_latent.py b/src/anonymizer/engine/detection/chunked_latent.py index 571e01ca..fd41674d 100644 --- a/src/anonymizer/engine/detection/chunked_latent.py +++ b/src/anonymizer/engine/detection/chunked_latent.py @@ -20,6 +20,7 @@ import functools import logging +from concurrent.futures import ThreadPoolExecutor from typing import Any from data_designer.config import custom_column_generator @@ -30,6 +31,7 @@ from anonymizer.engine.constants import ( COL_DETECTED_ENTITIES, COL_LATENT_ENTITIES, + COL_LATENT_FAILED_WINDOWS, COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, @@ -43,6 +45,11 @@ # Floor on window size so a pathologically entity-dense slice still progresses. 
_MIN_WINDOW_CHARS = 4000 +# Upper bound on latent windows dispatched concurrently for one record. The +# per-alias rate limit (``max_parallel_requests`` on the facade) still caps the +# real in-flight count; this just bounds thread creation on very long inputs. +_MAX_PARALLEL_WINDOWS = 8 + _PROMPT_ENV = Environment( loader=BaseLoader(), autoescape=False, @@ -134,6 +141,68 @@ def _call_latent(*, facade: Any, prompt: str, system_prompt: str | None, purpose return output +def plan_latent_windows( + *, + text: str, + spans: list[EntitySpan], + notation: TagNotation, + params: WindowedLatentParams, + cap: int, + initial_window: int, +) -> list[tuple[int, int, str]]: + """Walk the document applying the shrink rule, returning ``(start, end, rendered_prompt)``. + + No LLM calls — only Jinja renders + length checks — so this runs serially to fix + the (data-dependent) window boundaries before the parallel LLM pass. The window + shrinks when a render exceeds ``cap`` and the shrunk size carries forward. + """ + text_len = len(text) + windows: list[tuple[int, int, str]] = [] + window = initial_window + pos = 0 + while pos < text_len: + end = min(text_len, pos + window) + # Reuse augmentation's span-rebasing + tagging; latent ignores the seed JSON. 
+ tagged, _seed_json = build_window_inputs(text=text, all_spans=spans, start=pos, end=end, notation=notation) + rendered = render_latent_prompt(template=params.prompt_template, tagged_text=tagged, notation=notation) + if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: + shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) + logger.debug( + "latent @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", + pos, len(rendered), cap, window, shrunk, + ) + window = shrunk + continue + logger.debug("latent window @[%d, %d) size=%d, rendered=%d/%d chars", pos, end, end - pos, len(rendered), cap) + windows.append((pos, end, rendered)) + if end >= text_len: + break + pos = max(pos + 1, end - params.overlap_chars) + return windows + + +def _latent_window( + *, facade: Any, rendered: str, system_prompt: str | None, start: int +) -> LatentEntitiesSchema | None: + """Run one latent window; return its result, or ``None`` if the call fails. + + A single window's failure (unparseable response, timeout, ...) must not drop the + record: latent detection only contributes inferred entities, so a skipped window + degrades gracefully. The error is logged at WARNING so it stays visible. + """ + try: + result = _call_latent( + facade=facade, prompt=rendered, system_prompt=system_prompt, purpose=f"latent-detection-window-{start}" + ) + except Exception as exc: # noqa: BLE001 — one window must not sink the record + logger.warning( + "latent window @%d failed (%s: %s); skipping this window's entities", start, type(exc).__name__, exc + ) + return None + logger.debug("latent window @%d: detector proposed %d latent entities", start, len(result.latent_entities)) + return result + + def latent_row( row: dict[str, Any], params: WindowedLatentParams, @@ -168,66 +237,49 @@ def latent_row( purpose="latent-detection", ) row[COL_LATENT_ENTITIES] = output.model_dump(mode="json") + row[COL_LATENT_FAILED_WINDOWS] = 0 return row - # Windowed path. 
+ # Windowed path. Plan the windows serially (cheap renders + the data-dependent + # shrink), then run the LLM calls in parallel; a single window's failure is + # logged and skipped rather than dropping the record. spans = parse_detected_spans(row.get(COL_DETECTED_ENTITIES, {})) - results: list[LatentEntitiesSchema] = [] - window = initial_window - pos = 0 text_len = len(text) + windows = plan_latent_windows( + text=text, spans=spans, notation=notation, params=params, cap=cap, initial_window=initial_window + ) + max_workers = min(len(windows), _MAX_PARALLEL_WINDOWS) logger.info( - "latent: rendered prompt %d chars > cap %d; tiling %d-char document into " - "overlapping windows (initial_window=%d, overlap=%d, min_window=%d)", - len(full_rendered), cap, text_len, initial_window, params.overlap_chars, _MIN_WINDOW_CHARS, + "latent: rendered prompt %d chars > cap %d; tiling %d-char document into %d overlapping " + "window(s) (initial_window=%d, overlap=%d, min_window=%d, max_workers=%d)", + len(full_rendered), cap, text_len, len(windows), initial_window, params.overlap_chars, + _MIN_WINDOW_CHARS, max_workers, ) - window_index = 0 - while pos < text_len: - end = min(text_len, pos + window) - # Reuse augmentation's span-rebasing + tagging; latent ignores the seed JSON. - tagged, _seed_json = build_window_inputs( - text=text, all_spans=spans, start=pos, end=end, notation=notation - ) - rendered = render_latent_prompt(template=params.prompt_template, tagged_text=tagged, notation=notation) - if len(rendered) > cap and (end - pos) > _MIN_WINDOW_CHARS: - # Shrink proportionally to the measured overage so the next try lands - # just under the cap (0.95 = small safety margin), instead of halving - # and overshooting. Strictly decreasing -> converges, then the floor stops it. 
- shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) - logger.debug( - "latent window %d @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", - window_index, pos, len(rendered), cap, window, shrunk, - ) - window = shrunk - continue - logger.debug( - "latent window %d: chars [%d, %d) size=%d, rendered=%d/%d chars", - window_index, pos, end, end - pos, len(rendered), cap, - ) - result = _call_latent( - facade=facade, - prompt=rendered, - system_prompt=params.system_prompt, - purpose=f"latent-detection-window-{pos}", - ) - logger.debug("latent window %d: detector proposed %d latent entities", window_index, len(result.latent_entities)) - results.append(result) - if end >= text_len: - break - next_pos = max(pos + 1, end - params.overlap_chars) - logger.debug( - "latent window %d: advancing pos %d -> %d (overlap back %d chars)", - window_index, pos, next_pos, end - next_pos, - ) - pos = next_pos - window_index += 1 + results: list[LatentEntitiesSchema] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _latent_window, facade=facade, rendered=rendered, system_prompt=params.system_prompt, start=start + ) + for start, _end, rendered in windows + ] + for future in futures: + result = future.result() # _latent_window swallows errors -> never raises + if result is not None: + results.append(result) + + failed = len(windows) - len(results) + if failed: + logger.warning("latent: %d of %d window(s) failed and were skipped", failed, len(windows)) merged = merge_latent(results) logger.info( - "latent: %d window(s) over %d chars -> %d unique latent entities after dedupe (cap=%d, overlap=%d)", - len(results), text_len, len(merged.latent_entities), cap, params.overlap_chars, + "latent: %d window(s) over %d chars -> %d unique latent entities after dedupe " + "(cap=%d, overlap=%d, %d failed)", + len(windows), text_len, len(merged.latent_entities), cap, params.overlap_chars, failed, ) 
row[COL_LATENT_ENTITIES] = merged.model_dump(mode="json") + row[COL_LATENT_FAILED_WINDOWS] = failed return row @@ -248,6 +300,7 @@ def make_windowed_latent_generator(alias: str) -> Any: COL_DETECTED_ENTITIES, COL_TAG_NOTATION, ], + side_effect_columns=[COL_LATENT_FAILED_WINDOWS], model_aliases=[alias], ) def windowed_latent( diff --git a/src/anonymizer/engine/rewrite/chunked_final_judge.py b/src/anonymizer/engine/rewrite/chunked_final_judge.py new file mode 100644 index 00000000..4070dce7 --- /dev/null +++ b/src/anonymizer/engine/rewrite/chunked_final_judge.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Windowed final judge for long-context rewrites. + +The final judge scores a rewrite on privacy / quality / naturalness. Normally it +runs as a single DataDesigner ``LLMJudgeColumnConfig`` whose prompt embeds the full +original + rewritten text, which blows past NDD's ginja render cap (512000 chars) on +long documents. + +This module instead tiles the original and rewritten text into matching, *independent* +windows, judges each window in parallel (like detection — no cross-window state), and +aggregates the per-window scores by taking the **minimum** per dimension (the worst +section drives the score; conservative for privacy). Windows are paired positionally: +both texts are split into the same number of near-equal slices. + +Output column (``COL_JUDGE_EVALUATION``) keeps the judge's dict shape so the existing +display/extract code is unchanged: ``{rubric_name: {"score": int, "reasoning": str}}``. + +Public entry point: :func:`make_windowed_judge_generator`. 
+""" + +from __future__ import annotations + +import functools +import logging +import math +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from data_designer.config import custom_column_generator +from data_designer.engine.models.recipes.response_recipes import PydanticResponseRecipe +from jinja2 import BaseLoader, Environment, StrictUndefined +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import COL_JUDGE_EVALUATION, COL_REWRITTEN_TEXT, COL_TEXT + +logger = logging.getLogger("anonymizer.rewrite.chunked_judge") + +# Floor on the per-window text budget so a tiny cap still makes progress. +_MIN_BUDGET_CHARS = 4000 + +# Upper bound on judge windows dispatched concurrently for one record. The per-alias +# rate limit on the facade still caps real in-flight calls; this bounds thread creation. +_MAX_PARALLEL_WINDOWS = 8 + +_DIMENSIONS = ("privacy", "quality", "naturalness") + +_PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) + + +@functools.lru_cache(maxsize=4) +def _compile_template(template: str) -> Any: + return _PROMPT_ENV.from_string(template) + + +# --------------------------------------------------------------------------- +# Structured output schema (one window's scores) +# --------------------------------------------------------------------------- + + +class _DimensionScore(BaseModel): + score: int = Field(ge=1, le=10) + reasoning: str = "" + + +class JudgeScoresSchema(BaseModel): + """Per-window judge scores; mirrors the three final-judge rubrics.""" + + privacy: _DimensionScore + quality: _DimensionScore + naturalness: _DimensionScore + + +class WindowedJudgeParams(BaseModel): + """Params for the windowed final judge (via DD ``generator_params``).""" + + alias: str = Field(min_length=1) + prompt_template: str = Field(repr=False) + max_render_chars: int = Field(gt=0) + safety_margin_chars: int = Field(default=8000, ge=0) + 
system_prompt: str | None = Field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def slice_evenly(text: str, n: int) -> list[str]: + """Split ``text`` into ``n`` contiguous, near-equal character slices.""" + if n <= 1: + return [text] + size = max(1, math.ceil(len(text) / n)) + slices = [text[i * size : (i + 1) * size] for i in range(n)] + # Guarantee exactly n entries (pad with "" if text shorter than n chars). + while len(slices) < n: + slices.append("") + return slices[:n] + + +def render_judge_prompt(*, template: str, original: str, rewritten: str) -> str: + """Render the judge prompt for one window.""" + return _compile_template(template).render(original_text=original, rewritten_text=rewritten) + + +def plan_judge_windows( + *, original: str, rewritten: str, template: str, cap: int, safety_margin_chars: int +) -> list[str]: + """Return one rendered prompt per window (positionally paired slices). + + The window count is chosen so each window's combined text fits the render + budget; both texts are split into that many near-equal slices. No LLM calls. + """ + overhead = len(render_judge_prompt(template=template, original="", rewritten="")) + budget = max(_MIN_BUDGET_CHARS, cap - safety_margin_chars - overhead) + total = len(original) + len(rewritten) + n = max(1, math.ceil(total / budget)) + o_slices = slice_evenly(original, n) + r_slices = slice_evenly(rewritten, n) + return [render_judge_prompt(template=template, original=o, rewritten=r) for o, r in zip(o_slices, r_slices)] + + +def aggregate_judge(results: list[JudgeScoresSchema]) -> dict[str, dict[str, Any]]: + """Aggregate per-window scores into the judge-column dict shape. + + Per dimension, takes the minimum score across windows (worst section) and the + reasoning from that worst window. 
Shape matches LLMJudgeColumnConfig output: + ``{dim: {"score": int, "reasoning": str}}``. + """ + out: dict[str, dict[str, Any]] = {} + for dim in _DIMENSIONS: + dims = [getattr(r, dim) for r in results] + worst = min(dims, key=lambda d: d.score) + note = worst.reasoning + if len(results) > 1: + note = f"{note} [min over {len(results)} window(s)]".strip() + out[dim] = {"score": worst.score, "reasoning": note} + return out + + +# --------------------------------------------------------------------------- +# Dispatch +# --------------------------------------------------------------------------- + + +def _judge_window(*, facade: Any, prompt: str, system_prompt: str | None, idx: int) -> JudgeScoresSchema | None: + """Judge one window; return its scores, or ``None`` if the call fails (logged, skipped).""" + recipe = PydanticResponseRecipe(data_type=JudgeScoresSchema) + try: + output, _messages = facade.generate( + prompt=recipe.apply_recipe_to_user_prompt(prompt), + parser=recipe.parse, + system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt), + purpose=f"final-judge-window-{idx}", + ) + except Exception as exc: # noqa: BLE001 — one window must not sink the (non-critical) judge + logger.warning("final-judge window %d failed (%s: %s); skipping it", idx, type(exc).__name__, exc) + return None + return output + + +def judge_row(row: dict[str, Any], params: WindowedJudgeParams, models: dict[str, Any]) -> dict[str, Any]: + """Run the (possibly windowed) final judge for one row, writing ``COL_JUDGE_EVALUATION``.""" + if params.alias not in models: + raise KeyError( + f"Judge alias {params.alias!r} not present in models dict. Ensure " + "make_windowed_judge_generator was invoked with the same alias." + ) + facade = models[params.alias] + + original = str(row.get(COL_TEXT, "")) + rewritten = row.get(COL_REWRITTEN_TEXT) + if not rewritten: + # No rewrite to judge (dropped/empty). Leave the evaluation unset. 
+ row[COL_JUDGE_EVALUATION] = None + return row + rewritten = str(rewritten) + + prompts = plan_judge_windows( + original=original, + rewritten=rewritten, + template=params.prompt_template, + cap=params.max_render_chars, + safety_margin_chars=params.safety_margin_chars, + ) + max_workers = min(len(prompts), _MAX_PARALLEL_WINDOWS) + if len(prompts) > 1: + logger.info("final-judge: judging %d window(s) in parallel (max_workers=%d)", len(prompts), max_workers) + + results: list[JudgeScoresSchema] = [] + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _judge_window, facade=facade, prompt=prompt, system_prompt=params.system_prompt, idx=idx + ) + for idx, prompt in enumerate(prompts) + ] + for future in futures: + res = future.result() # _judge_window swallows errors -> never raises + if res is not None: + results.append(res) + + if not results: + logger.warning("final-judge: all %d window(s) failed; leaving evaluation unset", len(prompts)) + row[COL_JUDGE_EVALUATION] = None + return row + + row[COL_JUDGE_EVALUATION] = aggregate_judge(results) + return row + + +# --------------------------------------------------------------------------- +# DataDesigner wiring factory. 
+# --------------------------------------------------------------------------- + + +def make_windowed_judge_generator(alias: str) -> Any: + """Build a ``@custom_column_generator``-decorated final judge bound to ``alias``.""" + if not alias: + raise ValueError("Cannot build windowed judge generator: alias is empty.") + + @custom_column_generator( + required_columns=[COL_TEXT, COL_REWRITTEN_TEXT], + model_aliases=[alias], + ) + def windowed_judge( + row: dict[str, Any], + generator_params: WindowedJudgeParams, + models: dict[str, Any], + ) -> dict[str, Any]: + return judge_row(row, generator_params, models) + + return windowed_judge diff --git a/src/anonymizer/engine/rewrite/final_judge.py b/src/anonymizer/engine/rewrite/final_judge.py index 6f039999..b6f620ca 100644 --- a/src/anonymizer/engine/rewrite/final_judge.py +++ b/src/anonymizer/engine/rewrite/final_judge.py @@ -6,10 +6,11 @@ from typing import Any from data_designer.config import custom_column_generator -from data_designer.config.column_configs import CustomColumnConfig, LLMJudgeColumnConfig, Score +from data_designer.config.column_configs import CustomColumnConfig, Score from data_designer.config.column_types import ColumnConfigT from pydantic import BaseModel +from anonymizer.config.anonymizer_config import Detect as _DetectConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import EvaluationCriteria, PrivacyGoal from anonymizer.engine.constants import ( @@ -18,12 +19,15 @@ COL_LEAKAGE_MASS, COL_NEEDS_HUMAN_REVIEW, COL_REWRITTEN_TEXT, - COL_TEXT, COL_UTILITY_SCORE, - _jinja, ) from anonymizer.engine.ndd.model_loader import resolve_model_alias from anonymizer.engine.prompt_utils import substitute_placeholders +from anonymizer.engine.rewrite.chunked_final_judge import WindowedJudgeParams, make_windowed_judge_generator + +# Long-context window defaults, shared with detection (single source of truth). 
+_DEFAULT_MAX_RENDER_CHARS: int = _DetectConfig.model_fields["detection_window_max_render_chars"].default +_DEFAULT_SAFETY_MARGIN_CHARS: int = _DetectConfig.model_fields["detection_window_safety_margin_chars"].default # --------------------------------------------------------------------------- # Generator params @@ -127,17 +131,39 @@ def _judge_prompt(privacy_goal: PrivacyGoal) -> str: - Judge naturalness independently from privacy and quality. A rewrite can be natural even if it changes content, and it can preserve content while still sounding awkward. + + +<> + + + +Score each dimension from 1 to 10 using the scales above, then return your evaluation as JSON with +keys "privacy", "quality", and "naturalness" — each an object with an integer "score" (1-10) and a +brief "reasoning". + """ return substitute_placeholders( prompt, { "<>": privacy_goal.to_prompt_string(), - "<>": _jinja(COL_TEXT), - "<>": _jinja(COL_REWRITTEN_TEXT), + "<>": "{{ original_text }}", + "<>": "{{ rewritten_text }}", + "<>": _rubric_scale_block(), }, ) +def _rubric_scale_block() -> str: + """Render the three rubrics' 1-10 scales into prompt text (the windowed judge + calls the model directly, so the scales must live in the prompt rather than being + passed as DataDesigner ``Score`` objects).""" + blocks = [] + for rubric in (PRIVACY_RUBRIC, QUALITY_RUBRIC, NATURALNESS_RUBRIC): + lines = "\n".join(f" {score}: {desc}" for score, desc in sorted(rubric.options.items(), reverse=True)) + blocks.append(f"{rubric.name} — {rubric.description}\n{lines}") + return "\n\n".join(blocks) + + # --------------------------------------------------------------------------- # Custom column generators # --------------------------------------------------------------------------- @@ -249,11 +275,18 @@ def columns( judge_alias = resolve_model_alias("judge", selected_models) return [ - LLMJudgeColumnConfig( + # Windowed final judge: long documents are split into parallel, independent + # windows (the rubric scales 
live in the prompt and scores are parsed via + # structured output), bypassing NDD's single-render cap. See chunked_final_judge. + CustomColumnConfig( name=COL_JUDGE_EVALUATION, - prompt=_judge_prompt(privacy_goal), - model_alias=judge_alias, - scores=[PRIVACY_RUBRIC, QUALITY_RUBRIC, NATURALNESS_RUBRIC], + generator_function=make_windowed_judge_generator(judge_alias), + generator_params=WindowedJudgeParams( + alias=judge_alias, + prompt_template=_judge_prompt(privacy_goal), + max_render_chars=_DEFAULT_MAX_RENDER_CHARS, + safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + ), ), CustomColumnConfig( name=COL_NEEDS_HUMAN_REVIEW, diff --git a/tests/engine/test_chunked_augmentation.py b/tests/engine/test_chunked_augmentation.py index cdef0790..57a78d22 100644 --- a/tests/engine/test_chunked_augmentation.py +++ b/tests/engine/test_chunked_augmentation.py @@ -17,6 +17,7 @@ import pytest from anonymizer.engine.constants import ( + COL_AUGMENTATION_FAILED_WINDOWS, COL_AUGMENTED_ENTITIES, COL_INITIAL_TAGGED_TEXT, COL_SEED_ENTITIES_JSON, @@ -174,6 +175,35 @@ def test_multiple_windows_unioned(self) -> None: result = AugmentedEntitiesSchema.model_validate(out[COL_AUGMENTED_ENTITIES]) assert len(result.entities) == len(facade.calls) + def test_one_failing_window_is_skipped_not_fatal(self) -> None: + # A single window's LLM failure must not drop the whole record: the row + # still completes and the other windows' entities survive. 
+ text = ("A" * 5000) + ("B" * 5000) # forces multiple windows + counter = itertools.count() + + def resp(_prompt): + i = next(counter) + if i == 1: # exactly one window raises (order is nondeterministic under the pool) + raise ValueError("simulated unparseable model output") + return {"entities": [{"value": f"v{i}", "label": "name"}]} + + facade = FakeFacade(resp) + row = { + COL_TEXT: text, + COL_INITIAL_TAGGED_TEXT: text, + COL_SEED_ENTITIES_JSON: "[]", + COL_VALIDATED_SEED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + # Should NOT raise despite one window failing. + out = augment_row( + row, _params(max_render_chars=4000, safety_margin_chars=0, overlap_chars=1000), {"aug": facade} + ) + result = AugmentedEntitiesSchema.model_validate(out[COL_AUGMENTED_ENTITIES]) + assert len(facade.calls) > 1 # multiple windows attempted + assert len(result.entities) == len(facade.calls) - 1 # exactly one window dropped + assert out[COL_AUGMENTATION_FAILED_WINDOWS] == 1 # the skip is recorded for degraded-flagging + def test_missing_alias_raises(self) -> None: with pytest.raises(KeyError, match="not present in models"): augment_row({COL_TEXT: "x"}, _params(max_render_chars=10), {}) diff --git a/tests/engine/test_chunked_final_judge.py b/tests/engine/test_chunked_final_judge.py new file mode 100644 index 00000000..c953500e --- /dev/null +++ b/tests/engine/test_chunked_final_judge.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the windowed (parallel) final judge. + +Pure helpers (slicing, window planning, score aggregation) are tested directly; the +per-window dispatch is tested via a fake facade that replays canned judge scores +through the recipe parser, including a failing window (which must be skipped). 
+""" + +from __future__ import annotations + +import json +from typing import Any, Callable + +from anonymizer.engine.constants import COL_JUDGE_EVALUATION, COL_REWRITTEN_TEXT, COL_TEXT +from anonymizer.engine.rewrite.chunked_final_judge import ( + JudgeScoresSchema, + WindowedJudgeParams, + aggregate_judge, + judge_row, + plan_judge_windows, + slice_evenly, +) + +_TEMPLATE = "ORIG:{{ original_text }} REW:{{ rewritten_text }}" + + +def _scores(privacy: int, quality: int, naturalness: int, note: str = "r") -> JudgeScoresSchema: + return JudgeScoresSchema.model_validate( + { + "privacy": {"score": privacy, "reasoning": note}, + "quality": {"score": quality, "reasoning": note}, + "naturalness": {"score": naturalness, "reasoning": note}, + } + ) + + +class FakeJudgeFacade: + """Replays canned JudgeScoresSchema responses through the recipe parser.""" + + def __init__(self, response: dict | Callable[[str], dict]) -> None: + self._response = response + self.calls: list[str] = [] + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + self.calls.append(purpose) + resp = self._response(prompt) if callable(self._response) else self._response + return parser(f"```json\n{json.dumps(resp)}\n```"), [] + + +def _params(max_render_chars: int, **kw: Any) -> WindowedJudgeParams: + return WindowedJudgeParams(alias="judge", prompt_template=_TEMPLATE, max_render_chars=max_render_chars, **kw) + + +# --------------------------------------------------------------------------- +# Pure helpers +# --------------------------------------------------------------------------- + + +def test_slice_evenly_single_window(): + assert slice_evenly("abcdef", 1) == ["abcdef"] + + +def test_slice_evenly_splits_and_pads(): + assert slice_evenly("abcdef", 3) == ["ab", "cd", "ef"] + assert slice_evenly("ab", 3) == ["a", "b", ""] # padded to exactly n + + +def test_plan_judge_windows_single_when_small(): + prompts = plan_judge_windows(original="o", rewritten="r", 
template=_TEMPLATE, cap=10_000, safety_margin_chars=0) + assert len(prompts) == 1 + assert "ORIG:o REW:r" in prompts[0] + + +def test_plan_judge_windows_splits_when_large(): + original, rewritten = "o" * 9000, "r" * 9000 + # tiny budget -> multiple windows; both texts sliced into the same count + prompts = plan_judge_windows(original=original, rewritten=rewritten, template=_TEMPLATE, cap=4200, safety_margin_chars=0) + assert len(prompts) > 1 + + +def test_aggregate_judge_takes_min_per_dimension(): + agg = aggregate_judge([_scores(8, 9, 7, "good"), _scores(3, 6, 9, "leak")]) + assert agg["privacy"]["score"] == 3 + assert agg["quality"]["score"] == 6 + assert agg["naturalness"]["score"] == 7 + assert "min over 2 window(s)" in agg["privacy"]["reasoning"] + assert "leak" in agg["privacy"]["reasoning"] # reasoning from the worst window + + +# --------------------------------------------------------------------------- +# judge_row +# --------------------------------------------------------------------------- + + +def test_judge_row_fast_path_single_call(): + facade = FakeJudgeFacade(_scores(8, 8, 8).model_dump()) + row = {COL_TEXT: "original", COL_REWRITTEN_TEXT: "rewritten"} + judge_row(row, _params(max_render_chars=10_000), {"judge": facade}) + assert len(facade.calls) == 1 + ev = row[COL_JUDGE_EVALUATION] + assert set(ev) == {"privacy", "quality", "naturalness"} + assert ev["privacy"]["score"] == 8 + + +def test_judge_row_windowed_aggregates_min(): + # distinct score per window so we can verify min aggregation + import itertools + + counter = itertools.count() + + def resp(_prompt): + i = next(counter) + return _scores(5 + i, 9, 9).model_dump() + + facade = FakeJudgeFacade(resp) + row = {COL_TEXT: "o" * 9000, COL_REWRITTEN_TEXT: "r" * 9000} + judge_row(row, _params(max_render_chars=4200, safety_margin_chars=0), {"judge": facade}) + assert len(facade.calls) > 1 + # min privacy across windows is the first window's (5 + 0) + assert 
row[COL_JUDGE_EVALUATION]["privacy"]["score"] == 5 + + +def test_judge_row_skips_failing_window_not_fatal(): + import itertools + + counter = itertools.count() + + def resp(_prompt): + i = next(counter) + if i == 0: + raise ValueError("simulated judge parse failure") + return _scores(7, 7, 7).model_dump() + + facade = FakeJudgeFacade(resp) + row = {COL_TEXT: "o" * 9000, COL_REWRITTEN_TEXT: "r" * 9000} + judge_row(row, _params(max_render_chars=4200, safety_margin_chars=0), {"judge": facade}) + # one window failed but the rest produced a valid evaluation + assert row[COL_JUDGE_EVALUATION] is not None + assert row[COL_JUDGE_EVALUATION]["privacy"]["score"] == 7 + + +def test_judge_row_all_windows_fail_yields_none(): + def resp(_prompt): + raise RuntimeError("judge down") + + facade = FakeJudgeFacade(resp) + row = {COL_TEXT: "original", COL_REWRITTEN_TEXT: "rewritten"} + judge_row(row, _params(max_render_chars=10_000), {"judge": facade}) + assert row[COL_JUDGE_EVALUATION] is None + + +def test_judge_row_no_rewrite_yields_none(): + facade = FakeJudgeFacade(_scores(8, 8, 8).model_dump()) + row = {COL_TEXT: "original", COL_REWRITTEN_TEXT: None} + judge_row(row, _params(max_render_chars=10_000), {"judge": facade}) + assert row[COL_JUDGE_EVALUATION] is None + assert facade.calls == [] # no judge call when there's nothing to judge diff --git a/tests/engine/test_chunked_latent.py b/tests/engine/test_chunked_latent.py index d40e7233..efed585b 100644 --- a/tests/engine/test_chunked_latent.py +++ b/tests/engine/test_chunked_latent.py @@ -14,6 +14,7 @@ from anonymizer.engine.constants import ( COL_DETECTED_ENTITIES, COL_LATENT_ENTITIES, + COL_LATENT_FAILED_WINDOWS, COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, @@ -111,6 +112,33 @@ def test_multiple_windows_unioned(self) -> None: result = LatentEntitiesSchema.model_validate(out[COL_LATENT_ENTITIES]) assert len(result.latent_entities) == len(facade.calls) + def test_one_failing_window_is_skipped_not_fatal(self) -> None: + # A single 
window's LLM failure must not drop the record; the other + # windows' latent entities survive and no exception propagates. + text = ("A" * 5000) + ("B" * 5000) + counter = itertools.count() + + def resp(_p): + i = next(counter) + if i == 1: # exactly one window raises (order nondeterministic under the pool) + raise ValueError("simulated unparseable model output") + return {"latent_entities": [_latent("employer", f"Org{i}")]} + + facade = FakeFacade(resp) + row = { + COL_TEXT: text, + COL_TAGGED_TEXT: text, + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + out = latent_row( + row, _params(max_render_chars=4000, safety_margin_chars=0, overlap_chars=1000), {"lat": facade} + ) + result = LatentEntitiesSchema.model_validate(out[COL_LATENT_ENTITIES]) + assert len(facade.calls) > 1 + assert len(result.latent_entities) == len(facade.calls) - 1 # one window dropped + assert out[COL_LATENT_FAILED_WINDOWS] == 1 # the skip is recorded for degraded-flagging + def test_missing_alias_raises(self) -> None: with pytest.raises(KeyError, match="not present in models"): latent_row({COL_TEXT: "x"}, _params(max_render_chars=10), {}) diff --git a/tests/engine/test_final_judge.py b/tests/engine/test_final_judge.py index a33cb7c5..87b9c1fd 100644 --- a/tests/engine/test_final_judge.py +++ b/tests/engine/test_final_judge.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from data_designer.config.column_configs import CustomColumnConfig, LLMJudgeColumnConfig +from data_designer.config.column_configs import CustomColumnConfig from anonymizer.config.models import RewriteModelSelection from anonymizer.config.rewrite import EvaluationCriteria, PrivacyGoal @@ -14,7 +14,6 @@ COL_LEAKAGE_MASS, COL_NEEDS_HUMAN_REVIEW, COL_REWRITTEN_TEXT, - COL_TEXT, COL_UTILITY_SCORE, ) from anonymizer.engine.rewrite.final_judge import ( @@ -55,10 +54,15 @@ def test_judge_prompt_uses_xml_sections() -> None: assert "" in prompt -def 
test_judge_prompt_references_required_columns() -> None: +def test_judge_prompt_has_window_placeholders_and_scales() -> None: prompt = _judge_prompt(_STUB_PRIVACY_GOAL) - assert COL_TEXT in prompt - assert COL_REWRITTEN_TEXT in prompt + # Per-window slices are injected via these Jinja placeholders (not DD column refs). + assert "{{ original_text }}" in prompt + assert "{{ rewritten_text }}" in prompt + # The 1-10 rubric scales must be embedded in the prompt for the direct model call. + for name in ("privacy", "quality", "naturalness"): + assert name in prompt + assert "" in prompt # --------------------------------------------------------------------------- @@ -87,12 +91,11 @@ def test_judge_column_uses_judge_alias( privacy_goal=_STUB_PRIVACY_GOAL, evaluation=_STUB_EVALUATION, ) - judge_cols = [c for c in cols if isinstance(c, LLMJudgeColumnConfig)] - assert len(judge_cols) == 1 - assert judge_cols[0].model_alias == stub_rewrite_model_selection.judge + judge_col = next(c for c in cols if c.name == COL_JUDGE_EVALUATION) + assert judge_col.generator_params.alias == stub_rewrite_model_selection.judge -def test_judge_column_has_three_rubrics( +def test_judge_column_is_windowed_generator_with_three_rubrics( stub_rewrite_model_selection: RewriteModelSelection, ) -> None: wf = FinalJudgeWorkflow() @@ -101,13 +104,12 @@ def test_judge_column_has_three_rubrics( privacy_goal=_STUB_PRIVACY_GOAL, evaluation=_STUB_EVALUATION, ) - judge_col = next(c for c in cols if isinstance(c, LLMJudgeColumnConfig)) - assert judge_col.name == COL_JUDGE_EVALUATION - score_names = {s.name for s in judge_col.scores} - assert score_names == {"privacy", "quality", "naturalness"} - for score in judge_col.scores: - assert 1 in score.options - assert 10 in score.options + judge_col = next(c for c in cols if c.name == COL_JUDGE_EVALUATION) + assert isinstance(judge_col, CustomColumnConfig) + template = judge_col.generator_params.prompt_template + for name in ("privacy", "quality", "naturalness"): + 
assert name in template + assert judge_col.generator_params.max_render_chars > 0 def test_needs_human_review_column_present( @@ -119,9 +121,9 @@ def test_needs_human_review_column_present( privacy_goal=_STUB_PRIVACY_GOAL, evaluation=_STUB_EVALUATION, ) - custom_cols = [c for c in cols if isinstance(c, CustomColumnConfig)] - assert len(custom_cols) == 1 - assert custom_cols[0].name == COL_NEEDS_HUMAN_REVIEW + review_cols = [c for c in cols if c.name == COL_NEEDS_HUMAN_REVIEW] + assert len(review_cols) == 1 + assert isinstance(review_cols[0], CustomColumnConfig) def test_needs_human_review_column_uses_evaluation_thresholds( @@ -134,7 +136,7 @@ def test_needs_human_review_column_uses_evaluation_thresholds( privacy_goal=_STUB_PRIVACY_GOAL, evaluation=evaluation, ) - custom_col = next(c for c in cols if isinstance(c, CustomColumnConfig)) + custom_col = next(c for c in cols if c.name == COL_NEEDS_HUMAN_REVIEW) params = HumanReviewParams.model_validate(custom_col.generator_params) assert params.flag_utility_below == 0.6 assert params.flag_leakage_above == 1.0 From 76aa62c4e0e9cc3661a74dba1ec259f43b765fc0 Mon Sep 17 00:00:00 2001 From: eurekayuan Date: Wed, 3 Jun 2026 11:34:17 -0700 Subject: [PATCH 3/4] format scripts --- .../engine/detection/chunked_augmentation.py | 26 ++++++-- .../engine/detection/chunked_detection.py | 18 ++++- .../engine/detection/chunked_latent.py | 30 ++++++--- .../engine/detection/detection_workflow.py | 8 +-- .../engine/replace/chunked_replace.py | 37 ++++++++--- .../engine/rewrite/chunked_final_judge.py | 4 +- .../engine/rewrite/chunked_rewrite.py | 64 ++++++++++++++---- .../engine/rewrite/domain_classification.py | 1 + .../engine/rewrite/qa_generation.py | 1 + tests/engine/test_chunked_detection.py | 4 +- tests/engine/test_chunked_final_judge.py | 4 +- tests/engine/test_chunked_replace.py | 66 ++++++++++++++----- tests/engine/test_chunked_rewrite.py | 4 +- tests/engine/test_windowing.py | 2 - 14 files changed, 201 insertions(+), 68 deletions(-) 
diff --git a/src/anonymizer/engine/detection/chunked_augmentation.py b/src/anonymizer/engine/detection/chunked_augmentation.py index 3fe41b59..f76a0d6b 100644 --- a/src/anonymizer/engine/detection/chunked_augmentation.py +++ b/src/anonymizer/engine/detection/chunked_augmentation.py @@ -260,7 +260,11 @@ def plan_augmentation_windows( shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) logger.debug( "augmentation @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", - pos, len(rendered), cap, window, shrunk, + pos, + len(rendered), + cap, + window, + shrunk, ) window = shrunk continue @@ -357,8 +361,14 @@ def augment_row( logger.info( "augmentation: rendered prompt %d chars > cap %d; tiling %d-char document into %d overlapping " "window(s) (initial_window=%d, overlap=%d, min_window=%d, max_workers=%d)", - len(full_rendered), cap, text_len, len(windows), initial_window, params.overlap_chars, - _MIN_WINDOW_CHARS, max_workers, + len(full_rendered), + cap, + text_len, + len(windows), + initial_window, + params.overlap_chars, + _MIN_WINDOW_CHARS, + max_workers, ) results: list[AugmentedEntitiesSchema] = [] @@ -379,9 +389,13 @@ def augment_row( logger.warning("augmentation: %d of %d window(s) failed and were skipped", failed, len(windows)) merged = merge_augmented(results) logger.info( - "augmentation: %d window(s) over %d chars -> %d unique entities after dedupe " - "(cap=%d, overlap=%d, %d failed)", - len(windows), text_len, len(merged.entities), cap, params.overlap_chars, failed, + "augmentation: %d window(s) over %d chars -> %d unique entities after dedupe (cap=%d, overlap=%d, %d failed)", + len(windows), + text_len, + len(merged.entities), + cap, + params.overlap_chars, + failed, ) row[COL_AUGMENTED_ENTITIES] = merged.model_dump(mode="json") row[COL_AUGMENTATION_FAILED_WINDOWS] = failed diff --git a/src/anonymizer/engine/detection/chunked_detection.py b/src/anonymizer/engine/detection/chunked_detection.py index 
3a9fa372..ae6f4a04 100644 --- a/src/anonymizer/engine/detection/chunked_detection.py +++ b/src/anonymizer/engine/detection/chunked_detection.py @@ -197,7 +197,11 @@ def _detect_window( kept = drop_boundary_spans(global_spans, window_start=start, window_end=end, text_len=text_len) logger.debug( "detection window [%d, %d) size=%d -> %d span(s), %d after edge-drop", - start, end, end - start, len(local_spans), len(kept), + start, + end, + end - start, + len(local_spans), + len(kept), ) return kept @@ -237,7 +241,13 @@ def detect_row( logger.info( "detection: text %d chars > cap %d; tiling into %d overlapping window(s) " "(window=%d, overlap=%d, min_window=%d, max_workers=%d)", - text_len, cap, len(windows), window, params.overlap_chars, _MIN_WINDOW_CHARS, max_workers, + text_len, + cap, + len(windows), + window, + params.overlap_chars, + _MIN_WINDOW_CHARS, + max_workers, ) # Windows are independent LLM calls -> dispatch concurrently. ``future.result()`` @@ -263,7 +273,9 @@ def detect_row( merged = resolve_overlaps(all_spans) logger.info( "detection: %d window(s) over %d chars -> %d unique span(s) after merge", - len(windows), text_len, len(merged), + len(windows), + text_len, + len(merged), ) row[COL_RAW_DETECTED] = spans_to_detector_payload(merged) return row diff --git a/src/anonymizer/engine/detection/chunked_latent.py b/src/anonymizer/engine/detection/chunked_latent.py index fd41674d..0f3c8d72 100644 --- a/src/anonymizer/engine/detection/chunked_latent.py +++ b/src/anonymizer/engine/detection/chunked_latent.py @@ -169,7 +169,11 @@ def plan_latent_windows( shrunk = max(_MIN_WINDOW_CHARS, int(window * (cap / len(rendered)) * 0.95)) logger.debug( "latent @pos=%d: render %d > cap %d; shrinking window %d -> %d chars and retrying", - pos, len(rendered), cap, window, shrunk, + pos, + len(rendered), + cap, + window, + shrunk, ) window = shrunk continue @@ -181,9 +185,7 @@ def plan_latent_windows( return windows -def _latent_window( - *, facade: Any, rendered: str, 
system_prompt: str | None, start: int -) -> LatentEntitiesSchema | None: +def _latent_window(*, facade: Any, rendered: str, system_prompt: str | None, start: int) -> LatentEntitiesSchema | None: """Run one latent window; return its result, or ``None`` if the call fails. A single window's failure (unparseable response, timeout, ...) must not drop the @@ -252,8 +254,14 @@ def latent_row( logger.info( "latent: rendered prompt %d chars > cap %d; tiling %d-char document into %d overlapping " "window(s) (initial_window=%d, overlap=%d, min_window=%d, max_workers=%d)", - len(full_rendered), cap, text_len, len(windows), initial_window, params.overlap_chars, - _MIN_WINDOW_CHARS, max_workers, + len(full_rendered), + cap, + text_len, + len(windows), + initial_window, + params.overlap_chars, + _MIN_WINDOW_CHARS, + max_workers, ) results: list[LatentEntitiesSchema] = [] @@ -274,9 +282,13 @@ def latent_row( logger.warning("latent: %d of %d window(s) failed and were skipped", failed, len(windows)) merged = merge_latent(results) logger.info( - "latent: %d window(s) over %d chars -> %d unique latent entities after dedupe " - "(cap=%d, overlap=%d, %d failed)", - len(windows), text_len, len(merged.latent_entities), cap, params.overlap_chars, failed, + "latent: %d window(s) over %d chars -> %d unique latent entities after dedupe (cap=%d, overlap=%d, %d failed)", + len(windows), + text_len, + len(merged.latent_entities), + cap, + params.overlap_chars, + failed, ) row[COL_LATENT_ENTITIES] = merged.model_dump(mode="json") row[COL_LATENT_FAILED_WINDOWS] = failed diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index 7c9a041c..57fc4cec 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -81,15 +81,11 @@ _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS: int = AnonymizerDetectConfig.model_fields[ "validation_excerpt_window_chars" ].default 
-_DEFAULT_WINDOW_MAX_RENDER_CHARS: int = AnonymizerDetectConfig.model_fields[ - "detection_window_max_render_chars" -].default +_DEFAULT_WINDOW_MAX_RENDER_CHARS: int = AnonymizerDetectConfig.model_fields["detection_window_max_render_chars"].default _DEFAULT_WINDOW_SAFETY_MARGIN_CHARS: int = AnonymizerDetectConfig.model_fields[ "detection_window_safety_margin_chars" ].default -_DEFAULT_WINDOW_OVERLAP_CHARS: int = AnonymizerDetectConfig.model_fields[ - "detection_window_overlap_chars" -].default +_DEFAULT_WINDOW_OVERLAP_CHARS: int = AnonymizerDetectConfig.model_fields["detection_window_overlap_chars"].default @dataclass(frozen=True) diff --git a/src/anonymizer/engine/replace/chunked_replace.py b/src/anonymizer/engine/replace/chunked_replace.py index 0b81dfdc..7228237a 100644 --- a/src/anonymizer/engine/replace/chunked_replace.py +++ b/src/anonymizer/engine/replace/chunked_replace.py @@ -59,6 +59,7 @@ def _clip(text: str, limit: int = _LOG_CLIP_CHARS) -> str: flat = " ".join(text.split()) return flat if len(flat) <= limit else f"{flat[:limit]}… (+{len(flat) - limit} chars)" + _PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) @@ -296,7 +297,13 @@ def generate_replacement_map_row( logger.info( "replace-map: rendered prompt %d chars > cap %d; chunking %d-char document into %d boundary " "window(s) (initial_window=%d, delimiter=%r, summary_max=%d)", - len(single_rendered), cap, len(text), len(windows), initial_window, params.delimiter, params.summary_max_chars, + len(single_rendered), + cap, + len(text), + len(windows), + initial_window, + params.delimiter, + params.summary_max_chars, ) accumulated: list[dict[str, str]] = [] summary = "" @@ -305,7 +312,13 @@ def generate_replacement_map_row( chunk_entities = new_chunk_entities(spans, start, end, already) logger.debug( "replace-map window %d/%d: chars [%d, %d) size=%d, %d new entit(y/ies), %d already mapped", - i + 1, len(windows), start, end, end - 
start, len(chunk_entities), len(already), + i + 1, + len(windows), + start, + end, + end - start, + len(chunk_entities), + len(already), ) if chunk_entities: tagged = chunk_tagged_text(text, spans, start, end, notation) @@ -322,7 +335,10 @@ def generate_replacement_map_row( accumulated = merge_replacements(accumulated, chunk_map) logger.debug( "replace-map window %d/%d: chunk produced %d replacement(s); accumulated map now %d entries", - i + 1, len(windows), len(chunk_map.replacements), len(accumulated), + i + 1, + len(windows), + len(chunk_map.replacements), + len(accumulated), ) else: logger.debug("replace-map window %d/%d: no new entities, skipping map call", i + 1, len(windows)) @@ -338,16 +354,21 @@ def generate_replacement_map_row( ) logger.debug( "replace-map window %d/%d: rolling summary updated -> %d chars: %s", - i + 1, len(windows), len(summary), _clip(summary), + i + 1, + len(windows), + len(summary), + _clip(summary), ) logger.info( "replace-map: %d window(s) over %d chars -> %d total replacement(s)", - len(windows), len(text), len(accumulated), + len(windows), + len(text), + len(accumulated), + ) + row[COL_REPLACEMENT_MAP] = EntityReplacementMapSchema.model_validate({"replacements": accumulated}).model_dump( + mode="json" ) - row[COL_REPLACEMENT_MAP] = EntityReplacementMapSchema.model_validate( - {"replacements": accumulated} - ).model_dump(mode="json") return row diff --git a/src/anonymizer/engine/rewrite/chunked_final_judge.py b/src/anonymizer/engine/rewrite/chunked_final_judge.py index 4070dce7..7a2c2124 100644 --- a/src/anonymizer/engine/rewrite/chunked_final_judge.py +++ b/src/anonymizer/engine/rewrite/chunked_final_judge.py @@ -191,9 +191,7 @@ def judge_row(row: dict[str, Any], params: WindowedJudgeParams, models: dict[str results: list[JudgeScoresSchema] = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ - executor.submit( - _judge_window, facade=facade, prompt=prompt, system_prompt=params.system_prompt, idx=idx - ) 
+ executor.submit(_judge_window, facade=facade, prompt=prompt, system_prompt=params.system_prompt, idx=idx) for idx, prompt in enumerate(prompts) ] for future in futures: diff --git a/src/anonymizer/engine/rewrite/chunked_rewrite.py b/src/anonymizer/engine/rewrite/chunked_rewrite.py index 6c893e09..1c4e0324 100644 --- a/src/anonymizer/engine/rewrite/chunked_rewrite.py +++ b/src/anonymizer/engine/rewrite/chunked_rewrite.py @@ -52,6 +52,7 @@ def _clip(text: str, limit: int = _LOG_CLIP_CHARS) -> str: flat = " ".join(text.split()) return flat if len(flat) <= limit else f"{flat[:limit]}… (+{len(flat) - limit} chars)" + _PROMPT_ENV = Environment(loader=BaseLoader(), autoescape=False, undefined=StrictUndefined, keep_trailing_newline=True) @@ -102,7 +103,11 @@ def _filter_disposition_to_chunk(block: list[dict[str, Any]], chunk_raw: str) -> """Keep disposition entries whose entity value appears in this chunk's raw text.""" if not isinstance(block, list): return [] - return [e for e in block if isinstance(e, dict) and str(e.get("entity_value", "")) and str(e["entity_value"]) in chunk_raw] + return [ + e + for e in block + if isinstance(e, dict) and str(e.get("entity_value", "")) and str(e["entity_value"]) in chunk_raw + ] def _render_chunk_prompt(*, template: str, chunk_row: dict[str, Any], summary: str) -> str: @@ -128,7 +133,13 @@ def _rewrite_chunk(*, facade: Any, prompt: str, system_prompt: str | None, purpo def _update_summary( - *, facade: Any, prev_summary: str, rewritten_chunk: str, summary_max_chars: int, system_prompt: str | None, purpose: str + *, + facade: Any, + prev_summary: str, + rewritten_chunk: str, + summary_max_chars: int, + system_prompt: str | None, + purpose: str, ) -> str: recipe = TextResponseRecipe() rendered = _compile_template(_SUMMARY_PROMPT).render( @@ -163,8 +174,10 @@ def generate_rewrite_row( if len(single_rendered) <= cap: logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) text = 
_rewrite_chunk( - facade=facade, prompt=_compile_template(params.single_call_prompt_template).render(**row), - system_prompt=params.system_prompt, purpose="rewrite-generation", + facade=facade, + prompt=_compile_template(params.single_call_prompt_template).render(**row), + system_prompt=params.system_prompt, + purpose="rewrite-generation", ) row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=text).model_dump(mode="json") return row @@ -181,7 +194,13 @@ def generate_rewrite_row( logger.info( "rewrite: rendered prompt %d chars > cap %d; chunking %d-char tagged document into %d boundary " "window(s) (initial_window=%d, delimiter=%r, summary_max=%d)", - len(single_rendered), cap, len(tagged), len(windows), initial_window, params.delimiter, params.summary_max_chars, + len(single_rendered), + cap, + len(tagged), + len(windows), + initial_window, + params.delimiter, + params.summary_max_chars, ) rewritten_parts: list[str] = [] summary = "" @@ -198,29 +217,50 @@ def generate_rewrite_row( prompt = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=chunk_row, summary=summary) logger.debug( "rewrite window %d/%d: chars [%d, %d) size=%d, %d in-chunk disposition entr(y/ies), prompt=%d chars", - i + 1, len(windows), start, end, end - start, len(chunk_disposition), len(prompt), + i + 1, + len(windows), + start, + end, + end - start, + len(chunk_disposition), + len(prompt), ) rewritten_chunk = _rewrite_chunk( - facade=facade, prompt=prompt, system_prompt=params.system_prompt, purpose=f"rewrite-generation-chunk-{start}" + facade=facade, + prompt=prompt, + system_prompt=params.system_prompt, + purpose=f"rewrite-generation-chunk-{start}", ) logger.debug( - "rewrite window %d/%d: produced %d chars of rewritten text", i + 1, len(windows), len(rewritten_chunk), + "rewrite window %d/%d: produced %d chars of rewritten text", + i + 1, + len(windows), + len(rewritten_chunk), ) rewritten_parts.append(rewritten_chunk) if i < len(windows) - 1: summary = 
_update_summary( - facade=facade, prev_summary=summary, rewritten_chunk=rewritten_chunk, - summary_max_chars=params.summary_max_chars, system_prompt=params.system_prompt, + facade=facade, + prev_summary=summary, + rewritten_chunk=rewritten_chunk, + summary_max_chars=params.summary_max_chars, + system_prompt=params.system_prompt, purpose=f"rewrite-summary-{start}", ) logger.debug( "rewrite window %d/%d: continuity summary updated -> %d chars: %s", - i + 1, len(windows), len(summary), _clip(summary), + i + 1, + len(windows), + len(summary), + _clip(summary), ) stitched = "\n".join(part for part in rewritten_parts if part) logger.info( - "rewrite: %d window(s) over %d chars -> %d chars stitched output", len(windows), len(tagged), len(stitched), + "rewrite: %d window(s) over %d chars -> %d chars stitched output", + len(windows), + len(tagged), + len(stitched), ) row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=stitched).model_dump(mode="json") return row diff --git a/src/anonymizer/engine/rewrite/domain_classification.py b/src/anonymizer/engine/rewrite/domain_classification.py index 62173bc6..945ebed6 100644 --- a/src/anonymizer/engine/rewrite/domain_classification.py +++ b/src/anonymizer/engine/rewrite/domain_classification.py @@ -32,6 +32,7 @@ def _first_output(outputs: list[Any]) -> dict[str, Any]: """Domain is a single doc-level label: keep the first window's classification.""" return outputs[0].model_dump(mode="json") + # --------------------------------------------------------------------------- # Single source of truth for rewrite-domain metadata. 
# diff --git a/src/anonymizer/engine/rewrite/qa_generation.py b/src/anonymizer/engine/rewrite/qa_generation.py index f8ce5860..6427accb 100644 --- a/src/anonymizer/engine/rewrite/qa_generation.py +++ b/src/anonymizer/engine/rewrite/qa_generation.py @@ -65,6 +65,7 @@ def _concat_meaning_units(outputs: list[Any]) -> dict[str, Any]: units.append({**unit.model_dump(mode="json"), "id": len(units) + 1}) return MeaningUnitsSchema.model_validate({"units": units}).model_dump(mode="json") + # --------------------------------------------------------------------------- # Stage 1 pre-step: format disposition → disposition block # --------------------------------------------------------------------------- diff --git a/tests/engine/test_chunked_detection.py b/tests/engine/test_chunked_detection.py index b94ff4e8..ee1108af 100644 --- a/tests/engine/test_chunked_detection.py +++ b/tests/engine/test_chunked_detection.py @@ -67,7 +67,9 @@ def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs i = wt.find(t, i + 1) for k in range(len(t) - 1, 2, -1): # truncated prefix at right edge if wt.endswith(t[:k]): - ents.append({"text": t[:k], "label": self.label, "start": len(wt) - k, "end": len(wt), "score": 0.5}) + ents.append( + {"text": t[:k], "label": self.label, "start": len(wt) - k, "end": len(wt), "score": 0.5} + ) break for k in range(len(t) - 1, 2, -1): # truncated suffix at left edge if wt.startswith(t[-k:]): diff --git a/tests/engine/test_chunked_final_judge.py b/tests/engine/test_chunked_final_judge.py index c953500e..858af698 100644 --- a/tests/engine/test_chunked_final_judge.py +++ b/tests/engine/test_chunked_final_judge.py @@ -76,7 +76,9 @@ def test_plan_judge_windows_single_when_small(): def test_plan_judge_windows_splits_when_large(): original, rewritten = "o" * 9000, "r" * 9000 # tiny budget -> multiple windows; both texts sliced into the same count - prompts = plan_judge_windows(original=original, rewritten=rewritten, template=_TEMPLATE, cap=4200, 
safety_margin_chars=0) + prompts = plan_judge_windows( + original=original, rewritten=rewritten, template=_TEMPLATE, cap=4200, safety_margin_chars=0 + ) assert len(prompts) > 1 diff --git a/tests/engine/test_chunked_replace.py b/tests/engine/test_chunked_replace.py index d1c8f36f..9125896a 100644 --- a/tests/engine/test_chunked_replace.py +++ b/tests/engine/test_chunked_replace.py @@ -7,7 +7,6 @@ import json import re -from typing import Any import pytest @@ -32,12 +31,21 @@ # Single-call prompt stand-in (the real one is large); references the columns the # fast path renders with the row. -_SINGLE_PROMPT = "MAP {{ tagged_text }} || {% for e in _entities_for_replace %}{{ e.value }};{% endfor %} || {{ _entity_examples }}" +_SINGLE_PROMPT = ( + "MAP {{ tagged_text }} || {% for e in _entities_for_replace %}{{ e.value }};{% endfor %} || {{ _entity_examples }}" +) def _span(value: str, label: str, start: int, end: int) -> EntitySpan: - return EntitySpan(entity_id=f"{label}_{start}", value=value, label=label, - start_position=start, end_position=end, score=1.0, source="d") + return EntitySpan( + entity_id=f"{label}_{start}", + value=value, + label=label, + start_position=start, + end_position=end, + score=1.0, + source="d", + ) class TestNewChunkEntities: @@ -57,10 +65,12 @@ class TestMergeReplacements: def test_dedupes_by_original_label_earlier_wins(self) -> None: existing = [{"original": "Alice", "label": "name", "synthetic": "Jane"}] new = EntityReplacementMapSchema.model_validate( - {"replacements": [ - {"original": "Alice", "label": "name", "synthetic": "DIFFERENT"}, # dup -> ignored - {"original": "Bob", "label": "name", "synthetic": "Mike"}, - ]} + { + "replacements": [ + {"original": "Alice", "label": "name", "synthetic": "DIFFERENT"}, # dup -> ignored + {"original": "Bob", "label": "name", "synthetic": "Mike"}, + ] + } ) merged = merge_replacements(existing, new) assert merged == [ @@ -100,13 +110,17 @@ class TestGenerateReplacementMapRow: def 
test_fast_path_single_call(self) -> None: facade = _Fake() row = { - COL_TEXT: "Alice here", COL_TAGGED_TEXT: "Alice here", COL_TAG_NOTATION: "xml", + COL_TEXT: "Alice here", + COL_TAGGED_TEXT: "Alice here", + COL_TAG_NOTATION: "xml", COL_ENTITY_EXAMPLES: "{}", COL_ENTITIES_FOR_REPLACE: [{"value": "Alice", "labels": ["name"], "labels_str": "name"}], COL_FINAL_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), } # NB: fast path renders the single prompt; our stand-in lists entities so the fake maps them. - params = WindowedReplaceParams(alias="r", single_call_prompt_template='MAP - "Alice" (name)', max_render_chars=1_000_000) + params = WindowedReplaceParams( + alias="r", single_call_prompt_template='MAP - "Alice" (name)', max_render_chars=1_000_000 + ) out = generate_replacement_map_row(row, params, {"r": facade}) assert facade.map_calls == 1 and facade.summary_calls == 0 m = EntityReplacementMapSchema.model_validate(out[COL_REPLACEMENT_MAP]) @@ -122,17 +136,33 @@ def test_chunked_with_rolling_summary_and_dedupe(self) -> None: spans = [] for val in ["Alice", "Bob", "Carol"]: for mobj in re.finditer(re.escape(val), text): - spans.append({"id": f"name_{mobj.start()}", "value": val, "label": "name", - "start_position": mobj.start(), "end_position": mobj.end(), "score": 1.0, "source": "d"}) + spans.append( + { + "id": f"name_{mobj.start()}", + "value": val, + "label": "name", + "start_position": mobj.start(), + "end_position": mobj.end(), + "score": 1.0, + "source": "d", + } + ) final = EntitiesSchema.model_validate({"entities": spans}).model_dump(mode="json") facade = _Fake() row = { - COL_TEXT: text, COL_TAGGED_TEXT: text, COL_TAG_NOTATION: "xml", COL_ENTITY_EXAMPLES: "{}", - COL_ENTITIES_FOR_REPLACE: [{"value": v, "labels": ["name"], "labels_str": "name"} for v in ["Alice", "Bob", "Carol"]], + COL_TEXT: text, + COL_TAGGED_TEXT: text, + COL_TAG_NOTATION: "xml", + COL_ENTITY_EXAMPLES: "{}", + COL_ENTITIES_FOR_REPLACE: [ + {"value": v, "labels": ["name"], 
"labels_str": "name"} for v in ["Alice", "Bob", "Carol"] + ], COL_FINAL_ENTITIES: final, } - params = WindowedReplaceParams(alias="r", single_call_prompt_template=_SINGLE_PROMPT, max_render_chars=4000, safety_margin_chars=0) + params = WindowedReplaceParams( + alias="r", single_call_prompt_template=_SINGLE_PROMPT, max_render_chars=4000, safety_margin_chars=0 + ) out = generate_replacement_map_row(row, params, {"r": facade}) result = EntityReplacementMapSchema.model_validate(out[COL_REPLACEMENT_MAP]) assert facade.map_calls == 3 # one per window with new entities @@ -142,4 +172,8 @@ def test_chunked_with_rolling_summary_and_dedupe(self) -> None: def test_missing_alias_raises(self) -> None: with pytest.raises(KeyError, match="not present in models"): - generate_replacement_map_row({COL_TEXT: "x"}, WindowedReplaceParams(alias="r", single_call_prompt_template="x", max_render_chars=10), {}) + generate_replacement_map_row( + {COL_TEXT: "x"}, + WindowedReplaceParams(alias="r", single_call_prompt_template="x", max_render_chars=10), + {}, + ) diff --git a/tests/engine/test_chunked_rewrite.py b/tests/engine/test_chunked_rewrite.py index cee4a74c..12cb3095 100644 --- a/tests/engine/test_chunked_rewrite.py +++ b/tests/engine/test_chunked_rewrite.py @@ -79,4 +79,6 @@ def test_chunked_stitches_with_rolling_summary(self) -> None: def test_missing_alias_raises(self) -> None: with pytest.raises(KeyError, match="not present in models"): - generate_rewrite_row(_row("x"), WindowedRewriteParams(alias="w", single_call_prompt_template="x", max_render_chars=10), {}) + generate_rewrite_row( + _row("x"), WindowedRewriteParams(alias="w", single_call_prompt_template="x", max_render_chars=10), {} + ) diff --git a/tests/engine/test_windowing.py b/tests/engine/test_windowing.py index 42507693..23050c2b 100644 --- a/tests/engine/test_windowing.py +++ b/tests/engine/test_windowing.py @@ -5,8 +5,6 @@ from __future__ import annotations -import pytest - from anonymizer.engine.windowing import 
iter_boundary_windows, next_window_end From 6639ecbb7dbcec00a08bcd7d737ee1cc8a4a8443 Mon Sep 17 00:00:00 2001 From: eurekayuan Date: Thu, 11 Jun 2026 13:42:26 -0700 Subject: [PATCH 4/4] Address review: thread window sizing, localize window failures, validate overlap, flag empty rewrite chunks Signed-off-by: eurekayuan --- src/anonymizer/config/anonymizer_config.py | 35 ++++++++++- src/anonymizer/engine/constants.py | 4 ++ .../engine/rewrite/chunked_rewrite.py | 27 ++++++++- .../engine/rewrite/chunked_steps.py | 26 +++++++- .../engine/rewrite/domain_classification.py | 12 +++- src/anonymizer/engine/rewrite/final_judge.py | 12 +++- .../engine/rewrite/qa_generation.py | 17 ++++-- .../engine/rewrite/rewrite_workflow.py | 24 +++++++- .../engine/rewrite/sensitivity_disposition.py | 12 +++- tests/config/test_anonymizer_config.py | 49 +++++++++++++++ tests/engine/test_chunked_rewrite.py | 30 ++++++++++ tests/engine/test_chunked_steps.py | 59 +++++++++++++++++++ tests/engine/test_domain_classification.py | 12 ++++ tests/engine/test_final_judge.py | 16 +++++ tests/engine/test_qa_generation.py | 14 +++++ tests/engine/test_sensitivity_disposition.py | 13 ++++ 16 files changed, 347 insertions(+), 15 deletions(-) diff --git a/src/anonymizer/config/anonymizer_config.py b/src/anonymizer/config/anonymizer_config.py index 9fe3f09d..efebee23 100644 --- a/src/anonymizer/config/anonymizer_config.py +++ b/src/anonymizer/config/anonymizer_config.py @@ -31,6 +31,12 @@ # Clamped so it never exceeds NDD's cap if that is ever lowered. _DEFAULT_WINDOW_MAX_RENDER_CHARS = min(128_000, _NDD_MAX_RENDERED_LEN) +# Floor on the per-window character budget; mirrors ``_MIN_WINDOW_CHARS`` in the +# engine's chunked-detection planners. Defined here (rather than imported) to keep +# the user-facing config free of an engine import cycle. Used only to compute the +# effective window size for overlap validation. 
+_MIN_DETECTION_WINDOW_CHARS = 4_000 + def is_remote_input_source(value: str) -> bool: """Return True when the input source is an HTTP(S) URL.""" @@ -135,7 +141,9 @@ class Detect(BaseModel): ge=0, description=( "Overlap between adjacent augmentation/latent windows so an entity straddling a " - "window boundary is fully visible in at least one window." + "window boundary is fully visible in at least one window. Must be smaller than the " + "effective window size (max_render_chars - safety_margin_chars, floored at 4000); a " + "larger overlap stalls the planners to one character per step and is rejected at config time." ), ) @@ -152,6 +160,31 @@ def validate_entity_labels(cls, value: list[str] | None) -> list[str] | None: logger.warning("entity_labels contained duplicates, removed automatically.") return deduped + @model_validator(mode="after") + def validate_window_overlap(self) -> Detect: + """Reject overlaps that would stall the windowed detection planners. + + The augmentation/latent planners advance by ``window - overlap`` characters + per step. The effective window mirrors the engine's sizing: + ``max(_MIN_DETECTION_WINDOW_CHARS, max_render - safety_margin)``. When the + overlap meets or exceeds that, the stride collapses to a single character + and one long row explodes into tens of thousands of model calls (a 20k-char + row became 16,001 windows in testing), so require it to be strictly smaller. + """ + effective_window = max( + _MIN_DETECTION_WINDOW_CHARS, + self.detection_window_max_render_chars - self.detection_window_safety_margin_chars, + ) + if self.detection_window_overlap_chars >= effective_window: + raise ValueError( + f"detection_window_overlap_chars ({self.detection_window_overlap_chars}) must be smaller than " + f"the effective window size ({effective_window} chars = max({_MIN_DETECTION_WINDOW_CHARS}, " + "detection_window_max_render_chars - detection_window_safety_margin_chars)). 
A larger overlap " + "makes the windowed planners advance one character at a time, exploding a single row into " + "thousands of model calls." + ) + return self + class Rewrite(BaseModel): """Configuration for rewrite-mode execution.""" diff --git a/src/anonymizer/engine/constants.py b/src/anonymizer/engine/constants.py index 0e4e8da6..713043a1 100644 --- a/src/anonymizer/engine/constants.py +++ b/src/anonymizer/engine/constants.py @@ -101,6 +101,10 @@ COL_REWRITE_DISPOSITION_BLOCK = "_rewrite_disposition_block" COL_REPLACEMENT_MAP_FOR_PROMPT = "_replacement_map_for_prompt" COL_FULL_REWRITE = "_full_rewrite" +# Number of chunked-rewrite windows that returned empty text and were dropped +# from the stitched output. 0 on the single-call fast path. Surfaced (like the +# detection failed-window counts) so empty sections are not mistaken for success. +COL_REWRITE_EMPTY_WINDOWS = "_rewrite_empty_windows" COL_MEANING_UNITS = "_meaning_units" COL_MEANING_UNITS_SERIALIZED = "_meaning_units_serialized" COL_QUALITY_QA = "_quality_qa" diff --git a/src/anonymizer/engine/rewrite/chunked_rewrite.py b/src/anonymizer/engine/rewrite/chunked_rewrite.py index 1c4e0324..437da4a1 100644 --- a/src/anonymizer/engine/rewrite/chunked_rewrite.py +++ b/src/anonymizer/engine/rewrite/chunked_rewrite.py @@ -30,6 +30,7 @@ COL_FULL_REWRITE, COL_REPLACEMENT_MAP_FOR_PROMPT, COL_REWRITE_DISPOSITION_BLOCK, + COL_REWRITE_EMPTY_WINDOWS, COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, @@ -180,6 +181,7 @@ def generate_rewrite_row( purpose="rewrite-generation", ) row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=text).model_dump(mode="json") + row[COL_REWRITE_EMPTY_WINDOWS] = 0 return row # Chunked path: rewrite each boundary window with continuity carry-over, then stitch. 
@@ -203,6 +205,7 @@ def generate_rewrite_row( params.summary_max_chars, ) rewritten_parts: list[str] = [] + empty_windows = 0 summary = "" for i, (start, end) in enumerate(windows): chunk_tagged = tagged[start:end] @@ -237,6 +240,16 @@ def generate_rewrite_row( len(windows), len(rewritten_chunk), ) + if not rewritten_chunk.strip(): + empty_windows += 1 + logger.warning( + "rewrite window %d/%d (chars [%d, %d)): model returned an empty rewrite; this section will be " + "dropped from the stitched output", + i + 1, + len(windows), + start, + end, + ) rewritten_parts.append(rewritten_chunk) if i < len(windows) - 1: summary = _update_summary( @@ -255,14 +268,25 @@ def generate_rewrite_row( _clip(summary), ) + # ``if part`` drops empty chunks from the join; ``empty_windows`` records how + # many so a section silently lost to an empty model response is visible as a + # review signal rather than passing as successful output. stitched = "\n".join(part for part in rewritten_parts if part) + if empty_windows: + logger.warning( + "rewrite: %d of %d window(s) returned empty text and were omitted from the stitched output", + empty_windows, + len(windows), + ) logger.info( - "rewrite: %d window(s) over %d chars -> %d chars stitched output", + "rewrite: %d window(s) over %d chars -> %d chars stitched output (%d empty window(s))", len(windows), len(tagged), len(stitched), + empty_windows, ) row[COL_FULL_REWRITE] = RewriteOutputSchema(rewritten_text=stitched).model_dump(mode="json") + row[COL_REWRITE_EMPTY_WINDOWS] = empty_windows return row @@ -279,6 +303,7 @@ def make_windowed_rewrite_generator(alias: str) -> Any: COL_REWRITE_DISPOSITION_BLOCK, COL_REPLACEMENT_MAP_FOR_PROMPT, ], + side_effect_columns=[COL_REWRITE_EMPTY_WINDOWS], model_aliases=[alias], ) def windowed_rewrite( diff --git a/src/anonymizer/engine/rewrite/chunked_steps.py b/src/anonymizer/engine/rewrite/chunked_steps.py index 22d8f234..92afdf95 100644 --- a/src/anonymizer/engine/rewrite/chunked_steps.py +++ 
b/src/anonymizer/engine/rewrite/chunked_steps.py @@ -99,10 +99,34 @@ def _call(prompt: str, purpose: str) -> Any: windows = iter_boundary_windows(text, initial_window, delimiter=params.delimiter) if params.first_only: windows = windows[:1] + # Run each window independently: a single transient model error (or a chunk + # the model cannot parse into the schema) should drop only that window, not + # the whole record. Failures are logged and skipped; the all-failed case is + # handled explicitly below so we never merge an empty output set silently. outputs = [] + failed = 0 for start, end in windows: rendered = _compile_template(params.prompt_template).render(**{**row, params.text_column: text[start:end]}) - outputs.append(_call(rendered, f"{purpose_prefix}-{start}")) + try: + outputs.append(_call(rendered, f"{purpose_prefix}-{start}")) + except Exception: + failed += 1 + logger.warning( + "windowed step %s: window [%d, %d) failed and was skipped", + purpose_prefix, + start, + end, + exc_info=True, + ) + if not outputs: + raise RuntimeError( + f"windowed step {purpose_prefix!r}: all {len(windows)} window(s) failed; no output produced " + f"for column {params.output_column!r}." 
+ ) + if failed: + logger.warning( + "windowed step %s: %d of %d window(s) failed and were skipped", purpose_prefix, failed, len(windows) + ) logger.debug("windowed step %s: %d window(s) over %d chars", purpose_prefix, len(windows), len(text)) row[params.output_column] = merge_fn(outputs) return row diff --git a/src/anonymizer/engine/rewrite/domain_classification.py b/src/anonymizer/engine/rewrite/domain_classification.py index 945ebed6..a8912a1f 100644 --- a/src/anonymizer/engine/rewrite/domain_classification.py +++ b/src/anonymizer/engine/rewrite/domain_classification.py @@ -278,8 +278,16 @@ def columns( *, selected_models: RewriteModelSelection, data_summary: str | None = None, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> list[ColumnConfigT]: domain_alias = resolve_model_alias("domain_classifier", selected_models) + # Honor caller-provided (Detect-config-derived) window sizing; fall back + # to the shared defaults so a lowered detection cap also bounds this stage. 
+ max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) return [ CustomColumnConfig( name=COL_DOMAIN, @@ -295,8 +303,8 @@ def columns( prompt_template=_get_domain_classification_prompt(data_summary), output_column=COL_DOMAIN, text_column=COL_TEXT, - max_render_chars=_DEFAULT_MAX_RENDER_CHARS, - safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, first_only=True, ), ), diff --git a/src/anonymizer/engine/rewrite/final_judge.py b/src/anonymizer/engine/rewrite/final_judge.py index b6f620ca..06b6f0ab 100644 --- a/src/anonymizer/engine/rewrite/final_judge.py +++ b/src/anonymizer/engine/rewrite/final_judge.py @@ -271,8 +271,16 @@ def columns( selected_models: RewriteModelSelection, privacy_goal: PrivacyGoal, evaluation: EvaluationCriteria, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> list[ColumnConfigT]: judge_alias = resolve_model_alias("judge", selected_models) + # Honor caller-provided (Detect-config-derived) window sizing; fall back + # to the shared defaults so a lowered detection cap also bounds this stage. 
+ max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) return [ # Windowed final judge: long documents are split into parallel, independent @@ -284,8 +292,8 @@ def columns( generator_params=WindowedJudgeParams( alias=judge_alias, prompt_template=_judge_prompt(privacy_goal), - max_render_chars=_DEFAULT_MAX_RENDER_CHARS, - safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, ), ), CustomColumnConfig( diff --git a/src/anonymizer/engine/rewrite/qa_generation.py b/src/anonymizer/engine/rewrite/qa_generation.py index 6427accb..27eaa199 100644 --- a/src/anonymizer/engine/rewrite/qa_generation.py +++ b/src/anonymizer/engine/rewrite/qa_generation.py @@ -453,9 +453,18 @@ def columns( self, *, selected_models: RewriteModelSelection, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> list[ColumnConfigT]: meaning_extractor_alias = resolve_model_alias("meaning_extractor", selected_models) qa_generator_alias = resolve_model_alias("qa_generator", selected_models) + # Honor caller-provided (Detect-config-derived) window sizing; fall back + # to the shared defaults so a lowered detection cap also bounds meaning-unit + # extraction and the quality-QA batching below. 
+ max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) return [ CustomColumnConfig( name=COL_SENSITIVITY_DISPOSITION_BLOCK, @@ -480,8 +489,8 @@ def columns( prompt_template=_get_meaning_unit_extraction_prompt(), output_column=COL_MEANING_UNITS, text_column=COL_TEXT, - max_render_chars=_DEFAULT_MAX_RENDER_CHARS, - safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, ), ), CustomColumnConfig( @@ -492,8 +501,8 @@ def columns( name=COL_QUALITY_QA, generator_function=_make_quality_qa_column( qa_generator_alias, - _DEFAULT_MAX_RENDER_CHARS, - _DEFAULT_SAFETY_MARGIN_CHARS, + max_render_chars, + safety_margin_chars, ), ), CustomColumnConfig( diff --git a/src/anonymizer/engine/rewrite/rewrite_workflow.py b/src/anonymizer/engine/rewrite/rewrite_workflow.py index 94f95976..d03a03a9 100644 --- a/src/anonymizer/engine/rewrite/rewrite_workflow.py +++ b/src/anonymizer/engine/rewrite/rewrite_workflow.py @@ -221,15 +221,29 @@ def run( all_failed.extend(replace_result.failed_records) # --- Step 2: domain, disposition, QA, rewrite (single adapter call) --- + # Thread the user-supplied (Detect-config-derived) window sizing through + # every windowed stage so lowering the detection cap bounds prompt size in + # all of them, not just rewrite generation. 
pipeline_columns = [ - *self._domain_wf.columns(selected_models=selected_models, data_summary=data_summary), + *self._domain_wf.columns( + selected_models=selected_models, + data_summary=data_summary, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, + ), *self._disposition_wf.columns( selected_models=selected_models, privacy_goal=privacy_goal, data_summary=data_summary, strict_entity_protection=strict_entity_protection, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, + ), + *self._qa_wf.columns( + selected_models=selected_models, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, ), - *self._qa_wf.columns(selected_models=selected_models), *self._rewrite_gen_wf.columns( window_max_render_chars=window_max_render_chars, window_safety_margin_chars=window_safety_margin_chars, @@ -269,6 +283,8 @@ def run( privacy_goal=privacy_goal, evaluation=evaluation, preview_num_records=preview_num_records, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, ) all_failed.extend(judge_failed) @@ -386,12 +402,16 @@ def _run_final_judge( privacy_goal: PrivacyGoal, evaluation: EvaluationCriteria, preview_num_records: int | None, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> tuple[pd.DataFrame, list[FailedRecord]]: try: judge_columns = self._judge_wf.columns( selected_models=selected_models, privacy_goal=privacy_goal, evaluation=evaluation, + window_max_render_chars=window_max_render_chars, + window_safety_margin_chars=window_safety_margin_chars, ) judge_seed = select_seed_cols(df, derive_seed_columns(judge_columns, df)) judge_result = self._adapter.run_workflow( diff --git a/src/anonymizer/engine/rewrite/sensitivity_disposition.py b/src/anonymizer/engine/rewrite/sensitivity_disposition.py index 
0879ff0f..79d71b79 100644 --- a/src/anonymizer/engine/rewrite/sensitivity_disposition.py +++ b/src/anonymizer/engine/rewrite/sensitivity_disposition.py @@ -300,9 +300,17 @@ def columns( privacy_goal: PrivacyGoal, data_summary: str | None = None, strict_entity_protection: bool = False, + window_max_render_chars: int | None = None, + window_safety_margin_chars: int | None = None, ) -> list[ColumnConfigT]: disposition_alias = resolve_model_alias("disposition_analyzer", selected_models) output_schema = StrictSensitivityDispositionSchema if strict_entity_protection else SensitivityDispositionSchema + # Honor caller-provided (Detect-config-derived) window sizing; fall back + # to the shared defaults so a lowered detection cap also bounds this stage. + max_render_chars = window_max_render_chars if window_max_render_chars is not None else _DEFAULT_MAX_RENDER_CHARS + safety_margin_chars = ( + window_safety_margin_chars if window_safety_margin_chars is not None else _DEFAULT_SAFETY_MARGIN_CHARS + ) return [ CustomColumnConfig( name=COL_SENSITIVITY_DISPOSITION, @@ -329,8 +337,8 @@ def columns( ), output_column=COL_SENSITIVITY_DISPOSITION, text_column=COL_TAGGED_TEXT, - max_render_chars=_DEFAULT_MAX_RENDER_CHARS, - safety_margin_chars=_DEFAULT_SAFETY_MARGIN_CHARS, + max_render_chars=max_render_chars, + safety_margin_chars=safety_margin_chars, ), ), ] diff --git a/tests/config/test_anonymizer_config.py b/tests/config/test_anonymizer_config.py index ee8c8594..56ffa2d1 100644 --- a/tests/config/test_anonymizer_config.py +++ b/tests/config/test_anonymizer_config.py @@ -145,3 +145,52 @@ def test_detect_validation_max_entities_per_call_must_be_positive() -> None: def test_detect_validation_excerpt_window_chars_must_be_positive() -> None: with pytest.raises(ValidationError): AnonymizerConfig(detect={"validation_excerpt_window_chars": 0}, replace=Redact()) + + +def test_detect_window_overlap_at_or_above_effective_window_raises() -> None: + # overlap == effective window (max_render - 
safety_margin) collapses the + # planner stride to one character; a single row explodes into many windows. + with pytest.raises(ValidationError): + AnonymizerConfig( + detect={ + "detection_window_max_render_chars": 20_000, + "detection_window_safety_margin_chars": 0, + "detection_window_overlap_chars": 20_000, + }, + replace=Redact(), + ) + + +def test_detect_window_overlap_below_effective_window_ok() -> None: + config = AnonymizerConfig( + detect={ + "detection_window_max_render_chars": 20_000, + "detection_window_safety_margin_chars": 0, + "detection_window_overlap_chars": 19_999, + }, + replace=Redact(), + ) + assert config.detect.detection_window_overlap_chars == 19_999 + + +def test_detect_window_overlap_uses_min_window_floor() -> None: + # When max_render - safety_margin drops below the 4000-char floor, the floor + # is the effective window: overlaps below it pass, at/above it are rejected. + ok = AnonymizerConfig( + detect={ + "detection_window_max_render_chars": 5_000, + "detection_window_safety_margin_chars": 4_000, # 1000 < 4000 floor + "detection_window_overlap_chars": 3_999, + }, + replace=Redact(), + ) + assert ok.detect.detection_window_overlap_chars == 3_999 + with pytest.raises(ValidationError): + AnonymizerConfig( + detect={ + "detection_window_max_render_chars": 5_000, + "detection_window_safety_margin_chars": 4_000, + "detection_window_overlap_chars": 4_000, + }, + replace=Redact(), + ) diff --git a/tests/engine/test_chunked_rewrite.py b/tests/engine/test_chunked_rewrite.py index 12cb3095..a8fe45da 100644 --- a/tests/engine/test_chunked_rewrite.py +++ b/tests/engine/test_chunked_rewrite.py @@ -11,6 +11,7 @@ COL_FULL_REWRITE, COL_REPLACEMENT_MAP_FOR_PROMPT, COL_REWRITE_DISPOSITION_BLOCK, + COL_REWRITE_EMPTY_WINDOWS, COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, @@ -63,6 +64,35 @@ def test_fast_path_single_call(self) -> None: out = generate_rewrite_row(_row("short tagged text"), params, {"w": facade}) assert facade.rewrite_calls == 1 and 
facade.summary_calls == 0 assert RewriteOutputSchema.model_validate(out[COL_FULL_REWRITE]).rewritten_text == "OUT1" + assert out[COL_REWRITE_EMPTY_WINDOWS] == 0 + + def test_chunked_counts_and_drops_empty_window(self) -> None: + tagged = ("X" * 4000 + "\n") * 3 # ~12k chars -> several windows + + class _EmptyMiddle: + """Returns an empty rewrite for the second window, normal text otherwise.""" + + def __init__(self) -> None: + self.rewrite_calls = 0 + + def generate(self, *, prompt, parser, system_prompt=None, purpose=None, **kwargs): + if "summary" in (purpose or ""): + return parser("running summary"), [] + self.rewrite_calls += 1 + body = "" if self.rewrite_calls == 2 else "OUT%d" % self.rewrite_calls + return parser('```json\n{"rewritten_text":"%s"}\n```' % body), [] + + facade = _EmptyMiddle() + params = WindowedRewriteParams( + alias="w", single_call_prompt_template=_TEMPLATE, max_render_chars=4000, safety_margin_chars=0 + ) + out = generate_rewrite_row(_row(tagged), params, {"w": facade}) + # The empty window is counted, not silently treated as successful output. + assert out[COL_REWRITE_EMPTY_WINDOWS] == 1 + text = RewriteOutputSchema.model_validate(out[COL_FULL_REWRITE]).rewritten_text + # ...and dropped from the stitch rather than emitted as a blank section. + assert "OUT2" not in text + assert "" not in text.split("\n") def test_chunked_stitches_with_rolling_summary(self) -> None: tagged = ("X" * 4000 + "\n") * 3 # ~12k chars -> several windows diff --git a/tests/engine/test_chunked_steps.py b/tests/engine/test_chunked_steps.py index 2531df62..46f4cca6 100644 --- a/tests/engine/test_chunked_steps.py +++ b/tests/engine/test_chunked_steps.py @@ -128,6 +128,65 @@ def test_missing_alias_raises() -> None: ) +class _FlakyFacade: + """Facade stub that raises on selected (1-indexed) calls, else returns fixed JSON.""" + + def __init__( + self, response_obj: dict[str, Any], *, fail_calls: tuple[int, ...] 
= (), fail_all: bool = False + ) -> None: + self._payload = "```json\n" + json.dumps(response_obj) + "\n```" + self._fail_calls = set(fail_calls) + self._fail_all = fail_all + self.calls = 0 + + def generate(self, *, prompt: Any, parser: Any, system_prompt: Any = None, purpose: Any = None, **_: Any) -> Any: + self.calls += 1 + if self._fail_all or self.calls in self._fail_calls: + raise RuntimeError("transient model error") + return parser(self._payload), [] + + +def _multi_window_params() -> WindowedStepParams: + return WindowedStepParams( + alias="d", + prompt_template=_get_domain_classification_prompt(None), + output_column=COL_DOMAIN, + text_column=COL_TEXT, + max_render_chars=4000, + safety_margin_chars=0, + ) + + +def test_windowed_step_skips_failed_window_and_merges_survivors() -> None: + # First window's model call fails; remaining windows still run and merge. + facade = _FlakyFacade({"domain": "OTHER", "domain_confidence": 0.9}, fail_calls=(1,)) + long_text = ("x" * 4000 + "\n") * 3 + row = run_windowed_step( + {COL_TEXT: long_text}, + _multi_window_params(), + {"d": facade}, + schema=DomainClassificationSchema, + merge_fn=_first_output, + purpose_prefix="domain", + ) + assert facade.calls >= 2 # the window after the failure was still attempted + assert DomainClassificationSchema.model_validate(row[COL_DOMAIN]).domain.value == "OTHER" + + +def test_windowed_step_raises_when_all_windows_fail() -> None: + facade = _FlakyFacade({"domain": "OTHER", "domain_confidence": 0.9}, fail_all=True) + long_text = ("x" * 4000 + "\n") * 3 + with pytest.raises(RuntimeError, match="all .* window"): + run_windowed_step( + {COL_TEXT: long_text}, + _multi_window_params(), + {"d": facade}, + schema=DomainClassificationSchema, + merge_fn=_first_output, + purpose_prefix="domain", + ) + + # --------------------------------------------------------------------------- # Merges # --------------------------------------------------------------------------- diff --git 
a/tests/engine/test_domain_classification.py b/tests/engine/test_domain_classification.py index b814fb60..2d28c1d7 100644 --- a/tests/engine/test_domain_classification.py +++ b/tests/engine/test_domain_classification.py @@ -43,6 +43,18 @@ def test_columns_returns_exactly_three_in_order( assert cols[0].generator_params.first_only is True +def test_columns_threads_window_sizing( + stub_rewrite_model_selection: RewriteModelSelection, +) -> None: + cols = DomainClassificationWorkflow().columns( + selected_models=stub_rewrite_model_selection, + window_max_render_chars=12_345, + window_safety_margin_chars=678, + ) + assert cols[0].generator_params.max_render_chars == 12_345 + assert cols[0].generator_params.safety_margin_chars == 678 + + def test_enrich_domain_populates_supplement_for_known_domain() -> None: result = _enrich_domain({COL_DOMAIN: {"domain": Domain.BIOGRAPHY_PROFILE, "domain_confidence": 0.9}}) assert result[COL_DOMAIN_SUPPLEMENT] == _DOMAIN_BY_ENUM[Domain.BIOGRAPHY_PROFILE].quality_supplement diff --git a/tests/engine/test_final_judge.py b/tests/engine/test_final_judge.py index 87b9c1fd..41a82958 100644 --- a/tests/engine/test_final_judge.py +++ b/tests/engine/test_final_judge.py @@ -112,6 +112,22 @@ def test_judge_column_is_windowed_generator_with_three_rubrics( assert judge_col.generator_params.max_render_chars > 0 +def test_columns_threads_window_sizing( + stub_rewrite_model_selection: RewriteModelSelection, +) -> None: + wf = FinalJudgeWorkflow() + cols = wf.columns( + selected_models=stub_rewrite_model_selection, + privacy_goal=_STUB_PRIVACY_GOAL, + evaluation=_STUB_EVALUATION, + window_max_render_chars=12_345, + window_safety_margin_chars=678, + ) + judge_col = next(c for c in cols if c.name == COL_JUDGE_EVALUATION) + assert judge_col.generator_params.max_render_chars == 12_345 + assert judge_col.generator_params.safety_margin_chars == 678 + + def test_needs_human_review_column_present( stub_rewrite_model_selection: RewriteModelSelection, ) -> None: 
diff --git a/tests/engine/test_qa_generation.py b/tests/engine/test_qa_generation.py index 0a5fae19..b2e86857 100644 --- a/tests/engine/test_qa_generation.py +++ b/tests/engine/test_qa_generation.py @@ -123,6 +123,20 @@ def test_qa_generator_alias_used( assert metadata["model_aliases"] == [stub_rewrite_model_selection.qa_generator] +def test_columns_threads_window_sizing( + stub_rewrite_model_selection: RewriteModelSelection, +) -> None: + cols = QAGenerationWorkflow().columns( + selected_models=stub_rewrite_model_selection, + window_max_render_chars=12_345, + window_safety_margin_chars=678, + ) + # Meaning-unit extraction (cols[1]) is the windowed metadata generator; its + # params carry the same sizing also passed to the quality-QA batcher. + assert cols[1].generator_params.max_render_chars == 12_345 + assert cols[1].generator_params.safety_margin_chars == 678 + + def test_format_disposition_block_produces_valid_json() -> None: row = {COL_SENSITIVITY_DISPOSITION: _STUB_DISPOSITION} result = _format_disposition_block(row) diff --git a/tests/engine/test_sensitivity_disposition.py b/tests/engine/test_sensitivity_disposition.py index 6685180e..835a15ac 100644 --- a/tests/engine/test_sensitivity_disposition.py +++ b/tests/engine/test_sensitivity_disposition.py @@ -41,6 +41,19 @@ def test_columns_uses_disposition_analyzer_alias( assert cols[0].name == COL_SENSITIVITY_DISPOSITION +def test_columns_threads_window_sizing( + stub_rewrite_model_selection: RewriteModelSelection, +) -> None: + cols = SensitivityDispositionWorkflow().columns( + selected_models=stub_rewrite_model_selection, + privacy_goal=_STUB_PRIVACY_GOAL, + window_max_render_chars=12_345, + window_safety_margin_chars=678, + ) + assert cols[0].generator_params.max_render_chars == 12_345 + assert cols[0].generator_params.safety_margin_chars == 678 + + def test_privacy_goal_interpolated_into_prompt() -> None: prompt = _get_sensitivity_disposition_prompt(_STUB_PRIVACY_GOAL) assert "Protect direct identifiers and 
quasi-identifier combinations" in prompt