feat: add support for long-context documents#179
Conversation
|
All contributors have signed the DCO ✍️ ✅ |
|
I have read the DCO document and I hereby sign the DCO. |
Greptile SummaryThis PR adds windowed long-context support to every LLM-calling stage of the anonymizer pipeline (detection, augmentation, latent, validation, substitute-map, rewrite, judge) so documents exceeding DataDesigner's 512K
Confidence Score: 4/5The change is broadly safe to merge; the fast path preserves existing behaviour for documents that fit the cap, and the new windowed paths are well-structured with per-window fault tolerance. The core windowing mechanics are correct and well-tested. Two quality gaps exist in the substitute-map chunked path: entities straddling newline boundaries are queued for replacement in a window whose tagged text does not highlight them, and the full replacement map is forwarded to every chunk unfiltered. Neither gap causes outright failure on typical PII inputs, but they are real edge-case correctness concerns. src/anonymizer/engine/replace/chunked_replace.py — both the boundary-entity tagging gap and the unfiltered replacement-map overhead deserve a second look. Important Files Changed
Reviews (2): Last reviewed commit: "Address review: thread window sizing, lo..." | Re-trigger Greptile |
| _clip(summary), | ||
| ) | ||
|
|
||
| stitched = "\n".join(part for part in rewritten_parts if part) |
There was a problem hiding this comment.
Chunk boundaries are aligned to newlines by
iter_boundary_windows, so each tagged[start:end] slice already ends with " ". When the LLM mirrors that structure in its output (natural for paragraph-aware models), every rewritten_chunk also ends with " ", and " ".join(...) then inserts a second newline — producing a blank line between every chunk boundary in the final anonymized document. Joining with "" is sufficient because the delimiter is already part of each chunk.
| stitched = "\n".join(part for part in rewritten_parts if part) | |
| stitched = "".join(part for part in rewritten_parts if part) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| # Fast path: the full single-call rewrite prompt fits under the cap. | ||
| single_rendered = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=row, summary="") | ||
| if len(single_rendered) <= cap: | ||
| logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) | ||
| text = _rewrite_chunk( | ||
| facade=facade, | ||
| prompt=_compile_template(params.single_call_prompt_template).render(**row), | ||
| system_prompt=params.system_prompt, | ||
| purpose="rewrite-generation", | ||
| ) |
There was a problem hiding this comment.
The fast path measures
single_rendered as _render_chunk_prompt(..., summary=""), which prepends the ~270-char continuity preamble, but then the actual LLM call omits that preamble. This means a document whose body-only prompt falls in (cap - 270, cap] chars will be routed into the chunked path unnecessarily. Measure with just the body to match what is actually sent.
| # Fast path: the full single-call rewrite prompt fits under the cap. | |
| single_rendered = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=row, summary="") | |
| if len(single_rendered) <= cap: | |
| logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) | |
| text = _rewrite_chunk( | |
| facade=facade, | |
| prompt=_compile_template(params.single_call_prompt_template).render(**row), | |
| system_prompt=params.system_prompt, | |
| purpose="rewrite-generation", | |
| ) | |
| # Fast path: measure body-only prompt (no continuity preamble) since that is what is sent. | |
| single_rendered = _compile_template(params.single_call_prompt_template).render(**row) | |
| if len(single_rendered) <= cap: | |
| logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap) | |
| text = _rewrite_chunk( | |
| facade=facade, | |
| prompt=single_rendered, | |
| system_prompt=params.system_prompt, | |
| purpose="rewrite-generation", | |
| ) |
| ) | ||
| from anonymizer.engine.ndd.model_loader import resolve_model_alias | ||
| from anonymizer.engine.prompt_utils import substitute_placeholders |
There was a problem hiding this comment.
Private symbol imported across module boundary.
_compile_template is module-private (underscore-prefixed) in chunked_steps.py. Importing it here creates a hidden coupling: if the function is renamed or inlined, qa_generation.py breaks without any clear contract. Consider exposing it as a public helper in chunked_steps.py or defining a local copy with its own lru_cache in this module.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump() | ||
| return row | ||
|
|
||
| units = json.loads(row.get(COL_MEANING_UNITS_SERIALIZED) or "[]") | ||
| base_len = len(compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"})) | ||
| batches = _batch_units_by_size(units, base_len, max_render_chars - safety_margin_chars) | ||
| items: list[dict[str, Any]] = [] | ||
| for batch_idx, batch in enumerate(batches): | ||
| rendered = compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: json.dumps(batch, ensure_ascii=False)}) | ||
| out = _generate(rendered, f"quality-qa-generation-batch-{batch_idx}") | ||
| for item in out.items: | ||
| items.append({**item.model_dump(mode="json"), "id": len(items) + 1}) | ||
| row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump() |
There was a problem hiding this comment.
The fast path stores the result via
.model_dump() (no mode="json"), while every other windowed generator in this PR consistently uses .model_dump(mode="json"). Without mode="json", Pydantic returns native Python objects rather than JSON-serializable equivalents, which can cause downstream serialization failures. The batched path has the same inconsistency.
| row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump() | |
| return row | |
| units = json.loads(row.get(COL_MEANING_UNITS_SERIALIZED) or "[]") | |
| base_len = len(compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"})) | |
| batches = _batch_units_by_size(units, base_len, max_render_chars - safety_margin_chars) | |
| items: list[dict[str, Any]] = [] | |
| for batch_idx, batch in enumerate(batches): | |
| rendered = compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: json.dumps(batch, ensure_ascii=False)}) | |
| out = _generate(rendered, f"quality-qa-generation-batch-{batch_idx}") | |
| for item in out.items: | |
| items.append({**item.model_dump(mode="json"), "id": len(items) + 1}) | |
| row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump() | |
| row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump(mode="json") | |
| return row | |
| units = json.loads(row.get(COL_MEANING_UNITS_SERIALIZED) or "[]") | |
| base_len = len(compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"})) | |
| batches = _batch_units_by_size(units, base_len, max_render_chars - safety_margin_chars) | |
| items: list[dict[str, Any]] = [] | |
| for batch_idx, batch in enumerate(batches): | |
| rendered = compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: json.dumps(batch, ensure_ascii=False)}) | |
| out = _generate(rendered, f"quality-qa-generation-batch-{batch_idx}") | |
| for item in out.items: | |
| items.append({**item.model_dump(mode="json"), "id": len(items) + 1}) | |
| row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump(mode="json") |
| _DEFAULT_MAX_RENDER_CHARS: int = _DetectConfig.model_fields["detection_window_max_render_chars"].default | ||
| _DEFAULT_SAFETY_MARGIN_CHARS: int = _DetectConfig.model_fields["detection_window_safety_margin_chars"].default | ||
|
|
||
|
|
There was a problem hiding this comment.
Unguarded index on potentially empty list.
_first_output calls outputs[0] without checking length. In run_windowed_step with first_only=True, if iter_boundary_windows returns an empty list, outputs is empty and this raises IndexError. The fast path makes this unreachable today, but a defensive guard would make the failure mode explicit.
andreatgretel
left a comment
There was a problem hiding this comment.
Thanks for taking this on. This is a substantial first PR, and the overall direction makes sense: split long records into bounded windows, carry forward the state needed for consistency, and keep the replacement map explicit.
I left a few comments on edge cases I think are worth tightening before merge. The main themes are:
- thread the user-supplied window sizing through every windowed stage
- make per-window failures local where possible, instead of dropping the whole record
- validate overlap settings early so a bad config cannot explode into thousands of model calls
- avoid silently accepting empty rewrite chunks as successful output
The tests and docs coverage are in good shape, and I think the feature is close. These changes should make it more reliable on real long documents.
| ), | ||
| *self._qa_wf.columns(selected_models=selected_models), | ||
| *self._rewrite_gen_wf.columns( | ||
| window_max_render_chars=window_max_render_chars, |
There was a problem hiding this comment.
this only threads the user-supplied window cap into rewrite generation. domain classification, sensitivity disposition, QA generation, and final judge still build their window params from module defaults, so a user who lowers Detect.detection_window_max_render_chars still gets ~128k prompts in those stages. Could pass the same kwargs through those columns() calls and _run_final_judge too?
| if params.first_only: | ||
| windows = windows[:1] | ||
| outputs = [] | ||
| for start, end in windows: |
There was a problem hiding this comment.
Claude Code caught this one: once this takes the windowed path, a single transient model error or a chunk that legitimately has no meaning units can drop the whole record. Could wrap each window call, skip/log failed windows, and handle the all-failed case explicitly?
| "prompt scaffolding and tags when sizing augmentation/latent windows." | ||
| ), | ||
| ) | ||
| detection_window_overlap_chars: int = Field( |
There was a problem hiding this comment.
suggestion: can we validate that detection_window_overlap_chars is smaller than the effective window size? Right now overlap == window is accepted and the planners advance one character at a time. My smoke test turned a 20k-char row into 16,001 windows.
| _clip(summary), | ||
| ) | ||
|
|
||
| stitched = "\n".join(part for part in rewritten_parts if part) |
There was a problem hiding this comment.
separate from the newline-stitching comment already here: filtering with if part also hides an empty rewrite chunk. If one window returns {"rewritten_text": ""}, that section disappears with no failed-window count or review signal. Maybe count/flag empty chunks instead of treating them as successful output?
…ate overlap, flag empty rewrite chunks Signed-off-by: eurekayuan <zhuoweny@nvidia.com>
Summary
Several stages embedded the whole document in a single prompt and hit DataDesigner's 512K (
MAX_RENDERED_LEN) render cap, failing outright on long inputs. Every such stage is now windowed: each chunked generator renders its own per-window prompt and calls the model directly, bypassing the cap. Stages keep a single-call fast path when the rendered prompt already fits, so short-document behavior is unchanged.Per-stage windowing
chunked_detection.py, new): Overlapping fixed-size character windows; each window is a raw text slice sent to the detector. Per-window offsets are rebased to global, boundary-touching spans are dropped, and overlaps are resolved (resolve_overlaps).chunked_validation.py): Not a text window — batches candidate entities (≤100 per call), each with a ±500-character excerpt. Calls run in parallel across the validator pool with round-robin + failover. Decisions are merged per row; the row is dropped only if every pool member fails.chunked_augmentation.py): Overlapping character windows over tagged text plus seed JSON. A window dynamically shrinks if its rendered prompt exceeds the cap. Outputs are unioned and deduped by(value, label).chunked_latent.py): Same mechanism as augmentation (rewrite mode only); deduped by(label, value).chunked_replace.py): Abutting newline-aligned windows, no overlap. Each chunk carries the accumulated replacement map and a rolling summary, proposing replacements only for new entities so mappings stay consistent across chunks.chunked_rewrite.py): Abutting newline-aligned windows, no overlap. Runs sequentially, passing a continuity preamble and rolling summary between chunks; rewritten parts are stitched.chunked_final_judge.py, new): Splits original and rewritten text into N positionally-paired slices, scores each, and aggregates per-dimension by minimum. Rubric scales are embedded in the prompt with structured output. Replaces the non-windowedLLMJudgeColumnConfig.Parallel processing
ThreadPoolExecutor; the per-alias rate limit still governs real in-flight calls) and merge afterward.Window sizing
detection_window_max_render_chars(default 128 KiB, clamped ≤ NDD's render cap) is the single knob; it is threaded into detection, augmentation, latent, substitute-map, rewrite, and judge.detection_window_safety_margin_chars(8K) leaves headroom for prompt scaffolding;detection_window_overlap_chars(1K) sets the overlap for the overlapping stages; a 4K floor prevents pathological shrinking.Fault tolerance & failure tracking
trace_dataframe(COL_AUGMENTATION_FAILED_WINDOWS/COL_LATENT_FAILED_WINDOWS); the judge degrades to defaults if all windows fail.Observability
Per-window debug logging across all chunked stages: window ranges/sizes, rendered length vs cap, shrink events, rolling-summary contents, and per-stage entity/replacement/window counts.
Type of Change
Testing
make testpasses locallymake checkpasses locally (format + lint + typecheck + lock-check)Documentation
make docs-buildpasses locally