Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ license = "Apache-2.0"
dependencies = [
"data-designer==0.6.0",
"pydantic>=2.9,<3",
"pydantic-settings>=2.12,<3",
"cyclopts>=3",
"pygments>=2.20.0",
"cryptography>=46.0.6",
Expand Down
1 change: 1 addition & 0 deletions src/anonymizer/engine/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
COL_ENTITIES_BY_VALUE = "_entities_by_value"
COL_REPLACED_TEXT = "__nemo_anonymizer_text_output__"
COL_REPLACEMENT_MAP = "_replacement_map"
COL_REPLACEMENT_MAP_SOURCE = "_replacement_map_source"

# LlmReplaceWorkflow internal prompt-construction columns. Created by
# `LlmReplaceWorkflow.generate_map_only` for the replacement-generator prompt
Expand Down
12 changes: 11 additions & 1 deletion src/anonymizer/engine/detection/chunked_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ class ChunkedValidationParams(BaseModel):
max_entities_per_call: Upper bound on candidates per chunk.
excerpt_window_chars: Chars of surrounding raw text included in each
chunk's excerpt on either side of the chunk span.
single_chunk_full_text: If True, a row with one validation chunk sees
the full tagged document. If False, even a single chunk uses the
excerpt window. The default preserves production parity with the
pre-chunking validation path; benchmarks may disable it to probe
compact validation prompts.
prompt_template: Jinja2 source for the validation prompt (with
``_seed_tagged_text``, ``_validation_skeleton``, ``_tag_notation``
placeholders). Typically produced by ``_get_validation_prompt``.
Expand All @@ -119,6 +124,7 @@ class ChunkedValidationParams(BaseModel):
pool: list[str] = Field(min_length=1)
max_entities_per_call: int = Field(gt=0)
excerpt_window_chars: int = Field(gt=0)
single_chunk_full_text: bool = True
prompt_template: str = Field(repr=False)
system_prompt: str | None = Field(default=None, repr=False)

Expand Down Expand Up @@ -449,7 +455,11 @@ def chunked_validate_row(
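# only making one call there's no cost reason to clip, and clipping
# would silently narrow the context the validator sees. Callers can opt
# out via ``params.single_chunk_full_text``, in which case even a single
# chunk uses the excerpt window. Computed once here because both
# ``len(chunks) == 1`` and the flag are loop-invariant.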
single_chunk_tagged_text = build_tagged_text(text, all_spans, notation=notation) if len(chunks) == 1 else None
single_chunk_tagged_text = (
build_tagged_text(text, all_spans, notation=notation)
if len(chunks) == 1 and params.single_chunk_full_text
else None
)

dispatch_kwargs_per_chunk: list[dict[str, Any]] = []
for chunk_index, chunk in enumerate(chunks):
Expand Down
93 changes: 53 additions & 40 deletions src/anonymizer/engine/detection/detection_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
EntitiesSchema,
LatentEntitiesSchema,
)
from anonymizer.measurement import stage_timer

logger = logging.getLogger("anonymizer.detection")

Expand Down Expand Up @@ -94,6 +95,7 @@ def detect_and_validate_entities(
gliner_detection_threshold: float,
validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
validation_single_chunk_full_text: bool = True,
entity_labels: list[str] | None = None,
data_summary: str | None = None,
preview_num_records: int | None = None,
Expand Down Expand Up @@ -143,6 +145,7 @@ def detect_and_validate_entities(
pool=list(validator_aliases),
max_entities_per_call=validation_max_entities_per_call,
excerpt_window_chars=validation_excerpt_window_chars,
single_chunk_full_text=validation_single_chunk_full_text,
prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels),
)

Expand Down Expand Up @@ -266,54 +269,64 @@ def run(
``identify_latent_entities`` if ``tag_latent_entities`` is True
(rewrite mode). Merges failures from both stages.
"""
if tag_latent_entities and privacy_goal is None:
raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)")

compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities
detected_result = self.detect_and_validate_entities(
dataframe,
model_configs=model_configs,
selected_models=selected_models,
gliner_detection_threshold=gliner_detection_threshold,
validation_max_entities_per_call=validation_max_entities_per_call,
validation_excerpt_window_chars=validation_excerpt_window_chars,
entity_labels=entity_labels,
data_summary=data_summary,
preview_num_records=preview_num_records,
)

if tag_latent_entities:
latent_result = self.identify_latent_entities(
detected_result.dataframe,
with stage_timer(
"EntityDetectionWorkflow.run",
input_row_count=len(dataframe),
tag_latent_entities=tag_latent_entities,
) as measurement:
if tag_latent_entities and privacy_goal is None:
raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)")

compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities
detected_result = self.detect_and_validate_entities(
dataframe,
model_configs=model_configs,
selected_models=selected_models,
gliner_detection_threshold=gliner_detection_threshold,
validation_max_entities_per_call=validation_max_entities_per_call,
validation_excerpt_window_chars=validation_excerpt_window_chars,
entity_labels=entity_labels,
privacy_goal=privacy_goal,
data_summary=data_summary,
preview_num_records=preview_num_records,
)
final_df = latent_result.dataframe.copy()
final_failures = [*detected_result.failed_records, *latent_result.failed_records]
else:
final_df = detected_result.dataframe.copy()
final_failures = detected_result.failed_records

# When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS),
# the augmenter is strict and out-of-scope labels are filtered.
# entity_labels=None is the only way to get permissive augmentation.
# TODO(docs): document this None-vs-explicit contract in user-facing docs.
if COL_DETECTED_ENTITIES in final_df.columns:
allowed = set(entity_labels) if entity_labels is not None else None
final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply(
lambda raw: _materialize_final_entities(raw, allowed_labels=allowed)

if tag_latent_entities:
latent_result = self.identify_latent_entities(
detected_result.dataframe,
model_configs=model_configs,
selected_models=selected_models,
gliner_detection_threshold=gliner_detection_threshold,
entity_labels=entity_labels,
privacy_goal=privacy_goal,
data_summary=data_summary,
preview_num_records=preview_num_records,
)
final_df = latent_result.dataframe.copy()
final_failures = [*detected_result.failed_records, *latent_result.failed_records]
else:
final_df = detected_result.dataframe.copy()
final_failures = detected_result.failed_records

# When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS),
# the augmenter is strict and out-of-scope labels are filtered.
# entity_labels=None is the only way to get permissive augmentation.
# TODO(docs): document this None-vs-explicit contract in user-facing docs.
if COL_DETECTED_ENTITIES in final_df.columns:
allowed = set(entity_labels) if entity_labels is not None else None
final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply(
lambda raw: _materialize_final_entities(raw, allowed_labels=allowed)
)
if compute_grouped:
final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value)
result = EntityDetectionResult(
dataframe=final_df,
failed_records=final_failures,
)
if compute_grouped:
final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value)
return EntityDetectionResult(
dataframe=final_df,
failed_records=final_failures,
)
measurement.update(
output_row_count=len(result.dataframe),
failed_record_count=len(result.failed_records),
)
return result

def _inject_detector_params(
self,
Expand Down
Loading
Loading