diff --git a/pyproject.toml b/pyproject.toml index 29865798..972fbb12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" dependencies = [ "data-designer==0.6.0", "pydantic>=2.9,<3", + "pydantic-settings>=2.12,<3", "cyclopts>=3", "pygments>=2.20.0", "cryptography>=46.0.6", diff --git a/src/anonymizer/engine/constants.py b/src/anonymizer/engine/constants.py index fdfecadf..d9e1f079 100644 --- a/src/anonymizer/engine/constants.py +++ b/src/anonymizer/engine/constants.py @@ -45,6 +45,7 @@ COL_ENTITIES_BY_VALUE = "_entities_by_value" COL_REPLACED_TEXT = "__nemo_anonymizer_text_output__" COL_REPLACEMENT_MAP = "_replacement_map" +COL_REPLACEMENT_MAP_SOURCE = "_replacement_map_source" # LlmReplaceWorkflow internal prompt-construction columns. Created by # `LlmReplaceWorkflow.generate_map_only` for the replacement-generator prompt diff --git a/src/anonymizer/engine/detection/chunked_validation.py b/src/anonymizer/engine/detection/chunked_validation.py index 50601cad..b870fab9 100644 --- a/src/anonymizer/engine/detection/chunked_validation.py +++ b/src/anonymizer/engine/detection/chunked_validation.py @@ -102,6 +102,11 @@ class ChunkedValidationParams(BaseModel): max_entities_per_call: Upper bound on candidates per chunk. excerpt_window_chars: Chars of surrounding raw text included in each chunk's excerpt on either side of the chunk span. + single_chunk_full_text: If True, a row with one validation chunk sees + the full tagged document. If False, even a single chunk uses the + excerpt window. The default preserves production parity with the + pre-chunking validation path; benchmarks may disable it to probe + compact validation prompts. prompt_template: Jinja2 source for the validation prompt (with ``_seed_tagged_text``, ``_validation_skeleton``, ``_tag_notation`` placeholders). Typically produced by ``_get_validation_prompt``. @@ -119,6 +124,7 @@ class ChunkedValidationParams(BaseModel): pool: list[str] = Field(min_length=1) max_entities_per_call: int = Field(gt=0) excerpt_window_chars: int = Field(gt=0) + single_chunk_full_text: bool = True prompt_template: str = Field(repr=False) system_prompt: str | None = Field(default=None, repr=False) @@ -449,7 +455,11 @@ def chunked_validate_row( # only making one call there's no cost reason to clip, and clipping # would silently narrow the context the validator sees. Computed once # here because ``len(chunks) == 1`` is loop-invariant. - single_chunk_tagged_text = build_tagged_text(text, all_spans, notation=notation) if len(chunks) == 1 else None + single_chunk_tagged_text = ( + build_tagged_text(text, all_spans, notation=notation) + if len(chunks) == 1 and params.single_chunk_full_text + else None + ) dispatch_kwargs_per_chunk: list[dict[str, Any]] = [] for chunk_index, chunk in enumerate(chunks): diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index 87eb644b..c0a34f83 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -59,6 +59,7 @@ EntitiesSchema, LatentEntitiesSchema, ) +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.detection") @@ -94,6 +95,7 @@ def detect_and_validate_entities( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, data_summary: str | None = None, preview_num_records: int | None = None, @@ -143,6 +145,7 @@ def detect_and_validate_entities( pool=list(validator_aliases), max_entities_per_call=validation_max_entities_per_call, excerpt_window_chars=validation_excerpt_window_chars, + single_chunk_full_text=validation_single_chunk_full_text, prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels), ) @@ -266,54 +269,64 @@ def run( ``identify_latent_entities`` if ``tag_latent_entities`` is True (rewrite mode). Merges failures from both stages. """ - if tag_latent_entities and privacy_goal is None: - raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)") - - compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities - detected_result = self.detect_and_validate_entities( - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - - if tag_latent_entities: - latent_result = self.identify_latent_entities( - detected_result.dataframe, + with stage_timer( + "EntityDetectionWorkflow.run", + input_row_count=len(dataframe), + tag_latent_entities=tag_latent_entities, + ) as measurement: + if tag_latent_entities and privacy_goal is None: + raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)") + + compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities + detected_result = self.detect_and_validate_entities( + dataframe, model_configs=model_configs, selected_models=selected_models, gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, entity_labels=entity_labels, - privacy_goal=privacy_goal, data_summary=data_summary, preview_num_records=preview_num_records, ) - final_df = latent_result.dataframe.copy() - final_failures = [*detected_result.failed_records, *latent_result.failed_records] - else: - final_df = detected_result.dataframe.copy() - final_failures = detected_result.failed_records - - # When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS), - # the augmenter is strict and out-of-scope labels are filtered. - # entity_labels=None is the only way to get permissive augmentation. - # TODO(docs): document this None-vs-explicit contract in user-facing docs. - if COL_DETECTED_ENTITIES in final_df.columns: - allowed = set(entity_labels) if entity_labels is not None else None - final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply( - lambda raw: _materialize_final_entities(raw, allowed_labels=allowed) + + if tag_latent_entities: + latent_result = self.identify_latent_entities( + detected_result.dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + entity_labels=entity_labels, + privacy_goal=privacy_goal, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + final_df = latent_result.dataframe.copy() + final_failures = [*detected_result.failed_records, *latent_result.failed_records] + else: + final_df = detected_result.dataframe.copy() + final_failures = detected_result.failed_records + + # When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS), + # the augmenter is strict and out-of-scope labels are filtered. + # entity_labels=None is the only way to get permissive augmentation. + # TODO(docs): document this None-vs-explicit contract in user-facing docs. + if COL_DETECTED_ENTITIES in final_df.columns: + allowed = set(entity_labels) if entity_labels is not None else None + final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply( + lambda raw: _materialize_final_entities(raw, allowed_labels=allowed) + ) + if compute_grouped: + final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value) + result = EntityDetectionResult( + dataframe=final_df, + failed_records=final_failures, ) - if compute_grouped: - final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value) - return EntityDetectionResult( - dataframe=final_df, - failed_records=final_failures, - ) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result def _inject_detector_params( self, diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 8aa9b920..f48dee4a 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -6,10 +6,14 @@ import json import logging import tempfile +import time import uuid +from collections.abc import Iterator, Mapping +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from threading import RLock +from typing import TYPE_CHECKING, Any, cast from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder @@ -18,6 +22,7 @@ from data_designer.config.seed_source import LocalFileSeedSource from anonymizer.interface.errors import AnonymizerWorkflowError +from anonymizer.measurement import current_collector, record_ndd_workflow if TYPE_CHECKING: import pandas as pd @@ -26,6 +31,7 @@ logger = logging.getLogger("anonymizer.ndd") RECORD_ID_COLUMN = "_anonymizer_record_id" +_DD_MESSAGE_TRACE_PATCH_LOCK = RLock() @dataclass(frozen=True) @@ -86,7 +92,15 @@ def run_workflow( logger.debug("NDD workflow '%s' starting with %d records", workflow_name, len(workflow_input_df)) col_names = [c.name for c in columns] logger.debug("NDD workflow '%s': %d columns %s", workflow_name, len(col_names), col_names) - model_aliases = [m.alias for m in model_configs] + available_model_aliases = [m.alias for m in model_configs] + model_aliases = _extract_workflow_model_aliases(columns) or available_model_aliases + record_count = ( + min(preview_num_records, len(workflow_input_df)) + if preview_num_records is not None + else len(workflow_input_df) + ) + started = time.perf_counter() + usage_probe = _DataDesignerUsageProbe(self._data_designer, enabled=current_collector() is not None) with tempfile.TemporaryDirectory(prefix=f"anonymizer_{workflow_name}_") as tmp_dir: seed_path = str(Path(tmp_dir) / "seed.parquet") @@ -97,33 +111,29 @@ def run_workflow( for column in columns: config_builder.add_column(column) - record_count = ( - min(preview_num_records, len(workflow_input_df)) - if preview_num_records is not None - else len(workflow_input_df) - ) try: - if preview_num_records is None: - run_results = self._data_designer.create( - config_builder, - num_records=len(workflow_input_df), - dataset_name=workflow_name, - ) - output_df = run_results.load_dataset() - else: - preview_results = self._data_designer.preview( - config_builder, - num_records=record_count, - ) - if preview_results.dataset is None: - output_df = workflow_input_df.iloc[0:0].copy() + with usage_probe, _dd_message_trace(workflow_name=workflow_name): + if preview_num_records is None: + run_results = self._data_designer.create( + config_builder, + num_records=len(workflow_input_df), + dataset_name=workflow_name, + ) + output_df = run_results.load_dataset() else: - output_df = preview_results.dataset + preview_results = self._data_designer.preview( + config_builder, + num_records=record_count, + ) + if preview_results.dataset is None: + output_df = workflow_input_df.iloc[0:0].copy() + else: + output_df = preview_results.dataset except Exception as exc: logger.warning( "Workflow failed for %d input record(s) on model(s) %s: %s", record_count, - model_aliases, + available_model_aliases, exc, ) logger.debug( @@ -131,6 +141,20 @@ def run_workflow( workflow_name, col_names, ) + record_ndd_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=record_count, + seed_row_count=len(workflow_input_df), + output_row_count=None, + failed_record_count=None, + elapsed_sec=time.perf_counter() - started, + status="error", + preview_num_records=preview_num_records, + column_count=len(col_names), + column_names=col_names, + model_usage=usage_probe.model_usage(), + ) raise AnonymizerWorkflowError(f"Workflow failed: {exc}") from exc logger.debug("NDD workflow '%s' returned %d records", workflow_name, len(output_df)) @@ -143,6 +167,19 @@ def run_workflow( ), output_df=output_df, ) + record_ndd_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=record_count, + seed_row_count=len(workflow_input_df), + output_row_count=len(output_df), + failed_record_count=len(failed_records), + elapsed_sec=time.perf_counter() - started, + preview_num_records=preview_num_records, + column_count=len(col_names), + column_names=col_names, + model_usage=usage_probe.model_usage(), + ) return WorkflowRunResult(dataframe=output_df, failed_records=failed_records) def _attach_record_ids(self, df: pd.DataFrame) -> pd.DataFrame: @@ -225,3 +262,326 @@ def _detect_missing_records( ) for record_id in missing_ids ] + + +def _extract_workflow_model_aliases(columns: list[ColumnConfigT]) -> list[str]: + aliases: list[str] = [] + for column in columns: + aliases.extend(_as_alias_list(getattr(column, "model_alias", None))) + generator = getattr(column, "generator_function", None) + metadata = getattr(generator, "custom_column_metadata", None) + if isinstance(metadata, dict): + aliases.extend(_as_alias_list(metadata.get("model_aliases"))) + return list(dict.fromkeys(alias for alias in aliases if alias)) + + +def _as_alias_list(raw: Any) -> list[str]: + if raw is None: + return [] + if isinstance(raw, str): + return [raw] + if isinstance(raw, (list, tuple, set)): + return [str(item) for item in raw if item is not None and str(item)] + return [str(raw)] + + +class _DataDesignerUsageProbe: + """Capture DataDesigner model usage from the per-run private ResourceProvider.""" + + def __init__(self, data_designer: DataDesigner, *, enabled: bool) -> None: + self._data_designer = data_designer + self._enabled = enabled + self._original_create_resource_provider: Any | None = None + self._resource_providers: list[Any] = [] + + def __enter__(self) -> _DataDesignerUsageProbe: + if not self._enabled: + return self + + original = getattr(self._data_designer, "_create_resource_provider", None) + if not callable(original): + return self + + self._original_create_resource_provider = original + + def wrapper(*args: Any, **kwargs: Any) -> Any: + resource_provider = original(*args, **kwargs) + self._resource_providers.append(resource_provider) + return resource_provider + + setattr(self._data_designer, "_create_resource_provider", wrapper) + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + if self._original_create_resource_provider is not None: + setattr(self._data_designer, "_create_resource_provider", self._original_create_resource_provider) + + def model_usage(self) -> dict[str, Any] | None: + usage: dict[str, Any] = {} + for resource_provider in self._resource_providers: + model_registry = getattr(resource_provider, "model_registry", None) + snapshot = _get_model_usage_snapshot(model_registry) + if not snapshot: + continue + for model_name, stats in snapshot.items(): + usage[str(model_name)] = _model_usage_as_json(stats) + return usage or None + + +def _get_model_usage_snapshot(model_registry: object) -> Mapping[str, object] | None: + alias_snapshot = _get_model_usage_snapshot_by_alias(model_registry) + if alias_snapshot: + return alias_snapshot + + get_snapshot = getattr(model_registry, "get_model_usage_snapshot", None) + if not callable(get_snapshot): + return None + snapshot = get_snapshot() + if isinstance(snapshot, Mapping): + return snapshot + return None + + +def _get_model_usage_snapshot_by_alias(model_registry: object) -> Mapping[str, object] | None: + models = getattr(model_registry, "_models", None) + if not isinstance(models, Mapping): + return None + + snapshot: dict[str, object] = {} + for model_alias, model_facade in models.items(): + stats = getattr(model_facade, "usage_stats", None) + if stats is None or not getattr(stats, "has_usage", False): + continue + payload = _model_usage_as_json(stats) + if isinstance(payload, Mapping): + payload = { + **payload, + "model_alias": getattr(model_facade, "model_alias", str(model_alias)), + "model_name": getattr(model_facade, "model_name", None), + "model_provider_name": getattr(model_facade, "model_provider_name", None), + } + snapshot[str(model_alias)] = payload + return snapshot or None + + +def _model_usage_as_json(stats: object) -> Any: + model_dump = getattr(stats, "model_dump", None) + if callable(model_dump): + return model_dump(mode="json") + return stats + + +@contextmanager +def _dd_message_trace(*, workflow_name: str) -> Iterator[None]: + """Trace DataDesigner model messages for the active measurement context. + + DataDesigner constructs ``ModelFacade`` instances internally, so this hook + wraps the facade class while a traced workflow runs. The wrappers re-check + the active collector before writing a trace record so unrelated concurrent + workflows pass through without contaminating the traced collector. + """ + collector = current_collector() + if collector is None or not collector.dd_trace_enabled: + yield + return + + from data_designer.engine.models.facade import ModelFacade + + with _DD_MESSAGE_TRACE_PATCH_LOCK: + original_completion = ModelFacade.completion + original_acompletion = ModelFacade.acompletion + ModelFacade.completion = _traced_completion( + original_completion, collector=collector, workflow_name=workflow_name + ) + ModelFacade.acompletion = _traced_acompletion( + original_acompletion, + collector=collector, + workflow_name=workflow_name, + ) + try: + yield + finally: + ModelFacade.completion = original_completion + ModelFacade.acompletion = original_acompletion + + +def _traced_completion(original_completion: Any, *, collector: Any, workflow_name: str) -> Any: + def traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: + return _run_traced_completion( + original_completion, + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + args=args, + kwargs=kwargs, + ) + + return traced + + +def _run_traced_completion( + completion: Any, + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + started = time.perf_counter() + response: Any | None = None + status = "completed" + error_type: str | None = None + try: + response = completion(model_facade, messages, *args, **kwargs) + return response + except BaseException as exc: + status = "error" + error_type = type(exc).__name__ + raise + finally: + active_collector = _active_trace_collector(collector) + if active_collector is not None: + _record_dd_message_trace( + collector=active_collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=False, + ) + + +def _traced_acompletion(original_acompletion: Any, *, collector: Any, workflow_name: str) -> Any: + async def traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: + return await _run_traced_acompletion( + original_acompletion, + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + args=args, + kwargs=kwargs, + ) + + return traced + + +async def _run_traced_acompletion( + acompletion: Any, + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + started = time.perf_counter() + response: Any | None = None + status = "completed" + error_type: str | None = None + try: + response = await acompletion(model_facade, messages, *args, **kwargs) + return response + except BaseException as exc: + status = "error" + error_type = type(exc).__name__ + raise + finally: + active_collector = _active_trace_collector(collector) + if active_collector is not None: + _record_dd_message_trace( + collector=active_collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=True, + ) + + +def _active_trace_collector(expected_collector: Any) -> Any | None: + active_collector = current_collector() + if active_collector is expected_collector and active_collector.dd_trace_enabled: + return active_collector + return None + + +def _record_dd_message_trace( + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + response: Any | None, + elapsed_sec: float, + status: str, + error_type: str | None, + is_async: bool, +) -> None: + collector.record_dd_message_trace( + workflow_name=workflow_name, + model_alias=getattr(model_facade, "model_alias", None), + model_name=getattr(model_facade, "model_name", None), + model_provider_name=getattr(model_facade, "model_provider_name", None), + modality="chat", + is_async=is_async, + status=status, + error_type=error_type, + elapsed_sec=elapsed_sec, + messages=_trace_messages(messages, mode=collector.dd_trace_mode), + response=_trace_response(response), + usage=_trace_usage(getattr(response, "usage", None) if response is not None else None), + ) + + +def _trace_messages(messages: list[Any], *, mode: str) -> list[dict[str, Any]]: + selected = messages if mode == "all_messages" else messages[-1:] + return [_trace_message(message) for message in selected] + + +def _trace_message(message: Any) -> dict[str, Any]: + to_dict = getattr(message, "to_dict", None) + if callable(to_dict): + return cast(dict[str, Any], to_dict()) + if isinstance(message, Mapping): + return dict(message) + return {"role": getattr(message, "role", None), "content": getattr(message, "content", None)} + + +def _trace_response(response: Any | None) -> dict[str, Any] | None: + if response is None: + return None + message = getattr(response, "message", None) + if message is None: + return None + return { + "content": getattr(message, "content", None), + "reasoning_content": getattr(message, "reasoning_content", None), + "tool_calls": _trace_tool_calls(getattr(message, "tool_calls", [])), + } + + +def _trace_tool_calls(tool_calls: Any) -> list[Any]: + if isinstance(tool_calls, list): + return [getattr(tool_call, "__dict__", tool_call) for tool_call in tool_calls] + return [] + + +def _trace_usage(usage: Any | None) -> dict[str, Any] | None: + if usage is None: + return None + return { + "input_tokens": getattr(usage, "input_tokens", None), + "output_tokens": getattr(usage, "output_tokens", None), + "total_tokens": getattr(usage, "total_tokens", None), + } diff --git a/src/anonymizer/engine/replace/llm_replace_workflow.py b/src/anonymizer/engine/replace/llm_replace_workflow.py index ccd5cb1d..531b6827 100644 --- a/src/anonymizer/engine/replace/llm_replace_workflow.py +++ b/src/anonymizer/engine/replace/llm_replace_workflow.py @@ -19,6 +19,7 @@ COL_ENTITIES_FOR_REPLACE_JSON, COL_ENTITY_EXAMPLES, COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, ENTITY_LABEL_EXAMPLES, ) from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter @@ -28,6 +29,7 @@ from anonymizer.engine.schemas import EntitiesByValueSchema, EntityReplacementMapSchema logger = logging.getLogger("anonymizer.replace.llm_workflow") +REPLACEMENT_MAP_SOURCE_LLM = "llm" # Workflow-internal scratch columns used only to build the replacement-generator # prompt. Created in `generate_map_only` and dropped before returning — nothing @@ -71,6 +73,7 @@ def generate_map_only( # Partition: rows with an empty entity list bypass replacement-map generation. entity_rows, passthrough_rows = split_rows(working_df, column=COL_ENTITIES_FOR_REPLACE, predicate=bool) passthrough_rows[COL_REPLACEMENT_MAP] = [{"replacements": []} for _ in range(len(passthrough_rows))] + passthrough_rows[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LLM if entity_rows.empty: passthrough_only = merge_and_reorder(passthrough_rows) @@ -110,6 +113,7 @@ def generate_map_only( ), axis=1, ) + output_df[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LLM combined = merge_and_reorder(output_df, passthrough_rows) return LlmReplaceResult( @@ -160,28 +164,56 @@ def _filter_replacement_map_to_input_entities( for label in entity.labels if entity.value and label } + protected_original_values = {value for value, _ in allowed_pairs} filtered: list[dict[str, str]] = [] seen: set[tuple[str, str]] = set() + synthetic_collision_labels: Counter[str] = Counter() for replacement in parsed_map.replacements: key = (replacement.original, replacement.label) if key not in allowed_pairs or key in seen: continue + if replacement.synthetic in protected_original_values: + synthetic_collision_labels[replacement.label] += 1 + seen.add(key) + filtered.append( + { + "original": replacement.original, + "label": replacement.label, + "synthetic": _collision_safe_synthetic( + replacement.label, + index=synthetic_collision_labels[replacement.label], + protected_original_values=protected_original_values, + ), + } + ) + continue seen.add(key) filtered.append(replacement.model_dump()) + if synthetic_collision_labels: + logger.warning( + "Replacement map repaired synthetic-original collision entries for record %s; repaired=%d " + "(repaired_by_label=%s)", + record_id or "", + sum(synthetic_collision_labels.values()), + dict(synthetic_collision_labels), + ) if logger.isEnabledFor(logging.DEBUG): raw_pairs = {(r.original, r.label) for r in parsed_map.replacements} filtered_pairs = {(f["original"], f["label"]) for f in filtered} unrequested_labels = Counter(label for _, label in (raw_pairs - allowed_pairs)) unfilled_labels = Counter(label for _, label in (allowed_pairs - filtered_pairs)) logger.debug( - "Replacement map record %s: requested=%d raw=%d filtered=%d%s%s", + "Replacement map record %s: requested=%d raw=%d filtered=%d%s%s%s", record_id or "", len(allowed_pairs), len(parsed_map.replacements), len(filtered), f" unrequested_by_label={dict(unrequested_labels)}" if unrequested_labels else "", f" unfilled_by_label={dict(unfilled_labels)}" if unfilled_labels else "", + f" synthetic_original_collision_by_label={dict(synthetic_collision_labels)}" + if synthetic_collision_labels + else "", ) if not filtered and allowed_pairs: requested_labels = Counter(label for _, label in allowed_pairs) @@ -195,6 +227,15 @@ def _filter_replacement_map_to_input_entities( return {"replacements": filtered} +def _collision_safe_synthetic(label: str, *, index: int, protected_original_values: set[str]) -> str: + label_token = "".join(char.upper() if char.isalnum() else "_" for char in label).strip("_") or "VALUE" + while True: + candidate = f"[SUBSTITUTE_{label_token}_{index}]" + if candidate not in protected_original_values: + return candidate + index += 1 + + def _get_replacement_mapping_prompt(*, entities_column: str, instructions: str | None = None) -> str: instruction_block = f"\nAdditional instructions: {instructions}\n" if instructions else "" prompt = """Generate synthetic replacements for sensitive entities. ONE value per entity, used consistently. diff --git a/src/anonymizer/engine/replace/replace_runner.py b/src/anonymizer/engine/replace/replace_runner.py index d6501834..f3a95adc 100644 --- a/src/anonymizer/engine/replace/replace_runner.py +++ b/src/anonymizer/engine/replace/replace_runner.py @@ -28,6 +28,7 @@ from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, FailedRecord, NddAdapter from anonymizer.engine.replace.llm_replace_workflow import LlmReplaceWorkflow from anonymizer.engine.replace.strategies import apply_local_replace_strategy, apply_replacement_map +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.replace") @@ -73,27 +74,38 @@ def run( Evaluation is a separate concern — call ``evaluate()`` on the resulting dataframe when you want the LLM alignment scores. """ - logger.debug("replacement strategy: %s on %d records", type(replace_method).__name__, len(dataframe)) - - if isinstance(replace_method, (Annotate, Redact, Hash)): - local_df = apply_local_replace_strategy(dataframe, strategy=replace_method) - failed_records: list[FailedRecord] = [] - elif isinstance(replace_method, Substitute): - if self._llm_workflow is None: - raise ValueError("Substitute requires an llm_workflow, but none was provided.") - map_result = self._llm_workflow.generate_map_only( - dataframe, - model_configs=model_configs, - selected_models=selected_models, - instructions=replace_method.instructions, - preview_num_records=preview_num_records, - ) - local_df = apply_replacement_map(map_result.dataframe) - failed_records = list(map_result.failed_records) - else: - raise ValueError(f"Unsupported replace method: {type(replace_method).__name__}") + strategy = type(replace_method).__name__ + with stage_timer( + "ReplacementWorkflow.run", + strategy=strategy, + input_row_count=len(dataframe), + ) as measurement: + logger.debug("replacement strategy: %s on %d records", strategy, len(dataframe)) + + if isinstance(replace_method, (Annotate, Redact, Hash)): + local_df = apply_local_replace_strategy(dataframe, strategy=replace_method) + failed_records: list[FailedRecord] = [] + elif isinstance(replace_method, Substitute): + if self._llm_workflow is None: + raise ValueError("Substitute requires an llm_workflow, but none was provided.") + map_result = self._llm_workflow.generate_map_only( + dataframe, + model_configs=model_configs, + selected_models=selected_models, + instructions=replace_method.instructions, + preview_num_records=preview_num_records, + ) + local_df = apply_replacement_map(map_result.dataframe) + failed_records = list(map_result.failed_records) + else: + raise ValueError(f"Unsupported replace method: {type(replace_method).__name__}") - return ReplacementResult(dataframe=local_df, failed_records=failed_records) + result = ReplacementResult(dataframe=local_df, failed_records=failed_records) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result def evaluate( self, diff --git a/src/anonymizer/engine/replace/structured_substitute.py b/src/anonymizer/engine/replace/structured_substitute.py new file mode 100644 index 00000000..6205c005 --- /dev/null +++ b/src/anonymizer/engine/replace/structured_substitute.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import re +from collections.abc import Callable, Iterable + +import pandas as pd + +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP, COL_REPLACEMENT_MAP_SOURCE +from anonymizer.engine.schemas import EntitiesByValueSchema + +REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED = "local_structured" +SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS = frozenset( + { + "api_key", + "date_of_birth", + "email", + "http_cookie", + "organization_name", + "password", + "pin", + "religious_belief", + "street_address", + "unique_id", + "url", + "user_name", + } +) + +_RELIGIOUS_BELIEF_SUBSTITUTES = ( + "agnostic", + "atheist", + "buddhist", + "catholic", + "christian", + "hindu", + "jewish", + "muslim", + "secular", +) + + +def apply_structured_substitution_maps( + dataframe: pd.DataFrame, + *, + entities_column: str = COL_ENTITIES_BY_VALUE, +) -> pd.DataFrame: + """Attach deterministic substitute maps for supported structured labels. + + This helper intentionally builds only replacement maps. Text rewriting still + uses the normal replacement-map application path, so span handling remains + identical to LLM-backed ``Substitute``. + """ + output_df = dataframe.copy() + output_df[COL_REPLACEMENT_MAP] = output_df[entities_column].apply(build_structured_substitution_map) + output_df[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + return output_df + + +def build_structured_substitution_map(raw_entities: object) -> dict[str, list[dict[str, str]]]: + """Build a substitute map without model calls for narrow structured labels.""" + parsed = EntitiesByValueSchema.from_raw(raw_entities) + unsupported = _unsupported_labels(parsed) + if unsupported: + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + raise ValueError( + f"local structured substitute supports only deterministic structured labels; " + f"unsupported labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + original_values = {entity.value for entity in parsed.entities_by_value if entity.value} + synthetic_values: set[str] = set() + replacements: list[dict[str, str]] = [] + seen: set[tuple[str, str]] = set() + for entity in parsed.entities_by_value: + if not entity.value: + continue + for label in entity.labels: + if not label: + continue + key = (entity.value, label) + if key in seen: + continue + seen.add(key) + synthetic = structured_substitute_value( + entity.value, + label, + forbidden_values=original_values | synthetic_values, + ) + synthetic_values.add(synthetic) + replacements.append( + { + "original": entity.value, + "label": label, + "synthetic": synthetic, + } + ) + return {"replacements": replacements} + + +def structured_substitute_value( + value: str, + label: str, + *, + forbidden_values: Iterable[str] | None = None, +) -> str: + """Return a deterministic synthetic value for one supported structured label.""" + generator = _GENERATORS.get(label) + if generator is None: + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + raise ValueError(f"unsupported local structured substitute label: {label}; supported labels: {supported}") + forbidden = {str(item) for item in forbidden_values or () if item is not None} + forbidden.add(value) + + for salt in ("", "alternate"): + synthetic = generator(value, _digest(value=value, label=label, salt=salt)) + if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): + return synthetic + + fallback_index = 0 + while True: + synthetic = f"synthetic-{label}-{_digest(value=value, label=label, salt=f'fallback-{fallback_index}')[:12]}" + if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): + return synthetic + fallback_index += 1 + + +def _synthetic_is_allowed(synthetic: str, *, original: str, forbidden_values: set[str]) -> bool: + """Reject self-preservation and exact collisions with other protected originals.""" + return synthetic != original and original not in synthetic and synthetic not in forbidden_values + + +def _unsupported_labels(parsed: EntitiesByValueSchema) -> list[str]: + labels = {label for entity in parsed.entities_by_value for label in entity.labels if label} + return sorted(labels - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) + + +def _digest(*, value: str, label: str, salt: str = "") -> str: + return hashlib.sha256(f"{label}\0{salt}\0{value}".encode("utf-8")).hexdigest() + + +def _api_key(value: str, digest: str) -> str: + if value.startswith("ghp_"): + return "ghp_" + digest[:36] + if value.startswith("hf_"): + return "hf_" + digest[:40] + if value.startswith("pat-"): + return "pat-" + digest[:40] + if value.startswith("xoxb-"): + return f"xoxb-{digest[:12]}-{digest[12:24]}-{digest[24:36]}" + if value.startswith("ya29."): + return "ya29." + digest[:44] + if value.startswith("AKIA"): + return "AKIA" + digest[:16].upper() + sk_match = re.match(r"^(sk-(?:test|ant-api03|proj|prod)-)", value) + if sk_match: + return sk_match.group(1) + digest[:48] + return "tok_" + digest[:32] + + +def _password(_value: str, digest: str) -> str: + return f"Synthetic!{digest[:10]}A7" + + +def _email(_value: str, digest: str) -> str: + return f"user-{digest[:12]}@example.invalid" + + +def _http_cookie(value: str, digest: str) -> str: + parts = [part.strip() for part in value.split(";") if part.strip()] + rendered: list[str] = [] + for index, part in enumerate(parts): + if "=" not in part: + continue + name = part.split("=", 1)[0].strip() or f"cookie_{index}" + rendered.append(f"{name}={_cookie_value(name=name, digest=digest, index=index)}") + if rendered: + return "; ".join(rendered) + return f"session_id={digest[:32]}; auth_token={digest[32:56]}" + + +def _cookie_value(*, name: str, digest: str, index: int) -> str: + offset = (index * 8) % max(len(digest) - 8, 1) + chunk = digest[offset : offset + 24] + normalized = name.lower() + if "session" in normalized: + return chunk[:32] + if normalized.endswith("id") or normalized == "user_id": + return str(10000 + (int(chunk[:8], 16) % 90000)) + if "token" in normalized or "auth" in normalized or "jwt" in normalized: + return f"tok_{chunk}" + return chunk + + +def _pin(value: str, digest: str) -> str: + length = max(4, min(len(value), 8)) + number = int(digest[:12], 16) % (10**length) + return f"{number:0{length}d}" + + +def _unique_id(value: str, digest: str) -> str: + if re.fullmatch(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}", value): + return f"{digest[:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" + prefix_match = re.match(r"^([A-Za-z]+[-_])", value) + if prefix_match: + prefix = prefix_match.group(1) + return f"{prefix}{digest[:20]}" + if value.isdigit(): + return str(100000 + (int(digest[:12], 16) % 900000)) + return f"id_{digest[:24]}" + + +def _user_name(value: str, digest: str) -> str: + separator = "." if "." in value else "_" if "_" in value else "" + if separator: + return f"user{separator}{digest[:10]}" + return f"user{digest[:12]}" + + +def _url(value: str, digest: str) -> str: + scheme_match = re.match(r"^([A-Za-z][A-Za-z0-9+.-]*://)", value) + scheme = scheme_match.group(1) if scheme_match else "https://" + scheme_name = scheme[:-3].lower() + if scheme_name in {"postgres", "postgresql", "mysql", "mariadb", "mongodb", "mongodb+srv", "redis", "rediss"}: + return f"{scheme}user_{digest[:8]}:Synthetic!{digest[8:16]}@db-{digest[16:24]}.example.invalid:5432/app" + return f"{scheme}synthetic-{digest[:16]}.example.invalid/resource/{digest[16:24]}" + + +def _date_of_birth(value: str, digest: str) -> str: + year = 1950 + (int(digest[:4], 16) % 50) + month = 1 + (int(digest[4:6], 16) % 12) + day = 1 + (int(digest[6:8], 16) % 28) + if re.fullmatch(r"\d{4}", value): + return str(year) + ymd = re.fullmatch(r"\d{4}([/-])\d{1,2}\1\d{1,2}", value) + if ymd: + sep = ymd.group(1) + return f"{year:04d}{sep}{month:02d}{sep}{day:02d}" + mdy = re.fullmatch(r"\d{1,2}([/-])\d{1,2}\1(\d{2}|\d{4})", value) + if mdy: + sep = mdy.group(1) + rendered_year = f"{year % 100:02d}" if len(value.rsplit(sep, 1)[-1]) == 2 else f"{year:04d}" + return f"{month:02d}{sep}{day:02d}{sep}{rendered_year}" + return f"{year:04d}-{month:02d}-{day:02d}" + + +def _street_address(value: str, digest: str) -> str: + suffix_match = re.search( + r"\b(Street|St\.?|Avenue|Ave\.?|Road|Rd\.?|Drive|Dr\.?|Trail|Boulevard|Blvd\.?|Lane|Ln\.?|Court|Ct\.?)$", + value, + ) + suffix = suffix_match.group(1) if suffix_match else "Street" + number = 100 + (int(digest[:4], 16) % 8900) + return f"{number} Cedar Ridge {suffix}" + + +def _organization_name(value: str, digest: str) -> str: + suffixes = ( + "Center", + "Hospital", + "Clinic", + "University", + "College", + "Institute", + "Bank", + "Builders", + "Construction", + "Woodworks", + "Health", + ) + suffix = next((candidate for candidate in suffixes if value.endswith(candidate)), "Group") + prefixes = ("Northbridge", "Helios", "Mariner", "Summit", "Cedar") + prefix = prefixes[int(digest[:2], 16) % len(prefixes)] + return f"{prefix} {suffix}" + + +def _religious_belief(value: str, digest: str) -> str: + normalized = value.lower() + candidates = [candidate for candidate in _RELIGIOUS_BELIEF_SUBSTITUTES if candidate != normalized] + return candidates[int(digest[:2], 16) % len(candidates)] + + +_GENERATORS: dict[str, Callable[[str, str], str]] = { + "api_key": _api_key, + "date_of_birth": _date_of_birth, + "email": _email, + "http_cookie": _http_cookie, + "organization_name": _organization_name, + "password": _password, + "pin": _pin, + "religious_belief": _religious_belief, + "street_address": _street_address, + "unique_id": _unique_id, + "url": _url, + "user_name": _user_name, +} diff --git a/src/anonymizer/engine/rewrite/rewrite_workflow.py b/src/anonymizer/engine/rewrite/rewrite_workflow.py index 88c2b9c3..88fe07f0 100644 --- a/src/anonymizer/engine/rewrite/rewrite_workflow.py +++ b/src/anonymizer/engine/rewrite/rewrite_workflow.py @@ -37,6 +37,7 @@ from anonymizer.engine.rewrite.sensitivity_disposition import SensitivityDispositionWorkflow from anonymizer.engine.rewrite.workflow_utils import derive_seed_columns, select_seed_cols from anonymizer.engine.row_partitioning import merge_and_reorder, split_rows +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.rewrite.workflow") @@ -196,80 +197,95 @@ def run( preview_num_records: int | None = None, strict_entity_protection: bool = False, ) -> RewriteResult: - all_failed: list[FailedRecord] = [] - - entity_rows, passthrough_rows = split_rows(dataframe, column=COL_ENTITIES_BY_VALUE, predicate=_has_entities) + with stage_timer("RewriteWorkflow.run", input_row_count=len(dataframe)) as measurement: + all_failed: list[FailedRecord] = [] - # Fast path: no entities anywhere - if entity_rows.empty: - _apply_passthrough_defaults(passthrough_rows) - result_df = merge_and_reorder(passthrough_rows) - return RewriteResult(dataframe=result_df, failed_records=all_failed) + entity_rows, passthrough_rows = split_rows(dataframe, column=COL_ENTITIES_BY_VALUE, predicate=_has_entities) + measurement.update( + entity_row_count=len(entity_rows), + passthrough_row_count=len(passthrough_rows), + ) - # --- Step 1: replacement map (needs only detection output) --- - replace_workflow = LlmReplaceWorkflow(adapter=self._adapter) - replace_result = replace_workflow.generate_map_only( - entity_rows, - model_configs=model_configs, - selected_models=replace_model_selection, - ) - entity_rows = _join_new_columns(entity_rows, replace_result.dataframe) - all_failed.extend(replace_result.failed_records) + # Fast path: no entities anywhere + if entity_rows.empty: + _apply_passthrough_defaults(passthrough_rows) + result_df = merge_and_reorder(passthrough_rows) + result = RewriteResult(dataframe=result_df, failed_records=all_failed) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result + + # --- Step 1: replacement map (needs only detection output) --- + replace_workflow = LlmReplaceWorkflow(adapter=self._adapter) + replace_result = replace_workflow.generate_map_only( + entity_rows, + model_configs=model_configs, + selected_models=replace_model_selection, + ) + entity_rows = _join_new_columns(entity_rows, replace_result.dataframe) + all_failed.extend(replace_result.failed_records) + + # --- Step 2: domain, disposition, QA, rewrite (single adapter call) --- + pipeline_columns = [ + *self._domain_wf.columns(selected_models=selected_models, data_summary=data_summary), + *self._disposition_wf.columns( + selected_models=selected_models, + privacy_goal=privacy_goal, + data_summary=data_summary, + strict_entity_protection=strict_entity_protection, + ), + *self._qa_wf.columns(selected_models=selected_models), + *self._rewrite_gen_wf.columns( + selected_models=selected_models, + privacy_goal=privacy_goal, + data_summary=data_summary, + ), + ] + + pipeline_seed = select_seed_cols(entity_rows, derive_seed_columns(pipeline_columns, entity_rows)) + pipeline_result = self._adapter.run_workflow( + pipeline_seed, + model_configs=model_configs, + columns=pipeline_columns, + workflow_name="rewrite-pipeline", + preview_num_records=preview_num_records, + ) + entity_rows = _join_new_columns(entity_rows, pipeline_result.dataframe) + all_failed.extend(pipeline_result.failed_records) - # --- Step 2: domain, disposition, QA, rewrite (single adapter call) --- - pipeline_columns = [ - *self._domain_wf.columns(selected_models=selected_models, data_summary=data_summary), - *self._disposition_wf.columns( + # --- Step 5: evaluate-repair loop --- + entity_rows, eval_repair_failed = self._run_evaluate_repair_loop( + entity_rows, + model_configs=model_configs, selected_models=selected_models, privacy_goal=privacy_goal, - data_summary=data_summary, - strict_entity_protection=strict_entity_protection, - ), - *self._qa_wf.columns(selected_models=selected_models), - *self._rewrite_gen_wf.columns( + evaluation=evaluation, + preview_num_records=preview_num_records, + ) + all_failed.extend(eval_repair_failed) + + # --- Step 6: final judge (non-critical) --- + entity_rows, judge_failed = self._run_final_judge( + entity_rows, + model_configs=model_configs, selected_models=selected_models, privacy_goal=privacy_goal, - data_summary=data_summary, - ), - ] - - pipeline_seed = select_seed_cols(entity_rows, derive_seed_columns(pipeline_columns, entity_rows)) - pipeline_result = self._adapter.run_workflow( - pipeline_seed, - model_configs=model_configs, - columns=pipeline_columns, - workflow_name="rewrite-pipeline", - preview_num_records=preview_num_records, - ) - entity_rows = _join_new_columns(entity_rows, pipeline_result.dataframe) - all_failed.extend(pipeline_result.failed_records) - - # --- Step 5: evaluate-repair loop --- - entity_rows, eval_repair_failed = self._run_evaluate_repair_loop( - entity_rows, - model_configs=model_configs, - selected_models=selected_models, - privacy_goal=privacy_goal, - evaluation=evaluation, - preview_num_records=preview_num_records, - ) - all_failed.extend(eval_repair_failed) - - # --- Step 6: final judge (non-critical) --- - entity_rows, judge_failed = self._run_final_judge( - entity_rows, - model_configs=model_configs, - selected_models=selected_models, - privacy_goal=privacy_goal, - evaluation=evaluation, - preview_num_records=preview_num_records, - ) - all_failed.extend(judge_failed) + evaluation=evaluation, + preview_num_records=preview_num_records, + ) + all_failed.extend(judge_failed) - # --- Merge and return --- - _apply_passthrough_defaults(passthrough_rows) - combined = merge_and_reorder(entity_rows, passthrough_rows) - return RewriteResult(dataframe=combined, failed_records=all_failed) + # --- Merge and return --- + _apply_passthrough_defaults(passthrough_rows) + combined = merge_and_reorder(entity_rows, passthrough_rows) + result = RewriteResult(dataframe=combined, failed_records=all_failed) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result # --------------------------------------------------------------------------- # Evaluate-repair loop diff --git a/src/anonymizer/interface/anonymizer.py b/src/anonymizer/interface/anonymizer.py index ec08164a..b19762d3 100644 --- a/src/anonymizer/interface/anonymizer.py +++ b/src/anonymizer/interface/anonymizer.py @@ -59,6 +59,11 @@ from anonymizer.interface.errors import InvalidConfigError from anonymizer.interface.results import AnonymizerResult, PreviewResult from anonymizer.logging import LOG_INDENT, configure_logging, reapply_log_levels +from anonymizer.measurement import ( + record_record_metrics, + record_run_metadata, + stage_timer, +) from anonymizer.telemetry import ( NOT_APPLICABLE, AnonymizerEvent, @@ -331,6 +336,45 @@ def _run_internal( data: AnonymizerInput, context: ResolvedInput, preview_num_records: int | None, + ) -> AnonymizerResult: + input_df = context.dataframe + mode = "replace" if config.replace is not None else "rewrite" + strategy = type(config.replace).__name__ if config.replace is not None else "Rewrite" + with stage_timer( + "Anonymizer._run_internal", + mode=mode, + strategy=strategy, + input_row_count=len(input_df), + preview_num_records=preview_num_records, + ) as measurement: + record_run_metadata( + config=config, + data=data, + mode=mode, + strategy=strategy, + input_row_count=len(input_df), + preview_num_records=preview_num_records, + model_configs=self._model_configs, + ) + result = self._run_internal_impl( + config=config, + data=data, + context=context, + preview_num_records=preview_num_records, + ) + measurement.update( + output_row_count=len(result.trace_dataframe), + failed_record_count=len(result.failed_records), + ) + return result + + def _run_internal_impl( + self, + *, + config: AnonymizerConfig, + data: AnonymizerInput, + context: ResolvedInput, + preview_num_records: int | None, ) -> AnonymizerResult: input_df = context.dataframe num_records = len(input_df) @@ -455,6 +499,13 @@ def _run_internal( text_col = context.resolved_text_column renamed_trace = _rename_output_columns(final_df, resolved_text_column=text_col) logger.info("🎉 Pipeline complete — %d records processed, %d total failures", num_records, len(all_failures)) + record_record_metrics( + final_df, + mode="replace" if config.replace is not None else "rewrite", + strategy=type(config.replace).__name__ if config.replace is not None else "Rewrite", + text_column=COL_TEXT, + validation_max_entities_per_call=config.detect.validation_max_entities_per_call, + ) return AnonymizerResult( dataframe=_build_user_dataframe(renamed_trace, resolved_text_column=text_col), trace_dataframe=renamed_trace, diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py new file mode 100644 index 00000000..b6d3dbfa --- /dev/null +++ b/src/anonymizer/measurement.py @@ -0,0 +1,1509 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import hmac +import json +import logging +import math +import platform +import secrets +import time +import uuid +from collections import Counter +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +from numbers import Integral +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast +from urllib.parse import urlparse + +from pydantic import Field, ValidationError +from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError + +from anonymizer.engine.constants import COL_FINAL_ENTITIES + +if TYPE_CHECKING: + import pandas as pd + + +_ACTIVE_COLLECTOR: ContextVar[MeasurementCollector | None] = ContextVar( + "anonymizer_measurement_collector", + default=None, +) +_GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") +_ENTITY_LABEL_EQUIVALENCE_CLASSES = ( + frozenset( + { + "access_token", + "api_key", + "auth_token", + "bearer_token", + "password", + "secret_key", + "session_id", + "unique_id", + "user_id", + } + ), + frozenset({"full_name", "person_name", "user", "user_name", "username"}), + frozenset({"phone", "phone_number", "telephone"}), + frozenset({"email", "email_address"}), + frozenset({"cookie", "http_cookie", "session_cookie"}), +) +_ENTITY_LABEL_EQUIVALENCE: dict[str, str] = { + label: sorted(labels)[0] for labels in _ENTITY_LABEL_EQUIVALENCE_CLASSES for label in labels +} +MEASUREMENT_SCHEMA_VERSION = 1 +DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" +DD_TRACE_MODES = {"none", "last_message", "all_messages"} +DDTraceMode = Literal["none", "last_message", "all_messages"] + +logger = logging.getLogger("anonymizer.measurement") + + +class _MeasurementWriter(Protocol): + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: ... + + +class _MeasurementSink(Protocol): + def write_record(self, record: dict[str, Any]) -> None: ... + + def close(self) -> None: ... + + +class _JsonlMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + for record in records: + f.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + +class _JsonMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(records, f, ensure_ascii=True, indent=2, sort_keys=True) + + +def _writer_for_format(output_format: Literal["jsonl", "json"]) -> _MeasurementWriter: + if output_format == "json": + return _JsonMeasurementWriter() + return _JsonlMeasurementWriter() + + +class _JsonlMeasurementSink: + def __init__(self, path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + self._file = output_path.open("w", encoding="utf-8", buffering=1) + + def write_record(self, record: dict[str, Any]) -> None: + self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + def close(self) -> None: + self._file.close() + + +class _MeasurementEnvSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix=DEFAULT_MEASUREMENT_ENV_PREFIX, + env_ignore_empty=True, + extra="ignore", + ) + + output_path: str | None = None + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | None = None + fail_on_write_error: bool = False + run_id: str | None = None + run_tags: dict[str, Any] = Field(default_factory=dict) + + +class MeasurementCollector: + """In-memory collector for local benchmark and throughput records. + + Records contain counts, labels, lengths, aliases, timings, and run-scoped + HMACs. They must not contain raw text, entity values, prompts, generated + outputs, replacement maps, provider secrets, or API keys. + """ + + def __init__( + self, + *, + run_id: str | None = None, + record_hash_key: bytes | str | None = None, + record_level: bool = True, + run_tags: Mapping[str, Any] | None = None, + record_sink: _MeasurementSink | None = None, + keep_records: bool = True, + dd_trace_mode: DDTraceMode = "none", + dd_trace_sink: _MeasurementSink | None = None, + fail_on_write_error: bool = False, + ) -> None: + self.run_id = run_id or uuid.uuid4().hex + self.record_level = record_level + self.run_tags = cast(dict[str, Any], _json_safe(dict(run_tags or {}))) + self._record_sink = record_sink + self._keep_records = keep_records + self._dd_trace_mode = dd_trace_mode + self._dd_trace_sink = dd_trace_sink + self._fail_on_write_error = fail_on_write_error + self._sink_failed = False + self._dd_trace_failed = False + if record_hash_key is None: + self._record_hash_key = secrets.token_bytes(32) + elif isinstance(record_hash_key, str): + self._record_hash_key = record_hash_key.encode("utf-8") + else: + self._record_hash_key = bytes(record_hash_key) + self._records: list[dict[str, Any]] = [] + + @property + def records(self) -> list[dict[str, Any]]: + """Return a shallow copy of collected measurement records.""" + return list(self._records) + + def record(self, record_type: str, **fields: Any) -> None: + """Append one machine-readable measurement record.""" + record = { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": record_type, + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + safe_record = _json_safe(record) + if self._keep_records: + self._records.append(safe_record) + if self._record_sink is not None: + self._write_record_to_sink(safe_record) + + def close(self) -> None: + """Close any streaming measurement sink attached to this collector.""" + if self._record_sink is not None: + self._record_sink.close() + if self._dd_trace_sink is not None: + self._dd_trace_sink.close() + + @property + def dd_trace_mode(self) -> DDTraceMode: + return self._dd_trace_mode + + @property + def dd_trace_enabled(self) -> bool: + return self._dd_trace_mode != "none" and self._dd_trace_sink is not None + + def record_dd_message_trace(self, **fields: Any) -> None: + """Write an explicitly opt-in DataDesigner message trace record. + + These records may contain raw prompts, input text, model outputs, and + PII. They are intentionally written to a separate trace sink and are + never appended to the safe measurement record list. + """ + if not self.dd_trace_enabled or self._dd_trace_failed: + return + + record = _json_safe( + { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": "dd_message_trace", + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + ) + try: + cast(_MeasurementSink, self._dd_trace_sink).write_record(record) + except Exception: + self._dd_trace_failed = True + logger.warning("Failed to write DataDesigner message trace records") + if self._fail_on_write_error: + raise + + def _write_record_to_sink(self, record: dict[str, Any]) -> None: + if self._sink_failed: + return + try: + cast(_MeasurementSink, self._record_sink).write_record(record) + except Exception: + self._sink_failed = True + logger.warning("Failed to stream Anonymizer measurement records") + if self._fail_on_write_error: + raise + + def record_hash(self, *, row_index: object, text: str) -> str: + """Return a run-scoped HMAC for joining records without storing text.""" + serialized = json.dumps( + {"row_index": str(row_index), "text": text}, + default=str, + sort_keys=True, + separators=(",", ":"), + ) + return hmac.new(self._record_hash_key, serialized.encode("utf-8"), hashlib.sha256).hexdigest() + + def write_jsonl(self, path: str | Path) -> None: + """Write records as newline-delimited JSON.""" + _JsonlMeasurementWriter().write(self._records, path) + + def write_json(self, path: str | Path) -> None: + """Write records as a JSON array.""" + _JsonMeasurementWriter().write(self._records, path) + + def to_dataframe(self) -> pd.DataFrame: + """Return records as a pandas DataFrame for benchmark tooling.""" + import pandas as pd + + return pd.DataFrame(self._records) + + +@dataclass(frozen=True) +class MeasurementConfig: + """Configuration for writing structured measurement records around a run.""" + + output_path: str | Path + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | Path | None = None + run_id: str | None = None + record_hash_key: bytes | str | None = None + run_tags: Mapping[str, Any] | None = None + fail_on_write_error: bool = False + + def __post_init__(self) -> None: + if self.output_format not in {"jsonl", "json"}: + raise ValueError("output_format must be 'jsonl' or 'json'") + if self.streaming and self.output_format != "jsonl": + raise ValueError("streaming measurement output only supports jsonl") + if self.dd_trace not in DD_TRACE_MODES: + raise ValueError("dd_trace must be 'none', 'last_message', or 'all_messages'") + if self.dd_trace != "none" and self.dd_trace_path is None: + raise ValueError("dd_trace_path is required when dd_trace is enabled") + + @classmethod + def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> MeasurementConfig | None: + """Build measurement config from environment variables, or None if output is unset. + + This is intentionally opt-in. Anonymizer API and CLI calls do not read + measurement environment variables unless benchmark/tooling code calls this + helper explicitly. + """ + try: + settings = _load_measurement_env_settings(prefix=prefix) + except (SettingsError, ValidationError) as exc: + raise ValueError(_measurement_env_error_message(exc, prefix=prefix)) from None + + if settings.output_path is None: + return None + return cls( + output_path=settings.output_path, + output_format=settings.output_format, + record_level=settings.record_level, + streaming=settings.streaming, + keep_records=settings.keep_records, + dd_trace=settings.dd_trace, + dd_trace_path=settings.dd_trace_path, + run_id=settings.run_id, + run_tags=settings.run_tags, + fail_on_write_error=settings.fail_on_write_error, + ) + + @classmethod + def from_sources( + cls, + explicit: MeasurementConfig | None = None, + *, + env: bool = False, + prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX, + ) -> MeasurementConfig | None: + """Resolve measurement config from explicit config first, then optional env.""" + if explicit is not None: + return explicit + if env: + return cls.from_env(prefix=prefix) + return None + + def write_collector(self, collector: MeasurementCollector) -> None: + """Write a collector using this config's output format.""" + _writer_for_format(self.output_format).write(collector.records, self.output_path) + + +@contextmanager +def measurement_session(collector: MeasurementCollector | None = None) -> Iterator[MeasurementCollector]: + """Activate a collector for code running in this context.""" + active = collector or MeasurementCollector() + token = _ACTIVE_COLLECTOR.set(active) + try: + yield active + finally: + _ACTIVE_COLLECTOR.reset(token) + + +@contextmanager +def configured_measurement_session(config: MeasurementConfig | None) -> Iterator[MeasurementCollector | None]: + """Activate and persist a collector when a measurement config is provided.""" + if config is None: + yield None + return + + sink = _JsonlMeasurementSink(config.output_path) if config.streaming else None + dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) if config.dd_trace != "none" else None + collector = MeasurementCollector( + run_id=config.run_id, + record_hash_key=config.record_hash_key, + record_level=config.record_level, + run_tags=config.run_tags, + record_sink=sink, + keep_records=config.keep_records, + dd_trace_mode=config.dd_trace, + dd_trace_sink=dd_trace_sink, + fail_on_write_error=config.fail_on_write_error, + ) + with measurement_session(collector): + body_error: BaseException | None = None + try: + yield collector + except BaseException as exc: + body_error = exc + raise + finally: + if config.streaming: + _close_collector_safely(config=config, collector=collector, body_error=body_error) + else: + _write_collector_safely(config=config, collector=collector, body_error=body_error) + _close_collector_safely(config=config, collector=collector, body_error=body_error) + + +def current_collector() -> MeasurementCollector | None: + """Return the active collector, if measurement is enabled.""" + return _ACTIVE_COLLECTOR.get() + + +@contextmanager +def stage_timer(stage: str, **fields: Any) -> Iterator[dict[str, Any]]: + """Record wall time for a stage when collection is active.""" + collector = current_collector() + if collector is None: + yield fields + return + + started = time.perf_counter() + status = "completed" + try: + yield fields + except BaseException: + status = "error" + raise + finally: + elapsed_sec = time.perf_counter() - started + collector.record( + "stage", + stage=stage, + status=status, + elapsed_sec=elapsed_sec, + **fields, + **_row_throughput_fields( + elapsed_sec=elapsed_sec, + input_row_count=_coerce_int(fields.get("input_row_count"), default=-1), + output_row_count=_coerce_int(fields.get("output_row_count"), default=-1), + ), + ) + + +def record_stage(stage: str, *, elapsed_sec: float, status: str = "completed", **fields: Any) -> None: + """Record a pre-timed stage measurement if collection is active.""" + collector = current_collector() + if collector is None: + return + collector.record( + "stage", + stage=stage, + status=status, + elapsed_sec=elapsed_sec, + **fields, + **_row_throughput_fields( + elapsed_sec=elapsed_sec, + input_row_count=_coerce_int(fields.get("input_row_count"), default=-1), + output_row_count=_coerce_int(fields.get("output_row_count"), default=-1), + ), + ) + + +def record_ndd_workflow( + *, + workflow_name: str, + model_aliases: list[str], + input_row_count: int, + output_row_count: int | None, + failed_record_count: int | None, + elapsed_sec: float, + status: str = "completed", + seed_row_count: int | None = None, + preview_num_records: int | None = None, + column_count: int | None = None, + column_names: list[str] | None = None, + model_usage: Mapping[str, Any] | None = None, +) -> None: + """Record one DataDesigner workflow execution through the adapter boundary.""" + _record_model_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=input_row_count, + output_row_count=output_row_count, + failed_record_count=failed_record_count, + elapsed_sec=elapsed_sec, + status=status, + seed_row_count=seed_row_count, + preview_num_records=preview_num_records, + column_count=column_count, + column_names=column_names, + model_usage=model_usage, + record_type="ndd_workflow", + extra_fields=None, + ) + + +def record_model_workflow( + *, + workflow_name: str, + model_aliases: list[str], + input_row_count: int, + output_row_count: int | None, + failed_record_count: int | None, + elapsed_sec: float, + status: str = "completed", + seed_row_count: int | None = None, + preview_num_records: int | None = None, + column_count: int | None = None, + column_names: list[str] | None = None, + model_usage: Mapping[str, Any] | None = None, + extra_fields: Mapping[str, Any] | None = None, +) -> None: + """Record one sanitized model-backed workflow execution. + + Use this for non-DataDesigner model calls that still need benchmark + accounting. Raw prompts, text, responses, and replacement values do not + belong in ``model_usage``. + """ + _record_model_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=input_row_count, + output_row_count=output_row_count, + failed_record_count=failed_record_count, + elapsed_sec=elapsed_sec, + status=status, + seed_row_count=seed_row_count, + preview_num_records=preview_num_records, + column_count=column_count, + column_names=column_names, + model_usage=model_usage, + record_type="model_workflow", + extra_fields=extra_fields, + ) + + +def _record_model_workflow( + *, + workflow_name: str, + model_aliases: list[str], + input_row_count: int, + output_row_count: int | None, + failed_record_count: int | None, + elapsed_sec: float, + status: str, + seed_row_count: int | None, + preview_num_records: int | None, + column_count: int | None, + column_names: list[str] | None, + model_usage: Mapping[str, Any] | None, + record_type: str, + extra_fields: Mapping[str, Any] | None, +) -> None: + collector = current_collector() + if collector is None: + return + observed_usage = _summarize_model_usage(model_usage) + workflow_fields = { + "workflow_name": workflow_name, + "status": status, + "model_aliases": sorted(set(model_aliases)), + "input_row_count": input_row_count, + "seed_row_count": seed_row_count, + "output_row_count": output_row_count, + "failed_record_count": failed_record_count, + "elapsed_sec": elapsed_sec, + "preview_num_records": preview_num_records, + "column_count": column_count, + "column_names": column_names or [], + "model_usage": dict(model_usage or {}), + **dict(extra_fields or {}), + } + collector.record(record_type, **_model_workflow_fields(workflow_fields, observed_usage)) + + +def _model_workflow_fields(fields: dict[str, Any], observed_usage: dict[str, int | None]) -> dict[str, Any]: + return { + **fields, + **observed_usage, + "observed_failed_request_rate": _safe_ratio( + observed_usage["observed_failed_requests"], + observed_usage["observed_total_requests"], + ), + **_throughput_fields( + elapsed_sec=cast(float, fields["elapsed_sec"]), + input_row_count=cast(int, fields["input_row_count"]), + output_row_count=cast(int | None, fields["output_row_count"]), + total_tokens=observed_usage["observed_total_tokens"], + total_requests=observed_usage["observed_total_requests"], + successful_requests=observed_usage["observed_successful_requests"], + ), + } + + +def record_run_metadata( + *, + config: Any, + data: Any, + mode: str, + strategy: str, + input_row_count: int, + preview_num_records: int | None, + model_configs: list[Any], +) -> None: + """Record sanitized run/config metadata once per anonymizer run.""" + collector = current_collector() + if collector is None: + return + + detect = getattr(config, "detect", None) + source = str(getattr(data, "source", "")) + collector.record( + "run", + mode=mode, + strategy=strategy, + input_row_count=input_row_count, + preview_num_records=preview_num_records, + source_hash=collector.record_hash(row_index="source", text=source), + input_source=_source_metadata(source), + input_text_column=str(getattr(data, "text_column", "")), + input_has_id_column=bool(getattr(data, "id_column", None)), + input_has_data_summary=bool(getattr(data, "data_summary", None)), + detect=_detect_config_metadata(detect), + replace=_replace_config_metadata(getattr(config, "replace", None)), + rewrite=_rewrite_config_metadata(getattr(config, "rewrite", None)), + models=[_model_config_metadata(model_config) for model_config in model_configs], + runtime=_runtime_metadata(), + ) + + +def record_record_metrics( + dataframe: pd.DataFrame, + *, + mode: str, + strategy: str, + text_column: str, + validation_max_entities_per_call: int, +) -> None: + """Record per-row count, length, and nominal-call metrics from a trace DataFrame.""" + collector = current_collector() + if collector is None or not collector.record_level: + return + + ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None) + columns = set(dataframe.columns) + for row_index, row in dataframe.iterrows(): + final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES)) + collector.record( + "record", + **_base_record_fields( + collector=collector, + row_index=row_index, + row=row, + text_column=text_column, + mode=mode, + strategy=strategy, + ), + **_entity_record_fields(row, final_entities=final_entities, ground_truth_column=ground_truth_column), + **_replacement_record_fields(row, columns=columns, final_entities=final_entities), + **_rewrite_record_fields(row, columns=columns), + **_original_value_leak_record_fields(row, columns=columns, final_entities=final_entities), + **_llm_record_fields( + row, + columns=columns, + mode=mode, + strategy=strategy, + final_entity_count=len(final_entities), + validation_max_entities_per_call=validation_max_entities_per_call, + ), + ) + + +def _detect_config_metadata(detect: Any | None) -> dict[str, Any]: + entity_labels = getattr(detect, "entity_labels", None) + if entity_labels is None: + from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS + + entity_label_count = len(DEFAULT_ENTITY_LABELS) + else: + entity_label_count = len(entity_labels) + return { + "gliner_threshold": getattr(detect, "gliner_threshold", None), + "entity_label_source": "custom" if entity_labels is not None else "default", + "entity_label_count": entity_label_count, + "entity_labels": list(entity_labels) if entity_labels is not None else None, + "validation_max_entities_per_call": getattr(detect, "validation_max_entities_per_call", None), + "validation_excerpt_window_chars": getattr(detect, "validation_excerpt_window_chars", None), + } + + +def _base_record_fields( + *, + collector: MeasurementCollector, + row_index: object, + row: Any, + text_column: str, + mode: str, + strategy: str, +) -> dict[str, Any]: + text = str(row.get(text_column, "")) + text_length_tokens = _count_text_tokens(text) + return { + "mode": mode, + "strategy": strategy, + "row_index": _safe_row_index(row_index), + "record_hash": collector.record_hash(row_index=row_index, text=text), + "text_length_chars": len(text), + "text_length_chars_bucket": _size_bucket(len(text)), + "text_length_tokens": text_length_tokens, + "text_length_tokens_bucket": _size_bucket(text_length_tokens), + } + + +def _entity_record_fields( + row: Any, + *, + final_entities: list[dict[str, Any]], + ground_truth_column: str | None, +) -> dict[str, Any]: + ground_truth_entities = ( + _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None + ) + return { + "final_entity_count": len(final_entities), + "final_entity_label_counts": dict( + sorted(Counter(e.get("label", "") for e in final_entities if e.get("label")).items()) + ), + **_entity_ground_truth_metrics(final_entities, ground_truth_entities), + } + + +def _replacement_record_fields( + row: Any, + *, + columns: set[str], + final_entities: list[dict[str, Any]], +) -> dict[str, Any]: + from anonymizer.engine.constants import COL_REPLACEMENT_MAP + + if COL_REPLACEMENT_MAP not in columns: + return {} + raw_map = row.get(COL_REPLACEMENT_MAP) + return { + **_replacement_map_metrics(raw_map), + **_replacement_coverage_metrics(raw_map, final_entities), + **_replacement_collision_metrics(raw_map, final_entities), + } + + +def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: + from anonymizer.engine.constants import ( + COL_ANY_HIGH_LEAKED, + COL_LEAKAGE_MASS, + COL_NEEDS_HUMAN_REVIEW, + COL_NEEDS_REPAIR, + COL_UTILITY_SCORE, + COL_WEIGHTED_LEAKAGE_RATE, + ) + + return { + "utility_score": _coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in columns else None, + "leakage_mass": _coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in columns else None, + "weighted_leakage_rate": ( + _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE)) if COL_WEIGHTED_LEAKAGE_RATE in columns else None + ), + "any_high_leaked": _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in columns else None, + "needs_human_review": ( + _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in columns else None + ), + "needs_repair": _coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in columns else None, + } + + +def _original_value_leak_record_fields( + row: Any, + *, + columns: set[str], + final_entities: list[dict[str, Any]], +) -> dict[str, Any]: + output_column = _output_text_column(columns) + if output_column is None: + return {"original_value_leak_count": None, "original_value_leak_label_counts": {}} + output_text = str(row.get(output_column, "")) + leaked = [ + entity + for entity in final_entities + if entity.get("value") and _output_contains_original_value(output_text, str(entity.get("value"))) + ] + return { + "original_value_leak_count": len(leaked), + "original_value_leak_label_counts": dict( + sorted(Counter(str(entity.get("label") or "") for entity in leaked if entity.get("label")).items()) + ), + } + + +def _output_contains_original_value(output_text: str, value: str) -> bool: + if _needs_boundary_sensitive_leak_match(value): + return _contains_with_alnum_boundaries(output_text, value) + return value in output_text + + +def _needs_boundary_sensitive_leak_match(value: str) -> bool: + return len(value) <= 4 or value.isdigit() + + +def _contains_with_alnum_boundaries(output_text: str, value: str) -> bool: + start = 0 + while True: + match_start = output_text.find(value, start) + if match_start < 0: + return False + match_end = match_start + len(value) + if _has_alnum_boundaries(output_text, match_start, match_end): + return True + start = match_start + 1 + + +def _has_alnum_boundaries(text: str, start: int, end: int) -> bool: + before_is_alnum = start > 0 and text[start - 1].isalnum() + after_is_alnum = end < len(text) and text[end].isalnum() + return not before_is_alnum and not after_is_alnum + + +def _output_text_column(columns: set[str]) -> str | None: + from anonymizer.engine.constants import COL_REPLACED_TEXT, COL_REWRITTEN_TEXT + + if COL_REPLACED_TEXT in columns: + return COL_REPLACED_TEXT + if COL_REWRITTEN_TEXT in columns: + return COL_REWRITTEN_TEXT + return None + + +def _llm_record_fields( + row: Any, + *, + columns: set[str], + mode: str, + strategy: str, + final_entity_count: int, + validation_max_entities_per_call: int, +) -> dict[str, Any]: + from anonymizer.engine.constants import COL_REPAIR_ITERATIONS + + detected_candidate_count = _detected_candidate_count(row, columns=columns) + validation_chunk_count = _validation_chunk_count( + detected_candidate_count, + validation_max_entities_per_call=validation_max_entities_per_call, + ) + grouped_entity_count = _grouped_entity_count(row, columns=columns, final_entity_count=final_entity_count) + repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0) + replace_map_generation_uses_llm = _replace_map_generation_uses_llm(row, columns=columns) + calls_by_stage = estimate_llm_calls_by_stage( + mode=mode, + strategy=strategy, + has_grouped_entities=grouped_entity_count > 0, + validation_chunk_count=validation_chunk_count, + repair_iterations=repair_iterations, + replace_map_generation_uses_llm=replace_map_generation_uses_llm, + ) + total_estimated = ( + sum(calls_by_stage.values()) if all(value is not None for value in calls_by_stage.values()) else None + ) + return { + "detected_candidate_count": detected_candidate_count, + "validation_chunk_count": validation_chunk_count, + "repair_iterations": repair_iterations if mode == "rewrite" else 0, + "llm_calls_estimated_by_stage": calls_by_stage, + "llm_calls_estimated_total": total_estimated, + } + + +def _replace_map_generation_uses_llm(row: Any, *, columns: set[str]) -> bool: + from anonymizer.engine.constants import COL_REPLACEMENT_MAP_SOURCE + from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + + if COL_REPLACEMENT_MAP_SOURCE not in columns: + return True + return row.get(COL_REPLACEMENT_MAP_SOURCE) != REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + + +def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None: + from anonymizer.engine.constants import COL_SEED_VALIDATION_CANDIDATES + + if COL_SEED_VALIDATION_CANDIDATES not in columns: + return None + return _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",)) + + +def _grouped_entity_count(row: Any, *, columns: set[str], final_entity_count: int) -> int: + from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE + + if COL_ENTITIES_BY_VALUE not in columns: + return final_entity_count + return _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",)) + + +def _source_metadata(source: str) -> dict[str, Any]: + parsed = urlparse(source) + if parsed.scheme in {"http", "https"}: + return { + "kind": "remote_file", + "scheme": parsed.scheme, + "suffix": Path(parsed.path).suffix.lower() or None, + } + if parsed.scheme == "file": + return { + "kind": "local_file", + "scheme": "file", + "suffix": Path(parsed.path).suffix.lower() or None, + } + return { + "kind": "local_file" if source else "unknown", + "scheme": None, + "suffix": Path(source).suffix.lower() or None, + } + + +def _replace_config_metadata(replace_config: Any | None) -> dict[str, Any] | None: + if replace_config is None: + return None + + metadata: dict[str, Any] = { + "strategy": type(replace_config).__name__, + "has_instructions": bool(getattr(replace_config, "instructions", None)), + } + for attr in ("normalize_label", "algorithm", "digest_length"): + if hasattr(replace_config, attr): + metadata[attr] = getattr(replace_config, attr) + if hasattr(replace_config, "format_template"): + metadata["has_format_template"] = True + return metadata + + +def _rewrite_config_metadata(rewrite_config: Any | None) -> dict[str, Any] | None: + if rewrite_config is None: + return None + return { + "risk_tolerance": _enum_value(getattr(rewrite_config, "risk_tolerance", None)), + "max_repair_iterations": getattr(rewrite_config, "max_repair_iterations", None), + "strict_entity_protection": getattr(rewrite_config, "strict_entity_protection", None), + "has_privacy_goal": bool(getattr(rewrite_config, "privacy_goal", None)), + "has_instructions": bool(getattr(rewrite_config, "instructions", None)), + } + + +def _model_config_metadata(model_config: Any) -> dict[str, Any]: + inference_parameters = getattr(model_config, "inference_parameters", None) + return { + "alias": getattr(model_config, "alias", None), + "model": getattr(model_config, "model", None), + "provider": _enum_value(getattr(model_config, "provider", None)), + "base_url": bool(getattr(model_config, "base_url", None)), + "max_parallel_requests": getattr(inference_parameters, "max_parallel_requests", None), + } + + +def _runtime_metadata() -> dict[str, Any]: + try: + anonymizer_version = version("nemo-anonymizer") + except PackageNotFoundError: + anonymizer_version = None + return { + "anonymizer_version": anonymizer_version, + "measurement_schema_version": MEASUREMENT_SCHEMA_VERSION, + "platform_machine": platform.machine(), + "platform_system": platform.system(), + "python_version": platform.python_version(), + } + + +def _enum_value(value: Any) -> Any: + return getattr(value, "value", value) + + +def _throughput_fields( + *, + elapsed_sec: float, + input_row_count: int | None, + output_row_count: int | None, + total_tokens: int | None, + total_requests: int | None, + successful_requests: int | None, +) -> dict[str, float | None]: + return { + "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec), + "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec), + "observed_tokens_per_sec": _safe_rate(total_tokens, elapsed_sec), + "observed_requests_per_sec": _safe_rate(total_requests, elapsed_sec), + "observed_tokens_per_successful_request": _safe_ratio(total_tokens, successful_requests), + } + + +def _row_throughput_fields( + *, + elapsed_sec: float, + input_row_count: int | None, + output_row_count: int | None, +) -> dict[str, float | None]: + if input_row_count is not None and input_row_count < 0: + input_row_count = None + if output_row_count is not None and output_row_count < 0: + output_row_count = None + return { + "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec), + "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec), + } + + +def estimate_llm_calls_by_stage( + *, + mode: str, + strategy: str, + has_grouped_entities: bool, + validation_chunk_count: int | None, + repair_iterations: int = 0, + replace_map_generation_uses_llm: bool = True, +) -> dict[str, int | None]: + """Estimate nominal model calls for one record, split by workflow stage.""" + detection_calls = None if validation_chunk_count is None else 2 + validation_chunk_count + replace_map_generation = 0 + if replace_map_generation_uses_llm and has_grouped_entities and (mode == "rewrite" or strategy == "Substitute"): + replace_map_generation = 1 + + if mode != "rewrite": + return { + "entity_detection": detection_calls, + "replace_map_generation": replace_map_generation, + } + + rewrite_body_calls = has_grouped_entities + return { + "entity_detection": detection_calls, + "latent_entity_detection": 1 if rewrite_body_calls else 0, + "replace_map_generation": replace_map_generation, + "rewrite_pipeline": 5 if rewrite_body_calls else 0, + "rewrite_evaluate": 3 * (1 + repair_iterations) if rewrite_body_calls else 0, + "rewrite_repair": repair_iterations if rewrite_body_calls else 0, + "rewrite_final_judge": 1 if rewrite_body_calls else 0, + } + + +def _write_collector_safely( + *, + config: MeasurementConfig, + collector: MeasurementCollector, + body_error: BaseException | None, +) -> None: + try: + config.write_collector(collector) + except Exception as exc: + logger.warning("Failed to write Anonymizer measurement records (%s)", type(exc).__name__) + if body_error is None and config.fail_on_write_error: + raise + + +def _close_collector_safely( + *, + config: MeasurementConfig, + collector: MeasurementCollector, + body_error: BaseException | None, +) -> None: + try: + collector.close() + except Exception as exc: + logger.warning("Failed to close Anonymizer measurement stream (%s)", type(exc).__name__) + if body_error is None and config.fail_on_write_error: + raise + + +def _measurement_env_error_message(exc: SettingsError | ValidationError, *, prefix: str) -> str: + fields: set[str] = set() + if isinstance(exc, ValidationError): + for error in exc.errors(include_input=False): + loc = error.get("loc", ()) + if loc: + fields.add(str(loc[0]).upper()) + else: + error_text = str(exc).lower() + for field_name in _MeasurementEnvSettings.model_fields: + if field_name in error_text: + fields.add(field_name.upper()) + + if fields: + env_fields = ", ".join(f"{prefix}{field}" for field in sorted(fields)) + return f"Invalid Anonymizer measurement environment configuration for: {env_fields}" + return "Invalid Anonymizer measurement environment configuration" + + +def _load_measurement_env_settings(*, prefix: str) -> _MeasurementEnvSettings: + settings_factory = cast(Any, _MeasurementEnvSettings) + return cast(_MeasurementEnvSettings, settings_factory(_env_prefix=prefix)) + + +def _validation_chunk_count( + detected_candidate_count: int | None, + *, + validation_max_entities_per_call: int, +) -> int | None: + if detected_candidate_count is None: + return None + if detected_candidate_count <= 0: + return 0 + return int(math.ceil(detected_candidate_count / validation_max_entities_per_call)) + + +def _safe_row_index(row_index: object) -> int | None: + if isinstance(row_index, bool): + return None + if isinstance(row_index, Integral): + return int(row_index) + return None + + +def _entities_from_raw(raw: object) -> list[dict[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + items = cast(Mapping[str, Any], payload).get("entities", []) + elif isinstance(payload, list): + items = payload + else: + items = [] + return [dict(cast(Mapping[str, Any], item)) for item in items if isinstance(item, Mapping)] + + +def _entity_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]] | None, +) -> dict[str, Any]: + if ground_truth_entities is None: + return { + "ground_truth_entity_count": None, + "ground_truth_entity_label_counts": None, + "entity_true_positive_count": None, + "entity_false_positive_count": None, + "entity_false_negative_count": None, + "entity_precision": None, + "entity_recall": None, + "entity_f1": None, + "entity_relaxed_gt_found_count": None, + "entity_relaxed_detected_tp_count": None, + "entity_relaxed_label_compatible_gt_found_count": None, + "entity_relaxed_label_compatible_detected_tp_count": None, + "entity_relaxed_precision": None, + "entity_relaxed_recall": None, + "entity_relaxed_f1": None, + "entity_relaxed_label_compatible_precision": None, + "entity_relaxed_label_compatible_recall": None, + "entity_relaxed_label_compatible_f1": None, + } + + predicted = _entity_identity_set(final_entities) + expected = _entity_identity_set(ground_truth_entities) + true_positive = len(predicted & expected) + false_positive = len(predicted - expected) + false_negative = len(expected - predicted) + precision = _safe_ratio(true_positive, true_positive + false_positive) + recall = _safe_ratio(true_positive, true_positive + false_negative) + return { + "ground_truth_entity_count": len(ground_truth_entities), + "ground_truth_entity_label_counts": dict( + sorted(Counter(e.get("label", "") for e in ground_truth_entities if e.get("label")).items()) + ), + "entity_true_positive_count": true_positive, + "entity_false_positive_count": false_positive, + "entity_false_negative_count": false_negative, + "entity_precision": precision, + "entity_recall": recall, + "entity_f1": _f1(precision, recall), + **_entity_relaxed_ground_truth_metrics(final_entities, ground_truth_entities), + } + + +def _entity_identity_set(entities: list[dict[str, Any]]) -> set[tuple[str, str]]: + identities: set[tuple[str, str]] = set() + for entity in entities: + label = entity.get("label") + value = entity.get("value") + if label is None or value is None: + continue + identities.add((str(value), str(label))) + return identities + + +def _entity_relaxed_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]], +) -> dict[str, Any]: + gt_found = sum( + 1 + for ground_truth_entity in ground_truth_entities + if _has_relaxed_entity_match(final_entities, ground_truth_entity) + ) + detected_tp = sum( + 1 for final_entity in final_entities if _has_relaxed_entity_match(ground_truth_entities, final_entity) + ) + label_compatible_gt_found = sum( + 1 + for ground_truth_entity in ground_truth_entities + if _has_relaxed_entity_match(final_entities, ground_truth_entity, require_label_compatible=True) + ) + label_compatible_detected_tp = sum( + 1 + for final_entity in final_entities + if _has_relaxed_entity_match(ground_truth_entities, final_entity, require_label_compatible=True) + ) + precision = _safe_ratio(detected_tp, len(final_entities)) + recall = _safe_ratio(gt_found, len(ground_truth_entities)) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, len(final_entities)) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, len(ground_truth_entities)) + return { + "entity_relaxed_gt_found_count": gt_found, + "entity_relaxed_detected_tp_count": detected_tp, + "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, + "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, + "entity_relaxed_precision": precision, + "entity_relaxed_recall": recall, + "entity_relaxed_f1": _f1(precision, recall), + "entity_relaxed_label_compatible_precision": label_compatible_precision, + "entity_relaxed_label_compatible_recall": label_compatible_recall, + "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), + } + + +def _has_relaxed_entity_match( + candidates: list[dict[str, Any]], + target: dict[str, Any], + *, + require_label_compatible: bool = False, +) -> bool: + return any( + _entities_match_relaxed(candidate, target, require_label_compatible=require_label_compatible) + for candidate in candidates + ) + + +def _entities_match_relaxed( + left: dict[str, Any], + right: dict[str, Any], + *, + require_label_compatible: bool, +) -> bool: + if require_label_compatible and not _entity_labels_compatible(left.get("label"), right.get("label")): + return False + left_span = _entity_span(left) + right_span = _entity_span(right) + if left_span is not None and right_span is not None: + return left_span[0] < right_span[1] and right_span[0] < left_span[1] + left_value = left.get("value") + right_value = right.get("value") + return left_value is not None and right_value is not None and str(left_value) == str(right_value) + + +def _entity_span(entity: dict[str, Any]) -> tuple[int, int] | None: + start = _coerce_float(entity.get("start_position", entity.get("start"))) + end = _coerce_float(entity.get("end_position", entity.get("end"))) + if start is None or end is None: + return None + start_int = int(start) + end_int = int(end) + if start_int < 0 or end_int <= start_int: + return None + return start_int, end_int + + +def _entity_labels_compatible(left: object, right: object) -> bool: + left_key = _entity_label_key(left) + right_key = _entity_label_key(right) + return left_key is not None and right_key is not None and left_key == right_key + + +def _entity_label_key(label: object) -> str | None: + if label is None: + return None + normalized = str(label).strip().lower() + if not normalized: + return None + return _ENTITY_LABEL_EQUIVALENCE.get(normalized, normalized) + + +def _replacement_map_metrics(raw: object) -> dict[str, Any]: + replacement_maps = _replacement_maps_from_raw(raw) + synthetic_values = [] + for item in replacement_maps: + synthetic = item.get("replacement", item.get("synthetic")) + if synthetic is not None: + synthetic_values.append(str(synthetic)) + return { + "replacement_count": len(replacement_maps), + "replacement_label_counts": dict( + sorted(Counter(item.get("label", "") for item in replacement_maps if item.get("label")).items()) + ), + "replacement_duplicate_value_count": max(0, len(synthetic_values) - len(set(synthetic_values))), + } + + +def _replacement_coverage_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + replacement_original_values = { + str(original) + for item in _replacement_maps_from_raw(raw) + if (original := item.get("original")) is not None and str(original) + } + missing_entities = [ + entity + for entity in final_entities + if entity.get("value") and str(entity.get("value")) not in replacement_original_values + ] + missing_values = {str(entity.get("value")) for entity in missing_entities if entity.get("value")} + return { + "replacement_missing_final_entity_count": len(missing_entities), + "replacement_missing_final_entity_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in missing_entities if entity.get("label")).items() + ) + ), + "replacement_missing_final_value_count": len(missing_values), + } + + +def _replacement_collision_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + synthetic_values = { + str(synthetic) + for item in _replacement_maps_from_raw(raw) + if (synthetic := item.get("replacement", item.get("synthetic"))) is not None and str(synthetic) + } + collided_entities = [ + entity for entity in final_entities if entity.get("value") and str(entity.get("value")) in synthetic_values + ] + collided_values = {str(entity.get("value")) for entity in collided_entities if entity.get("value")} + return { + "replacement_synthetic_original_collision_count": len(collided_entities), + "replacement_synthetic_original_collision_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in collided_entities if entity.get("label")).items() + ) + ), + "replacement_synthetic_original_collision_value_count": len(collided_values), + } + + +def _replacement_maps_from_raw(raw: object) -> list[Mapping[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + replacements_raw = cast(Mapping[str, Any], payload).get("replacements") + tolist = getattr(replacements_raw, "tolist", None) + if callable(tolist): + replacements_raw = tolist() + replacements = replacements_raw if isinstance(replacements_raw, list) else [] + elif isinstance(payload, list): + replacements = payload + else: + replacements = [] + return [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] + + +def _count_items(raw: object, *, primary_key: str, fallback_keys: tuple[str, ...] = ()) -> int: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + payload_map = cast(Mapping[str, Any], payload) + for key in (primary_key, *fallback_keys): + items = payload_map.get(key) + if isinstance(items, list): + return len(items) + return 0 + if isinstance(payload, list): + return len(payload) + return 0 + + +def _coerce_payload(raw: object) -> object: + model_dump = getattr(raw, "model_dump", None) + if callable(model_dump): + return model_dump(mode="python") + if isinstance(raw, str): + try: + return json.loads(raw) + except json.JSONDecodeError: + return {} + if raw is None: + return {} + return raw + + +def _coerce_int(raw: object, *, default: int) -> int: + try: + return int(cast(Any, raw)) + except (TypeError, ValueError): + return default + + +def _coerce_float(raw: object) -> float | None: + try: + value = float(cast(Any, raw)) + except (TypeError, ValueError): + return None + return None if math.isnan(value) else value + + +def _coerce_bool(raw: object) -> bool | None: + if raw is None: + return None + if isinstance(raw, float) and math.isnan(raw): + return None + if isinstance(raw, bool): + return raw + if isinstance(raw, str): + lowered = raw.strip().lower() + if lowered in {"true", "1", "yes"}: + return True + if lowered in {"false", "0", "no"}: + return False + return None + try: + return bool(cast(Any, raw)) + except (TypeError, ValueError): + return None + + +def _safe_rate(numerator: int | float | None, elapsed_sec: float) -> float | None: + if numerator is None or elapsed_sec <= 0: + return None + return float(numerator) / elapsed_sec + + +def _safe_ratio(numerator: int | float | None, denominator: int | float | None) -> float | None: + if numerator is None or denominator is None or denominator == 0: + return None + return float(numerator) / float(denominator) + + +def _f1(precision: float | None, recall: float | None) -> float | None: + if precision is None or recall is None or precision + recall == 0: + return None + return 2 * precision * recall / (precision + recall) + + +def _size_bucket(value: int) -> str: + if value == 0: + return "0" + for upper in (128, 512, 2048, 8192): + if value < upper: + return f"1-{upper - 1}" if upper == 128 else f"{upper // 4}-{upper - 1}" + return "8192+" + + +def _count_text_tokens(text: str) -> int: + try: + import tiktoken + + tokenizer = tiktoken.get_encoding("cl100k_base") + return len(tokenizer.encode(text, disallowed_special=())) + except Exception: + return len(text.split()) + + +def _summarize_model_usage(model_usage: Mapping[str, Any] | None) -> dict[str, int | None]: + totals = _empty_model_usage_totals() + for usage in (model_usage or {}).values(): + if not isinstance(usage, Mapping): + continue + _add_model_usage_totals(totals, usage) + + if totals["total_tokens"] == 0: + totals["total_tokens"] = totals["input_tokens"] + totals["output_tokens"] + if totals["total_requests"] == 0: + totals["total_requests"] = totals["successful_requests"] + totals["failed_requests"] + + return { + "observed_input_tokens": totals["input_tokens"], + "observed_output_tokens": totals["output_tokens"], + "observed_total_tokens": totals["total_tokens"], + "observed_reasoning_tokens": totals["reasoning_tokens"] if totals["has_reasoning_tokens"] else None, + "observed_successful_requests": totals["successful_requests"], + "observed_failed_requests": totals["failed_requests"], + "observed_total_requests": totals["total_requests"], + } + + +def _empty_model_usage_totals() -> dict[str, int | bool]: + return { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "reasoning_tokens": 0, + "has_reasoning_tokens": False, + "successful_requests": 0, + "failed_requests": 0, + "total_requests": 0, + } + + +def _add_model_usage_totals(totals: dict[str, int | bool], usage: Mapping[str, Any]) -> None: + token_usage = usage.get("token_usage") + if isinstance(token_usage, Mapping): + totals["input_tokens"] += _coerce_int(token_usage.get("input_tokens"), default=0) + totals["output_tokens"] += _coerce_int(token_usage.get("output_tokens"), default=0) + totals["total_tokens"] += _coerce_int(token_usage.get("total_tokens"), default=0) + if token_usage.get("reasoning_tokens") is not None: + totals["has_reasoning_tokens"] = True + totals["reasoning_tokens"] += _coerce_int(token_usage.get("reasoning_tokens"), default=0) + + request_usage = usage.get("request_usage") + if isinstance(request_usage, Mapping): + totals["successful_requests"] += _coerce_int(request_usage.get("successful_requests"), default=0) + totals["failed_requests"] += _coerce_int(request_usage.get("failed_requests"), default=0) + totals["total_requests"] += _coerce_int(request_usage.get("total_requests"), default=0) + + +def _json_safe(value: object) -> Any: + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + if isinstance(value, list): + return [_json_safe(v) for v in value] + if isinstance(value, tuple): + return [_json_safe(v) for v in value] + if isinstance(value, set): + return sorted((_json_safe(v) for v in value), key=str) + if isinstance(value, float) and not math.isfinite(value): + return None + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return str(value) diff --git a/tests/engine/test_chunked_validation.py b/tests/engine/test_chunked_validation.py index f9b402a3..6eb3ed9e 100644 --- a/tests/engine/test_chunked_validation.py +++ b/tests/engine/test_chunked_validation.py @@ -447,6 +447,41 @@ def test_single_chunk_sends_single_chunk_tagged_text_not_windowed_excerpt(self) "around Alice/Bob would clip the suffix entirely." ) + def test_single_chunk_can_use_compact_excerpt_when_configured(self) -> None: + prefix = "START_ONLY_MARKER " + ("prefix filler " * 80) + suffix = (" suffix filler" * 80) + " END_ONLY_MARKER" + middle = "Alice met Bob." + text = prefix + middle + suffix + alice_start = len(prefix) + bob_start = alice_start + 10 + + spans = [ + _entity_span("a", "Alice", "first_name", alice_start, alice_start + 5), + _entity_span("b", "Bob", "first_name", bob_start, bob_start + 3), + ] + candidates = _candidates_schema( + ("a", "Alice", "first_name"), + ("b", "Bob", "first_name"), + ) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade("v0", response={"decisions": [{"id": "a", "decision": "keep"}]}) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=20, + single_chunk_full_text=False, + prompt_template=_MINIMAL_TEMPLATE, + ) + + chunked_validate_row(row, params, {"v0": facade}) + + prompt = facade.calls[0]["prompt"] + assert "Alice" in prompt + assert "Bob" in prompt + assert "START_ONLY_MARKER" not in prompt + assert "END_ONLY_MARKER" not in prompt + def test_empty_candidates_short_circuits_without_calls(self) -> None: row = _build_row(text="hello", seed_entities=[], candidates=_candidates_schema()) facade = FakeFacade("v0", response={"decisions": []}) diff --git a/tests/engine/test_llm_replace_workflow.py b/tests/engine/test_llm_replace_workflow.py index f7abbc44..f6ae372b 100644 --- a/tests/engine/test_llm_replace_workflow.py +++ b/tests/engine/test_llm_replace_workflow.py @@ -330,6 +330,41 @@ def test_filter_replacement_map_anomaly_summaries_do_not_leak_pii( _assert_no_pii_in_logs(caplog, extra_secrets=("Acme Corp", "NovaCorp")) +def test_filter_replacement_map_repairs_synthetic_original_collisions_without_pii( + caplog: pytest.LogCaptureFixture, +) -> None: + """Synthetic values must not reuse another protected original from the same row.""" + parsed_entities = EntitiesByValueSchema.model_validate( + { + "entities_by_value": [ + {"value": "1979-01-01", "labels": ["date"]}, + {"value": "1980-02-02", "labels": ["date"]}, + ] + } + ) + raw_map = { + "replacements": [ + {"original": "1979-01-01", "label": "date", "synthetic": "1980-02-02"}, + {"original": "1980-02-02", "label": "date", "synthetic": "1991-03-04"}, + ] + } + + with caplog.at_level(logging.WARNING, logger="anonymizer"): + result = _filter_replacement_map_to_input_entities( + raw_map=raw_map, parsed_entities=parsed_entities, record_id="row-collision" + ) + + assert result == { + "replacements": [ + {"original": "1979-01-01", "label": "date", "synthetic": "[SUBSTITUTE_DATE_1]"}, + {"original": "1980-02-02", "label": "date", "synthetic": "1991-03-04"}, + ] + } + assert "synthetic-original collision" in caplog.text + assert "date" in caplog.text + _assert_no_pii_in_logs(caplog, extra_secrets=("1979-01-01", "1980-02-02", "1991-03-04")) + + def test_filter_replacement_map_empty_warning_does_not_leak_pii( caplog: pytest.LogCaptureFixture, ) -> None: diff --git a/tests/engine/test_ndd_adapter.py b/tests/engine/test_ndd_adapter.py index ea0a01a8..7563f371 100644 --- a/tests/engine/test_ndd_adapter.py +++ b/tests/engine/test_ndd_adapter.py @@ -13,6 +13,7 @@ from data_designer.config.models import ModelConfig from data_designer.interface.data_designer import DataDesigner +from anonymizer.engine.ndd import adapter as ndd_adapter from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter from anonymizer.interface.errors import AnonymizerWorkflowError @@ -60,6 +61,10 @@ def _make_columns() -> list[ColumnConfigT]: ] +def test_as_alias_list_drops_none_items_before_stringifying() -> None: + assert ndd_adapter._as_alias_list(["validator", None, "", 0]) == ["validator", "0"] + + def test_attach_record_ids_adds_deterministic_ids() -> None: adapter = NddAdapter(data_designer=Mock(spec=DataDesigner)) input_df = pd.DataFrame({"text": ["a", "b"]}) diff --git a/tests/engine/test_structured_substitute.py b/tests/engine/test_structured_substitute.py new file mode 100644 index 00000000..83014cba --- /dev/null +++ b/tests/engine/test_structured_substitute.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pandas as pd +import pytest + +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP +from anonymizer.engine.replace import structured_substitute as structured_substitute_module +from anonymizer.engine.replace.structured_substitute import ( + apply_structured_substitution_maps, + build_structured_substitution_map, + structured_substitute_value, +) + + +@pytest.mark.parametrize( + ("label", "value"), + [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("date_of_birth", "1978-02-03"), + ("email", "alice@example.com"), + ("http_cookie", "session_id=abc123xyz; user_id=12345; auth_token=secret-token"), + ("organization_name", "Acme Research Center"), + ("password", "CorrectHorse!123"), + ("pin", "97294"), + ("religious_belief", "secular"), + ("street_address", "123 Maple Street"), + ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), + ("url", "https://staging.example.com/admin"), + ("user_name", "sloanenguy217"), + ], +) +def test_structured_substitute_value_does_not_preserve_original(label: str, value: str) -> None: + synthetic = structured_substitute_value(value, label) + + assert synthetic != value + assert value not in synthetic + + +def test_build_structured_substitution_map_for_entities_by_value() -> None: + raw_entities = { + "entities_by_value": [ + {"value": "alice@example.com", "labels": ["email"]}, + {"value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "labels": ["api_key"]}, + ] + } + + replacement_map = build_structured_substitution_map(raw_entities) + + replacements = replacement_map["replacements"] + assert [(item["original"], item["label"]) for item in replacements] == [ + ("alice@example.com", "email"), + ("sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "api_key"), + ] + serialized = str(replacement_map) + assert "alice@example.com" in serialized + assert "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" in serialized + assert all(item["original"] != item["synthetic"] for item in replacements) + + +def test_build_structured_substitution_map_avoids_other_original_values(monkeypatch: pytest.MonkeyPatch) -> None: + first_original = "alice@example.com" + second_original = "bob@example.com" + first_default_digest = structured_substitute_module._digest(value=first_original, label="email") + + def synthetic_email(value: str, digest: str) -> str: + if value == first_original and digest == first_default_digest: + return second_original + return f"user-{digest[:12]}@example.invalid" + + monkeypatch.setitem(structured_substitute_module._GENERATORS, "email", synthetic_email) + raw_entities = { + "entities_by_value": [ + {"value": first_original, "labels": ["email"]}, + {"value": second_original, "labels": ["email"]}, + ] + } + + replacement_map = build_structured_substitution_map(raw_entities) + + replacements = {(item["original"], item["label"]): item["synthetic"] for item in replacement_map["replacements"]} + assert replacements[(first_original, "email")] != second_original + assert set(replacements.values()).isdisjoint({first_original, second_original}) + + +def test_build_structured_substitution_map_avoids_duplicate_synthetic_values( + monkeypatch: pytest.MonkeyPatch, +) -> None: + first_original = "Acme Research Center" + second_original = "Globex Research Center" + default_digests = { + structured_substitute_module._digest(value=first_original, label="organization_name"), + structured_substitute_module._digest(value=second_original, label="organization_name"), + } + + def synthetic_organization(_value: str, digest: str) -> str: + if digest in default_digests: + return "Northbridge Center" + return f"Helios {digest[:8]} Center" + + monkeypatch.setitem(structured_substitute_module._GENERATORS, "organization_name", synthetic_organization) + raw_entities = { + "entities_by_value": [ + {"value": first_original, "labels": ["organization_name"]}, + {"value": second_original, "labels": ["organization_name"]}, + ] + } + + replacement_map = build_structured_substitution_map(raw_entities) + + synthetics = [item["synthetic"] for item in replacement_map["replacements"]] + assert len(synthetics) == 2 + assert len(set(synthetics)) == 2 + assert "Northbridge Center" in synthetics + + +def test_structured_cookie_substitute_preserves_cookie_shape() -> None: + synthetic = structured_substitute_value( + "session_id=abc123xyz; user_id=12345; auth_token=secret-token", + "http_cookie", + ) + + assert "session_id=" in synthetic + assert "user_id=" in synthetic + assert "auth_token=" in synthetic + session_value = synthetic.split("session_id=", 1)[1].split(";", 1)[0] + assert len(session_value) >= 16 + assert not session_value.isdigit() + assert "abc123xyz" not in synthetic + assert "secret-token" not in synthetic + + +def test_structured_unique_id_substitute_preserves_uuid_shape() -> None: + synthetic = structured_substitute_value("1ce21179-998b-447b-2dee-3e8adb6afa35", "unique_id") + + assert synthetic.count("-") == 4 + assert synthetic != "1ce21179-998b-447b-2dee-3e8adb6afa35" + + +def test_build_structured_substitution_map_rejects_unsupported_labels() -> None: + raw_entities = {"entities_by_value": [{"value": "Alice", "labels": ["person"]}]} + + with pytest.raises(ValueError, match="unsupported labels: person"): + build_structured_substitution_map(raw_entities) + + +def test_apply_structured_substitution_maps_adds_replacement_map_column() -> None: + dataframe = pd.DataFrame( + {COL_ENTITIES_BY_VALUE: [{"entities_by_value": [{"value": "alice@example.com", "labels": ["email"]}]}]} + ) + + output = apply_structured_substitution_maps(dataframe) + + assert COL_REPLACEMENT_MAP in output.columns + replacement = output[COL_REPLACEMENT_MAP].iloc[0]["replacements"][0] + assert replacement["original"] == "alice@example.com" + assert replacement["label"] == "email" + assert replacement["synthetic"].endswith("@example.invalid") diff --git a/tests/test_measurement.py b/tests/test_measurement.py new file mode 100644 index 00000000..38dc0881 --- /dev/null +++ b/tests/test_measurement.py @@ -0,0 +1,1257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import logging +import threading +from pathlib import Path +from types import SimpleNamespace +from typing import cast +from unittest.mock import Mock + +import numpy as np +import pandas as pd +import pytest +from data_designer.config.column_configs import LLMTextColumnConfig +from data_designer.config.models import ModelConfig +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, Usage +from data_designer.engine.models.facade import ModelFacade +from data_designer.engine.models.utils import ChatMessage +from data_designer.interface.data_designer import DataDesigner + +import anonymizer.measurement as measurement +from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect +from anonymizer.config.models import DetectionModelSelection +from anonymizer.config.replace_strategies import Redact +from anonymizer.engine.constants import ( + COL_ANY_HIGH_LEAKED, + COL_FINAL_ENTITIES, + COL_LEAKAGE_MASS, + COL_NEEDS_HUMAN_REVIEW, + COL_NEEDS_REPAIR, + COL_REPAIR_ITERATIONS, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, + COL_SEED_VALIDATION_CANDIDATES, + COL_TEXT, + COL_UTILITY_SCORE, + COL_WEIGHTED_LEAKAGE_RATE, +) +from anonymizer.engine.detection.detection_workflow import EntityDetectionResult, EntityDetectionWorkflow +from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter +from anonymizer.engine.replace.replace_runner import ReplacementResult, ReplacementWorkflow +from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED +from anonymizer.engine.rewrite.rewrite_workflow import RewriteResult, RewriteWorkflow +from anonymizer.interface.anonymizer import Anonymizer +from anonymizer.measurement import ( + DEFAULT_MEASUREMENT_ENV_PREFIX, + MEASUREMENT_SCHEMA_VERSION, + MeasurementCollector, + MeasurementConfig, + configured_measurement_session, + estimate_llm_calls_by_stage, + measurement_session, + record_record_metrics, + stage_timer, +) + + +def test_ndd_adapter_records_workflow_measurement_without_raw_text() -> None: + input_df = pd.DataFrame( + { + "text": ["Alice works at Acme", "Bob works at Beta"], + RECORD_ID_COLUMN: ["record-a", "record-b"], + } + ) + mock_dd = Mock(spec=DataDesigner) + mock_dd.preview.return_value = SimpleNamespace(dataset=input_df.iloc[[0]].copy()) + adapter = NddAdapter(data_designer=mock_dd) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + result = adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="detector", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="detector", + ) + ], + workflow_name="entity-detection", + preview_num_records=2, + ) + + assert len(result.failed_records) == 1 + records = [record for record in collector.records if record["record_type"] == "ndd_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection" + assert record["model_aliases"] == ["detector"] + assert record["input_row_count"] == 2 + assert record["seed_row_count"] == 2 + assert record["output_row_count"] == 1 + assert record["failed_record_count"] == 1 + assert record["elapsed_sec"] >= 0 + + serialized = json.dumps(record) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert "Bob" not in serialized + + +def test_ndd_adapter_records_datadesigner_model_usage() -> None: + input_df = pd.DataFrame( + { + "text": ["Alice works at Acme"], + RECORD_ID_COLUMN: ["record-a"], + } + ) + + class UsageStats: + def model_dump(self, *, mode: str) -> dict[str, object]: + assert mode == "json" + return { + "token_usage": { + "input_tokens": 12, + "output_tokens": 4, + "total_tokens": 16, + }, + "request_usage": { + "successful_requests": 2, + "failed_requests": 1, + "total_requests": 3, + }, + } + + class ModelRegistry: + def get_model_usage_snapshot(self) -> dict[str, UsageStats]: + return {"dummy-model": UsageStats()} + + class UsageDataDesigner: + def _create_resource_provider(self, *_args: object, **_kwargs: object) -> SimpleNamespace: + return SimpleNamespace(model_registry=ModelRegistry()) + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + self._create_resource_provider("preview-dataset", _config_builder) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, UsageDataDesigner())) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="detector", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="detector", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + record = next(record for record in collector.records if record["record_type"] == "ndd_workflow") + assert record["model_usage"]["dummy-model"]["token_usage"]["input_tokens"] == 12 + assert record["observed_input_tokens"] == 12 + assert record["observed_output_tokens"] == 4 + assert record["observed_total_tokens"] == 16 + assert record["observed_successful_requests"] == 2 + assert record["observed_failed_requests"] == 1 + assert record["observed_total_requests"] == 3 + assert record["input_rows_per_sec"] >= 0 + assert record["output_rows_per_sec"] >= 0 + assert record["observed_tokens_per_sec"] >= 0 + assert record["observed_requests_per_sec"] >= 0 + assert record["observed_tokens_per_successful_request"] == 8 + + +def test_ndd_adapter_records_datadesigner_model_usage_by_alias_for_shared_model_names() -> None: + input_df = pd.DataFrame( + { + "text": ["Alice works at Acme"], + RECORD_ID_COLUMN: ["record-a"], + } + ) + + class UsageStats: + has_usage = True + + def __init__(self, *, input_tokens: int, output_tokens: int, successful: int, failed: int) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.successful = successful + self.failed = failed + + def model_dump(self, *, mode: str) -> dict[str, object]: + assert mode == "json" + return { + "token_usage": { + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "total_tokens": self.input_tokens + self.output_tokens, + }, + "request_usage": { + "successful_requests": self.successful, + "failed_requests": self.failed, + "total_requests": self.successful + self.failed, + }, + } + + class ModelRegistry: + def __init__(self) -> None: + self._models = { + "validator": SimpleNamespace( + model_alias="validator", + model_name="shared-model", + model_provider_name="local-vllm", + usage_stats=UsageStats(input_tokens=12, output_tokens=4, successful=2, failed=1), + ), + "augmenter": SimpleNamespace( + model_alias="augmenter", + model_name="shared-model", + model_provider_name="local-vllm", + usage_stats=UsageStats(input_tokens=20, output_tokens=8, successful=1, failed=0), + ), + } + + def get_model_usage_snapshot(self) -> dict[str, UsageStats]: + return { + "shared-model": UsageStats(input_tokens=999, output_tokens=999, successful=99, failed=99), + } + + class UsageDataDesigner: + def _create_resource_provider(self, *_args: object, **_kwargs: object) -> SimpleNamespace: + return SimpleNamespace(model_registry=ModelRegistry()) + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + self._create_resource_provider("preview-dataset", _config_builder) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, UsageDataDesigner())) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="validator", model="shared-model")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="validator", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + record = next(record for record in collector.records if record["record_type"] == "ndd_workflow") + assert sorted(record["model_usage"]) == ["augmenter", "validator"] + assert record["model_usage"]["validator"]["model_alias"] == "validator" + assert record["model_usage"]["validator"]["model_name"] == "shared-model" + assert record["model_usage"]["validator"]["model_provider_name"] == "local-vllm" + assert record["model_usage"]["validator"]["token_usage"]["input_tokens"] == 12 + assert record["model_usage"]["augmenter"]["token_usage"]["input_tokens"] == 20 + assert record["observed_input_tokens"] == 32 + assert record["observed_output_tokens"] == 12 + assert record["observed_total_tokens"] == 44 + assert record["observed_successful_requests"] == 3 + assert record["observed_failed_requests"] == 1 + assert record["observed_total_requests"] == 4 + + +def test_records_generic_model_workflow_usage_without_raw_text() -> None: + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + assert hasattr(measurement, "record_model_workflow") + measurement.record_model_workflow( + workflow_name="entity-detection-native-rules-router", + model_aliases=["native-direct"], + input_row_count=1, + output_row_count=1, + failed_record_count=0, + elapsed_sec=0.25, + model_usage={ + "native-direct": { + "model_alias": "native-direct", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 3, + "failed_requests": 0, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + }, + }, + }, + ) + + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-rules-router" + assert record["model_aliases"] == ["native-direct"] + assert record["observed_total_requests"] == 3 + assert record["observed_input_tokens"] == 30 + assert record["observed_output_tokens"] == 12 + assert record["observed_total_tokens"] == 42 + assert record["observed_failed_request_rate"] == 0 + assert record["observed_tokens_per_successful_request"] == 14 + + serialized = json.dumps(record) + assert "Alice" not in serialized + assert "sk-test" not in serialized + + +def test_anonymizer_records_per_record_measurement_without_raw_pii(tmp_path: Path) -> None: + input_csv = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_csv, index=False) + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Acme", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + validation_candidates = { + "candidates": [ + {"value": "Alice", "label": "first_name"}, + {"value": "Acme", "label": "company_name"}, + ] + } + detection_workflow = Mock(spec=EntityDetectionWorkflow) + detection_workflow.run.return_value = EntityDetectionResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [validation_candidates], + } + ), + failed_records=[], + ) + replace_runner = Mock(spec=ReplacementWorkflow) + replace_runner.run.return_value = ReplacementResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_REPLACED_TEXT: ["[REDACTED] works at [REDACTED]"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [validation_candidates], + } + ), + failed_records=[], + ) + rewrite_runner = Mock(spec=RewriteWorkflow) + rewrite_runner.run.return_value = RewriteResult(dataframe=pd.DataFrame(), failed_records=[]) + anonymizer = Anonymizer( + detection_workflow=detection_workflow, + replace_runner=replace_runner, + rewrite_runner=rewrite_runner, + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + anonymizer.run( + config=AnonymizerConfig(replace=Redact(), detect=Detect(validation_max_entities_per_call=2)), + data=AnonymizerInput(source=str(input_csv)), + ) + + record_metrics = [record for record in collector.records if record["record_type"] == "record"] + assert len(record_metrics) == 1 + record = record_metrics[0] + assert record["mode"] == "replace" + assert record["strategy"] == "Redact" + assert record["text_length_chars"] == len("Alice works at Acme") + assert record["text_length_chars_bucket"] == "1-127" + assert record["text_length_tokens"] > 0 + assert record["text_length_tokens_bucket"] == "1-127" + assert record["final_entity_count"] == 2 + assert record["final_entity_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["detected_candidate_count"] == 2 + assert record["validation_chunk_count"] == 1 + assert record["original_value_leak_count"] == 0 + assert record["original_value_leak_label_counts"] == {} + assert record["llm_calls_estimated_by_stage"] == { + "entity_detection": 3, + "replace_map_generation": 0, + } + assert record["llm_calls_estimated_total"] == 3 + assert len(record["record_hash"]) == 64 + + stage_records = [record for record in collector.records if record["record_type"] == "stage"] + assert any(record["stage"] == "Anonymizer._run_internal" for record in stage_records) + assert any(record.get("input_rows_per_sec") is not None for record in stage_records) + + run_records = [record for record in collector.records if record["record_type"] == "run"] + assert len(run_records) == 1 + run_record = run_records[0] + assert run_record["mode"] == "replace" + assert run_record["strategy"] == "Redact" + assert run_record["input_row_count"] == 1 + assert run_record["input_source"] == {"kind": "local_file", "scheme": None, "suffix": ".csv"} + assert run_record["input_text_column"] == "text" + assert run_record["input_has_id_column"] is False + assert run_record["input_has_data_summary"] is False + assert run_record["detect"]["entity_label_source"] == "default" + assert run_record["detect"]["entity_label_count"] > 0 + assert run_record["replace"]["strategy"] == "Redact" + assert run_record["replace"]["normalize_label"] is True + assert len(run_record["source_hash"]) == 64 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert str(input_csv) not in serialized + + +def test_anonymizer_measurement_config_writes_jsonl(tmp_path: Path) -> None: + input_csv = tmp_path / "input.csv" + output_jsonl = tmp_path / "measurements.jsonl" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_csv, index=False) + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + ] + } + detection_workflow = Mock(spec=EntityDetectionWorkflow) + detection_workflow.run.return_value = EntityDetectionResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + } + ), + failed_records=[], + ) + replace_runner = Mock(spec=ReplacementWorkflow) + replace_runner.run.return_value = ReplacementResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_REPLACED_TEXT: ["[REDACTED] works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + } + ), + failed_records=[], + ) + anonymizer = Anonymizer( + detection_workflow=detection_workflow, + replace_runner=replace_runner, + rewrite_runner=Mock(spec=RewriteWorkflow), + ) + + with configured_measurement_session( + MeasurementConfig( + output_path=output_jsonl, + run_id="measurement-run", + record_hash_key="test-key", + run_tags={"config_id": "redact-default", "workload_id": "unit-small"}, + ) + ): + anonymizer.run( + config=AnonymizerConfig(replace=Redact()), + data=AnonymizerInput(source=str(input_csv)), + ) + + records = [json.loads(line) for line in output_jsonl.read_text(encoding="utf-8").splitlines()] + assert {record["record_type"] for record in records} >= {"record", "run", "stage"} + assert {record["run_id"] for record in records} == {"measurement-run"} + assert {record["schema_version"] for record in records} == {MEASUREMENT_SCHEMA_VERSION} + assert {record["run_tags"]["workload_id"] for record in records} == {"unit-small"} + assert all(isinstance(record["timestamp_unix_sec"], float) for record in records) + + serialized = json.dumps(records) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert str(input_csv) not in serialized + + +def test_measurement_records_write_strict_json_safe_values(tmp_path: Path) -> None: + output_jsonl = tmp_path / "measurements.jsonl" + collector = MeasurementCollector(record_hash_key="test-key") + collector.record("run", non_finite=float("nan"), mixed_set={1, "two"}) + + collector.write_jsonl(output_jsonl) + + payload = json.loads(output_jsonl.read_text(encoding="utf-8")) + assert payload["non_finite"] is None + assert payload["mixed_set"] == [1, "two"] + + +def test_measurement_config_record_level_false_skips_record_rows(tmp_path: Path) -> None: + output_json = tmp_path / "measurements.json" + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + } + ) + + with configured_measurement_session( + MeasurementConfig( + output_path=output_json, + output_format="json", + record_level=False, + run_id="stage-only", + record_hash_key="test-key", + ) + ): + with stage_timer("example", input_row_count=1): + pass + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + records = json.loads(output_json.read_text(encoding="utf-8")) + assert [record["record_type"] for record in records] == ["stage"] + assert records[0]["run_id"] == "stage-only" + assert records[0]["input_rows_per_sec"] >= 0 + + +def test_measurement_config_from_env_returns_none_without_output_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + prefix = "ANON_TEST_EMPTY_MEASUREMENT_" + monkeypatch.setenv(f"{prefix}RUN_ID", "env-run") + + assert MeasurementConfig.from_env(prefix=prefix) is None + + +def test_measurement_config_from_env_parses_supported_values( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + output_path = tmp_path / "measurements.json" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(output_path)) + monkeypatch.setenv(f"{prefix}OUTPUT_FORMAT", "json") + monkeypatch.setenv(f"{prefix}RECORD_LEVEL", "false") + monkeypatch.setenv(f"{prefix}FAIL_ON_WRITE_ERROR", "true") + monkeypatch.setenv(f"{prefix}RUN_ID", "env-run") + monkeypatch.setenv(f"{prefix}RUN_TAGS", '{"config_id": "redact-default", "attempt": 2}') + + config = MeasurementConfig.from_env(prefix=prefix) + + assert config is not None + assert config.output_path == str(output_path) + assert config.output_format == "json" + assert config.record_level is False + assert config.fail_on_write_error is True + assert config.run_id == "env-run" + assert config.run_tags == {"config_id": "redact-default", "attempt": 2} + + +def test_measurement_config_from_sources_keeps_env_opt_in( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(tmp_path / "env.jsonl")) + explicit = MeasurementConfig(output_path=tmp_path / "explicit.jsonl") + + assert MeasurementConfig.from_sources(env=False, prefix=prefix) is None + assert MeasurementConfig.from_sources(explicit=explicit, env=True, prefix=prefix) is explicit + + +def test_measurement_config_from_env_reports_sanitized_invalid_values( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + secret_payload = "sk-secret-token-value" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(tmp_path / "measurements.jsonl")) + monkeypatch.setenv(f"{prefix}RUN_TAGS", secret_payload) + + with pytest.raises(ValueError) as exc_info: + MeasurementConfig.from_env(prefix=prefix) + + message = str(exc_info.value) + assert f"{prefix}RUN_TAGS" in message + assert secret_payload not in message + assert str(tmp_path) not in message + + +def test_default_measurement_env_prefix_is_anonymizer_scoped() -> None: + assert DEFAULT_MEASUREMENT_ENV_PREFIX == "ANONYMIZER_MEASUREMENT_" + + +def test_measurement_config_write_errors_are_best_effort( + caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError(f"cannot write {_self.output_path}") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + caplog.set_level(logging.WARNING, logger="anonymizer.measurement") + output_path = tmp_path / "secret-output-sk-live-value.jsonl" + + with configured_measurement_session(MeasurementConfig(output_path=output_path)) as collector: + assert collector is not None + collector.record("example") + + assert "Failed to write Anonymizer measurement records" in caplog.text + assert str(output_path) not in caplog.text + assert "sk-live-value" not in caplog.text + + +def test_measurement_config_strict_write_errors_can_fail_clean_body( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError("cannot write") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + + with pytest.raises(OSError, match="cannot write"): + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", fail_on_write_error=True) + ) as collector: + assert collector is not None + collector.record("example") + + +def test_measurement_config_write_errors_do_not_mask_body_errors( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError("cannot write") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + + with pytest.raises(RuntimeError, match="body failed"): + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", fail_on_write_error=True) + ) as collector: + assert collector is not None + collector.record("example") + raise RuntimeError("body failed") + + +def test_streaming_measurement_session_writes_jsonl_without_retaining_records(tmp_path: Path) -> None: + output_path = tmp_path / "measurements.jsonl" + + with configured_measurement_session( + MeasurementConfig(output_path=output_path, streaming=True, keep_records=False) + ) as collector: + assert collector is not None + collector.record("example", value=1) + + assert collector.records == [] + assert output_path.read_text(encoding="utf-8").count("\n") == 1 + + collector.record("example", value=2) + + lines = output_path.read_text(encoding="utf-8").splitlines() + assert len(lines) == 2 + assert [json.loads(line)["value"] for line in lines] == [1, 2] + + +def test_streaming_measurement_requires_jsonl_output(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="streaming measurement output only supports jsonl"): + MeasurementConfig(output_path=tmp_path / "measurements.json", output_format="json", streaming=True) + + +def test_dd_message_trace_requires_trace_path(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="dd_trace_path is required"): + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_trace="last_message") + + +def test_ndd_adapter_writes_opt_in_dd_message_trace( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + def fake_completion( + _self: object, + _messages: list[ChatMessage], + skip_usage_tracking: bool = False, + **_kwargs: object, + ) -> ChatCompletionResponse: + assert skip_usage_tracking is False + return ChatCompletionResponse( + message=AssistantMessage(content="secret response"), + usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), + ) + + monkeypatch.setattr(ModelFacade, "completion", fake_completion) + + class TraceDataDesigner: + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + ModelFacade.completion( + SimpleNamespace(model_alias="alias", model_name="dummy-model", model_provider_name="provider"), + [ + ChatMessage.as_system("system secret"), + ChatMessage.as_user("prompt secret"), + ], + ) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + trace_path = tmp_path / "trace.jsonl" + + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", dd_trace="last_message", dd_trace_path=trace_path + ) + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="alias", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="alias", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + trace = json.loads(trace_path.read_text(encoding="utf-8").strip()) + assert trace["record_type"] == "dd_message_trace" + assert trace["workflow_name"] == "entity-detection" + assert trace["model_alias"] == "alias" + assert trace["status"] == "completed" + assert trace["messages"] == [{"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}] + assert trace["response"]["content"] == "secret response" + assert trace["usage"] == {"input_tokens": 3, "output_tokens": 2, "total_tokens": 5} + + serialized_measurements = (tmp_path / "measurements.jsonl").read_text(encoding="utf-8") + assert "prompt secret" not in serialized_measurements + assert "secret response" not in serialized_measurements + + +def test_dd_message_trace_does_not_capture_concurrent_unmeasured_calls( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + def fake_completion( + _self: object, + _messages: list[ChatMessage], + **_kwargs: object, + ) -> ChatCompletionResponse: + return ChatCompletionResponse( + message=AssistantMessage(content="response"), + usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), + ) + + monkeypatch.setattr(ModelFacade, "completion", fake_completion) + + class TraceDataDesigner: + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + errors: list[BaseException] = [] + + def concurrent_call_without_measurement_context() -> None: + try: + ModelFacade.completion( + SimpleNamespace( + model_alias="outside", + model_name="outside-model", + model_provider_name="provider", + ), + [ChatMessage.as_user("outside prompt")], + ) + except BaseException as exc: + errors.append(exc) + + thread = threading.Thread(target=concurrent_call_without_measurement_context) + thread.start() + thread.join() + assert errors == [] + + ModelFacade.completion( + SimpleNamespace(model_alias="inside", model_name="inside-model", model_provider_name="provider"), + [ChatMessage.as_user("inside prompt")], + ) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + trace_path = tmp_path / "trace.jsonl" + + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", dd_trace="all_messages", dd_trace_path=trace_path + ) + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="inside", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="inside", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] + assert [trace["model_alias"] for trace in traces] == ["inside"] + serialized_traces = json.dumps(traces) + assert "inside prompt" in serialized_traces + assert "outside prompt" not in serialized_traces + + +def test_record_metrics_capture_generic_counts_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Acme", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + ground_truth_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Beta", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + {"original": "Acme", "label": "company_name", "synthetic": "Maya"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + "ground_truth_entities": [ground_truth_entities], + COL_REPLACEMENT_MAP: [replacement_map], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + COL_REPAIR_ITERATIONS: [2], + COL_UTILITY_SCORE: [0.82], + COL_LEAKAGE_MASS: [0.2], + COL_WEIGHTED_LEAKAGE_RATE: [0.1], + COL_ANY_HIGH_LEAKED: [False], + COL_NEEDS_HUMAN_REVIEW: [True], + COL_NEEDS_REPAIR: [False], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="rewrite", + strategy="Rewrite", + text_column=COL_TEXT, + validation_max_entities_per_call=2, + ) + + record = collector.records[0] + assert record["ground_truth_entity_count"] == 2 + assert record["ground_truth_entity_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["entity_true_positive_count"] == 1 + assert record["entity_false_positive_count"] == 1 + assert record["entity_false_negative_count"] == 1 + assert record["entity_precision"] == 0.5 + assert record["entity_recall"] == 0.5 + assert record["entity_f1"] == 0.5 + assert record["entity_relaxed_gt_found_count"] == 2 + assert record["entity_relaxed_detected_tp_count"] == 2 + assert record["entity_relaxed_label_compatible_gt_found_count"] == 2 + assert record["entity_relaxed_label_compatible_detected_tp_count"] == 2 + assert record["entity_relaxed_precision"] == 1.0 + assert record["entity_relaxed_recall"] == 1.0 + assert record["entity_relaxed_f1"] == 1.0 + assert record["entity_relaxed_label_compatible_precision"] == 1.0 + assert record["entity_relaxed_label_compatible_recall"] == 1.0 + assert record["entity_relaxed_label_compatible_f1"] == 1.0 + assert record["replacement_count"] == 2 + assert record["replacement_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["replacement_duplicate_value_count"] == 1 + assert record["replacement_missing_final_entity_count"] == 0 + assert record["replacement_missing_final_entity_label_counts"] == {} + assert record["replacement_missing_final_value_count"] == 0 + assert record["replacement_synthetic_original_collision_count"] == 0 + assert record["replacement_synthetic_original_collision_label_counts"] == {} + assert record["replacement_synthetic_original_collision_value_count"] == 0 + assert record["repair_iterations"] == 2 + assert record["utility_score"] == 0.82 + assert record["leakage_mass"] == 0.2 + assert record["weighted_leakage_rate"] == 0.1 + assert record["any_high_leaked"] is False + assert record["needs_human_review"] is True + assert record["needs_repair"] is False + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert "Beta" not in serialized + assert "Maya" not in serialized + + +def test_record_metrics_capture_relaxed_gt_label_equivalence_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "builduser42", "label": "user_name", "start_position": 4, "end_position": 15}, + ] + } + ground_truth_entities = { + "entities": [ + {"value": "legacy-user", "label": "username", "start_position": 6, "end_position": 14}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["ssh builduser42@host"], + COL_FINAL_ENTITIES: [final_entities], + "ground_truth_entities": [ground_truth_entities], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=2, + ) + + record = collector.records[0] + assert record["entity_true_positive_count"] == 0 + assert record["entity_false_positive_count"] == 1 + assert record["entity_false_negative_count"] == 1 + assert record["entity_relaxed_gt_found_count"] == 1 + assert record["entity_relaxed_detected_tp_count"] == 1 + assert record["entity_relaxed_label_compatible_gt_found_count"] == 1 + assert record["entity_relaxed_label_compatible_detected_tp_count"] == 1 + assert record["entity_relaxed_precision"] == 1.0 + assert record["entity_relaxed_recall"] == 1.0 + assert record["entity_relaxed_f1"] == 1.0 + assert record["entity_relaxed_label_compatible_precision"] == 1.0 + assert record["entity_relaxed_label_compatible_recall"] == 1.0 + assert record["entity_relaxed_label_compatible_f1"] == 1.0 + + serialized = json.dumps(collector.records) + assert "builduser42" not in serialized + assert "legacy-user" not in serialized + + +def test_record_metrics_counts_missing_replacement_map_entries_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "2030-01-01", "label": "date", "start_position": 13, "end_position": 23}, + {"value": "2030-01-01", "label": "date", "start_position": 27, "end_position": 37}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice filed 2030-01-01 and 2030-01-01"], + COL_FINAL_ENTITIES: [final_entities], + COL_REPLACEMENT_MAP: [replacement_map], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["replacement_missing_final_entity_count"] == 2 + assert record["replacement_missing_final_entity_label_counts"] == {"date": 2} + assert record["replacement_missing_final_value_count"] == 1 + assert record["replacement_synthetic_original_collision_count"] == 0 + assert record["replacement_synthetic_original_collision_label_counts"] == {} + assert record["replacement_synthetic_original_collision_value_count"] == 0 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "2030-01-01" not in serialized + assert "Maya" not in serialized + + +def test_record_metrics_counts_synthetic_original_collisions_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "2030-01-01", "label": "date", "start_position": 13, "end_position": 23}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + {"original": "2029-12-01", "label": "date", "synthetic": "2030-01-01"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice filed 2030-01-01"], + COL_FINAL_ENTITIES: [final_entities], + COL_REPLACEMENT_MAP: [replacement_map], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["replacement_synthetic_original_collision_count"] == 1 + assert record["replacement_synthetic_original_collision_label_counts"] == {"date": 1} + assert record["replacement_synthetic_original_collision_value_count"] == 1 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "2030-01-01" not in serialized + assert "2029-12-01" not in serialized + assert "Maya" not in serialized + + +def test_record_metrics_counts_original_value_replacement_leaks_without_raw_values() -> None: + leaked_key = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + dataframe = pd.DataFrame( + { + COL_TEXT: [f"token={leaked_key}"], + COL_REPLACED_TEXT: [f"still token={leaked_key}"], + COL_FINAL_ENTITIES: [{"entities": [{"value": leaked_key, "label": "api_key"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["original_value_leak_count"] == 1 + assert record["original_value_leak_label_counts"] == {"api_key": 1} + assert leaked_key not in json.dumps(collector.records) + + +def test_record_metrics_ignores_short_value_inside_hash_replacement_token() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice is 34 years old."], + COL_REPLACED_TEXT: ["Alice is years old."], + COL_FINAL_ENTITIES: [{"entities": [{"value": "34", "label": "age"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["original_value_leak_count"] == 0 + assert record["original_value_leak_label_counts"] == {} + + +def test_record_metrics_counts_standalone_short_value_replacement_leaks() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice is 34 years old."], + COL_REPLACED_TEXT: ["Alice is 34 years old."], + COL_FINAL_ENTITIES: [{"entities": [{"value": "34", "label": "age"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["original_value_leak_count"] == 1 + assert record["original_value_leak_label_counts"] == {"age": 1} + + +def test_record_metrics_do_not_estimate_llm_replace_map_call_for_local_structured_source() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "label": "api_key"}]}], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACEMENT_MAP_SOURCE: [REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["llm_calls_estimated_by_stage"] == { + "entity_detection": None, + "replace_map_generation": 0, + } + assert record["llm_calls_estimated_total"] is None + + +def test_record_metrics_normalizes_integral_row_index_types() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + }, + index=pd.Index([np.int64(7)]), + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + assert collector.records[0]["row_index"] == 7 + + +def test_record_hash_uses_run_scoped_secret_by_default() -> None: + first = MeasurementCollector() + second = MeasurementCollector() + + assert first.record_hash(row_index=0, text="Alice works at Acme") != second.record_hash( + row_index=0, + text="Alice works at Acme", + ) + + +def test_stage_timer_records_errors() -> None: + workflow = EntityDetectionWorkflow(adapter=Mock(spec=NddAdapter)) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector), pytest.raises(ValueError, match="privacy_goal is required"): + workflow.run( + pd.DataFrame({COL_TEXT: ["Alice"]}), + model_configs=[], + selected_models=DetectionModelSelection( + entity_detector="detector", + entity_validator=["validator"], + entity_augmenter="augmenter", + latent_detector="latent", + ), + gliner_detection_threshold=0.3, + tag_latent_entities=True, + privacy_goal=None, + ) + + stage_records = [record for record in collector.records if record["record_type"] == "stage"] + assert len(stage_records) == 1 + record = stage_records[0] + assert record["schema_version"] == MEASUREMENT_SCHEMA_VERSION + assert record["record_type"] == "stage" + assert record["run_id"] == collector.run_id + assert record["run_tags"] == {} + assert isinstance(record["timestamp_unix_sec"], float) + assert record["stage"] == "EntityDetectionWorkflow.run" + assert record["status"] == "error" + assert record["elapsed_sec"] >= 0 + assert record["input_row_count"] == 1 + assert record["input_rows_per_sec"] >= 0 + assert record["output_rows_per_sec"] is None + assert record["tag_latent_entities"] is True + + +def test_rewrite_llm_call_estimate_splits_by_stage() -> None: + calls = estimate_llm_calls_by_stage( + mode="rewrite", + strategy="Rewrite", + has_grouped_entities=True, + validation_chunk_count=2, + repair_iterations=2, + ) + + assert calls == { + "entity_detection": 4, + "latent_entity_detection": 1, + "replace_map_generation": 1, + "rewrite_pipeline": 5, + "rewrite_evaluate": 9, + "rewrite_repair": 2, + "rewrite_final_judge": 1, + } + + +def test_rewrite_llm_call_estimate_skips_rewrite_body_without_entities() -> None: + calls = estimate_llm_calls_by_stage( + mode="rewrite", + strategy="Rewrite", + has_grouped_entities=False, + validation_chunk_count=0, + repair_iterations=2, + ) + + assert calls == { + "entity_detection": 2, + "latent_entity_detection": 0, + "replace_map_generation": 0, + "rewrite_pipeline": 0, + "rewrite_evaluate": 0, + "rewrite_repair": 0, + "rewrite_final_judge": 0, + } diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py new file mode 100644 index 00000000..6d620dbe --- /dev/null +++ b/tests/tools/test_benchmark_output_analysis.py @@ -0,0 +1,1000 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "ndd_workflow", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "elapsed_sec": 8.5, + "observed_total_requests": 4, + "observed_successful_requests": 3, + "observed_input_tokens": 5000, + "observed_output_tokens": 1000, + "observed_total_tokens": 6000, + "observed_failed_requests": 1, + "model_usage": { + "nvidia/gliner-pii": { + "request_usage": { + "successful_requests": 1, + "failed_requests": 0, + "total_requests": 1, + }, + "token_usage": { + "input_tokens": 1000, + "output_tokens": 100, + "total_tokens": 1100, + }, + }, + "local-nemotron-json": { + "model_alias": "local-nemotron-json", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 2, + "failed_requests": 1, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 4000, + "output_tokens": 900, + "total_tokens": 4900, + }, + }, + }, + "detect": {"validation_max_entities_per_call": 10}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "topology_endpoint_count": 2, + "topology_gpu_count": 4, + "topology_tensor_parallelism": 2, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "stage", + "run_id": "bio__default__r000", + "stage": "Anonymizer._run_internal", + "elapsed_sec": 10.0, + "status": "completed", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "topology_endpoint_count": 2, + "topology_gpu_count": 4, + "topology_tensor_parallelism": 2, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r000", + "text_length_tokens": 1200, + "final_entity_count": 14, + "ground_truth_entity_count": 20, + "entity_true_positive_count": 10, + "entity_false_positive_count": 4, + "entity_false_negative_count": 10, + "entity_relaxed_gt_found_count": 15, + "entity_relaxed_detected_tp_count": 14, + "entity_relaxed_label_compatible_gt_found_count": 13, + "entity_relaxed_label_compatible_detected_tp_count": 12, + "replacement_count": 12, + "replacement_missing_final_entity_count": 2, + "replacement_missing_final_entity_label_counts": {"date": 2}, + "replacement_missing_final_value_count": 1, + "replacement_synthetic_original_collision_count": 1, + "replacement_synthetic_original_collision_label_counts": {"date": 1}, + "replacement_synthetic_original_collision_value_count": 1, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r000", + "text_length_tokens": 300, + "final_entity_count": 0, + "ground_truth_entity_count": 2, + "entity_true_positive_count": 0, + "entity_false_positive_count": 0, + "entity_false_negative_count": 2, + "entity_relaxed_gt_found_count": 0, + "entity_relaxed_detected_tp_count": 0, + "entity_relaxed_label_compatible_gt_found_count": 0, + "entity_relaxed_label_compatible_detected_tp_count": 0, + "replacement_count": 0, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_entity_label_counts": {}, + "replacement_missing_final_value_count": 0, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_label_counts": {}, + "replacement_synthetic_original_collision_value_count": 0, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "shell__native-local__r000", + "text_length_tokens": 750, + "final_entity_count": 8, + "replacement_count": 8, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_entity_label_counts": {}, + "replacement_missing_final_value_count": 0, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_label_counts": {}, + "replacement_synthetic_original_collision_value_count": 0, + "original_value_leak_count": 1, + "original_value_leak_label_counts": {"api_key": 1}, + "run_tags": { + "suite_id": "suite", + "workload_id": "shell", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", + "experimental_replacement_strategy": "local_structured_substitute", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "shell__native-local__r000", + }, + }, + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "repetition": 0, + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "seed_entity_count": 13, + "seed_validation_candidate_count": 13, + "augmented_entity_count": 1, + "augmented_new_final_value_count": 1, + "final_entity_count": 14, + "final_source_counts": {"detector": 11, "augmenter": 3}, + "final_entity_signature_hashes": ["bio-hash-a", "bio-hash-b"], + "final_entity_signature_labels": {"bio-hash-a": "person", "bio-hash-b": "city"}, + "final_entity_signature_details": { + "bio-hash-a": { + "label": "person", + "source": "detector", + "row_index": 0, + "start_position": 0, + "end_position": 5, + "value_hash": "hash-person", + "value_length": 5, + } + }, + "final_entity_signature_count": 2, + }, + { + "suite_id": "suite", + "workload_id": "shell", + "config_id": "native-local", + "repetition": 0, + "case_id": "shell__native-local__r000", + "run_id": "shell__native-local__r000", + "workflow_name": "native-single-pass", + "seed_entity_count": 8, + "seed_validation_candidate_count": 0, + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "final_entity_count": 8, + "final_source_counts": {"augmenter": 8}, + "final_entity_signature_hashes": ["shell-hash-a"], + "final_entity_signature_labels": {"shell-hash-a": "api_key"}, + "final_entity_signature_details": { + "shell-hash-a": { + "label": "api_key", + "source": "native", + "row_index": 0, + "start_position": 12, + "end_position": 32, + "value_hash": "hash-secret", + "value_length": 20, + } + }, + "final_entity_signature_count": 1, + }, + ], + ) + traces_dir = benchmark_dir / "traces" + traces_dir.mkdir() + _write_jsonl( + traces_dir / "bio__default__r000.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "model_alias": "local-nemotron-json", + "status": "error", + "error_type": "SyncClientUnavailableError", + "is_async": False, + "messages": [{"role": "user", "content": "Alice has sk-test"}], + "response": "Alice still has sk-test", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "dd_message_trace", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "model_alias": "local-nemotron-json", + "status": "success", + "is_async": True, + "messages": [{"role": "user", "content": "sk-test"}], + "response": "Alice", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.case_count == 2 + assert result.group_count == 2 + assert result.model_usage_count == 2 + assert result.model_usage_group_count == 2 + cases = {row.case_id: row for row in result.cases} + assert cases["bio__default__r000"].workload_category == "synthetic_biography" + assert cases["bio__default__r000"].entity_label_set_id == "agent" + assert cases["bio__default__r000"].entity_label_count == 4 + assert cases["bio__default__r000"].gliner_threshold == pytest.approx(0.3) + assert cases["bio__default__r000"].experimental_replacement_strategy == "default" + assert cases["bio__default__r000"].observed_total_requests == 4 + assert cases["bio__default__r000"].observed_successful_requests == 3 + assert cases["bio__default__r000"].observed_input_tokens == 5000 + assert cases["bio__default__r000"].observed_output_tokens == 1000 + assert cases["bio__default__r000"].observed_total_tokens == 6000 + assert cases["bio__default__r000"].observed_failed_request_rate == pytest.approx(1 / 4) + assert cases["bio__default__r000"].dd_trace_record_count == 2 + assert cases["bio__default__r000"].dd_trace_error_count == 1 + assert cases["bio__default__r000"].dd_trace_sync_client_unavailable_count == 1 + assert cases["bio__default__r000"].observed_bridge_fallback_requests == 1 + assert cases["bio__default__r000"].observed_non_bridge_total_requests == 3 + assert cases["bio__default__r000"].observed_non_bridge_failed_requests == 0 + assert cases["bio__default__r000"].observed_non_bridge_failed_request_rate == 0 + assert cases["bio__default__r000"].record_count == 2 + assert cases["bio__default__r000"].input_text_tokens_total == 1500 + assert cases["bio__default__r000"].records_per_pipeline_sec == pytest.approx(0.2) + assert cases["bio__default__r000"].records_per_ndd_sec == pytest.approx(2 / 8.5) + assert cases["bio__default__r000"].input_text_tokens_per_pipeline_sec == 150 + assert cases["bio__default__r000"].input_text_tokens_per_ndd_sec == pytest.approx(1500 / 8.5) + assert cases["bio__default__r000"].topology_endpoint_count == 2 + assert cases["bio__default__r000"].topology_gpu_count == 4 + assert cases["bio__default__r000"].topology_tensor_parallelism == 2 + assert cases["bio__default__r000"].input_text_tokens_per_endpoint_sec == 75 + assert cases["bio__default__r000"].input_text_tokens_per_gpu_sec == 37.5 + assert cases["bio__default__r000"].empty_detection_count == 1 + assert cases["bio__default__r000"].empty_detection_rate == 0.5 + assert cases["bio__default__r000"].empty_detection_with_ground_truth_count == 1 + assert cases["bio__default__r000"].empty_detection_with_ground_truth_rate == 0.5 + assert cases["bio__default__r000"].ground_truth_record_count == 2 + assert cases["bio__default__r000"].ground_truth_entity_count == 22 + assert cases["bio__default__r000"].entity_true_positive_count == 10 + assert cases["bio__default__r000"].entity_false_positive_count == 4 + assert cases["bio__default__r000"].entity_false_negative_count == 12 + assert cases["bio__default__r000"].entity_precision == pytest.approx(10 / 14) + assert cases["bio__default__r000"].entity_recall == pytest.approx(10 / 22) + assert cases["bio__default__r000"].entity_relaxed_gt_found_count == 15 + assert cases["bio__default__r000"].entity_relaxed_detected_tp_count == 14 + assert cases["bio__default__r000"].entity_relaxed_label_compatible_gt_found_count == 13 + assert cases["bio__default__r000"].entity_relaxed_label_compatible_detected_tp_count == 12 + assert cases["bio__default__r000"].entity_relaxed_precision == 1.0 + assert cases["bio__default__r000"].entity_relaxed_recall == pytest.approx(15 / 22) + assert cases["bio__default__r000"].entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) + assert cases["bio__default__r000"].entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) + assert cases["bio__default__r000"].validation_max_entities_per_call == 10 + assert cases["bio__default__r000"].original_value_leak_count == 0 + assert cases["bio__default__r000"].original_value_leak_record_count == 0 + assert cases["bio__default__r000"].original_value_leak_label_counts == {} + assert cases["bio__default__r000"].replacement_missing_final_entity_count == 2 + assert cases["bio__default__r000"].replacement_missing_final_entity_label_counts == {"date": 2} + assert cases["bio__default__r000"].replacement_missing_final_value_count == 1 + assert cases["bio__default__r000"].replacement_synthetic_original_collision_count == 1 + assert cases["bio__default__r000"].replacement_synthetic_original_collision_label_counts == {"date": 1} + assert cases["bio__default__r000"].replacement_synthetic_original_collision_value_count == 1 + assert cases["bio__default__r000"].seed_validation_candidate_count == 13 + assert cases["bio__default__r000"].estimated_seed_validation_chunk_count == 2 + assert cases["bio__default__r000"].augmented_new_final_value_count == 1 + assert cases["bio__default__r000"].artifact_final_detector_entity_count == 11 + assert cases["bio__default__r000"].artifact_final_augmenter_entity_count == 3 + assert cases["bio__default__r000"].artifact_final_entity_signature_hashes == ["bio-hash-a", "bio-hash-b"] + assert cases["bio__default__r000"].artifact_final_entity_signature_labels == { + "bio-hash-a": "person", + "bio-hash-b": "city", + } + assert cases["bio__default__r000"].artifact_final_entity_signature_details == { + "bio-hash-a": { + "label": "person", + "source": "detector", + "row_index": 0, + "start_position": 0, + "end_position": 5, + "value_length": 5, + } + } + assert cases["bio__default__r000"].artifact_final_entity_signature_count == 2 + assert cases["shell__native-local__r000"].observed_total_requests == 0 + assert cases["shell__native-local__r000"].experimental_replacement_strategy == "local_structured_substitute" + assert cases["shell__native-local__r000"].observed_failed_request_rate is None + assert cases["shell__native-local__r000"].observed_bridge_fallback_requests is None + assert cases["shell__native-local__r000"].observed_non_bridge_failed_requests is None + assert cases["shell__native-local__r000"].final_entity_count == 8 + assert cases["shell__native-local__r000"].replacement_missing_final_entity_count == 0 + assert cases["shell__native-local__r000"].replacement_missing_final_entity_label_counts == {} + assert cases["shell__native-local__r000"].replacement_missing_final_value_count == 0 + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_count == 0 + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_label_counts == {} + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_value_count == 0 + assert cases["shell__native-local__r000"].original_value_leak_count == 1 + assert cases["shell__native-local__r000"].original_value_leak_record_count == 1 + assert cases["shell__native-local__r000"].original_value_leak_label_counts == {"api_key": 1} + assert cases["shell__native-local__r000"].artifact_final_augmenter_entity_count == 8 + assert cases["shell__native-local__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] + assert cases["shell__native-local__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} + assert ( + cases["shell__native-local__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" + ) + model_rows = {row.model_name: row for row in result.model_usage} + assert model_rows["nvidia/gliner-pii"].observed_failed_requests == 0 + assert model_rows["nvidia/gliner-pii"].observed_failed_request_rate == 0 + assert model_rows["nvidia/gliner-pii"].observed_total_tokens == 1100 + assert model_rows["nvidia/nemotron-3-super"].model_alias == "local-nemotron-json" + assert model_rows["nvidia/nemotron-3-super"].experimental_replacement_strategy == "default" + assert model_rows["nvidia/nemotron-3-super"].model_provider_name == "local-vllm" + assert model_rows["nvidia/nemotron-3-super"].observed_failed_requests == 1 + assert model_rows["nvidia/nemotron-3-super"].observed_failed_request_rate == pytest.approx(1 / 3) + assert model_rows["nvidia/nemotron-3-super"].observed_total_tokens == 4900 + model_groups = {(row.model_alias, row.model_name): row for row in result.model_usage_groups} + nemotron_group = model_groups[("local-nemotron-json", "nvidia/nemotron-3-super")] + assert nemotron_group.model_provider_name == "local-vllm" + assert nemotron_group.sum_observed_failed_requests == 1 + assert nemotron_group.observed_failed_request_rate == pytest.approx(1 / 3) + assert nemotron_group.median_observed_total_requests == 3 + bio_group = next(group for group in result.groups if group.workload_id == "bio") + assert bio_group.workload_category == "synthetic_biography" + assert bio_group.entity_label_set_id == "agent" + assert bio_group.entity_label_count == 4 + assert bio_group.gliner_threshold == pytest.approx(0.3) + assert bio_group.experimental_replacement_strategy == "default" + assert bio_group.median_observed_bridge_fallback_requests == 1 + assert bio_group.median_observed_non_bridge_total_requests == 3 + assert bio_group.median_observed_non_bridge_failed_requests == 0 + assert bio_group.median_observed_non_bridge_failed_request_rate == 0 + assert bio_group.median_replacement_missing_final_entity_count == 2 + assert bio_group.median_replacement_missing_final_value_count == 1 + assert bio_group.replacement_missing_final_entity_label_counts == {"date": 2} + assert bio_group.median_replacement_synthetic_original_collision_count == 1 + assert bio_group.median_replacement_synthetic_original_collision_value_count == 1 + assert bio_group.replacement_synthetic_original_collision_label_counts == {"date": 1} + assert bio_group.total_record_count == 2 + assert bio_group.total_input_text_tokens == 1500 + assert bio_group.median_input_text_tokens_per_pipeline_sec == 150 + assert bio_group.median_input_text_tokens_per_endpoint_sec == 75 + assert bio_group.median_input_text_tokens_per_gpu_sec == 37.5 + assert bio_group.total_empty_detection_count == 1 + assert bio_group.empty_detection_rate == 0.5 + assert bio_group.total_empty_detection_with_ground_truth_count == 1 + assert bio_group.empty_detection_with_ground_truth_rate == 0.5 + assert bio_group.total_ground_truth_record_count == 2 + assert bio_group.sum_ground_truth_entity_count == 22 + assert bio_group.sum_entity_true_positive_count == 10 + assert bio_group.sum_entity_false_positive_count == 4 + assert bio_group.sum_entity_false_negative_count == 12 + assert bio_group.micro_entity_precision == pytest.approx(10 / 14) + assert bio_group.micro_entity_recall == pytest.approx(10 / 22) + assert bio_group.sum_entity_relaxed_gt_found_count == 15 + assert bio_group.sum_entity_relaxed_detected_tp_count == 14 + assert bio_group.sum_entity_relaxed_label_compatible_gt_found_count == 13 + assert bio_group.sum_entity_relaxed_label_compatible_detected_tp_count == 12 + assert bio_group.micro_entity_relaxed_precision == 1.0 + assert bio_group.micro_entity_relaxed_recall == pytest.approx(15 / 22) + assert bio_group.micro_entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) + assert bio_group.micro_entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) + shell_group = next(group for group in result.groups if group.workload_id == "shell") + assert shell_group.experimental_replacement_strategy == "local_structured_substitute" + assert shell_group.sum_original_value_leak_count == 1 + assert shell_group.leaking_case_count == 1 + assert shell_group.median_original_value_leak_count == 1 + + serialized = result.model_dump_json() + assert "Alice" not in serialized + assert "sk-test" not in serialized + + +def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_model_workflow", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "model_workflow", + "run_id": "bio__native__r000", + "workflow_name": "entity-detection-native-single-pass", + "elapsed_sec": 0.25, + "observed_total_requests": 3, + "observed_successful_requests": 3, + "observed_failed_requests": 0, + "observed_input_tokens": 30, + "observed_output_tokens": 12, + "observed_total_tokens": 42, + "model_usage": { + "native-direct": { + "model_alias": "native-direct", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 3, + "failed_requests": 0, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + }, + } + }, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "native", + "experimental_detection_strategy": "native_single_pass", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__native__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__native__r000", + "final_entity_count": 2, + "replacement_count": 2, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "native", + "experimental_detection_strategy": "native_single_pass", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__native__r000", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.case_count == 1 + case = result.cases[0] + assert case.observed_total_requests == 3 + assert case.observed_total_tokens == 42 + assert case.observed_failed_request_rate == 0 + assert result.model_usage_count == 1 + model_row = result.model_usage[0] + assert model_row.workflow_name == "entity-detection-native-single-pass" + assert model_row.model_alias == "native-direct" + assert model_row.model_name == "nvidia/nemotron-3-super" + assert model_row.observed_total_tokens == 42 + assert result.groups[0].median_observed_total_requests == 3 + assert result.model_usage_groups[0].sum_observed_total_tokens == 42 + + +def test_analyze_benchmark_output_accepts_detection_artifact_override(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_override", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r000", + }, + } + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "final_entity_signature_hashes": ["stale-hash"], + "final_entity_signature_count": 1, + } + ], + ) + refreshed_artifacts = tmp_path / "refreshed-detection-artifacts.jsonl" + _write_jsonl( + refreshed_artifacts, + [ + { + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "final_entity_signature_hashes": ["fresh-hash-a", "fresh-hash-b"], + "final_entity_signature_labels": {"fresh-hash-a": "person", "fresh-hash-b": "email"}, + "final_entity_signature_count": 2, + } + ], + ) + + default_result = tool.analyze_benchmark_output(benchmark_dir) + override_result = tool.analyze_benchmark_output(benchmark_dir, detection_artifacts=refreshed_artifacts) + + assert default_result.cases[0].artifact_final_entity_signature_hashes == ["stale-hash"] + assert override_result.detection_artifacts_path == str(refreshed_artifacts) + assert override_result.cases[0].artifact_final_entity_signature_hashes == ["fresh-hash-a", "fresh-hash-b"] + assert override_result.cases[0].artifact_final_entity_signature_labels == { + "fresh-hash-a": "person", + "fresh-hash-b": "email", + } + + +def test_analyze_benchmark_output_requires_detection_artifact_override_path(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_override_missing", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "bio__default__r000", + "run_tags": {"case_id": "bio__default__r000"}, + } + ], + ) + + with pytest.raises(ValueError, match="input path does not exist"): + tool.analyze_benchmark_output(benchmark_dir, detection_artifacts=tmp_path / "missing.jsonl") + + +def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_export", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + result = tool.BenchmarkOutputAnalysis( + benchmark_dir=str(tmp_path / "benchmark"), + cases=[ + tool.CaseAnalysisRow( + suite_id="suite", + workload_id="shell", + config_id="native", + experimental_detection_strategy="native_single_pass", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + repetition=0, + case_id="shell__native__r000", + run_id="shell__native__r000", + final_entity_count=8, + ) + ], + groups=[ + tool.GroupAnalysisRow( + workload_id="shell", + config_id="native", + experimental_detection_strategy="native_single_pass", + experimental_replacement_strategy="local_structured_substitute", + case_count=1, + median_final_entity_count=8, + median_observed_successful_requests=0, + median_observed_input_tokens=0, + median_observed_output_tokens=0, + median_observed_failed_request_rate=0, + median_artifact_final_entity_count=8, + ) + ], + model_usage=[ + tool.ModelUsageAnalysisRow( + workload_id="shell", + config_id="native", + experimental_detection_strategy="native_single_pass", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + case_id="shell__native__r000", + run_id="shell__native__r000", + workflow_name="entity-detection", + model_name="nvidia/gliner-pii", + observed_total_requests=1, + observed_successful_requests=1, + observed_total_tokens=1200, + ) + ], + model_usage_groups=[ + tool.ModelUsageGroupAnalysisRow( + workload_id="shell", + config_id="native", + experimental_detection_strategy="native_single_pass", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + workflow_name="entity-detection", + model_name="nvidia/gliner-pii", + case_count=1, + workflow_count=1, + sum_observed_total_requests=1, + sum_observed_successful_requests=1, + sum_observed_total_tokens=1200, + median_observed_total_requests=1, + median_observed_total_tokens=1200, + ) + ], + ) + + output_dir = tmp_path / "tables" + tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) + + assert pd.read_csv(output_dir / "case_analysis.csv")["case_id"].tolist() == ["shell__native__r000"] + assert pd.read_csv(output_dir / "case_analysis.csv")["experimental_replacement_strategy"].tolist() == [ + "local_structured_substitute" + ] + assert pd.read_csv(output_dir / "group_analysis.csv")["case_count"].tolist() == [1] + assert pd.read_csv(output_dir / "model_analysis.csv")["model_name"].tolist() == ["nvidia/gliner-pii"] + assert pd.read_csv(output_dir / "model_group_analysis.csv")["workflow_count"].tolist() == [1] + assert (output_dir / "manifest.json").exists() + + +def test_analyze_benchmark_output_preserves_zero_entity_cases(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_zero", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "empty__redact__r000", + "final_entity_count": 0, + "replacement_count": 0, + "run_tags": { + "workload_id": "empty", + "config_id": "redact", + "experimental_detection_strategy": "default", + "case_id": "empty__redact__r000", + }, + } + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.cases[0].final_entity_count == 0 + assert result.groups[0].median_final_entity_count == 0 + + +def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_replacement_strategy_groups", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "secrets__candidate__r000", + "final_entity_count": 4, + "run_tags": { + "workload_id": "secrets", + "config_id": "candidate", + "experimental_detection_strategy": "native_single_pass", + "experimental_replacement_strategy": "default", + "case_id": "secrets__candidate__r000", + }, + }, + { + "record_type": "record", + "run_id": "secrets__candidate__r001", + "final_entity_count": 4, + "run_tags": { + "workload_id": "secrets", + "config_id": "candidate", + "experimental_detection_strategy": "native_single_pass", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "secrets__candidate__r001", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.group_count == 2 + assert {group.experimental_replacement_strategy for group in result.groups} == { + "default", + "local_structured_substitute", + } + + +def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_failures", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "stage", + "run_id": "shell__candidate__r000", + "stage": "Anonymizer._run_internal", + "status": "completed", + "elapsed_sec": 1.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "repetition": 0, + "case_id": "shell__candidate__r000", + }, + }, + { + "record_type": "ndd_workflow", + "run_id": "shell__candidate__r001", + "workflow_name": "entity-detection", + "status": "error", + "elapsed_sec": 0.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "repetition": 1, + "case_id": "shell__candidate__r001", + }, + }, + { + "record_type": "stage", + "run_id": "shell__candidate__r001", + "stage": "Anonymizer._run_internal", + "status": "error", + "elapsed_sec": 0.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "repetition": 1, + "case_id": "shell__candidate__r001", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + cases = {row.case_id: row for row in result.cases} + assert cases["shell__candidate__r000"].case_failed is False + assert cases["shell__candidate__r000"].error_stage_count == 0 + assert cases["shell__candidate__r000"].error_ndd_workflow_count == 0 + assert cases["shell__candidate__r001"].case_failed is True + assert cases["shell__candidate__r001"].error_stage_count == 1 + assert cases["shell__candidate__r001"].error_ndd_workflow_count == 1 + assert result.groups[0].failed_case_count == 1 + assert result.groups[0].failed_case_rate == pytest.approx(0.5) + assert result.groups[0].error_stage_count == 1 + assert result.groups[0].error_ndd_workflow_count == 1 + assert "failed_cases=1/2" in tool.render_result(result, json_output=False) + + +def test_analyze_benchmark_output_groups_artifact_contribution_metrics(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_group", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "bio__default__r000", + "final_entity_count": 10, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r001", + "final_entity_count": 14, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r001", + }, + }, + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "workload_id": "bio", + "config_id": "default", + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "seed_entity_count": 9, + "seed_validation_candidate_count": 9, + "augmented_entity_count": 4, + "augmented_new_final_value_count": 1, + "final_entity_count": 11, + "final_source_counts": {"detector": 10, "augmenter": 1}, + "final_entity_signature_hashes": ["a", "b"], + "final_entity_signature_count": 2, + }, + { + "workload_id": "bio", + "config_id": "default", + "case_id": "bio__default__r001", + "run_id": "bio__default__r001", + "seed_entity_count": 13, + "seed_validation_candidate_count": 13, + "augmented_entity_count": 8, + "augmented_new_final_value_count": 3, + "final_entity_count": 15, + "final_source_counts": {"detector": 12, "augmenter": 3}, + "final_entity_signature_hashes": ["a", "b", "c", "d"], + "final_entity_signature_count": 4, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + group = result.groups[0] + assert group.median_final_entity_count == 12 + assert group.median_observed_successful_requests == 0 + assert group.median_observed_input_tokens == 0 + assert group.median_observed_output_tokens == 0 + assert group.median_observed_failed_request_rate is None + assert group.median_seed_entity_count == 11 + assert group.median_seed_validation_candidate_count == 11 + assert group.median_augmented_entity_count == 6 + assert group.median_augmented_new_final_value_count == 2 + assert group.median_artifact_final_entity_count == 13 + assert group.median_artifact_final_detector_entity_count == 11 + assert group.median_artifact_final_augmenter_entity_count == 2 + assert group.median_artifact_final_entity_signature_count == 3 diff --git a/tests/tools/test_compare_strategy_pairs.py b/tests/tools/test_compare_strategy_pairs.py new file mode 100644 index 00000000..ad0ab436 --- /dev/null +++ b/tests/tools/test_compare_strategy_pairs.py @@ -0,0 +1,1473 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> None: + tool = load_tool("measurement_compare_strategy_pairs", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py") + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "shell-no-augment", + "experimental_detection_strategy": "no_augment", + "experimental_replacement_strategy": "default", + "case_id": "shell__base", + "pipeline_elapsed_sec": 4.5, + "observed_total_requests": 3, + "observed_total_tokens": 2875, + "observed_failed_requests": 1, + "final_entity_count": 4, + "seed_validation_candidate_count": 5, + "augmented_entity_count": 2, + "augmented_new_final_value_count": 1, + "artifact_final_detector_entity_count": 4, + }, + { + "workload_id": "shell-1", + "config_id": "shell-native", + "experimental_detection_strategy": "native_candidate_validate_no_augment", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 0.8, + "observed_total_requests": 1, + "observed_total_tokens": 101, + "observed_failed_requests": 0, + "final_entity_count": 8, + "seed_validation_candidate_count": 0, + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "artifact_final_augmenter_entity_count": 8, + }, + { + "workload_id": "legal-1", + "config_id": "legal-no-augment", + "experimental_detection_strategy": "no_augment", + "experimental_replacement_strategy": "default", + "case_id": "legal__base", + "pipeline_elapsed_sec": 21, + "observed_total_requests": 3, + "observed_total_tokens": 8847, + "observed_failed_requests": 1, + "final_entity_count": 26, + "seed_validation_candidate_count": 40, + "artifact_final_detector_entity_count": 26, + }, + { + "workload_id": "legal-1", + "config_id": "legal-native", + "experimental_detection_strategy": "native_candidate_validate_no_augment", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "legal__candidate", + "pipeline_elapsed_sec": 20.9, + "observed_total_requests": 3, + "observed_total_tokens": 8847, + "observed_failed_requests": 1, + "final_entity_count": 26, + "seed_validation_candidate_count": 40, + "artifact_final_detector_entity_count": 26, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_strategy="no_augment", + candidate_strategy="native_candidate_validate_no_augment", + ) + + by_workload = {row.workload_id: row for row in rows} + shell = by_workload["shell-1"] + assert shell.baseline_replacement_strategy == "default" + assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.final_entity_count_delta == 4 + assert shell.observed_total_requests_delta == -2 + assert shell.observed_total_tokens_delta == -2774 + assert shell.seed_validation_candidate_count_delta == -5 + assert shell.baseline_augmented_entity_count == 2 + assert shell.candidate_augmented_entity_count == 0 + assert shell.augmented_entity_count_delta == -2 + assert shell.baseline_augmented_new_final_value_count == 1 + assert shell.candidate_augmented_new_final_value_count == 0 + assert shell.augmented_new_final_value_count_delta == -1 + assert shell.candidate_augmenter_entity_count == 8 + assert shell.safety_verdict == "review" + assert shell.performance_verdict == "improved" + assert shell.candidate_verdict == "review" + assert shell.flags == ["no_candidate_detector_entities"] + + legal = by_workload["legal-1"] + assert legal.baseline_replacement_strategy == "default" + assert legal.candidate_replacement_strategy == "local_structured_substitute" + assert legal.final_entity_count_delta == 0 + assert legal.observed_total_tokens_delta == 0 + assert legal.candidate_detector_entity_count == 26 + assert legal.safety_verdict == "pass" + assert legal.performance_verdict == "improved" + assert legal.candidate_verdict == "candidate_viable" + assert legal.flags == [] + + +def test_compare_case_analysis_rejects_ambiguous_strategy_selector() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_ambiguous", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "base-a", + "experimental_detection_strategy": "no_augment", + "case_id": "a", + }, + { + "workload_id": "shell-1", + "config_id": "base-b", + "experimental_detection_strategy": "no_augment", + "case_id": "b", + }, + { + "workload_id": "shell-1", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "case_id": "c", + }, + ] + ) + + with pytest.raises(ValueError, match="baseline selector matched multiple configs"): + tool.compare_case_analysis(table, baseline_strategy="no_augment", candidate_strategy="detector_only") + + +def test_compare_case_analysis_rejects_candidate_synthetic_original_collisions() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_replacement_collisions", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-1", + "config_id": "baseline", + "experimental_detection_strategy": "default", + "case_id": "base", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 4, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_value_count": 0, + }, + { + "workload_id": "legal-1", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "cand", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 2, + "observed_total_tokens": 1000, + "final_entity_count": 4, + "replacement_synthetic_original_collision_count": 1, + "replacement_synthetic_original_collision_value_count": 1, + "replacement_synthetic_original_collision_label_counts.date": 1, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_replacement_synthetic_original_collision_count == 1 + assert row.replacement_synthetic_original_collision_count_delta == 1 + assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + assert "candidate_replacement_synthetic_original_collision" in row.flags + assert row.value_protection_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_compare_case_analysis_rejects_candidate_missing_replacement_map_entries() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_missing_replacement_map_entries", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "baseline", + "experimental_detection_strategy": "default", + "case_id": "base", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 4, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_value_count": 0, + }, + { + "workload_id": "structured-identifiers", + "config_id": "candidate", + "experimental_detection_strategy": "default", + "case_id": "cand", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 2, + "observed_total_tokens": 1000, + "final_entity_count": 4, + "replacement_missing_final_entity_count": 1, + "replacement_missing_final_value_count": 1, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_replacement_missing_final_entity_count == 1 + assert row.replacement_missing_final_entity_count_delta == 1 + assert "candidate_replacement_missing_final_entity" in row.flags + assert row.value_protection_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_compare_case_tables_allows_candidate_from_separate_run() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_cross_run", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + baseline = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "legal-no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "legal__base", + "observed_total_tokens": 55790, + "final_entity_count": 193, + "artifact_final_detector_entity_count": 172, + } + ] + ) + candidate = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "legal-native-validate", + "experimental_detection_strategy": "detector_native_validate_no_augment", + "case_id": "legal__candidate", + "observed_total_tokens": 55805, + "final_entity_count": 193, + "artifact_final_detector_entity_count": 172, + } + ] + ) + + rows = tool.compare_case_tables( + baseline, + candidate, + baseline_strategy="no_augment", + candidate_strategy="detector_native_validate_no_augment", + ) + + assert len(rows) == 1 + assert rows[0].workload_id == "legal-5" + assert rows[0].baseline_config_id == "legal-no-augment" + assert rows[0].candidate_config_id == "legal-native-validate" + assert rows[0].observed_total_tokens_delta == 15 + assert rows[0].flags == ["token_increase"] + + +def test_compare_case_analysis_preserves_augmentation_contribution_deltas() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_augmentation", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-2", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "augmented_entity_count": 8, + "augmented_new_final_value_count": 3, + "artifact_final_augmenter_entity_count": 3, + "artifact_final_entity_signature_hashes": ["a", "b", "c"], + }, + { + "workload_id": "legal-2", + "config_id": "no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "no-augment-r0", + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "artifact_final_augmenter_entity_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="no-augment") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_augmented_entity_count == 8 + assert row.candidate_augmented_entity_count == 0 + assert row.augmented_entity_count_delta == -8 + assert row.baseline_augmented_new_final_value_count == 3 + assert row.candidate_augmented_new_final_value_count == 0 + assert row.augmented_new_final_value_count_delta == -3 + assert row.baseline_augmenter_entity_count == 3 + assert row.candidate_augmenter_entity_count == 0 + + +def test_compare_case_analysis_review_gates_detector_only_candidate_shell_case() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_detector_only", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-1", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "bio-1", + "config_id": "detector-only", + "experimental_detection_strategy": "detector_only", + "case_id": "bio__detector", + "pipeline_elapsed_sec": 5, + "observed_total_requests": 1, + "observed_total_tokens": 1000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="detector_only") + + assert len(rows) == 1 + row = rows[0] + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["candidate_skips_llm_validation"] + + +def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_detector_only", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "shell__default", + "pipeline_elapsed_sec": 8, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-1", + "config_id": "detector-only", + "experimental_detection_strategy": "detector_only", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 1, + "observed_total_requests": 1, + "observed_total_tokens": 200, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_strategy="default", + candidate_strategy="detector_only", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["candidate_skips_llm_validation"] + + +def test_compare_case_analysis_review_gates_non_detector_sources_when_signatures_match() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_non_detector_sources", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "native-source-default", + "experimental_detection_strategy": "default", + "case_id": "shell__default", + "pipeline_elapsed_sec": 21.4, + "observed_total_requests": 8, + "observed_total_tokens": 9854, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + { + "workload_id": "shell-1", + "config_id": "native-source-candidate", + "experimental_detection_strategy": "native_single_pass", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 0.001, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 2, + "artifact_final_augmenter_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_config="native-source-default", + candidate_config="native-source-candidate", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.shared_final_entity_signature_label_counts == {"api_key": 1, "password": 1} + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "pass" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["no_candidate_detector_entities"] + + +def test_compare_case_analysis_flags_signature_loss_even_when_counts_match() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_signature_loss", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "final_entity_count": 3, + "artifact_final_entity_signature_hashes": '["a","b","c"]', + "artifact_final_entity_signature_labels": '{"a":"first_name","b":"city","c":"first_name"}', + }, + { + "workload_id": "bio-5", + "config_id": "no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "bio__no_augment", + "final_entity_count": 3, + "artifact_final_entity_signature_hashes": ["a", "b", "d"], + "artifact_final_entity_signature_labels": {"a": "first_name", "b": "city", "d": "last_name"}, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_config="default", + candidate_config="no-augment", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.final_entity_count_delta == 0 + assert row.baseline_only_final_entity_signature_count == 1 + assert row.candidate_only_final_entity_signature_count == 1 + assert row.shared_final_entity_signature_count == 2 + assert row.baseline_only_final_entity_signature_label_counts == {"first_name": 1} + assert row.candidate_only_final_entity_signature_label_counts == {"last_name": 1} + assert row.shared_final_entity_signature_label_counts == {"city": 1, "first_name": 1} + assert row.safety_verdict == "fail" + assert row.performance_verdict == "unknown" + assert row.candidate_verdict == "reject" + assert row.flags == ["entity_signature_loss"] + + +def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_coverage", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["api-token", "pin"], + "artifact_final_entity_signature_labels": {"api-token": "api_key", "pin": "pin"}, + "artifact_final_entity_signature_details": { + "api-token": { + "label": "api_key", + "source": "augmenter", + "row_index": 0, + "start_position": 42, + "end_position": 66, + "value_hash": "token-hash", + "value_length": 24, + }, + "pin": { + "label": "pin", + "source": "detector", + "row_index": 0, + "start_position": 90, + "end_position": 95, + "value_hash": "pin-hash", + "value_length": 5, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 2, + "artifact_final_augmenter_entity_count": 2, + "artifact_final_entity_signature_hashes": ["cookie", "pin"], + "artifact_final_entity_signature_labels": {"cookie": "http_cookie", "pin": "pin"}, + "artifact_final_entity_signature_details": { + "cookie": { + "label": "http_cookie", + "source": "native", + "row_index": 0, + "start_position": 30, + "end_position": 80, + "value_hash": "cookie-hash", + "value_length": 50, + }, + "pin": { + "label": "pin", + "source": "native", + "row_index": 0, + "start_position": 90, + "end_position": 95, + "value_hash": "pin-hash", + "value_length": 5, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 1 + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_covered_signature_label_counts == {"api_key": 1} + assert row.baseline_only_candidate_uncovered_signature_label_counts == {} + assert "entity_signature_loss" not in row.flags + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_review_gates_covered_label_mismatch() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_label_mismatch", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-row", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 1000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["dob"], + "artifact_final_entity_signature_labels": {"dob": "date_of_birth"}, + "artifact_final_entity_signature_details": { + "dob": { + "label": "date_of_birth", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 35, + "value_hash": "date-hash", + "value_length": 15, + }, + }, + }, + { + "workload_id": "legal-row", + "config_id": "native-validation", + "experimental_detection_strategy": "detector_native_validate_no_augment", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "observed_total_requests": 2, + "observed_total_tokens": 300, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["date"], + "artifact_final_entity_signature_labels": {"date": "date"}, + "artifact_final_entity_signature_details": { + "date": { + "label": "date", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 35, + "value_hash": "date-hash", + "value_length": 15, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-validation") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_label_mismatch_signature_count == 1 + assert row.baseline_only_candidate_label_mismatch_signature_label_counts == {"date_of_birth": 1} + assert "entity_signature_loss" not in row.flags + assert "covered_label_mismatch" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_treats_high_overlap_candidate_span_as_covered() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_overlap", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-assignment"], + "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, + "artifact_final_entity_signature_details": { + "token-assignment": { + "label": "http_cookie", + "source": "augmenter", + "row_index": 0, + "start_position": 20, + "end_position": 68, + "value_hash": "token-assignment-hash", + "value_length": 48, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 1, + "artifact_final_augmenter_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-value"], + "artifact_final_entity_signature_labels": {"token-value": "api_key"}, + "artifact_final_entity_signature_details": { + "token-value": { + "label": "api_key", + "source": "native", + "row_index": 0, + "start_position": 26, + "end_position": 69, + "value_hash": "token-value-hash", + "value_length": 43, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 1 + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_overlapping_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_overlapping_signature_label_counts == {"http_cookie": 1} + assert row.baseline_only_candidate_uncovered_signature_label_counts == {} + assert "entity_signature_loss" not in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_treats_small_assignment_prefix_gap_as_boundary_overlap() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_boundary", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["login-assignment"], + "artifact_final_entity_signature_labels": {"login-assignment": "unique_id"}, + "artifact_final_entity_signature_details": { + "login-assignment": { + "label": "unique_id", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 38, + "value_hash": "login-assignment-hash", + "value_length": 18, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 1, + "artifact_final_augmenter_entity_count": 1, + "artifact_final_entity_signature_hashes": ["login-value"], + "artifact_final_entity_signature_labels": {"login-value": "user_name"}, + "artifact_final_entity_signature_details": { + "login-value": { + "label": "user_name", + "source": "native", + "row_index": 0, + "start_position": 26, + "end_position": 38, + "value_hash": "login-value-hash", + "value_length": 12, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_overlapping_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert "entity_signature_loss" not in row.flags + assert "span_boundary_mismatch" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_flags_replacement_only_detection_instability() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_replacement_only_detection_instability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "dd-substitute", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-assignment"], + "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, + "artifact_final_entity_signature_details": { + "token-assignment": { + "label": "http_cookie", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 70, + "value_hash": "token-assignment-hash", + "value_length": 50, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "local-substitute", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 7, + "observed_total_requests": 3, + "observed_total_tokens": 3000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-value"], + "artifact_final_entity_signature_labels": {"token-value": "api_key"}, + "artifact_final_entity_signature_details": { + "token-value": { + "label": "api_key", + "source": "detector", + "row_index": 0, + "start_position": 26, + "end_position": 70, + "value_hash": "token-value-hash", + "value_length": 44, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="dd-substitute", candidate_config="local-substitute") + + assert len(rows) == 1 + row = rows[0] + assert "covered_label_mismatch" in row.flags + assert "replacement_only_detection_instability" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_rejects_candidate_original_value_leaks() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_original_value_leak", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-secrets", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "structured__default", + "pipeline_elapsed_sec": 10, + "observed_total_tokens": 1000, + "final_entity_count": 2, + "original_value_leak_count": 0, + "original_value_leak_record_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + { + "workload_id": "structured-secrets", + "config_id": "candidate", + "experimental_detection_strategy": "native_single_pass", + "case_id": "structured__candidate", + "pipeline_elapsed_sec": 1, + "observed_total_tokens": 0, + "final_entity_count": 2, + "original_value_leak_count": 2, + "original_value_leak_record_count": 1, + "original_value_leak_label_counts.api_key": 1, + "original_value_leak_label_counts.password": 1, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_original_value_leak_count == 2 + assert row.candidate_original_value_leak_record_count == 1 + assert row.original_value_leak_count_delta == 2 + assert row.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + assert row.flags == ["candidate_original_value_leak"] + + +def test_compare_case_analysis_marks_clean_but_expensive_candidate_for_review() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_verdicts", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "legal__default", + "pipeline_elapsed_sec": 20, + "observed_total_tokens": 1000, + "final_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "legal__candidate", + "pipeline_elapsed_sec": 25, + "observed_total_tokens": 1200, + "final_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + assert rows[0].safety_verdict == "pass" + assert rows[0].performance_verdict == "regressed" + assert rows[0].candidate_verdict == "review" + assert rows[0].flags == ["token_increase"] + + +def test_compare_case_analysis_separates_bridge_fallbacks_from_provider_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_trace_adjusted", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_failed_requests": 1, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 3, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 12, + "artifact_final_detector_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "bio-5", + "config_id": "windowed", + "experimental_detection_strategy": "windowed", + "case_id": "bio__windowed", + "pipeline_elapsed_sec": 15, + "observed_total_requests": 6, + "observed_failed_requests": 2, + "observed_bridge_fallback_requests": 2, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 12, + "artifact_final_detector_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="windowed") + + assert len(rows) == 1 + row = rows[0] + assert row.observed_failed_requests_delta == 1 + assert row.observed_bridge_fallback_requests_delta == 1 + assert row.observed_non_bridge_failed_requests_delta == 0 + assert "bridge_fallback_increase" in row.flags + assert "failed_request_increase" not in row.flags + + +def test_compare_case_analysis_flags_non_bridge_provider_failure_increase() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_non_bridge_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "legal__default", + "observed_total_requests": 5, + "observed_failed_requests": 1, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 10, + "artifact_final_detector_entity_count": 10, + "artifact_final_entity_signature_hashes": ["a"], + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "legal__candidate", + "observed_total_requests": 5, + "observed_failed_requests": 2, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 1, + "final_entity_count": 10, + "artifact_final_detector_entity_count": 10, + "artifact_final_entity_signature_hashes": ["a"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.observed_bridge_fallback_requests_delta == 0 + assert row.observed_non_bridge_failed_requests_delta == 1 + assert "bridge_fallback_increase" not in row.flags + assert "failed_request_increase" in row.flags + assert row.safety_verdict == "review" + assert row.performance_verdict == "unchanged" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_rejects_candidate_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_case_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 8, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r1", + "pipeline_elapsed_sec": 8, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "experimental_detection_strategy": "detector_only", + "case_id": "candidate-r1", + "pipeline_elapsed_sec": 0.2, + "case_failed": True, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 0 + assert row.candidate_failed_case_count == 1 + assert row.failed_case_count_delta == 1 + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + assert "candidate_case_failures" in row.flags + + +def test_compare_case_analysis_counts_model_workflow_errors_as_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_model_workflow_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-5", + "config_id": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 8, + "error_stage_count": 0, + "error_ndd_workflow_count": 0, + "error_model_workflow_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "error_stage_count": 0, + "error_ndd_workflow_count": 0, + "error_model_workflow_count": 1, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 0 + assert row.candidate_failed_case_count == 1 + assert row.safety_verdict == "fail" + assert row.candidate_verdict == "reject" + assert "candidate_case_failures" in row.flags + + +def test_compare_case_analysis_review_gates_baseline_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_baseline_case_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "baseline", + "case_id": "baseline-r0", + "pipeline_elapsed_sec": 30, + "case_failed": True, + "artifact_final_entity_signature_hashes": [], + }, + { + "workload_id": "bio-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 20, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 1 + assert row.candidate_failed_case_count == 0 + assert row.failed_case_count_delta == -1 + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["baseline_case_failures"] + + +def test_compare_case_analysis_flags_repeated_signature_instability() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_stability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r1", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r1", + "artifact_final_entity_signature_hashes": ["a"], + "artifact_final_entity_signature_labels": {"a": "person"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.baseline_stable_final_entity_signature_count == 2 + assert row.candidate_stable_final_entity_signature_count == 1 + assert row.stable_final_entity_signature_count_delta == -1 + assert row.baseline_stable_candidate_unstable_final_entity_signature_count == 1 + assert row.baseline_stable_candidate_unstable_final_entity_signature_label_counts == {"date": 1} + assert row.safety_verdict == "fail" + assert row.candidate_verdict == "reject" + assert row.flags == ["stable_entity_signature_loss"] + + +def test_compare_case_analysis_does_not_infer_stability_loss_from_single_candidate_run() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_single_candidate_stability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r1", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.baseline_stable_final_entity_signature_count is None + assert row.candidate_stable_final_entity_signature_count is None + assert row.stable_final_entity_signature_count_delta is None + assert row.baseline_stable_candidate_unstable_final_entity_signature_count is None + assert row.safety_verdict == "pass" + assert "stable_entity_signature_loss" not in row.flags + + +def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_export", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" + ) + rows = [ + tool.ComparisonRow( + workload_id="shell-1", + baseline_config_id="base", + candidate_config_id="candidate", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="review", + signature_parity_verdict="pass", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + baseline_final_entity_count=4, + candidate_final_entity_count=8, + final_entity_count_delta=4, + flags=["candidate_skips_llm_validation"], + ) + ] + + output = tmp_path / "comparison.csv" + tool.write_comparisons(rows, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["workload_id"].tolist() == ["shell-1"] + assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["final_entity_count_delta"].tolist() == [4] + assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] + + +def test_compare_strategy_pairs_summarizes_candidate_verdicts() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_summary", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" + ) + rows = [ + tool.ComparisonRow( + workload_id="legal-1", + baseline_config_id="legal-default", + candidate_config_id="legal-no-augment", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="pass", + signature_parity_verdict="pass", + safety_verdict="pass", + performance_verdict="improved", + candidate_verdict="candidate_viable", + ), + tool.ComparisonRow( + workload_id="bio-1", + baseline_config_id="bio-default", + candidate_config_id="bio-no-augment", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="fail", + signature_parity_verdict="fail", + safety_verdict="fail", + performance_verdict="improved", + candidate_verdict="reject", + baseline_only_final_entity_signature_label_counts={"first_name": 2}, + ), + tool.ComparisonRow( + workload_id="shell-1", + baseline_config_id="shell-default", + candidate_config_id="shell-detector-only", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="review", + signature_parity_verdict="pass", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + ), + ] + + summary = tool.summarize_comparisons(rows) + + assert summary.comparison_count == 3 + assert summary.value_protection_verdict_counts == {"fail": 1, "pass": 1, "review": 1} + assert summary.signature_parity_verdict_counts == {"fail": 1, "pass": 2, "review": 0} + assert summary.safety_verdict_counts == {"fail": 1, "pass": 1, "review": 1} + assert summary.performance_verdict_counts["improved"] == 3 + assert summary.candidate_verdict_counts == {"candidate_viable": 1, "reject": 1, "review": 1} + assert summary.candidate_viable_workloads == ["legal-1"] + assert summary.review_workloads == ["shell-1"] + assert summary.rejected_workloads == ["bio-1"] + + rendered = tool.render_result( + tool.ComparisonResult( + input_path="case_analysis.csv", + baseline_selector="strategy:default", + candidate_selector="strategy:no_augment", + summary=summary, + comparisons=rows, + ), + json_output=False, + ) + assert "Compared 3 workload(s): viable=1, review=1, reject=1" in rendered + assert ( + "- bio-1: verdict=reject (safety=fail, value_protection=fail, signature_parity=fail, performance=improved)" + ) in rendered + assert "elapsed unknown->unknown" in rendered + assert "lost_labels=first_name:2" in rendered diff --git a/tests/tools/test_dd_parser_compat.py b/tests/tools/test_dd_parser_compat.py new file mode 100644 index 00000000..b9b0241c --- /dev/null +++ b/tests/tools/test_dd_parser_compat.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +from data_designer.engine.models.recipes import response_recipes as recipes +from pydantic import BaseModel + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +class TinyPayload(BaseModel): + value: str + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_raw_json_compat_accepts_pydantic_raw_json_and_restores_parser() -> None: + tool = load_tool("measurement_dd_parser_compat_pydantic", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + original = recipes.PydanticResponseRecipe._build_parser_fn + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.PydanticResponseRecipe(TinyPayload) + parsed = recipe._build_parser_fn()('{"value": "ok"}') + + assert parsed.value == "ok" + assert recipes.PydanticResponseRecipe._build_parser_fn is original + + +def test_raw_json_compat_accepts_pydantic_json_after_reasoning_prefix() -> None: + tool = load_tool( + "measurement_dd_parser_compat_pydantic_reasoning", REPO_ROOT / "tools/measurement/dd_parser_compat.py" + ) + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.PydanticResponseRecipe(TinyPayload) + parsed = recipe.parse('reasoning text\n\n\n\n{"value": "ok"}') + + assert parsed.value == "ok" + + +def test_raw_json_compat_accepts_structured_raw_json_and_restores_parser() -> None: + tool = load_tool("measurement_dd_parser_compat_structured", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + original = recipes.StructuredResponseRecipe._build_parser_fn + schema = { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + } + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.StructuredResponseRecipe(schema) + parsed = recipe._build_parser_fn()('{"value": "ok"}') + + assert parsed == {"value": "ok"} + assert recipes.StructuredResponseRecipe._build_parser_fn is original + + +def test_raw_json_compat_uses_outermost_embedded_json_object() -> None: + tool = load_tool("measurement_dd_parser_compat_outer_json", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + }, + } + }, + "required": ["items"], + } + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.StructuredResponseRecipe(schema) + parsed = recipe.parse('text before {"items": [{"value": "first"}, {"value": "last"}]}') + + assert parsed == {"items": [{"value": "first"}, {"value": "last"}]} diff --git a/tests/tools/test_dd_trace_analysis.py b/tests/tools/test_dd_trace_analysis.py new file mode 100644 index 00000000..82bb5457 --- /dev/null +++ b/tests/tools/test_dd_trace_analysis.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_dd_traces_classifies_response_shape_without_raw_content(tmp_path: Path) -> None: + tool = load_tool("measurement_dd_trace_analysis", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") + trace_path = tmp_path / "trace.jsonl" + _write_jsonl( + trace_path, + [ + { + "record_type": "dd_message_trace", + "run_id": "case-1", + "run_tags": { + "suite_id": "suite", + "workload_id": "legal", + "config_id": "default", + "case_id": "case-1", + }, + "workflow_name": "entity-detection", + "model_alias": "validator", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "status": "completed", + "error_type": None, + "elapsed_sec": 3.0, + "messages": [{"role": "user", "content": [{"type": "text", "text": "secret prompt text"}]}], + "response": {"content": '{"decisions": []}', "reasoning_content": None, "tool_calls": []}, + "usage": {"input_tokens": 11, "output_tokens": 7, "total_tokens": 18}, + }, + { + "record_type": "dd_message_trace", + "run_id": "case-1", + "run_tags": { + "suite_id": "suite", + "workload_id": "legal", + "config_id": "default", + "case_id": "case-1", + }, + "workflow_name": "entity-detection", + "model_alias": "augmenter", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "status": "completed", + "error_type": None, + "elapsed_sec": 4.0, + "messages": [{"role": "user", "content": [{"type": "text", "text": "another secret prompt"}]}], + "response": { + "content": 'reasoning\n\n```json\n{"entities": []}\n```', + "reasoning_content": None, + "tool_calls": [], + }, + "usage": {"input_tokens": 13, "output_tokens": 17, "total_tokens": 30}, + }, + { + "record_type": "other", + "response": {"content": "ignored raw text"}, + }, + ], + ) + + result = tool.analyze_trace_path(trace_path) + + assert result.trace_record_count == 2 + rows = {row.model_alias: row for row in result.rows} + assert rows["validator"].response_shape == "raw_json" + assert rows["validator"].response_has_embedded_json is True + assert rows["validator"].prompt_chars == len("secret prompt text") + assert rows["augmenter"].response_shape == "fenced_json" + assert rows["augmenter"].response_has_thinking is True + assert rows["augmenter"].response_chars > 0 + serialized = result.model_dump_json() + assert "secret prompt text" not in serialized + assert "another secret prompt" not in serialized + assert "decisions" not in serialized + assert "entities" not in serialized + + group = next(row for row in result.groups if row.response_shape == "fenced_json") + assert group.model_alias == "augmenter" + assert group.model_name == "nvidia/nemotron-3-super" + assert group.trace_record_count == 1 + assert group.sum_total_tokens == 30 + assert group.error_count == 0 + + +def test_analyze_dd_traces_reads_trace_directory_and_exports_tables(tmp_path: Path) -> None: + tool = load_tool("measurement_dd_trace_analysis_export", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") + trace_dir = tmp_path / "traces" + trace_dir.mkdir() + _write_jsonl( + trace_dir / "a.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "case-a", + "workflow_name": "entity-detection", + "model_name": "nvidia/nemotron-3-super", + "status": "error", + "error_type": "ParserException", + "elapsed_sec": 1.0, + "response": None, + "usage": None, + } + ], + ) + _write_jsonl( + trace_dir / "b.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "case-b", + "workflow_name": "entity-detection", + "model_name": "nvidia/gliner-pii", + "status": "completed", + "elapsed_sec": 0.2, + "response": {"content": "plain text", "reasoning_content": None, "tool_calls": []}, + "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + } + ], + ) + + result = tool.analyze_trace_path(trace_dir) + output_dir = tmp_path / "analysis" + tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) + + assert result.trace_record_count == 2 + rows = {row.run_id: row for row in result.rows} + assert rows["case-a"].status == "error" + assert rows["case-a"].response_shape == "none" + assert rows["case-b"].response_shape == "text" + assert pd.read_csv(output_dir / "trace_analysis.csv")["run_id"].tolist() == ["case-a", "case-b"] + assert pd.read_csv(output_dir / "trace_group_analysis.csv")["trace_record_count"].sum() == 2 + assert (output_dir / "manifest.json").exists() diff --git a/tests/tools/test_detection_artifact_analysis.py b/tests/tools/test_detection_artifact_analysis.py new file mode 100644 index 00000000..f99fffbd --- /dev/null +++ b/tests/tools/test_detection_artifact_analysis.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _entity(value: str, label: str, start: int, end: int, *, source: str = "detector") -> dict[str, object]: + return { + "id": f"{label}_{start}_{end}", + "value": value, + "label": label, + "start_position": start, + "end_position": end, + "score": 1.0, + "source": source, + } + + +def _write_artifact(root: Path, workflow: str, rows: list[dict[str, object]]) -> None: + parquet_dir = root / workflow / "parquet-files" + parquet_dir.mkdir(parents=True) + pd.DataFrame(rows).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + + +def test_detection_artifact_analysis_reports_augmentation_contribution(tmp_path: Path) -> None: + tool = load_tool( + "measurement_detection_artifact_analysis", + REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py", + ) + artifact_root = tmp_path / "artifacts" + _write_artifact( + artifact_root, + "entity-detection", + [ + { + "_seed_entities_json": json.dumps([_entity("Alice", "first_name", 0, 5)]), + "_seed_validation_candidates": json.dumps( + {"candidates": [{"id": "first_name_0_5", "value": "Alice", "label": "first_name"}]} + ), + "_augmented_entities": json.dumps( + { + "entities": [ + {"value": "Alice", "label": "first_name", "reason": "duplicate"}, + {"value": "12 February 1980", "label": "api_key", "reason": "date mislabeled"}, + ] + } + ), + "_detected_entities": json.dumps( + { + "entities": [ + _entity("Alice", "first_name", 0, 5), + _entity("12 February 1980", "api_key", 20, 36, source="augmenter"), + ] + } + ), + "_validation_candidates": json.dumps( + { + "candidates": [ + {"id": "first_name_0_5", "value": "Alice", "label": "first_name"}, + {"id": "api_key_20_36", "value": "12 February 1980", "label": "api_key"}, + ] + } + ), + } + ], + ) + + result = tool.analyze_artifacts(artifact_root) + + assert len(result.rows) == 1 + row = result.rows[0] + assert row.seed_entity_count == 1 + assert row.seed_validation_candidate_count == 1 + assert row.merged_validation_candidate_count == 2 + assert row.augmented_entity_count == 2 + assert row.augmented_duplicate_seed_value_count == 1 + assert row.augmented_new_value_count == 1 + assert row.augmented_new_final_value_count == 1 + assert row.final_entity_count == 2 + assert row.weak_api_key_shape_count == 1 + assert row.weak_api_key_shape_label_counts == {"api_key": 1} + assert row.final_entity_signature_count == 2 + assert len(row.final_entity_signature_hashes) == 2 + assert set(row.final_entity_signature_labels) == set(row.final_entity_signature_hashes) + assert sorted(row.final_entity_signature_labels.values()) == ["api_key", "first_name"] + assert set(row.final_entity_signature_details) == set(row.final_entity_signature_hashes) + first_name_detail = next( + detail for detail in row.final_entity_signature_details.values() if detail["label"] == "first_name" + ) + assert first_name_detail["source"] == "detector" + assert first_name_detail["row_index"] == 0 + assert first_name_detail["start_position"] == 0 + assert first_name_detail["end_position"] == 5 + assert first_name_detail["value_length"] == 5 + assert "value_hash" not in first_name_detail + + serialized = row.model_dump_json() + assert "Alice" not in serialized + assert "12 February" not in serialized + + +def test_detection_artifact_analysis_handles_no_augment_rows(tmp_path: Path) -> None: + tool = load_tool( + "measurement_detection_artifact_analysis_no_augment", + REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py", + ) + artifact_root = tmp_path / "artifacts" + _write_artifact( + artifact_root, + "entity-detection-no-augment", + [ + { + "_seed_entities_json": json.dumps([_entity("Aydin", "city", 12, 17)]), + "_seed_validation_candidates": json.dumps( + {"candidates": [{"id": "city_12_17", "value": "Aydin", "label": "city"}]} + ), + "_augmented_entities": json.dumps({"entities": []}), + "_detected_entities": json.dumps({"entities": [_entity("Aydin", "city", 12, 17)]}), + } + ], + ) + + result = tool.analyze_artifacts(artifact_root) + + assert len(result.rows) == 1 + row = result.rows[0] + assert row.workflow_name == "entity-detection-no-augment" + assert row.seed_entity_count == 1 + assert row.seed_validation_candidate_count == 1 + assert row.merged_validation_candidate_count == 0 + assert row.augmented_entity_count == 0 + assert row.augmented_new_value_count == 0 + assert row.augmented_new_final_value_count == 0 + assert row.final_entity_count == 1 + assert row.final_source_counts == {"detector": 1} + assert row.final_entity_signature_count == 1 + assert row.final_entity_signature_hashes == sorted(row.final_entity_signature_hashes) + assert list(row.final_entity_signature_labels.values()) == ["city"] diff --git a/tests/tools/test_detection_strategies.py b/tests/tools/test_detection_strategies.py new file mode 100644 index 00000000..5ef9affd --- /dev/null +++ b/tests/tools/test_detection_strategies.py @@ -0,0 +1,1093 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +import threading +import time +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import Mock + +import pandas as pd + +from anonymizer.engine.constants import ( + COL_DETECTED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, +) +from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow +from anonymizer.engine.ndd.model_loader import load_default_model_selection +from anonymizer.engine.schemas import EntitiesSchema +from anonymizer.measurement import MeasurementCollector, measurement_session + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_native_candidate_validate_no_augment_strategy_skips_data_designer_and_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_native_candidate_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SequencedClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs = [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + ] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=self.outputs.pop(0), + elapsed_sec=0.1, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + adapter = Mock() + client = SequencedClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "direct_seed"), + ] + assert len(client.prompts) == 2 + assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-candidate-validate-no-augment" + assert record["observed_total_requests"] == 2 + assert record["observed_total_tokens"] == 28 + + +def test_detector_native_validate_no_augment_strategy_reuses_detector_seed_and_direct_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_called_once() + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( + "entity-detection-detector-native-validate-no-augment-seed" + ) + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + assert len(client.prompts) == 1 + assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + + +def test_detector_native_validate_no_augment_ignores_invalid_reclass_labels() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate_invalid_label", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class ValidationClient: + def complete(self, request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content=( + '{"decisions": [' + '{"id": "first_name_0_5", "decision": "reclass", ' + '"proposed_label": "drop", "reason": "invalid label"}' + "]}" + ), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=ValidationClient(), + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + + +def test_detector_native_validate_native_augment_uses_direct_validation_and_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_augment", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class DirectClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + if "Find additional sensitive entities" in request.prompt: + return SimpleNamespace( + content='{"entities": [{"value": "NVIDIA", "label": "organization_name"}]}', + elapsed_sec=0.2, + usage={"prompt_tokens": 50, "completion_tokens": 6, "total_tokens": 56}, + ) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + client = DirectClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_called_once() + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( + "entity-detection-detector-native-validate-native-augment-seed" + ) + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ("organization_name", "NVIDIA", "augmenter"), + ] + assert len(client.prompts) == 2 + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + assert records[0]["workflow_name"] == "entity-detection-detector-native-validate-native-augment" + assert records[0]["observed_total_requests"] == 2 + assert records[0]["observed_total_tokens"] == 104 + + +def test_detector_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate_parallel", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + def seed_row(text: str, value: str) -> dict[str, object]: + row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": f"first_name_0_{len(value)}", + "value": value, + "label": "first_name", + "start_position": 0, + "end_position": len(value), + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(row) + return row + + class ValidationClient: + def __init__(self) -> None: + self.lock = threading.Lock() + self.active_count = 0 + self.max_active_count = 0 + + def complete(self, request): # type: ignore[no-untyped-def] + with self.lock: + self.active_count += 1 + self.max_active_count = max(self.max_active_count, self.active_count) + time.sleep(0.05) + with self.lock: + self.active_count -= 1 + candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" + return SimpleNamespace( + content=json.dumps({"decisions": [{"id": candidate_id, "decision": "keep", "reason": "real name"}]}), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + seed_df = pd.DataFrame( + [ + seed_row("Alice works at NVIDIA.", "Alice"), + seed_row("Bob works at NVIDIA.", "Bob"), + ], + index=[10, 4], + ) + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=seed_df, failed_records=[]) + client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + assert client.max_active_count > 1 + assert list(result.dataframe.index) == [10, 4] + entities_by_row = [ + [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] + for raw in result.dataframe[COL_DETECTED_ENTITIES] + ] + assert entities_by_row == [ + [("first_name", "Alice", "detector")], + [("first_name", "Bob", "detector")], + ] + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert [record["observed_total_requests"] for record in records] == [1, 1] + assert [record["observed_total_tokens"] for record in records] == [48, 48] + record = records[0] + assert record["workflow_name"] == "entity-detection-detector-native-validate-no-augment" + assert record["observed_total_requests"] == 1 + assert record["observed_total_tokens"] == 48 + + +def test_gliner_native_validate_no_augment_strategy_bypasses_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + assert request.text == text + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + seed_client = GlinerSeedClient() + validation_client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=validation_client, + gliner_seed_client=seed_client, + native_runtime=tool.NativeDetectionRuntime( + model="test/native", + provider="test-native-provider", + gliner_model="test/gliner", + gliner_provider="test-gliner-provider", + ), + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + assert len(validation_client.prompts) == 1 + assert all("Find additional sensitive entities" not in prompt for prompt in validation_client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-gliner-native-validate-no-augment" + assert record["observed_total_requests"] == 2 + assert record["observed_total_tokens"] == 73 + assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] + assert record["model_usage"]["gliner-direct"]["model_name"] == "test/gliner" + assert record["model_usage"]["gliner-direct"]["model_provider_name"] == "test-gliner-provider" + assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 + assert record["model_usage"]["native-direct"]["model_name"] == "test/native" + assert record["model_usage"]["native-direct"]["model_provider_name"] == "test-native-provider" + assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 48 + + +def test_gliner_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_parallel", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + value = str(request.text).split()[0] + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": value, + "label": "first_name", + "start": 0, + "end": len(value), + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" + return SimpleNamespace( + content=json.dumps({"decisions": [{"id": candidate_id, "decision": "keep", "reason": "real name"}]}), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + collector = MeasurementCollector(record_hash_key="test-key") + dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]) + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=ValidationClient(), + gliner_seed_client=GlinerSeedClient(), + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + dataframe, + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + assert list(result.dataframe.index) == [10, 4] + entities_by_row = [ + [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] + for raw in result.dataframe[COL_DETECTED_ENTITIES] + ] + assert entities_by_row == [ + [("first_name", "Alice", "detector")], + [("first_name", "Bob", "detector")], + ] + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert [record["observed_total_requests"] for record in records] == [2, 2] + assert [record["observed_total_tokens"] for record in records] == [73, 73] + + +def test_gliner_native_validate_no_augment_parallel_rows_keep_failed_records() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_parallel_failure", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + if str(request.text).startswith("Broken"): + raise RuntimeError("seed unavailable") + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Broken row"]}, index=[0, 1]) + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=ValidationClient(), + gliner_seed_client=GlinerSeedClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + dataframe, + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name"], + ) + + assert list(result.dataframe.index) == [0] + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("first_name", "Alice")] + assert [(failed.record_id, failed.step) for failed in result.failed_records] == [ + ("1", "entity-detection-gliner-native-validate-no-augment") + ] + assert "seed unavailable" in result.failed_records[0].reason + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + assert records[0]["observed_total_requests"] == 2 + + +def test_gliner_native_validate_native_augment_strategy_bypasses_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_augment", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + assert request.text == text + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class NativeClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs = [ + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=self.outputs.pop(0), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + seed_client = GlinerSeedClient() + native_client = NativeClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + native_client=native_client, + gliner_seed_client=seed_client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ("organization_name", "NVIDIA", "augmenter"), + ] + assert any("Find additional sensitive entities" in prompt for prompt in native_client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-gliner-native-validate-native-augment" + assert record["observed_total_requests"] == 3 + assert record["observed_total_tokens"] == 121 + assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] + assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 + assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 96 + + +def test_native_single_pass_strategy_runs_one_direct_call_without_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + {"value": "Alice", "label": "first_name", "start": 0, "end": 5}, + {"value": "NVIDIA", "label": "organization_name", "start": 15, "end": 21}, + ] + } + ), + elapsed_sec=0.1, + usage={}, + ) + + adapter = Mock() + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "direct_single_pass"), + ("organization_name", "NVIDIA", "direct_single_pass"), + ] + assert len(client.prompts) == 1 + assert '"start"' in client.prompts[0] + assert '"end"' in client.prompts[0] + + +def test_native_single_pass_recall_strategy_uses_label_examples() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_recall", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={}, + ) + + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass_recall, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person", "email"], + ) + + assert len(client.prompts) == 1 + assert "- person" in client.prompts[0] + assert "- email:" in client.prompts[0] + assert "Bias toward high recall" in client.prompts[0] + + +def test_native_single_pass_values_strategy_uses_value_only_prompt() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_values", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person"}]}', + elapsed_sec=0.1, + usage={}, + ) + + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass_values, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice met Alice."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.value, entity.start_position, entity.end_position) for entity in entities] == [ + ("Alice", 0, 5), + ("Alice", 10, 15), + ] + assert len(client.prompts) == 1 + assert '"start"' not in client.prompts[0] + assert '"end"' not in client.prompts[0] + assert '{"entities": [{"value": "exact substring", "label": "one_allowed_label"' in client.prompts[0] + + +def test_native_single_pass_strategy_records_direct_model_usage() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_usage", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "first_name", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 20, "completion_tokens": 7, "total_tokens": 27}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=SinglePassClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-single-pass" + assert record["observed_total_requests"] == 1 + assert record["observed_successful_requests"] == 1 + assert record["observed_failed_requests"] == 0 + assert record["observed_input_tokens"] == 20 + assert record["observed_output_tokens"] == 7 + assert record["observed_total_tokens"] == 27 + + +def test_native_single_pass_strategy_uses_only_native_spans() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_native_spans", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice logged in.\nPassword: SuperSecret123!\n" + + class SinglePassClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={}, + ) + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=SinglePassClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person", "password"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("person", "Alice", "direct_single_pass"), + ] + + +def test_native_single_pass_strategy_records_parser_errors_as_failures() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_parser_error", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class InvalidJsonClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content="not json", + elapsed_sec=0.1, + usage={"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=InvalidJsonClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + + assert result.dataframe.empty + assert len(result.failed_records) == 1 + assert result.failed_records[0].step == "entity-detection-native-single-pass" + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-single-pass" + assert record["status"] == "error" + assert record["failed_record_count"] == 1 + assert record["observed_successful_requests"] == 1 + assert record["observed_failed_requests"] == 0 + + +def test_detector_only_strategy_finalizes_gliner_spans_without_validation_or_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_only", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + text = "Alice met Alice at the lab." + start = text.index("Alice") + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: + assert [column.name for column in columns] == [ + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_DETECTED_ENTITIES, + ] + assert kwargs["workflow_name"] == "entity-detection-detector-only" + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": start, + "end": start + len("Alice"), + "score": 0.99, + } + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = json.loads(row[COL_SEED_ENTITIES_JSON]) + assert [(entity["label"], entity["value"]) for entity in seed_entities] == [("person", "Alice")] + row = columns[3].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.detector_only): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("person", "Alice"), ("person", "Alice")] + assert "Alice" in result.dataframe[COL_TAGGED_TEXT].iloc[0] + + +def test_compact_validation_strategy_disables_full_text_single_chunk_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_compact_validation", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.compact_validation): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at Acme."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert EntityDetectionWorkflow.detect_and_validate_entities is original + assert calls[0]["validation_single_chunk_full_text"] is False + + +def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: + tool = load_tool( + "measurement_detection_strategies_prose_augment_focus", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + before = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.prose_augment_focus): + inside = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + + after = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + assert "Contextual prose recall focus" not in before + assert "Contextual prose recall focus" in inside + assert "organization and institution names" in inside + assert after == before diff --git a/tests/tools/test_direct_detection_probe.py b/tests/tools/test_direct_detection_probe.py new file mode 100644 index 00000000..c37960d6 --- /dev/null +++ b/tests/tools/test_direct_detection_probe.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +class FakeClient: + def complete(self, request): # type: ignore[no-untyped-def] + module = sys.modules["measurement_direct_detection_probe_case"] + return module.DirectCompletion( + content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + elapsed_sec=1.25, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + +def test_direct_detection_case_uses_direct_client_and_canonical_artifacts() -> None: + tool = load_tool( + "measurement_direct_detection_probe_case", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + result = tool.run_direct_detection_case( + tool.DirectDetectionRequest( + case_id="case-1", + text="Alice met Alice.", + labels=["first_name"], + row_index=0, + prompt_mode=tool.PromptMode.compact, + ), + client=FakeClient(), + ) + + assert result.status == tool.CaseStatus.completed + assert result.elapsed_sec == 1.25 + assert result.usage == {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14} + assert result.allowed_suggestion_count == 1 + assert result.final_entity_count == 2 + assert result.final_entity_signature_count == 2 + assert result.final_label_counts == {"first_name": 2} + assert result.artifact.final_source_counts == {"direct_llm": 2} + + +def test_finalize_suggestions_filters_labels_and_deduplicates_signature_hashes() -> None: + tool = load_tool( + "measurement_direct_detection_probe_finalize", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + artifact = tool.finalize_direct_suggestions( + text="Alice works at NVIDIA.", + suggestions=[ + {"value": "Alice", "label": "first_name"}, + {"value": "NVIDIA", "label": "organization_name"}, + {"value": "works", "label": "unsupported"}, + ], + labels=["first_name", "organization_name"], + row_index=0, + workflow_name="direct-detection", + ) + + assert artifact.final_entity_count == 2 + assert artifact.final_label_counts == {"first_name": 1, "organization_name": 1} + assert artifact.weak_api_key_shape_count == 0 + assert set(artifact.final_entity_signature_labels.values()) == {"first_name", "organization_name"} + + +def test_signature_comparison_counts_shared_baseline_only_and_direct_only_labels() -> None: + tool = load_tool( + "measurement_direct_detection_probe_comparison", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + comparison = tool.compare_signature_sets( + baseline_hashes={"shared", "baseline-only"}, + baseline_labels={"shared": "person", "baseline-only": "date"}, + direct_hashes={"shared", "direct-only"}, + direct_labels={"shared": "person", "direct-only": "city"}, + ) + + assert comparison.shared_final_entity_signature_count == 1 + assert comparison.baseline_only_final_entity_signature_count == 1 + assert comparison.direct_only_final_entity_signature_count == 1 + assert comparison.baseline_only_label_counts == {"date": 1} + assert comparison.direct_only_label_counts == {"city": 1} + + +def test_baseline_comparison_skips_rows_without_signature_hashes() -> None: + tool = load_tool( + "measurement_direct_detection_probe_missing_signatures", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + class LocalFakeClient: + def complete(self, request): # type: ignore[no-untyped-def] + return tool.DirectCompletion( + content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + elapsed_sec=1.25, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + case = tool.run_direct_detection_case( + tool.DirectDetectionRequest( + case_id="case-1", + text="Alice met Alice.", + labels=["first_name"], + row_index=0, + prompt_mode=tool.PromptMode.compact, + ), + client=LocalFakeClient(), + ) + + compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 2}) + + assert compared.comparison is None + + +def test_baseline_artifact_reader_rejects_duplicate_row_indexes(tmp_path: Path) -> None: + tool = load_tool( + "measurement_direct_detection_probe_duplicate_baseline", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + baseline_path = tmp_path / "baseline.jsonl" + baseline_path.write_text('{"row_index": 0}\n{"row_index": 0}\n', encoding="utf-8") + + with pytest.raises(ValueError, match="multiple rows for row_index=0"): + tool._read_baseline_artifacts(baseline_path) diff --git a/tests/tools/test_extract_signature_deltas.py b/tests/tools/test_extract_signature_deltas.py new file mode 100644 index 00000000..922d73eb --- /dev/null +++ b/tests/tools/test_extract_signature_deltas.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _entity(value: str, label: str, start: int, end: int, source: str = "detector") -> dict[str, object]: + return {"value": value, "label": label, "start_position": start, "end_position": end, "source": source} + + +def _write_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: + parquet_file = root / "entity-detection" / "parquet-files" / "batch_00000.parquet" + parquet_file.parent.mkdir(parents=True) + pd.DataFrame([{COL_TEXT: text, COL_DETECTED_ENTITIES: {"entities": entities}}]).to_parquet(parquet_file) + artifact_path = root / "detection-artifacts.jsonl" + row = tool.build_detection_artifact_row_from_entities( + workflow_name="entity-detection", + batch_file="entity-detection/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) + return artifact_path + + +def _write_seed_only_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: + parquet_file = root / "entity-detection-seed" / "parquet-files" / "batch_00000.parquet" + parquet_file.parent.mkdir(parents=True) + pd.DataFrame([{COL_TEXT: text}]).to_parquet(parquet_file) + artifact_path = root / "detection-artifacts.jsonl" + row = tool.build_detection_artifact_row_from_entities( + workflow_name="entity-detection-seed", + batch_file="entity-detection-seed/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) + return artifact_path + + +def _case_metadata() -> dict[str, object]: + return { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "config", + "case_id": "bio__config__r000", + "run_id": "bio__config__r000", + "repetition": 0, + } + + +def test_extract_signature_deltas_masks_candidate_only_context(tmp_path: Path) -> None: + analyzer = load_tool( + "measurement_detection_artifact_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_artifact_case(baseline_root, analyzer, [_entity("Alice", "person", 0, 5)], "Alice met NASA") + candidate = _write_artifact_case( + candidate_root, + analyzer, + [_entity("Alice", "person", 0, 5), _entity("NASA", "organization_name", 10, 14, "augmenter")], + "Alice met NASA", + ) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.side == "candidate_only" + assert row.label == "organization_name" + assert row.source == "augmenter" + assert row.resolution == "parquet" + assert "NASA" not in row.masked_context + assert "[ORGANIZATION_NAME:" in row.masked_context + + +def test_extract_signature_deltas_recovers_artifact_detail_context(tmp_path: Path) -> None: + analyzer = load_tool( + "measurement_detection_artifact_context_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas_context", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_artifact_case(baseline_root, analyzer, [], "The applicant was born in 1990.") + candidate = _write_artifact_case(candidate_root, analyzer, [], "The applicant was born in 1990.") + detail_entity = analyzer.EntitySchema( + value="1990", label="date_of_birth", start_position=26, end_position=30, source="rule" + ) + detail_row = analyzer.build_detection_artifact_row_from_entities( + workflow_name="entity-detection", + batch_file="entity-detection/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[detail_entity], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **detail_row}], sep=".").to_json(candidate, orient="records", lines=True) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.label == "date_of_birth" + assert row.source == "rule" + assert row.resolution == "artifact_details" + assert "1990" not in row.masked_context + assert "[DATE_OF_BIRTH:" in row.masked_context + + +def test_extract_signature_deltas_uses_signature_details_when_final_parquet_is_unavailable(tmp_path: Path) -> None: + analyzer = load_tool( + "measurement_detection_artifact_detail_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas_details", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_seed_only_artifact_case(baseline_root, analyzer, [], "Alice met NASA") + candidate = _write_seed_only_artifact_case( + candidate_root, + analyzer, + [_entity("NASA", "organization_name", 10, 14, "detector")], + "Alice met NASA", + ) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.label == "organization_name" + assert row.source == "detector" + assert row.resolution == "artifact_details" + assert row.start_position == 10 + assert row.end_position == 14 + assert "NASA" not in row.masked_context + assert "[ORGANIZATION_NAME:" in row.masked_context diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py new file mode 100644 index 00000000..99642535 --- /dev/null +++ b/tests/tools/test_measurement_tools.py @@ -0,0 +1,1357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path +from types import ModuleType +from typing import Any + +import pandas as pd +import pytest +import yaml +from pydantic import ValidationError + +from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT +from anonymizer.engine.constants import COL_FINAL_ENTITIES + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[str, Any]: + return { + "base_dir": tmp_path, + "workloads": {workload.id: workload for workload in spec.workloads}, + "configs": {config.id: config for config in spec.configs}, + "raw_dir": tmp_path / "raw", + "dd_trace": tool.DDTraceMode.none, + "trace_dir": tmp_path / "traces", + "dd_parser_compat": spec.dd_parser_compat, + "artifact_path": tmp_path / "artifacts", + } + + +def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: + tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") + dataframe = pd.DataFrame( + [ + {"record_type": "run", "run_id": "case-a", "run_tags": {"suite_id": "suite-a"}}, + {"record_type": "stage", "run_id": "case-a", "stage": "detect", "metrics": {"rows": 2}}, + ] + ) + + result = tool.export_tables( + dataframe, + input_path=tmp_path / "measurements.jsonl", + output_dir=tmp_path / "tables", + export_format=tool.ExportFormat.csv, + overwrite=False, + ) + + assert result.total_rows == 2 + assert {table.record_type for table in result.tables} == {"run", "stage"} + assert (tmp_path / "tables/run.csv").exists() + assert (tmp_path / "tables/stage.csv").exists() + assert (tmp_path / "tables/manifest.json").exists() + + +def test_benchmark_exports_detection_artifact_analysis(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_analysis", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + parquet_dir = artifact_root / "entity-detection" / "parquet-files" + parquet_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": '[{"value":"Alice","label":"first_name","start_position":0,"end_position":5}]', + "_augmented_entities": '{"entities":[{"value":"Alice","label":"first_name"}]}', + "_detected_entities": ( + '{"entities":[{"value":"Alice","label":"first_name",' + '"start_position":0,"end_position":5,"source":"detector"}]}' + ), + } + ] + ).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + output_path = tmp_path / "detection-artifacts.jsonl" + + result_path = tool.export_detection_artifact_analysis(artifact_root, output_path) + + assert result_path == output_path + rows = [pd.read_json(output_path, lines=True).iloc[0].to_dict()] + assert rows[0]["augmented_duplicate_seed_value_count"] == 1 + assert "Alice" not in output_path.read_text(encoding="utf-8") + + +def test_benchmark_detection_artifact_analysis_ignores_stale_artifacts(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_delta", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + stale_dir = artifact_root / "entity-detection-old" / "parquet-files" + stale_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": '{"entities":[]}', + } + ] + ).to_parquet(stale_dir / "batch_00000.parquet", index=False) + snapshot = tool.snapshot_detection_artifacts(artifact_root) + fresh_dir = artifact_root / "entity-detection-new" / "parquet-files" + fresh_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": ( + '{"entities":[{"value":"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA",' + '"label":"api_key","start_position":0,"end_position":32,"source":"augmenter"}]}' + ), + } + ] + ).to_parquet(fresh_dir / "batch_00000.parquet", index=False) + output_path = tmp_path / "detection-artifacts.jsonl" + + result_path = tool.export_detection_artifact_analysis( + artifact_root, + output_path, + artifact_snapshot=snapshot, + ) + + assert result_path == output_path + rows = pd.read_json(output_path, lines=True) + assert rows["workflow_name"].tolist() == ["entity-detection-new"] + assert rows["final_entity_count"].tolist() == [1] + + +def test_benchmark_case_detection_artifact_analysis_adds_case_metadata(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_case", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + parquet_dir = artifact_root / "entity-detection" / "parquet-files" + parquet_dir.mkdir(parents=True) + snapshot = tool.snapshot_detection_artifacts(artifact_root) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": ( + '{"entities":[{"value":"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA",' + '"label":"api_key","start_position":0,"end_position":32,"source":"detector"}]}' + ), + } + ] + ).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + case = tool.BenchmarkCase( + suite_id="suite-a", + workload_id="shell", + config_id="rules", + repetition=2, + case_id="shell__rules__r002", + ) + output_path = tmp_path / "raw" / "shell__rules__r002.detection-artifacts.jsonl" + + result_path = tool.export_case_detection_artifact_analysis( + artifact_root, + output_path, + case=case, + artifact_snapshot=snapshot, + ) + + assert result_path == output_path + rows = pd.read_json(output_path, lines=True) + assert rows["suite_id"].tolist() == ["suite-a"] + assert rows["workload_id"].tolist() == ["shell"] + assert rows["config_id"].tolist() == ["rules"] + assert rows["case_id"].tolist() == ["shell__rules__r002"] + assert rows["run_id"].tolist() == ["shell__rules__r002"] + assert rows["repetition"].tolist() == [2] + assert "sk-test" not in output_path.read_text(encoding="utf-8") + + +def _stale_detection_artifact_payload() -> dict[str, Any]: + return { + "workflow_name": "entity-detection", + "batch_file": "entity-detection/parquet-files/batch_00000.parquet", + "row_index": 0, + "seed_entity_count": 1, + "seed_validation_candidate_count": 1, + "merged_validation_candidate_count": 1, + "augmented_entity_count": 0, + "final_entity_count": 1, + "augmented_duplicate_seed_value_count": 0, + "augmented_new_value_count": 0, + "augmented_new_final_value_count": 0, + "weak_api_key_shape_count": 0, + "final_entity_signature_count": 1, + "final_entity_signature_hashes": ["stale"], + "final_entity_signature_labels": {"stale": "person"}, + "weak_api_key_shape_label_counts": {}, + "final_label_counts": {"person": 1}, + "final_source_counts": {"detector": 1}, + } + + +def _final_trace_dataframe_with_rule_entity() -> pd.DataFrame: + return pd.DataFrame( + { + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": "Alice", + "label": "person", + "start_position": 0, + "end_position": 5, + "source": "detector", + }, + { + "value": "1990", + "label": "date_of_birth", + "start_position": 25, + "end_position": 29, + "source": "rule", + }, + ] + } + ] + } + ) + + +def test_benchmark_patches_detection_artifacts_from_final_trace_dataframe(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_patch_result_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + output_path = tmp_path / "raw" / "case.detection-artifacts.jsonl" + tool.write_detection_artifact_payloads([_stale_detection_artifact_payload()], output_path) + + result = tool.patch_case_detection_artifacts_from_trace_dataframe( + output_path, + _final_trace_dataframe_with_rule_entity(), + ) + + assert result == output_path + text = output_path.read_text(encoding="utf-8") + assert "Alice" not in text + assert "1990" not in text + row = json.loads(text) + assert row["final_entity_count"] == 2 + assert row["final_entity_signature_count"] == 2 + assert row["final_label_counts.person"] == 1 + assert row["final_label_counts.date_of_birth"] == 1 + assert row["final_source_counts.detector"] == 1 + assert row["final_source_counts.rule"] == 1 + assert any(key.startswith("final_entity_signature_details.") for key in row) + assert "final_entity_signature_labels.stale" not in row + + +def test_run_suite_records_detection_artifact_analysis_path( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_run_suite_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec = tool.BenchmarkSpec( + suite_id="artifact-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + output_dir = tmp_path / "output" + output_dir.mkdir() + artifact_path = output_dir / "artifacts" + analysis_path = output_dir / "detection-artifacts.jsonl" + + class FakeAnonymizer: + def __init__(self, **kwargs: Any) -> None: + assert kwargs["artifact_path"] == artifact_path + + def fake_run_case(case: Any, *_args: Any, **_kwargs: Any) -> Any: + assert _kwargs["export_detection_artifacts"] is True + raw_path = output_dir / "raw" / f"{case.case_id}.jsonl" + raw_path.parent.mkdir() + raw_path.write_text('{"record_type":"run","run_id":"case"}\n', encoding="utf-8") + artifact_output_path = output_dir / "raw" / f"{case.case_id}.detection-artifacts.jsonl" + artifact_output_path.write_text( + '{"case_id":"input__redact__r000","workflow_name":"entity-detection"}\n', + encoding="utf-8", + ) + return case.model_copy( + update={ + "status": tool.CaseStatus.completed, + "measurement_path": str(raw_path), + "detection_artifact_path": str(artifact_output_path), + } + ) + + monkeypatch.setattr(tool, "Anonymizer", FakeAnonymizer) + monkeypatch.setattr(tool, "_run_case", fake_run_case) + monkeypatch.setattr(tool, "export_measurement_tables", lambda *_args: output_dir / "tables") + + result = tool.run_suite( + spec, + spec_path=tmp_path / "suite.yaml", + output_dir=output_dir, + export=True, + fail_fast=False, + dd_trace=tool.DDTraceMode.none, + trace_dir=None, + ) + + assert result.detection_artifact_analysis_path == str(analysis_path) + assert analysis_path.read_text(encoding="utf-8") == ( + '{"case_id":"input__redact__r000","workflow_name":"entity-detection"}\n' + ) + summary = (output_dir / "summary.json").read_text(encoding="utf-8") + assert "detection_artifact_analysis_path" in summary + assert "detection_artifact_path" in summary + + +def test_run_suite_skips_detection_artifact_analysis_when_export_disabled( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_run_suite_no_export_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + spec = tool.BenchmarkSpec( + suite_id="artifact-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + output_dir = tmp_path / "output" + output_dir.mkdir() + + class FakeAnonymizer: + def __init__(self, **_kwargs: Any) -> None: + pass + + def fake_run_case(case: Any, *_args: Any, **_kwargs: Any) -> Any: + assert _kwargs["export_detection_artifacts"] is False + raw_path = output_dir / "raw" / f"{case.case_id}.jsonl" + raw_path.parent.mkdir() + raw_path.write_text('{"record_type":"run","run_id":"case"}\n', encoding="utf-8") + return case.model_copy(update={"status": tool.CaseStatus.completed, "measurement_path": str(raw_path)}) + + monkeypatch.setattr(tool, "Anonymizer", FakeAnonymizer) + monkeypatch.setattr(tool, "_run_case", fake_run_case) + monkeypatch.setattr( + tool, + "combine_detection_artifact_analysis", + lambda *_args: pytest.fail("artifact analysis should not be combined"), + ) + + result = tool.run_suite( + spec, + spec_path=tmp_path / "suite.yaml", + output_dir=output_dir, + export=False, + fail_fast=False, + dd_trace=tool.DDTraceMode.none, + trace_dir=None, + ) + + assert result.detection_artifact_analysis_path is None + assert not (output_dir / "detection-artifacts.jsonl").exists() + + +def test_benchmark_case_retries_transient_errors_and_records_attempts( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_case_retry_success", REPO_ROOT / "tools/measurement/run_benchmarks.py") + attempts: list[Path] = [] + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=1, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + pd.DataFrame({"text": ["Alice"]}).to_csv(tmp_path / "input.csv", index=False) + + def fake_execute_case(*_args: Any, raw_path: Path, **_kwargs: Any) -> Any: + attempts.append(raw_path) + if len(attempts) == 1: + raise RuntimeError("transient provider health check failure") + raw_path.parent.mkdir(parents=True, exist_ok=True) + raw_path.write_text('{"record_type":"run"}\n', encoding="utf-8") + return tool._CaseExecution( + input_data=tool.AnonymizerInput(source=str(tmp_path / "input.csv"), text_column="text") + ) + + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + + result = tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=False, + export_detection_artifacts=False, + ) + + assert result.status == tool.CaseStatus.completed + assert result.attempt_count == 2 + assert result.attempt_errors == ["transient provider health check failure"] + assert attempts == [tmp_path / "raw" / "input__redact__r000.jsonl"] * 2 + + +def test_benchmark_case_records_persistent_retry_failures( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_case_retry_failure", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=1, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + + attempts = 0 + errors: list[str] = [] + + def fake_execute_case(*_args: Any, **_kwargs: Any) -> Any: + nonlocal attempts + attempts += 1 + raise RuntimeError(f"provider unavailable #{attempts}") + + def capture_error(case: Any, *, error: Exception, **kwargs: Any) -> Any: + errors.append(str(error)) + return original_run_case_error(case, error=error, **kwargs) + + original_run_case_error = tool._run_case_error + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + monkeypatch.setattr(tool, "_run_case_error", capture_error) + + result = tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=False, + export_detection_artifacts=False, + ) + + assert result.status == tool.CaseStatus.error + assert result.error == "provider unavailable #2" + assert result.attempt_count == 2 + assert result.attempt_errors == ["provider unavailable #1", "provider unavailable #2"] + assert errors == ["provider unavailable #2"] + assert attempts == 2 + + +def test_benchmark_case_fail_fast_skips_retries( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_case_retry_fail_fast", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=3, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + attempts = 0 + + def fake_execute_case(*_args: Any, **_kwargs: Any) -> Any: + nonlocal attempts + attempts += 1 + raise RuntimeError("fail fast") + + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + + with pytest.raises(RuntimeError, match="fail fast"): + tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=True, + export_detection_artifacts=False, + ) + + assert attempts == 1 + + +def test_combine_detection_artifact_analysis_separates_jsonl_chunks(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_combine_artifact_newlines", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + first = tmp_path / "first.jsonl" + second = tmp_path / "second.jsonl" + first.write_text('{"case_id":"one"}', encoding="utf-8") + second.write_text('{"case_id":"two"}\n', encoding="utf-8") + destination = tmp_path / "combined.jsonl" + cases = [ + tool.BenchmarkCase( + suite_id="suite", + workload_id="input", + config_id="first", + repetition=0, + case_id="first", + detection_artifact_path=str(first), + ), + tool.BenchmarkCase( + suite_id="suite", + workload_id="input", + config_id="second", + repetition=0, + case_id="second", + detection_artifact_path=str(second), + ), + ] + + result = tool.combine_detection_artifact_analysis(cases, destination) + + assert result == destination + assert destination.read_text(encoding="utf-8") == '{"case_id":"one"}\n{"case_id":"two"}\n' + + +def test_benchmark_spec_validates_matrix_references(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-suite +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +matrix: + - workload: missing + config: redact +""", + encoding="utf-8", + ) + + with pytest.raises(ValidationError, match="unknown workload"): + tool.load_spec(spec_path) + + +def test_benchmark_partial_rewrite_goal_uses_public_defaults() -> None: + tool = load_tool("measurement_benchmark_tool_defaults", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + rewrite = tool.build_rewrite(tool.RewriteSpec(protect="Direct payroll identifiers")) + + assert rewrite.privacy_goal.protect == "Direct payroll identifiers" + assert rewrite.privacy_goal.preserve == DEFAULT_PRESERVE_TEXT + + +def test_benchmark_output_dir_requires_overwrite_for_existing_files(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_output", REPO_ROOT / "tools/measurement/run_benchmarks.py") + output_dir = tmp_path / "benchmark-output" + output_dir.mkdir() + existing = output_dir / "summary.json" + existing.write_text("{}", encoding="utf-8") + + with pytest.raises(ValueError, match="not empty"): + tool.prepare_output_dir(output_dir, overwrite=False, dry_run=False) + + tool.prepare_output_dir(output_dir, overwrite=True, dry_run=False) + + assert (output_dir / "raw").is_dir() + assert not existing.exists() + + +def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_dry_run", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: smoke-suite +workloads: + - id: biography + source: biographies.csv +configs: + - id: redact + replace: redact +matrix: + - workload: biography + config: redact + repetitions: 2 +""", + encoding="utf-8", + ) + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "biographies.csv", index=False) + output_dir = tmp_path / "dry-run-output" + + result = tool.run_or_plan( + spec_path, + output=output_dir, + overwrite=False, + dry_run=True, + export=False, + fail_fast=False, + ) + + assert len(result.cases) == 2 + assert result.table_dir is None + assert {case.status for case in result.cases} == {tool.CaseStatus.planned} + assert not output_dir.exists() + + +def test_benchmark_preflight_rejects_missing_text_column(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_preflight_input", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"body": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-input-suite +workloads: + - id: biography + source: input.csv + text_column: text +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="workload 'biography' text_column 'text' not found"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_build_input_materializes_sliced_csv_workload(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_sliced_input", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"id": ["a", "b", "c", "d"], "text": ["row-a", "row-b", "row-c", "row-d"]}).to_csv( + input_path, index=False + ) + workload = tool.WorkloadSpec( + id="slice", + source="input.csv", + text_column="text", + id_column="id", + row_offset=1, + row_limit=2, + ) + + anonymizer_input = tool.build_input( + workload, + tmp_path, + slice_dir=tmp_path / "slices", + case_id="slice__redact__r000", + ) + + assert anonymizer_input.text_column == "text" + assert anonymizer_input.id_column == "id" + assert Path(anonymizer_input.source) != input_path + sliced = pd.read_csv(anonymizer_input.source) + assert sliced.to_dict("records") == [ + {"id": "b", "text": "row-b"}, + {"id": "c", "text": "row-c"}, + ] + + +def test_benchmark_preflight_rejects_sliced_remote_workload(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_sliced_remote", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-remote-slice +workloads: + - id: remote + source: s3://bucket/input.csv + row_limit: 2 +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="row slicing requires a local workload source"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_preflight_models", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-model-suite +model_configs: | + selected_models: + detection: + entity_detector: detector + entity_validator: [validator] + entity_augmenter: augmenter + replace: + replacement_generator: missing-replacer + model_configs: + - alias: detector + model: test/detector + - alias: validator + model: test/validator + - alias: augmenter + model: test/augmenter +workloads: + - id: biography + source: input.csv +configs: + - id: substitute + replace: substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="missing-replacer"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_local_structured_substitute_for_contextual_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_local_substitute_contextual_labels", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: local-substitute-suite +workloads: + - id: input + source: input.csv +configs: + - id: local-substitute + detect: + entity_labels: [api_key, person] + replace: substitute + experimental_replacement_strategy: local_structured_substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="unsupported local structured substitute labels: person"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_local_structured_substitute_supported_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_local_substitute_supported_labels", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: local-substitute-suite +workloads: + - id: input + source: input.csv +configs: + - id: local-substitute + detect: + entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] + replace: substitute + experimental_replacement_strategy: local_structured_substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_example_suites_are_portable() -> None: + example_paths = sorted((REPO_ROOT / "tools/measurement/examples").glob("*.yaml")) + assert example_paths + + machine_specific_fragments = ( + "/root/", + "/Users/", + "/stable-cache/", + "gpu-dev-pod", + "serve-svc", + ) + path_fields = {"source", "model_configs", "model_providers", "artifact_path"} + + def walk(value: Any) -> Iterator[tuple[str, Any]]: + if isinstance(value, dict): + for key, item in value.items(): + yield str(key), item + yield from walk(item) + elif isinstance(value, list): + for item in value: + yield from walk(item) + + for example_path in example_paths: + payload = yaml.safe_load(example_path.read_text(encoding="utf-8")) + assert isinstance(payload, dict) + + for key, value in walk(payload): + if isinstance(value, str): + assert not any(fragment in value for fragment in machine_specific_fragments), ( + f"{example_path} contains machine-specific value for {key}: {value}" + ) + if key in path_fields: + assert not Path(value).is_absolute(), f"{example_path} uses absolute path for {key}: {value}" + if key in {"endpoint", "gliner_endpoint"}: + raise AssertionError(f"{example_path} should use endpoint_env for {key}, not a literal endpoint") + + +def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + provider_path = tmp_path / "providers.yaml" + provider_path.write_text("not_providers: []\n", encoding="utf-8") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-provider-suite +model_providers: providers.yaml +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="providers"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_provider_path", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + provider_path = tmp_path / "providers.yaml" + provider_path.write_text( + """ +providers: + - name: test-provider + endpoint: https://example.com/v1 + provider_type: openai + api_key: TEST_API_KEY +""", + encoding="utf-8", + ) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: provider-path-suite +model_providers: providers.yaml +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_native_strategy_without_runtime( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_required", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="native_runtime.runtime_id"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_native_runtime_config_override_merges_suite_defaults() -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_merge", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + config = tool.ConfigSpec( + id="native-single-pass", + experimental_detection_strategy="native_single_pass", + native_runtime=tool.NativeRuntimeSpec(model="config-model", max_workers=2), + replace="redact", + ) + spec = tool.BenchmarkSpec( + suite_id="native-runtime-suite", + native_runtime=tool.NativeRuntimeSpec( + runtime_id="suite-runtime", + endpoint="http://suite-endpoint/v1", + model="suite-model", + provider="suite-provider", + max_workers=4, + ), + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[config], + ) + + runtime = tool._native_detection_runtime(spec, config) + + assert runtime.endpoint == "http://suite-endpoint/v1" + assert runtime.model == "config-model" + assert runtime.provider == "suite-provider" + assert runtime.max_workers == 2 + + +def test_benchmark_preflight_rejects_native_strategy_without_endpoint_or_model( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_endpoint_required", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +native_runtime: + runtime_id: native-runtime +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="native_runtime.endpoint"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_native_runtime_resolves_endpoint_and_model_from_env( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_env", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", "http://runtime-from-env/v1") + monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_MODEL", "env-model") + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +native_runtime: + runtime_id: env-runtime +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + runtime = tool._native_detection_runtime(spec, spec.configs[0]) + tags = tool._run_tags( + tool.BenchmarkCase( + suite_id="native-runtime-suite", + workload_id="input", + config_id="native-single-pass", + repetition=0, + case_id="input__native-single-pass__r000", + ), + spec, + ) + + assert runtime.endpoint == "http://runtime-from-env/v1" + assert runtime.model == "env-model" + assert tags["native_runtime_id"] == "env-runtime" + assert "native_endpoint" not in tags + assert tags["native_endpoint_env"] == "ANONYMIZER_BENCH_NATIVE_ENDPOINT" + assert tags["native_model"] == "env-model" + + +def test_benchmark_preflight_skips_inactive_native_configs(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_inactive_native_runtime", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: inactive-native-suite +workloads: + - id: input + source: input.csv +configs: + - id: redact + replace: redact + - id: inactive-native + experimental_detection_strategy: native_single_pass + replace: redact +matrix: + - workload: input + config: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_case_passes_dd_trace_config_to_measurement_session( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_dd_trace", REPO_ROOT / "tools/measurement/run_benchmarks.py") + captured: list[Any] = [] + + @contextmanager + def fake_measurement_session(config: Any) -> Iterator[None]: + captured.append(config) + yield None + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> None: + assert config.replace is not None + assert data.text_column == "text" + + monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) + + spec = tool.BenchmarkSpec( + suite_id="trace-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "input.csv", index=False) + case = tool.BenchmarkCase( + suite_id="trace-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + trace_path = tmp_path / "traces" / "input__redact__r000.jsonl" + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=tmp_path / "raw" / "input__redact__r000.jsonl", + trace_path=trace_path, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.all_messages, + dd_parser_compat=tool.DDParserCompatMode.none, + ) + + assert len(captured) == 1 + assert captured[0].dd_trace == "all_messages" + assert captured[0].dd_trace_path == trace_path + assert captured[0].streaming is True + assert captured[0].keep_records is False + + +def test_benchmark_config_accepts_experimental_detection_strategy() -> None: + tool = load_tool( + "measurement_benchmark_tool_detection_strategy_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + + detector_only = tool.ConfigSpec( + id="detector-only", + replace="redact", + experimental_detection_strategy="detector_only", + ) + + assert detector_only.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.detector_only + anonymizer_config = tool.build_anonymizer_config(detector_only) + assert not hasattr(anonymizer_config.detect, "experimental_detection_strategy") + + native_candidate_validate = tool.ConfigSpec( + id="native-candidate-validate", + replace="redact", + experimental_detection_strategy="native_candidate_validate_no_augment", + ) + + assert ( + native_candidate_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment + ) + + detector_native_validate = tool.ConfigSpec( + id="detector-native-validate", + replace="redact", + experimental_detection_strategy="detector_native_validate_no_augment", + ) + + assert ( + detector_native_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment + ) + + detector_native_augment = tool.ConfigSpec( + id="detector-native-augment", + replace="redact", + experimental_detection_strategy="detector_native_validate_native_augment", + ) + + assert ( + detector_native_augment.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment + ) + + gliner_native_validate = tool.ConfigSpec( + id="gliner-native-validate", + replace="redact", + experimental_detection_strategy="gliner_native_validate_no_augment", + ) + + assert ( + gliner_native_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment + ) + + gliner_native_augment = tool.ConfigSpec( + id="gliner-native-augment", + replace="redact", + experimental_detection_strategy="gliner_native_validate_native_augment", + ) + + assert ( + gliner_native_augment.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment + ) + + native_single_pass = tool.ConfigSpec( + id="native-single-pass", + replace="redact", + experimental_detection_strategy="native_single_pass", + ) + + assert native_single_pass.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.native_single_pass + + native_single_pass_recall = tool.ConfigSpec( + id="native-single-pass-recall", + replace="redact", + experimental_detection_strategy="native_single_pass_recall", + ) + + assert ( + native_single_pass_recall.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_recall + ) + + native_single_pass_values = tool.ConfigSpec( + id="native-single-pass-values", + replace="redact", + experimental_detection_strategy="native_single_pass_values", + ) + + assert ( + native_single_pass_values.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_values + ) + + native_single_pass_values_recall = tool.ConfigSpec( + id="native-single-pass-values-recall", + replace="redact", + experimental_detection_strategy="native_single_pass_values_recall", + ) + + assert ( + native_single_pass_values_recall.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_values_recall + ) + + +def test_benchmark_spec_accepts_dd_parser_compat() -> None: + tool = load_tool( + "measurement_benchmark_tool_dd_parser_compat_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + + spec = tool.BenchmarkSpec( + suite_id="raw-json-suite", + dd_parser_compat="raw_json", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + + assert spec.dd_parser_compat == tool.DDParserCompatMode.raw_json + + +def test_benchmark_case_enters_experimental_detection_strategy_context( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_detection_strategy_case", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + captured_measurements: list[Any] = [] + captured_parser_compat: list[Any] = [] + captured_strategies: list[Any] = [] + captured_context_kwargs: list[dict[str, Any]] = [] + + @contextmanager + def fake_measurement_session(config: Any) -> Iterator[None]: + captured_measurements.append(config) + yield None + + @contextmanager + def fake_detection_strategy_context(strategy: Any, **kwargs: Any) -> Iterator[None]: + captured_strategies.append(strategy) + captured_context_kwargs.append(kwargs) + yield None + + @contextmanager + def fake_dd_parser_compat_context(mode: Any) -> Iterator[None]: + captured_parser_compat.append(mode) + yield None + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> None: + assert config.replace is not None + assert data.text_column == "text" + + monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) + monkeypatch.setattr(tool, "dd_parser_compat_context", fake_dd_parser_compat_context) + monkeypatch.setattr(tool, "experimental_detection_strategy_context", fake_detection_strategy_context) + + spec = tool.BenchmarkSpec( + suite_id="native-suite", + dd_parser_compat="raw_json", + native_runtime=tool.NativeRuntimeSpec( + runtime_id="native-test", + endpoint="http://runtime.example/v1", + model="test-model", + ), + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[ + tool.ConfigSpec( + id="native-single-pass-redact", + replace="redact", + experimental_detection_strategy="native_single_pass", + ) + ], + ) + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(tmp_path / "input.csv", index=False) + case = tool.BenchmarkCase( + suite_id="native-suite", + workload_id="input", + config_id="native-single-pass-redact", + repetition=0, + case_id="input__native-single-pass-redact__r000", + ) + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=tmp_path / "raw" / "input__native-single-pass-redact__r000.jsonl", + trace_path=None, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.none, + dd_parser_compat=spec.dd_parser_compat, + ) + + assert captured_parser_compat == [tool.DDParserCompatMode.raw_json] + assert captured_strategies == [tool.ExperimentalDetectionStrategy.native_single_pass] + assert captured_context_kwargs[0]["native_runtime"].endpoint == "http://runtime.example/v1" + assert captured_context_kwargs[0]["native_runtime"].model == "test-model" + assert captured_measurements[0].run_tags["experimental_detection_strategy"] == "native_single_pass" + assert captured_measurements[0].run_tags["native_runtime_id"] == "native-test" + assert captured_measurements[0].run_tags["dd_parser_compat"] == "raw_json" diff --git a/tests/tools/test_replacement_strategies.py b/tests/tools/test_replacement_strategies.py new file mode 100644 index 00000000..3b1ce86b --- /dev/null +++ b/tests/tools/test_replacement_strategies.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import Mock + +import pandas as pd + +from anonymizer.config.replace_strategies import Substitute +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_FINAL_ENTITIES, COL_REPLACED_TEXT, COL_TEXT +from anonymizer.engine.ndd.model_loader import load_default_model_selection +from anonymizer.engine.replace.llm_replace_workflow import LlmReplaceWorkflow +from anonymizer.engine.replace.replace_runner import ReplacementWorkflow + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_local_structured_substitute_context_bypasses_dd_and_restores_method() -> None: + tool = load_tool( + "measurement_replacement_strategies_local_substitute", + REPO_ROOT / "tools/measurement/replacement_strategies.py", + ) + original_method = LlmReplaceWorkflow.generate_map_only + secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + adapter = Mock() + runner = ReplacementWorkflow(llm_workflow=LlmReplaceWorkflow(adapter=adapter)) + dataframe = pd.DataFrame( + { + COL_TEXT: [f"export API_KEY={secret}"], + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": secret, + "label": "api_key", + "start_position": len("export API_KEY="), + "end_position": len("export API_KEY=") + len(secret), + } + ] + } + ], + COL_ENTITIES_BY_VALUE: [{"entities_by_value": [{"value": secret, "labels": ["api_key"]}]}], + } + ) + + with tool.experimental_replacement_strategy_context( + tool.ExperimentalReplacementStrategy.local_structured_substitute + ): + result = runner.run( + dataframe, + replace_method=Substitute(), + model_configs=[], + selected_models=load_default_model_selection().replace, + ) + + adapter.run_workflow.assert_not_called() + assert LlmReplaceWorkflow.generate_map_only is original_method + replaced = result.dataframe[COL_REPLACED_TEXT].iloc[0] + assert secret not in replaced + assert "sk-test-" in replaced diff --git a/tests/tools/test_replay_replacement_strategies.py b/tests/tools/test_replay_replacement_strategies.py new file mode 100644 index 00000000..482d1919 --- /dev/null +++ b/tests/tools/test_replay_replacement_strategies.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pandas as pd + +from anonymizer.engine.constants import ( + COL_FINAL_ENTITIES, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, + COL_TEXT, +) + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_replacement_row_metrics_counts_sanitized_replay_failures() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_metrics", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + row = pd.Series( + { + COL_FINAL_ENTITIES: { + "entities": [ + {"value": "secret-a", "label": "api_key", "start_position": 0, "end_position": 8}, + {"value": "secret-b", "label": "password", "start_position": 9, "end_position": 17}, + {"value": "1234", "label": "pin", "start_position": 18, "end_position": 22}, + ] + }, + COL_REPLACEMENT_MAP: { + "replacements": [ + {"original": "secret-a", "label": "api_key", "synthetic": "secret-b"}, + {"original": "secret-b", "label": "password", "synthetic": "Synthetic!123"}, + ] + }, + COL_REPLACED_TEXT: "the leaked pin is 1234", + } + ) + + metrics = tool.replacement_row_metrics(row) + + assert metrics.final_entity_count == 3 + assert metrics.replacement_count == 2 + assert metrics.missing_count == 1 + assert metrics.missing_labels == {"pin": 1} + assert metrics.collision_count == 1 + assert metrics.collision_labels == {"api_key": 1} + assert metrics.leak_count == 1 + assert metrics.leak_labels == {"pin": 1} + + +def test_summarize_replacement_dataframe_counts_sources_and_totals() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_summary", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + dataframe = pd.DataFrame( + [ + { + COL_FINAL_ENTITIES: { + "entities": [ + {"value": "alice@example.com", "label": "email", "start_position": 0, "end_position": 17} + ] + }, + COL_REPLACEMENT_MAP: { + "replacements": [ + { + "original": "alice@example.com", + "label": "email", + "synthetic": "user-123@example.invalid", + } + ] + }, + COL_REPLACED_TEXT: "user-123@example.invalid", + COL_REPLACEMENT_MAP_SOURCE: "local_structured", + } + ] + ) + + summary = tool.summarize_replacement_dataframe( + dataframe, + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + elapsed_sec=0.01, + ) + + assert summary.status == tool.ReplayStatus.completed + assert summary.final_entity_count == 1 + assert summary.replacement_count == 1 + assert summary.missing_count == 0 + assert summary.collision_count == 0 + assert summary.leak_count == 0 + assert summary.replacement_map_sources == {"local_structured": 1} + + +def test_aggregate_strategy_summaries_sums_repeated_backend_runs() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_aggregate", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + + summary = tool.aggregate_strategy_summaries( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + summaries=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=1.5, + row_count=5, + final_entity_count=10, + replacement_count=9, + missing_count=1, + leak_count=1, + missing_labels={"api_key": 1}, + leak_labels={"api_key": 1}, + replacement_map_sources={"llm": 5}, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.error, + elapsed_sec=2.5, + row_count=5, + final_entity_count=10, + replacement_count=10, + collision_count=1, + collision_labels={"password": 1}, + replacement_map_sources={"llm": 4}, + error="provider failed", + ), + ], + ) + + assert summary.status == tool.ReplayStatus.error + assert summary.repetition_count == 2 + assert summary.elapsed_sec == 4.0 + assert summary.row_count == 10 + assert summary.final_entity_count == 20 + assert summary.replacement_count == 19 + assert summary.missing_count == 1 + assert summary.collision_count == 1 + assert summary.leak_count == 1 + assert summary.missing_labels == {"api_key": 1} + assert summary.collision_labels == {"password": 1} + assert summary.leak_labels == {"api_key": 1} + assert summary.replacement_map_sources == {"llm": 9} + assert summary.error == "provider failed" + + +def test_replay_comparison_row_marks_fast_complete_local_replay_viable() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_row", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/structured_identifiers.csv", + text_column="text", + labels=["api_key"], + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=8.8, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=6.2, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result) + + assert row.workload_id == "structured_identifiers" + assert row.baseline_replacement_strategy == "default" + assert row.candidate_replacement_strategy == "local_structured_substitute" + assert row.pipeline_elapsed_sec_delta < 0 + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "pass" + assert row.safety_verdict == "pass" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "candidate_viable" + assert row.flags == [] + + +def test_replay_comparison_row_rejects_missing_local_replacements() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_missing", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/structured_identifiers.csv", + text_column="text", + labels=["api_key"], + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=8.8, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=6.2, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=1, + final_entity_count=2, + replacement_count=1, + missing_count=1, + ), + ], + ) + + row = tool.replay_comparison_row(result) + + assert row.candidate_replacement_missing_final_entity_count == 1 + assert "candidate_replacement_missing_final_entity" in row.flags + assert "replacement_count_loss" in row.flags + assert row.value_protection_verdict == "fail" + assert row.signature_parity_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_replay_comparison_row_reviews_duplicate_local_synthetic_values() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_duplicates", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/biographies.csv", + text_column="biography", + labels=["organization_name", "religious_belief"], + nrows=5, + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=18.8, + detected_final_entity_count=10, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=7.9, + row_count=5, + final_entity_count=10, + replacement_count=10, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=5, + final_entity_count=10, + replacement_count=10, + duplicate_synthetic_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") + + assert row.candidate_duplicate_synthetic_replacement_count == 2 + assert "candidate_duplicate_synthetic_replacement" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_replay_comparison_row_reviews_candidate_replacement_count_gain_over_flawed_baseline() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_baseline_flaw", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/biographies.csv", + text_column="biography", + labels=["organization_name"], + nrows=5, + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=12.9, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=7.6, + row_count=5, + final_entity_count=2, + replacement_count=1, + missing_count=1, + leak_count=1, + missing_labels={"organization_name": 1}, + leak_labels={"organization_name": 1}, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=5, + final_entity_count=2, + replacement_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") + + assert row.replacement_count_delta == 1 + assert row.replacement_missing_final_entity_count_delta == -1 + assert row.baseline_replacement_missing_final_entity_label_counts == {"organization_name": 1} + assert row.candidate_replacement_missing_final_entity_label_counts == {} + assert row.baseline_original_value_leak_label_counts == {"organization_name": 1} + assert row.candidate_original_value_leak_label_counts == {} + assert "baseline_replacement_missing_final_entity" in row.flags + assert "baseline_original_value_leak" in row.flags + assert "candidate_covers_baseline_replacement_missing_final_entity" in row.flags + assert "candidate_covers_baseline_original_value_leak" in row.flags + assert "replacement_count_delta" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_strip_replacement_columns_removes_prior_strategy_outputs() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_strip", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + dataframe = pd.DataFrame( + { + "text": ["hello"], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACEMENT_MAP_SOURCE: ["redact"], + COL_REPLACED_TEXT: ["hello"], + } + ) + + stripped = tool.strip_replacement_columns(dataframe) + + assert list(stripped.columns) == ["text"] + + +def test_build_replay_dataframe_uses_preview_when_nrows_is_set(tmp_path: Path) -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_nrows", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + trace_dataframe = pd.DataFrame( + { + COL_TEXT: ["born on 1988-11-21"], + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": "1988-11-21", + "label": "date_of_birth", + "start_position": 8, + "end_position": 18, + } + ] + } + ], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACED_TEXT: ["born on [REDACTED_DATE_OF_BIRTH]"], + } + ) + + class StubAnonymizer: + preview_num_records: int | None = None + + def run(self, **_kwargs: object) -> object: + raise AssertionError("run should not be used for row-limited replay") + + def preview(self, **kwargs: object) -> object: + self.preview_num_records = kwargs["num_records"] # type: ignore[assignment] + return SimpleNamespace(trace_dataframe=trace_dataframe, resolved_text_column="text") + + anonymizer = StubAnonymizer() + source = tmp_path / "multiline.csv" + source.write_text('text\n"born on 1988-11-21"\n', encoding="utf-8") + + _elapsed, replay_df = tool.build_replay_dataframe( + anonymizer, + source=source, + text_column="text", + labels=["date_of_birth"], + nrows=1, + dd_parser_compat=tool.DDParserCompatMode.none, + ) + + assert anonymizer.preview_num_records == 1 + assert COL_REPLACEMENT_MAP not in replay_df.columns + assert COL_REPLACED_TEXT not in replay_df.columns + assert replay_df[COL_FINAL_ENTITIES].iloc[0]["entities"][0]["value"] == "1988-11-21" diff --git a/tests/tools/test_screen_strategy_comparisons.py b/tests/tools/test_screen_strategy_comparisons.py new file mode 100644 index 00000000..8b2fa237 --- /dev/null +++ b/tests/tools/test_screen_strategy_comparisons.py @@ -0,0 +1,974 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + analysis_dir = tmp_path / "analysis" + analysis_dir.mkdir() + pd.DataFrame( + [ + { + "workload_id": "shell-3", + "baseline_config_id": "default", + "candidate_config_id": "detector-only", + "baseline_strategy": "default", + "candidate_strategy": "detector_only", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "baseline_case_count": 3, + "candidate_case_count": 3, + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.0, + "observed_total_requests_delta": -12, + "observed_total_tokens_delta": -11000, + "augmented_entity_count_delta": -3, + "augmented_new_final_value_count_delta": 0, + "baseline_only_final_entity_signature_count": 0, + "baseline_stable_candidate_unstable_final_entity_signature_count": 0, + "baseline_stable_final_entity_signature_count": 12, + "candidate_stable_final_entity_signature_count": 12, + "shared_stable_final_entity_signature_count": 12, + "flags": '["candidate_skips_llm_validation"]', + }, + { + "workload_id": "legal-2", + "baseline_config_id": "default", + "candidate_config_id": "no-augment", + "baseline_strategy": "default", + "candidate_strategy": "no_augment", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -10.0, + "observed_total_requests_delta": -2, + "observed_total_tokens_delta": -10000, + "augmented_entity_count_delta": -5.5, + "augmented_new_final_value_count_delta": -2, + "baseline_only_final_entity_signature_count": 2, + "baseline_stable_candidate_unstable_final_entity_signature_count": 1, + "baseline_only_final_entity_signature_label_counts.first_name": 2, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, + "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', + }, + ] + ).to_csv(analysis_dir / "default-vs-candidates.csv", index=False) + pd.DataFrame([{"workload_id": "shell-3", "config_id": "default"}]).to_csv( + analysis_dir / "group_analysis.csv", + index=False, + ) + pd.DataFrame( + [ + { + "source_path": "analysis/default-vs-candidates.csv", + "workload_id": "shell-3", + "baseline_config_id": "default", + "candidate_config_id": "detector-only", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + } + ] + ).to_csv(analysis_dir / "strategy-screen.csv", index=False) + (analysis_dir / "empty.csv").write_text("", encoding="utf-8") + + result = tool.screen_comparison_paths([analysis_dir]) + + assert result.scanned_file_count == 4 + assert result.comparison_file_count == 1 + assert result.row_count == 2 + assert result.summary.candidate_verdict_counts == {"reject": 1, "review": 1} + assert result.summary.review_count == 1 + assert result.summary.reject_count == 1 + assert result.summary.viable_count == 0 + legal = next(row for row in result.rows if row.workload_id == "legal-2") + assert legal.workload_family == "legal" + assert legal.flags == ["entity_signature_loss", "stable_entity_signature_loss"] + assert legal.baseline_only_label_counts == {"first_name": 2} + assert legal.stable_lost_label_counts == {"date": 1} + assert legal.augmented_new_final_value_count_delta == -2 + shell = next(row for row in result.rows if row.workload_id == "shell-3") + assert shell.baseline_replacement_strategy == "default" + assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.baseline_case_count == 3 + assert shell.candidate_case_count == 3 + assert shell.shared_stable_final_entity_signature_count == 12 + detector_local = next( + group + for group in result.groups + if group.group_key == "strategy:detector_only|replacement:local_structured_substitute" + ) + assert detector_local.candidate_replacement_strategy == "local_structured_substitute" + assert detector_local.row_count == 1 + no_augment = next(group for group in result.groups if group.group_key == "strategy:no_augment") + assert no_augment.row_count == 1 + assert no_augment.reject_count == 1 + assert no_augment.has_conflicting_verdicts is False + + +def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_export", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + rows = [ + tool.ScreenRow( + source_path="analysis/default-vs-detector-only.csv", + workload_id="shell", + baseline_config_id="default", + candidate_config_id="detector-only", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + flags=["candidate_skips_llm_validation"], + ) + ] + + output = tmp_path / "screen.csv" + tool.write_rows(rows, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["workload_id"].tolist() == ["shell"] + assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] + + +def test_screen_strategy_comparisons_dedupes_exact_rows(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_dedupe", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + comparison = pd.DataFrame( + [ + { + "workload_id": "bio", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -5.0, + } + ] + ) + comparison.to_csv(tmp_path / "a.csv", index=False) + comparison.to_csv(tmp_path / "b.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.scanned_file_count == 2 + assert result.comparison_file_count == 2 + assert result.row_count == 1 + assert result.duplicate_row_count == 1 + assert result.summary.viable_count == 1 + + +def test_screen_strategy_comparisons_surfaces_evidence_level_counts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_evidence_level", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default", + "candidate_config_id": "local-substitute", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + }, + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default", + "candidate_config_id": "local-substitute-legacy", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "shared_stable_final_entity_signature_count": 17, + }, + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} + assert {row.evidence_level for row in result.rows} == {"split_verdicts", "stable_signatures"} + group = result.groups[0] + assert group.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} + assert group.split_verdict_candidate_verdict_counts == {"review": 1} + rendered = tool.render_result(result, json_output=False, limit=10) + assert "evidence=split_verdicts" in rendered + assert "evidence_counts=split_verdicts:1,stable_signatures:1" in rendered + + +def test_screen_strategy_comparisons_surfaces_candidate_original_value_leaks(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_original_value_leaks", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-secrets", + "baseline_config_id": "default", + "candidate_config_id": "native-single-pass", + "candidate_strategy": "native_single_pass", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "candidate_original_value_leak_count": 2, + "candidate_original_value_leak_record_count": 1, + "original_value_leak_count_delta": 2, + "candidate_original_value_leak_label_counts.api_key": 1, + "candidate_original_value_leak_label_counts.password": 1, + "candidate_replacement_synthetic_original_collision_count": 1, + "candidate_replacement_synthetic_original_collision_value_count": 1, + "replacement_synthetic_original_collision_count_delta": 1, + "candidate_replacement_synthetic_original_collision_label_counts.date": 1, + "flags": ('["candidate_original_value_leak", "candidate_replacement_synthetic_original_collision"]'), + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.reject_count == 1 + row = result.rows[0] + assert row.candidate_original_value_leak_count == 2 + assert row.candidate_original_value_leak_record_count == 1 + assert row.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert row.candidate_replacement_synthetic_original_collision_count == 1 + assert row.candidate_replacement_synthetic_original_collision_value_count == 1 + assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + group = result.groups[0] + assert group.sum_candidate_original_value_leak_count == 2 + assert group.sum_candidate_original_value_leak_record_count == 1 + assert group.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert group.sum_candidate_replacement_synthetic_original_collision_count == 1 + assert group.sum_candidate_replacement_synthetic_original_collision_value_count == 1 + assert group.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + rendered = tool.render_result(result, json_output=False, limit=10) + assert "candidate_original_value_leaks=2.0" in rendered + assert "candidate_replacement_collisions=1.0" in rendered + assert "leak_labels=api_key:1,password:1" in rendered + assert "collision_labels=date:1" in rendered + + +def test_screen_strategy_comparisons_surfaces_label_policy_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_label_policy_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "legal-r1", + "baseline_config_id": "legal-default", + "candidate_config_id": "legal-detector-native-validate", + "candidate_strategy": "detector_native_validate_no_augment", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -26.5, + "observed_total_requests_delta": -1, + "observed_total_tokens_delta": -5366, + "flags": '["covered_label_mismatch"]', + "baseline_only_candidate_label_mismatch_signature_label_counts.date_of_birth": 1, + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.value_protection_verdict_counts == {"pass": 1} + assert result.summary.signature_parity_verdict_counts == {"review": 1} + row = result.rows[0] + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.label_mismatch_label_counts == {"date_of_birth": 1} + group = result.groups[0] + assert group.value_protection_verdict_counts == {"pass": 1} + assert group.signature_parity_verdict_counts == {"review": 1} + assert group.label_mismatch_label_counts == {"date_of_birth": 1} + assert group.recommendation == "label_policy_review" + rendered = tool.render_result(result, json_output=False, limit=10) + assert "value_protection=pass" in rendered + assert "signature_parity=review" in rendered + assert "label_mismatch=date_of_birth:1" in rendered + + +def test_screen_strategy_comparisons_surfaces_reliability_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_reliability_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default-substitute", + "candidate_config_id": "local-substitute", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "pass", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "flags": '["failed_request_increase"]', + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + group = result.groups[0] + assert group.flag_counts == {"failed_request_increase": 1} + assert group.recommendation == "reliability_review" + + +def test_screen_strategy_comparisons_surfaces_replacement_replay_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_replacement_replay_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default-substitute", + "candidate_config_id": "local-substitute", + "baseline_strategy": "default", + "candidate_strategy": "default", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "flags": '["covered_label_mismatch", "replacement_only_detection_instability"]', + "baseline_only_candidate_label_mismatch_signature_label_counts.api_key": 1, + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + group = result.groups[0] + assert group.flag_counts == { + "covered_label_mismatch": 1, + "replacement_only_detection_instability": 1, + } + assert group.recommendation == "replacement_replay_review" + + +def test_screen_strategy_comparisons_surfaces_baseline_defect_improvement_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_baseline_defect_improvement", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "biography-supported-structured", + "baseline_config_id": "dd-substitute", + "candidate_config_id": "local-substitute", + "baseline_strategy": "default", + "candidate_strategy": "default", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "baseline_replacement_missing_final_entity_count": 2, + "candidate_replacement_missing_final_entity_count": 0, + "baseline_original_value_leak_count": 2, + "candidate_original_value_leak_count": 0, + "baseline_duplicate_synthetic_replacement_count": 1, + "candidate_duplicate_synthetic_replacement_count": 0, + "baseline_replacement_missing_final_entity_label_counts.organization_name": 1, + "baseline_replacement_missing_final_entity_label_counts.religious_belief": 1, + "baseline_original_value_leak_label_counts.organization_name": 2, + "flags": ( + '["baseline_replacement_missing_final_entity", "baseline_original_value_leak", ' + '"candidate_covers_baseline_replacement_missing_final_entity", ' + '"candidate_covers_baseline_original_value_leak", "replacement_count_delta"]' + ), + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + row = result.rows[0] + assert row.baseline_replacement_missing_final_entity_count == 2 + assert row.candidate_replacement_missing_final_entity_count == 0 + assert row.baseline_original_value_leak_count == 2 + assert row.candidate_original_value_leak_count == 0 + assert row.baseline_duplicate_synthetic_replacement_count == 1 + assert row.candidate_duplicate_synthetic_replacement_count == 0 + assert row.baseline_replacement_missing_final_entity_label_counts == { + "organization_name": 1, + "religious_belief": 1, + } + assert row.baseline_original_value_leak_label_counts == {"organization_name": 2} + group = result.groups[0] + assert group.baseline_defect_improvement_count == 1 + assert group.sum_baseline_replacement_missing_final_entity_count == 2 + assert group.sum_candidate_replacement_missing_final_entity_count == 0 + assert group.sum_baseline_original_value_leak_count == 2 + assert group.sum_candidate_original_value_leak_count == 0 + assert group.sum_baseline_duplicate_synthetic_replacement_count == 1 + assert group.sum_candidate_duplicate_synthetic_replacement_count == 0 + assert group.baseline_replacement_missing_final_entity_label_counts == { + "organization_name": 1, + "religious_belief": 1, + } + assert group.baseline_original_value_leak_label_counts == {"organization_name": 2} + assert group.recommendation == "candidate_covers_baseline_defects" + rendered = tool.render_result(result, json_output=False, limit=10) + assert "baseline_defect_improvements=1" in rendered + assert "baseline_missing_replacements=2.0" in rendered + assert "candidate_missing_replacements=0.0" in rendered + assert "baseline_missing_labels=organization_name:1,religious_belief:1" in rendered + + +def test_screen_strategy_comparisons_groups_default_detection_by_replacement_strategy() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_replacement_group", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + row = tool.ScreenRow( + source_path="comparison.csv", + workload_id="structured-identifiers", + baseline_config_id="substitute-dd", + candidate_config_id="substitute-local", + candidate_strategy="default", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + safety_verdict="pass", + performance_verdict="improved", + candidate_verdict="candidate_viable", + ) + + assert tool.group_base_for_row(row, config_aliases={}) == "replacement:local_structured_substitute" + + +def test_screen_strategy_comparisons_keeps_generic_review_without_leak_metrics() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_generic_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + group = tool.ScreenGroup( + group_key="strategy:detector_only", + candidate_strategy="detector_only", + row_count=1, + review_count=1, + performance_verdict_counts={"improved": 1}, + flag_counts={"candidate_skips_llm_validation": 1}, + ) + + assert tool.group_recommendation(group) == "review_only" + + +def test_screen_strategy_comparisons_filters_source_paths(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_source_filters", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + current_dir = tmp_path / "analysis-current-csv" + current_dir.mkdir() + stale_dir = tmp_path / "analysis" + stale_dir.mkdir() + pd.DataFrame( + [ + { + "workload_id": "bio-current", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + } + ] + ).to_csv(current_dir / "current-comparison.csv", index=False) + pd.DataFrame( + [ + { + "workload_id": "bio-stale", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "fail", + "performance_verdict": "regressed", + "candidate_verdict": "reject", + } + ] + ).to_csv(stale_dir / "stale-comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], source_includes=["analysis-current-csv"]) + + assert result.scanned_file_count == 1 + assert result.comparison_file_count == 1 + assert [row.workload_id for row in result.rows] == ["bio-current"] + + result = tool.screen_comparison_paths( + [tmp_path], + source_includes=["analysis"], + source_excludes=["analysis-current-csv"], + ) + + assert result.scanned_file_count == 1 + assert result.comparison_file_count == 1 + assert [row.workload_id for row in result.rows] == ["bio-stale"] + + +def test_screen_strategy_comparisons_groups_candidate_strategy_conflicts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-small", + "baseline_config_id": "default", + "candidate_config_id": "no-augment-small", + "candidate_strategy": "no_augment", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -7.0, + "observed_total_requests_delta": -2, + "observed_total_tokens_delta": -8000, + "baseline_case_count": 3, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 20, + "flags": "[]", + }, + { + "workload_id": "legal-offset", + "baseline_config_id": "default", + "candidate_config_id": "no-augment-offset", + "candidate_strategy": "no_augment", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -10.0, + "observed_total_requests_delta": -1, + "observed_total_tokens_delta": -10000, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 14, + "baseline_only_final_entity_signature_label_counts.first_name": 2, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, + "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', + }, + { + "workload_id": "shell", + "baseline_config_id": "default", + "candidate_config_id": "detector-only", + "candidate_strategy": "detector_only", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.9, + "observed_total_tokens_delta": -11000, + "flags": '["candidate_skips_llm_validation"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == ["strategy:detector_only", "strategy:no_augment"] + no_augment = groups["strategy:no_augment"] + assert no_augment.row_count == 2 + assert no_augment.viable_count == 1 + assert no_augment.reject_count == 1 + assert no_augment.has_conflicting_verdicts is True + assert no_augment.recommendation == "conflicting_evidence" + assert no_augment.best_pipeline_elapsed_sec_delta_pct == -10.0 + assert no_augment.best_observed_total_tokens_delta == -10000 + assert no_augment.best_observed_total_requests_delta == -2 + assert no_augment.worst_pipeline_elapsed_sec_delta_pct == -7.0 + assert no_augment.worst_observed_total_tokens_delta == -8000 + assert no_augment.worst_observed_total_requests_delta == -1 + assert no_augment.min_baseline_case_count == 2 + assert no_augment.min_candidate_case_count == 2 + assert no_augment.min_shared_stable_final_entity_signature_count == 14 + assert no_augment.flag_counts == {"entity_signature_loss": 1, "stable_entity_signature_loss": 1} + assert no_augment.baseline_only_label_counts == {"first_name": 2} + assert no_augment.stable_lost_label_counts == {"date": 1} + + +def test_screen_strategy_comparisons_groups_default_strategy_by_config(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_default_config_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "biographies-r5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -6.0, + "observed_total_tokens_delta": -325, + "flags": "[]", + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-hybrid", + "candidate_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "regressed", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": 10.0, + "observed_total_tokens_delta": 100, + "baseline_only_final_entity_signature_label_counts.university": 1, + "flags": '["entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == [ + "config:biography-augment-temp07|family:biographies", + "config:biography-hybrid|family:biographies", + ] + assert groups["config:biography-augment-temp07|family:biographies"].recommendation == "single_slice_viable" + assert groups["config:biography-hybrid|family:biographies"].recommendation == "reject" + + +def test_screen_strategy_comparisons_groups_default_strategy_by_config_alias(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_config_alias_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "biographies-r2", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-hybrid-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -10.4, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 48, + "flags": "[]", + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "mixed", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": 16.0, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 116, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.university": 1, + "flags": '["stable_entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths( + [tmp_path], + group_by=tool.GroupBy.strategy_workload_family, + config_aliases={ + "biography-hybrid-augment-temp07": "biography-temp07-routing", + "biography-augment-temp07": "biography-temp07-routing", + }, + ) + + groups = {group.group_key: group for group in result.groups} + group = groups["alias:biography-temp07-routing|family:biographies"] + assert group.row_count == 2 + assert group.candidate_config_ids == ["biography-augment-temp07", "biography-hybrid-augment-temp07"] + assert group.viable_count == 1 + assert group.reject_count == 1 + assert group.recommendation == "conflicting_evidence" + assert group.min_baseline_case_count == 2 + assert group.min_candidate_case_count == 2 + assert group.min_shared_stable_final_entity_signature_count == 48 + assert group.stable_lost_label_counts == {"university": 1} + + +def test_screen_strategy_comparisons_can_group_by_strategy_and_workload_family(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_family_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-secrets-3", + "baseline_config_id": "default", + "candidate_config_id": "detector-only-shell", + "candidate_strategy": "detector_only", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.9, + "observed_total_tokens_delta": -11000, + "flags": '["candidate_skips_llm_validation"]', + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "default", + "candidate_config_id": "detector-only-bio", + "candidate_strategy": "detector_only", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -90.0, + "observed_total_tokens_delta": -17000, + "baseline_only_final_entity_signature_label_counts.first_name": 4, + "flags": '["entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == ["strategy:detector_only|family:shell-secrets", "strategy:detector_only|family:biographies"] + assert groups["strategy:detector_only|family:shell-secrets"].recommendation == "review_only" + assert groups["strategy:detector_only|family:shell-secrets"].workload_families == ["shell-secrets"] + assert groups["strategy:detector_only|family:biographies"].recommendation == "reject" + assert groups["strategy:detector_only|family:biographies"].baseline_only_label_counts == {"first_name": 4} + + +def test_workload_family_normalizes_slice_and_offset_suffixes() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_workload_family", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + + assert tool.workload_family("legal-r2-offset1") == "legal" + assert tool.workload_family("legal-slice-2") == "legal" + assert tool.workload_family("biographies-r5-offset5") == "biographies" + assert tool.workload_family("shell-secrets-3") == "shell-secrets" + assert tool.workload_family("support-ticket") == "support-ticket" + + +def test_screen_strategy_comparisons_writes_group_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_group_export", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + groups = [ + tool.ScreenGroup( + group_key="strategy:no_augment", + candidate_strategy="no_augment", + row_count=2, + viable_count=1, + reject_count=1, + has_conflicting_verdicts=True, + performance_verdict_counts={"improved": 1, "regressed": 1}, + flag_counts={"entity_signature_loss": 1}, + ) + ] + + output = tmp_path / "groups.csv" + tool.write_groups(groups, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["group_key"].tolist() == ["strategy:no_augment"] + assert exported["has_conflicting_verdicts"].tolist() == [True] + assert exported["recommendation"].tolist() == ["conflicting_evidence"] + assert exported["performance_verdict_counts"].tolist() == ['{"improved": 1, "regressed": 1}'] + assert exported["flag_counts"].tolist() == ['{"entity_signature_loss": 1}'] + + +def test_screen_strategy_group_recommendations() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_recommendations", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="viable", row_count=1, viable_count=1), + ) + == "single_slice_viable" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="repeated-viable", row_count=2, viable_count=2), + ) + == "candidate_family_viable" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="promising", row_count=2, viable_count=1, review_count=1), + ) + == "promising_needs_review" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="weak-promising", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"signature_counts": 1, "stable_signatures": 1}, + ), + ) + == "needs_split_verdict_rerun" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="partial-split-review", + row_count=2, + review_count=2, + evidence_level_counts={"split_verdicts": 1, "stable_signatures": 1}, + split_verdict_candidate_verdict_counts={"review": 1}, + ), + ) + == "needs_split_verdict_rerun" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="split-review-with-legacy-viable", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"signature_counts": 1, "split_verdicts": 1}, + split_verdict_candidate_verdict_counts={"review": 1}, + ), + ) + == "needs_viable_split_verdict" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="split-viable-with-review", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"split_verdicts": 2}, + split_verdict_candidate_verdict_counts={"candidate_viable": 1, "review": 1}, + ), + ) + == "promising_needs_review" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review", + row_count=2, + review_count=2, + performance_verdict_counts={"improved": 2}, + ), + ) + == "review_only" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-mixed", + row_count=2, + review_count=2, + performance_verdict_counts={"mixed": 2}, + ), + ) + == "review_mixed_performance" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-improved-and-mixed", + row_count=2, + review_count=2, + performance_verdict_counts={"improved": 1, "mixed": 1}, + ), + ) + == "review_mixed_performance" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-regressed", + row_count=2, + review_count=2, + performance_verdict_counts={"regressed": 2}, + ), + ) + == "no_performance_win" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="reject", row_count=2, reject_count=2), + ) + == "reject" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="conflict", row_count=2, viable_count=1, reject_count=1), + ) + == "conflicting_evidence" + ) diff --git a/tests/tools/test_staged_detection_output_analysis.py b/tests/tools/test_staged_detection_output_analysis.py new file mode 100644 index 00000000..4c791501 --- /dev/null +++ b/tests/tools/test_staged_detection_output_analysis.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + output_dir = tmp_path / "staged" + output_dir.mkdir() + _write_jsonl( + output_dir / "staged-detection-cases.jsonl", + [ + { + "record_type": "staged_detection_case", + "case_id": "shell-row-0", + "row_index": 0, + "seed_source": "gliner", + "status": "completed", + "elapsed_sec": 0.002, + "model_elapsed_sec": 0.0, + "model_phase_count": 0, + "model_request_count": 0, + "final_entity_count": 5, + "final_entity_signature_count": 5, + "final_label_counts": {"api_key": 2, "email": 1, "password": 1, "url": 1}, + "total_usage": {}, + "comparison": { + "baseline_final_entity_signature_count": 5, + "shared_final_entity_signature_count": 5, + "baseline_only_final_entity_signature_count": 0, + "direct_only_final_entity_signature_count": 0, + "baseline_only_label_counts": {}, + "direct_only_label_counts": {}, + }, + }, + { + "record_type": "staged_detection_case", + "case_id": "bio-row-0", + "row_index": 0, + "seed_source": "direct_llm", + "status": "completed", + "elapsed_sec": 10.0, + "model_elapsed_sec": 9.5, + "model_phase_count": 3, + "model_request_count": 3, + "final_entity_count": 3, + "final_entity_signature_count": 3, + "final_label_counts": {"person": 2, "api_key": 1}, + "total_usage": {"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, + "comparison": { + "baseline_final_entity_signature_count": 4, + "shared_final_entity_signature_count": 2, + "baseline_only_final_entity_signature_count": 2, + "direct_only_final_entity_signature_count": 1, + "baseline_only_label_counts": {"city": 1, "person": 1}, + "direct_only_label_counts": {"api_key": 1}, + }, + }, + { + "record_type": "staged_detection_case", + "case_id": "bio-row-1", + "row_index": 1, + "seed_source": "direct_llm", + "status": "error", + "elapsed_sec": 1.0, + "model_elapsed_sec": 0.8, + "model_phase_count": 1, + "model_request_count": 1, + "final_entity_count": 0, + "final_entity_signature_count": 0, + "total_usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}, + "comparison": None, + "error": "provider failed", + }, + ], + ) + + result = tool.analyze_staged_detection_output(output_dir) + + assert result.case_count == 3 + assert result.group_count == 2 + groups = {row.seed_source: row for row in result.groups} + assert groups["gliner"].case_count == 1 + assert groups["gliner"].completed_case_count == 1 + assert groups["gliner"].model_elapsed_sec_sum == 0.0 + assert groups["gliner"].model_request_count_sum == 0 + assert groups["gliner"].baseline_shared_signature_rate == 1.0 + assert groups["direct_llm"].case_count == 2 + assert groups["direct_llm"].completed_case_count == 1 + assert groups["direct_llm"].error_case_count == 1 + assert groups["direct_llm"].elapsed_sec_sum == pytest.approx(11.0) + assert groups["direct_llm"].model_elapsed_sec_sum == pytest.approx(10.3) + assert groups["direct_llm"].model_request_count_sum == 4 + assert groups["direct_llm"].total_tokens_sum == 132 + assert groups["direct_llm"].baseline_final_entity_signature_count_sum == 4 + assert groups["direct_llm"].shared_final_entity_signature_count_sum == 2 + assert groups["direct_llm"].baseline_only_final_entity_signature_count_sum == 2 + assert groups["direct_llm"].direct_only_final_entity_signature_count_sum == 1 + assert groups["direct_llm"].baseline_shared_signature_rate == pytest.approx(0.5) + + label_deltas = {(row.seed_source, row.delta_type, row.label): row.count for row in result.label_deltas} + assert label_deltas == { + ("direct_llm", "baseline_only", "city"): 1, + ("direct_llm", "baseline_only", "person"): 1, + ("direct_llm", "direct_only", "api_key"): 1, + } + + +def test_staged_detection_output_analysis_writes_csv_tables(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis_export", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + cases_path = tmp_path / "cases.jsonl" + _write_jsonl( + cases_path, + [ + { + "case_id": "case-0", + "row_index": 0, + "seed_source": "gliner", + "status": "completed", + "elapsed_sec": 0.01, + "model_elapsed_sec": 0.0, + "model_request_count": 0, + "total_usage": {}, + } + ], + ) + + result = tool.analyze_staged_detection_output(cases_path) + export = tool.write_analysis_tables(result, tmp_path / "analysis", tool.ExportFormat.csv) + + assert Path(export.manifest_path).exists() + assert [table.table for table in export.tables] == [ + "case_analysis", + "group_analysis", + "label_delta_analysis", + ] + case_table = pd.read_csv(tmp_path / "analysis" / "case_analysis.csv") + assert case_table.loc[0, "case_id"] == "case-0" + assert case_table.loc[0, "model_request_count"] == 0 + + +def test_staged_detection_output_analysis_rejects_missing_case_file(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis_missing_input", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + + with pytest.raises(ValueError, match="input path does not exist"): + tool.analyze_staged_detection_output(tmp_path / "missing") diff --git a/tests/tools/test_staged_detection_probe.py b/tests/tools/test_staged_detection_probe.py new file mode 100644 index 00000000..29987f90 --- /dev/null +++ b/tests/tools/test_staged_detection_probe.py @@ -0,0 +1,579 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +from anonymizer.engine.schemas import ValidationCandidateSchema + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +class SequencedClient: + def __init__(self, tool: ModuleType, outputs: list[str]) -> None: + self._tool = tool + self._outputs = list(outputs) + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + content = self._outputs.pop(0) + return self._tool.DirectCompletion( + content=content, + elapsed_sec=0.5, + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + ) + + +class StaticSeedClient: + def __init__(self, tool: ModuleType, content: str) -> None: + self._tool = tool + self.content = content + self.requests = [] + + def detect(self, request): # type: ignore[no-untyped-def] + self.requests.append(request) + return self._tool.DirectCompletion( + content=self.content, + elapsed_sec=0.25, + usage={"prompt_tokens": 1, "completion_tokens": 3, "total_tokens": 4}, + ) + + +def test_staged_detection_reuses_validation_and_augmentation_flow() -> None: + tool = load_tool( + "measurement_staged_detection_probe_case", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_suggestion_count == 1 + assert result.seed_entity_count == 1 + assert result.validation_decision_count == 1 + assert result.augmented_suggestion_count == 1 + assert result.final_entity_count == 2 + assert result.final_label_counts == {"first_name": 1, "organization_name": 1} + assert result.artifact.final_source_counts == {"direct_seed": 1, "augmenter": 1} + assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=True) + assert result.phase_skip_reasons == tool.PhaseSkipReasons() + assert result.model_phase_count == 3 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) + assert result.model_request_count == 3 + assert result.total_usage == {"prompt_tokens": 30, "completion_tokens": 15, "total_tokens": 45} + assert len(client.prompts) == 3 + + +def test_staged_detection_can_disable_augmentation_phase() -> None: + tool = load_tool( + "measurement_staged_detection_probe_disable_augmentation", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_label_counts == {"first_name": 1} + assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=False) + assert result.phase_skip_reasons == tool.PhaseSkipReasons(augmentation="disabled") + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=0) + assert result.model_request_count == 2 + assert len(client.prompts) == 2 + + +def test_staged_detection_execution_exposes_output_row_without_serializing_it() -> None: + tool = load_tool( + "measurement_staged_detection_probe_execution", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": []}', + ], + ) + + execution = tool.execute_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=client, + ) + + assert execution.case.status == tool.CaseStatus.completed + assert execution.row[tool.COL_DETECTED_ENTITIES]["entities"][0]["value"] == "Alice" + assert "row" not in execution.case.model_dump() + + +def test_staged_detection_validation_drop_removes_seed_entity() -> None: + tool = load_tool( + "measurement_staged_detection_probe_drop", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "name", "label": "first_name", "reason": "placeholder"}]}', + '{"decisions": [{"id": "first_name_3_7", "decision": "drop", "reason": "placeholder"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="my name is hidden", + labels=["first_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_entity_count == 1 + assert result.validation_decision_count == 1 + assert result.final_entity_count == 0 + assert result.final_label_counts == {} + + +def test_staged_detection_invalid_reclass_label_keeps_seed_label() -> None: + tool = load_tool( + "measurement_staged_detection_probe_invalid_reclass_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + ( + '{"decisions": [' + '{"id": "first_name_0_5", "decision": "reclass", ' + '"proposed_label": "drop", "reason": "invalid label"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"first_name": 1} + + +def test_staged_detection_discards_invalid_augmentation_labels() -> None: + tool = load_tool( + "measurement_staged_detection_probe_invalid_augmentation_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "drop", "reason": "invalid label"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.augmented_suggestion_count == 0 + assert result.final_entity_count == 1 + assert result.final_label_counts == {"first_name": 1} + + +def test_staged_detection_preserves_more_specific_seed_label_on_native_reclass() -> None: + tool = load_tool( + "measurement_staged_detection_probe_specific_label_reclass", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "birth date"}]}', + ( + '{"decisions": [' + '{"id": "date_of_birth_26_41", "decision": "reclass", ' + '"proposed_label": "date", "reason": "date expression"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="The applicant was born on 23 October 1992.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date_of_birth": 1} + + +def test_staged_detection_allows_date_of_birth_reclass_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_generic_date_reclass", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "date"}]}', + ( + '{"decisions": [' + '{"id": "date_of_birth_3_18", "decision": "reclass", ' + '"proposed_label": "date", "reason": "filing date"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_demotes_native_birth_date_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_generic_date_birth_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date", "reason": "date"}]}', + ( + '{"decisions": [' + '{"id": "date_3_18", "decision": "reclass", ' + '"proposed_label": "date_of_birth", "reason": "ambiguous date"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_can_seed_from_direct_gliner_payload_without_llm_seed_prompt() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + '{"entities": [{"text": "Alice", "label": "first_name", "start": 0, "end": 5, "score": 0.99}]}', + ) + llm_client = SequencedClient( + tool, + [ + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"first_name": 1, "organization_name": 1} + assert result.total_usage == {"prompt_tokens": 21, "completion_tokens": 13, "total_tokens": 34} + assert len(seed_client.requests) == 1 + assert len(llm_client.prompts) == 2 + + +def test_staged_detection_promotes_gliner_date_seed_in_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_birth_date_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + ('{"entities": [{"text": "23 October 1992", "label": "date", "start": 26, "end": 41, "score": 0.99}]}'), + ) + llm_client = SequencedClient( + tool, + ['{"decisions": [{"id": "date_of_birth_26_41", "decision": "keep", "reason": "birth date"}]}'], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="The applicant was born on 23 October 1992.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"date_of_birth": 1} + + +def test_staged_detection_keeps_gliner_date_seed_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_generic_date_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + ('{"entities": [{"text": "23 October 1992", "label": "date", "start": 3, "end": 18, "score": 0.99}]}'), + ) + llm_client = SequencedClient( + tool, + ['{"decisions": [{"id": "date_3_18", "decision": "keep", "reason": "filing date"}]}'], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_validation_prompt_preserves_degree_label_guidance() -> None: + tool = load_tool( + "measurement_staged_detection_probe_degree_guidance", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + request = tool.StagedDetectionRequest( + case_id="case-1", + text="He earned his Bachelor of Science in physics.", + labels=["degree", "education_level", "field_of_study"], + ) + candidates = tool.ValidationCandidatesSchema( + candidates=[ + ValidationCandidateSchema( + id="education_level_14_33", + value="Bachelor of Science", + label="education_level", + context_before="He earned his ", + context_after=" in physics.", + ) + ] + ) + + prompt = tool._validation_prompt(request, candidates) + + assert "degree" in prompt + assert "education_level" in prompt + assert "Bachelor of Science" in prompt + assert "Prefer degree for credential names" in prompt + + +def test_staged_detection_augmentation_prompt_discourages_grouped_person_and_surname_spans() -> None: + tool = load_tool( + "measurement_staged_detection_probe_augmentation_guidance", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + request = tool.StagedDetectionRequest( + case_id="case-1", + text="Her parents, Mark and Linda, live near West Baker Drive and Baker's grocery.", + labels=["first_name", "last_name", "organization_name", "place_name", "company_name"], + ) + + prompt = tool._augmentation_prompt(request, {tool.COL_SEED_ENTITIES_JSON: "[]"}) + + assert "split personal names connected by 'and'" in prompt + assert "Do not label a list of people as organization_name" in prompt + assert "also return the surname substring as last_name" in prompt + + +def test_staged_detection_baseline_comparison_skips_rows_without_signature_hashes() -> None: + tool = load_tool( + "measurement_staged_detection_probe_missing_baseline_signatures", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "person name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "person name"}]}', + '{"entities": []}', + ], + ) + case = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=llm_client, + ) + + compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 1}) + + assert compared.comparison is None + + +def test_staged_detection_can_chunk_validation_into_local_excerpts() -> None: + tool = load_tool( + "measurement_staged_detection_probe_chunked_validation", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + ( + '{"entities": [' + '{"value": "Alice", "label": "first_name", "reason": "name"},' + '{"value": "Paris", "label": "city", "reason": "city"}' + "]}" + ), + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"decisions": [{"id": "city_61_66", "decision": "keep", "reason": "city"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works in a very long remote biography before moving to Paris.", + labels=["first_name", "city"], + row_index=0, + ), + client=client, + validation_prompt_mode=tool.ValidationPromptMode.chunked_excerpt, + validation_max_entities_per_call=1, + validation_excerpt_window_chars=8, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_label_counts == {"city": 1, "first_name": 1} + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=2, augmentation=1) + assert result.model_phase_count == 3 + assert result.model_request_count == 4 + assert len(client.prompts) == 4 + assert "Alice" in client.prompts[1] + assert "Paris" not in client.prompts[1] + assert "Paris" in client.prompts[2] diff --git a/tools/measurement/README.md b/tools/measurement/README.md new file mode 100644 index 00000000..69a017af --- /dev/null +++ b/tools/measurement/README.md @@ -0,0 +1,511 @@ + + + +# Measurement tools + +This directory contains developer tools for measuring Anonymizer runs, exporting +measurement JSONL to tables, and comparing benchmark strategies. Run the tools +inside the project environment, either with an activated venv or through +`uv run`. + +```bash +uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables +``` + +By default, `export_measurements.py` writes Parquet files plus +`manifest.json`: + +- `run.parquet` +- `stage.parquet` +- `record.parquet` +- `ndd_workflow.parquet` when DataDesigner adapter records are present +- `model_workflow.parquet` when direct model workflow records are present + +Use `--format csv` or `--format jsonl` for non-Parquet output, and +`--overwrite` to replace existing output files. + +## Benchmark runner + +`run_benchmarks.py` runs repeatable Anonymizer workloads and writes the same +measurement JSONL format, one raw file per benchmark case plus a combined +`measurements.jsonl`. + +```bash +uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark-runs/suite +uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json +uv run python tools/measurement/run_benchmarks.py suite.yaml \ + --output benchmark-runs/suite \ + --dd-trace last-message +``` + +The repo-data smoke suite can be run with DataDesigner traces enabled: + +```bash +bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh +``` + +The script writes to `/tmp/anonymizer-repo-data-smoke-dd-traces` by default. +Pass a different output directory as the first argument, or set +`DD_TRACE_MODE=all-messages` when full chat history is needed: + +```bash +DD_TRACE_MODE=all-messages \ + bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh \ + /tmp/anonymizer-repo-data-smoke-dd-traces-full +``` + +Benchmark suites are YAML files with three parts: + +- `workloads`: input datasets and text-column metadata. +- `configs`: Anonymizer replace or rewrite configurations. +- `matrix`: optional workload/config pairs and repetition counts. When omitted, + every workload is crossed with every config once. + +Example: + +```yaml +suite_id: biography-smoke +model_configs: ./model-configs.yaml +model_providers: ./providers.yaml +dd_parser_compat: none +case_retries: 1 +case_retry_backoff_sec: 10 +workloads: + - id: biographies + source: ./data/biographies.csv + text_column: text + row_limit: 25 + - id: support + source: ./data/support.csv + text_column: body + id_column: ticket_id + row_offset: 100 + row_limit: 50 +configs: + - id: redact-default + replace: redact + - id: hash-agent-labels + detect: + entity_labels: [person, email, api_key, password] + replace: + strategy: hash + digest_length: 12 + - id: rewrite-low-risk + rewrite: + risk_tolerance: low + max_repair_iterations: 1 +matrix: + - workload: biographies + config: redact-default + repetitions: 3 + - workload: support + config: hash-agent-labels +``` + +Use `row_limit` and `row_offset` to create cheap, repeatable slices of a local +CSV or Parquet workload. The runner materializes a per-case sliced input under +`raw/inputs/` before calling Anonymizer, so each case keeps a stable input file +even when the matrix has multiple configs or repetitions. Slicing is rejected +for URL-like sources because the runner cannot safely materialize a local +subset without downloading the whole dataset first. + +Relative paths in suite files are resolved from the suite file's directory. +The runner refuses to write into a non-empty output directory unless +`--overwrite` is set. By default it also exports Parquet tables into `tables/`; +pass `--no-export` when only raw measurement JSONL is needed. + +Before starting a real run, the benchmark runner performs cheap preflight +checks: suite/config parsing, local dataset existence, CSV/Parquet text-column +metadata, provider YAML shape, native runtime requirements, and active +model-alias references. `--dry-run` runs those same checks, expands the planned +matrix, and skips output-dir writes and model work. + +Use `case_retries` and `case_retry_backoff_sec` for long-running suites on +shared model endpoints. Retries are disabled by default. When enabled, a failed +case is retried with the same `case_id` and output paths; the final case still +records `attempt_count` and `attempt_errors` in `summary.json`. `--fail-fast` +remains fail-fast and bypasses retries. + +## Benchmark-only detection strategies + +Configs may set `experimental_detection_strategy` for benchmark-only pipeline +probes. These values are not public `Detect` config fields, and they should not +be treated as safe defaults across arbitrary data. + +```yaml +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +``` + +Supported values: + +- `default`: run the normal Anonymizer detection pipeline. +- `no_augment`: run GLiNER detection and validation, but skip LLM augmentation. +- `detector_only`: run only GLiNER detection and local finalization. This skips + LLM validation and LLM augmentation. +- `native_candidate_validate_no_augment`: run a benchmark-only native staged + detector without DataDesigner using direct OpenAI-compatible calls for seed + extraction and validation, then skip augmentation. +- `detector_native_validate_no_augment`: run the normal GLiNER detector seed + through Anonymizer/DataDesigner, then bypass DataDesigner validation and + augmentation with direct OpenAI-compatible validation calls. +- `detector_native_validate_native_augment`: run the normal GLiNER detector seed + through Anonymizer/DataDesigner, then bypass DataDesigner validation and + augmentation with direct OpenAI-compatible validation and augmentation calls. +- `gliner_native_validate_no_augment`: run a direct hosted-GLiNER seed without + DataDesigner, validate those detector candidates with direct + OpenAI-compatible calls, and skip augmentation. +- `gliner_native_validate_native_augment`: run a direct hosted-GLiNER seed + without DataDesigner, validate those detector candidates with direct + OpenAI-compatible calls, then run direct native augmentation. +- `native_single_pass`: run a benchmark-only native detector without + DataDesigner using one direct OpenAI-compatible provider call per row. The + model must return exact values plus `start`/`end` offsets; local code + validates offsets, resolves overlaps, and records parser/runtime failures as + `model_workflow` errors. +- `native_single_pass_recall`: the same one-call native detector with a + recall-oriented prompt that includes Anonymizer's label examples and stronger + high-recall guidance. +- `native_single_pass_values`: the same one-call native detector, but with the + value-only prompt shape from `direct_detection_probe.py`. The model returns + exact values and labels only; local code resolves every occurrence of each + returned value into spans. +- `native_single_pass_values_recall`: the value-only one-call detector with the + recall-oriented prompt from `direct_detection_probe.py`. + +Native benchmark strategies require an explicit runtime. Set top-level +`native_runtime.endpoint` and `native_runtime.model`, set the standard +`ANONYMIZER_BENCH_NATIVE_ENDPOINT` and `ANONYMIZER_BENCH_NATIVE_MODEL` +environment variables, or override runtime fields per config with +`configs[].native_runtime`. GLiNER-seeded native strategies also require +`native_runtime.gliner_endpoint` and `native_runtime.gliner_model`, or the +standard `ANONYMIZER_BENCH_GLINER_ENDPOINT` and +`ANONYMIZER_BENCH_GLINER_MODEL` environment variables. The runner records +runtime id, alias, provider, model, and env-variable names as run tags; raw +endpoint URLs are not emitted into measurement tables. + +```yaml +native_runtime: + runtime_id: local-vllm-json + endpoint_env: ANONYMIZER_BENCH_NATIVE_ENDPOINT + model_env: ANONYMIZER_BENCH_NATIVE_MODEL + provider: local-vllm + alias: native-direct +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +``` + +Use `detector_only` only as a lower-bound ablation. It skips the LLM validation +pass that drops false positives and reclassifies ambiguous spans. A faster run +that loses baseline signatures is a rejection. + +Use staged native strategies when the question is "can direct provider calls +replace part of DataDesigner orchestration?" They still need repeated signature, +leak, label-mismatch, parser, and reliability gates before any workload-specific +promotion. + +Use one-call native strategies for the more aggressive "collapse detection to +one call" experiment. They are often faster when the prompt works, but they are +more parser- and recall-sensitive. Any malformed JSON response becomes a failed +case in analysis, and any missed baseline signature should be treated as a +rejection rather than a latency win. + +## Replacement strategy probes + +Replacement-map generation has a separate benchmark-only knob: + +```yaml +configs: + - id: structured-local-substitute + experimental_replacement_strategy: local_structured_substitute + detect: + entity_labels: [api_key, email, password, url] + replace: + strategy: substitute +``` + +Supported replacement strategy values: + +- `default`: run normal replacement behavior. `Substitute` uses the configured + `replacement_generator` role through DataDesigner. +- `local_structured_substitute`: for `replace: substitute`, build deterministic + synthetic replacement maps locally for supported structured labels. Text + replacement still uses Anonymizer's normal replacement-map application code. + +`local_structured_substitute` requires `replace: substitute` and explicit +`detect.entity_labels`. Every label must be one of the structured substitute +labels: `api_key`, `date_of_birth`, `email`, `organization_name`, `password`, +`http_cookie`, `pin`, `religious_belief`, `street_address`, `unique_id`, `url`, +or `user_name`. The preflight rejects contextual labels such as `person`. This +is deliberate. The local substitute map generator does not understand names, +social relations, cultural consistency, or prose semantics; use the default +DataDesigner-backed `Substitute` path for those workloads. + +## DataDesigner traces + +For debugging DataDesigner calls, pass `--dd-trace last-message` or +`--dd-trace all-messages`. Trace records are written separately from sanitized +measurements, under `traces/{case_id}.jsonl` by default. Use `--trace-dir` to +choose another directory. `last-message` stores only the final prompt message +for each DataDesigner model call; `all-messages` stores the full message list. + +DataDesigner traces may contain raw input text, prompts, model outputs, entity +values, replacement values, secrets, and PII. Treat them as debug artifacts: +keep them out of shared benchmark bundles unless they have been reviewed or +redacted. + +Summarize traced calls without copying raw prompts or responses into analysis +output: + +```bash +uv run python tools/measurement/analyze_dd_traces.py \ + benchmark-runs/suite-id/traces \ + --output benchmark-runs/suite-id/trace-analysis \ + --format csv +``` + +This writes `trace_analysis.*` and `trace_group_analysis.*`. The row table +captures run tags, workflow/model metadata, status, elapsed time, prompt and +response lengths, token counts, and response-shape flags such as `raw_json`, +`fenced_json`, `embedded_json`, `text`, and `none`. + +## Direct probes + +`direct_detection_probe.py` calls a local OpenAI-compatible endpoint directly +for a small slice of records. It is useful for prompt experiments before adding +a benchmark strategy. + +```bash +uv run python tools/measurement/direct_detection_probe.py \ + docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column text \ + --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ + --model nvidia/nemotron-3-super \ + --labels person,email,api_key,password \ + --row-limit 5 \ + --output /tmp/direct-detection-probe +``` + +`staged_detection_probe.py` runs a no-DataDesigner staged detector outside the +main benchmark harness. It can compare seed extraction, validation, and +augmentation boundaries before integrating a strategy into `run_benchmarks.py`. + +```bash +uv run python tools/measurement/staged_detection_probe.py \ + docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column text \ + --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ + --model nvidia/nemotron-3-super \ + --labels person,email,api_key,password \ + --row-limit 5 \ + --output /tmp/staged-detection-probe +``` + +Useful staged options: + +- `--seed-source direct-llm`: use direct LLM seed extraction. +- `--seed-source gliner`: use direct hosted GLiNER seeding. +- `--skip-augmentation`: disable augmentation for any seed source. This is an + ablation for measuring how much recall the augmentation phase carries. +- `--validation-prompt-mode chunked-excerpt`: split seed validation candidates + into chunks of `--validation-max-entities-per-call` and send each chunk with a + tagged local excerpt bounded by `--validation-excerpt-window-chars`. + +The staged tool writes `staged-detection-cases.jsonl`, +`staged-detection-artifacts.jsonl`, and `summary.json`. Case rows include +per-phase usage for seed extraction, validation, and augmentation, true case +wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, +`phase_model_work`, `phase_skip_reasons`, `phase_model_requests`, +`model_phase_count`, `model_request_count`, total usage, and optional baseline +signature deltas. Treat `summary.json` as a sensitive debug artifact because it +records the resolved endpoint/model runtime used for the probe. + +Summarize staged probe outputs: + +```bash +uv run python tools/measurement/analyze_staged_detection_output.py \ + /tmp/staged-detection-probe \ + --output /tmp/staged-detection-probe/analysis \ + --format csv +``` + +The analyzer accepts either the staged output directory or the +`staged-detection-cases.jsonl` file directly. It writes per-case, per-seed +source group, and label-delta tables. Use `group_analysis.csv` for latency, +token, request, and signature-overlap totals; use `label_delta_analysis.csv` to +see which labels account for baseline-only misses or direct-only additions. The +analysis tables omit raw text and raw entity values. + +## Benchmark analysis + +`analyze_benchmark_output.py` joins `measurements.jsonl`, optional +DataDesigner traces, and detection artifact sidecars into richer case/group +tables: + +```bash +uv run python tools/measurement/analyze_benchmark_output.py \ + benchmark-runs/suite-id \ + --output benchmark-runs/suite-id/analysis \ + --format csv +``` + +Important outputs: + +- `case_analysis.*`: one row per benchmark case. +- `group_analysis.*`: median and aggregate metrics grouped by workload/config. +- `model_usage.*`: one row per measured model usage entry. +- `model_usage_group_analysis.*`: model usage rolled up by workflow/model. + +Use `--detection-artifacts` to provide an explicit detection artifact JSONL +sidecar. Otherwise, the analyzer reads `detection-artifacts.jsonl` in the +benchmark directory when present. + +`compare_strategy_pairs.py` compares baseline/candidate case rows: + +```bash +uv run python tools/measurement/compare_strategy_pairs.py \ + benchmark-runs/suite-id/analysis/case_analysis.csv \ + --baseline-config default \ + --candidate-config native-single-pass \ + --output benchmark-runs/suite-id/analysis/default-vs-native-single-pass.csv +``` + +When one CSV does not contain both arms, pass `--candidate-case-analysis`: + +```bash +uv run python tools/measurement/compare_strategy_pairs.py \ + baseline/analysis/case_analysis.csv \ + --candidate-case-analysis candidate/analysis/case_analysis.csv \ + --baseline-strategy default \ + --candidate-strategy detector_native_validate_no_augment \ + --output comparison.csv +``` + +`screen_strategy_comparisons.py` screens many comparison CSVs: + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ \ + --output benchmark-runs/strategy-screen.csv +``` + +Use `--group-by strategy_workload_family` when the same candidate behaves +differently across workload families. Use `--config-aliases aliases.json` to +group related config IDs, such as temperature or validation-window variants of +the same strategy. + +## Pandas patterns + +Analysis tables are regular CSV/Parquet files. A typical local workflow: + +```python +import pandas as pd + +cases = pd.read_parquet("benchmark-runs/suite/analysis/case_analysis.parquet") +groups = pd.read_parquet("benchmark-runs/suite/analysis/group_analysis.parquet") + +cols = [ + "workload_id", + "config_id", + "experimental_detection_strategy", + "median_pipeline_elapsed_sec", + "median_observed_total_requests", + "median_observed_total_tokens", + "median_artifact_final_entity_signature_count", +] +print(groups[cols].sort_values(["workload_id", "median_pipeline_elapsed_sec"])) + +failures = cases[ + (cases["case_failed"]) | + (cases["observed_failed_requests"] > 0) | + (cases["dd_trace_error_count"] > 0) +] +print(failures[["case_id", "config_id", "observed_failed_requests", "dd_trace_error_count"]]) +``` + +Compare a candidate against a baseline: + +```python +comparison = pd.read_csv("benchmark-runs/suite/analysis/default-vs-native.csv") +candidate_rows = comparison[ + ["workload_id", "candidate_verdict", "safety_verdict", "performance_verdict", "flags"] +] +print(candidate_rows) +``` + +Find candidate-specific misses: + +```python +loss_cols = [ + column for column in comparison.columns + if column.startswith("baseline_only_final_entity_signature_label_counts.") +] +print(comparison[["workload_id", *loss_cols]].fillna(0)) +``` + +## Metric interpretation + +Use metrics as signals, not as a single score. + +Latency and throughput: + +- `elapsed_sec`: wall time for a measured stage or workflow. +- `pipeline_elapsed_sec`: end-to-end Anonymizer wall time for a case. +- `records_per_pipeline_sec`: completed input records per pipeline second. +- `input_text_tokens_per_pipeline_sec`: input text tokens processed per + pipeline second. + +Model work: + +- `observed_total_requests`: measured model requests from DataDesigner or direct + model workflow records. +- `observed_total_tokens`: measured input plus output tokens. +- `observed_failed_requests`: provider-level failed requests. +- `observed_bridge_fallback_requests`: sync-client fallback requests recorded + from DataDesigner traces. +- `observed_non_bridge_failed_requests`: failed requests after subtracting + sync-client bridge fallbacks. Prefer this field when judging endpoint + reliability from trace-enabled runs. + +Detection artifacts: + +- `seed_entity_count`: detector or direct-seed candidate count before + validation. +- `seed_validation_candidate_count`: candidates sent to validation. +- `estimated_seed_validation_chunk_count`: estimated validator chunks from the + active validation chunk size. +- `augmented_entity_count`: augmenter suggestions. +- `augmented_new_final_value_count`: augmenter suggestions that add values not + already present in the seed/final set. +- `artifact_final_detector_entity_count` and + `artifact_final_augmenter_entity_count`: final entity source counts derived + from detection artifact sidecars. +- `artifact_final_entity_signature_count` and + `artifact_final_entity_signature_hashes`: opaque final-span signatures derived + from detection artifacts. These do not include raw entity values. + +Safety and replacement: + +- `original_value_leak_count`: count of protected original values still present + in replaced output. +- `replacement_missing_final_entity_count`: final entity occurrences whose + original value has no replacement-map entry. +- `replacement_missing_final_value_count`: unique final entity values with no + replacement-map entry. +- `replacement_synthetic_original_collision_count`: final entity occurrences + whose original value was reused as a synthetic replacement value elsewhere in + the same record. +- `baseline_only_candidate_covered_signature_count`, + `baseline_only_candidate_overlapping_signature_count`, and + `baseline_only_candidate_uncovered_signature_count`: comparison-only fields + from `compare_strategy_pairs.py`. These split exact signature deltas into + covered, boundary-overlap, and uncovered losses. +- `candidate_verdict`: `candidate_viable`, `review`, or `reject`. + +Treat `candidate_viable` as a promotion candidate, not as an automatic default. +It means the sampled comparison passed the current gates and improved at least +one performance metric without regressing another. Re-run candidates on the +target workload family, with repetitions, before changing production defaults. diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py new file mode 100644 index 00000000..056c2287 --- /dev/null +++ b/tools/measurement/analyze_benchmark_output.py @@ -0,0 +1,1552 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze joined benchmark measurements and detection artifact sidecars. + +Usage: + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --output analysis + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --detection-artifacts current.jsonl + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --json +""" + +from __future__ import annotations + +import json +import logging +import math +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any, cast + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, computed_field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.benchmark_output") + +_SYNC_CLIENT_UNAVAILABLE_ERROR = "SyncClientUnavailableError" +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class CaseAnalysisRow(BaseModel): + suite_id: str | None = None + workload_id: str | None = None + workload_category: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | None = None + entity_label_set_id: str | None = None + entity_label_count: int | None = None + gliner_threshold: float | None = None + repetition: int | None = None + case_id: str + run_id: str + case_failed: bool = False + error_stage_count: int = 0 + error_ndd_workflow_count: int = 0 + error_model_workflow_count: int = 0 + pipeline_elapsed_sec: float | None = None + ndd_workflow_count: int = 0 + ndd_elapsed_sec_total: float = 0.0 + observed_total_requests: int = 0 + observed_successful_requests: int = 0 + observed_input_tokens: int = 0 + observed_output_tokens: int = 0 + observed_total_tokens: int = 0 + observed_failed_requests: int = 0 + observed_failed_request_rate: float | None = None + dd_trace_record_count: int = 0 + dd_trace_error_count: int = 0 + dd_trace_sync_client_unavailable_count: int = 0 + observed_bridge_fallback_requests: int | None = None + observed_non_bridge_total_requests: int | None = None + observed_non_bridge_failed_requests: int | None = None + observed_non_bridge_failed_request_rate: float | None = None + record_count: int = 0 + input_text_tokens_total: int | None = None + records_per_pipeline_sec: float | None = None + records_per_ndd_sec: float | None = None + input_text_tokens_per_pipeline_sec: float | None = None + input_text_tokens_per_ndd_sec: float | None = None + topology_endpoint_count: float | None = None + topology_gpu_count: float | None = None + topology_tensor_parallelism: float | None = None + topology_shard_count: float | None = None + input_text_tokens_per_endpoint_sec: float | None = None + input_text_tokens_per_gpu_sec: float | None = None + final_entity_count: float | None = None + empty_detection_count: int = 0 + empty_detection_rate: float | None = None + empty_detection_with_ground_truth_count: int = 0 + empty_detection_with_ground_truth_rate: float | None = None + ground_truth_record_count: int = 0 + ground_truth_entity_count: float | None = None + entity_true_positive_count: float | None = None + entity_false_positive_count: float | None = None + entity_false_negative_count: float | None = None + entity_precision: float | None = None + entity_recall: float | None = None + entity_f1: float | None = None + entity_relaxed_gt_found_count: float | None = None + entity_relaxed_detected_tp_count: float | None = None + entity_relaxed_label_compatible_gt_found_count: float | None = None + entity_relaxed_label_compatible_detected_tp_count: float | None = None + entity_relaxed_precision: float | None = None + entity_relaxed_recall: float | None = None + entity_relaxed_f1: float | None = None + entity_relaxed_label_compatible_precision: float | None = None + entity_relaxed_label_compatible_recall: float | None = None + entity_relaxed_label_compatible_f1: float | None = None + replacement_count: float | None = None + replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + replacement_missing_final_value_count: float | None = None + replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + replacement_synthetic_original_collision_value_count: float | None = None + original_value_leak_count: float | None = None + original_value_leak_record_count: int = 0 + original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + validation_max_entities_per_call: int | None = None + detection_artifact_rows: int = 0 + seed_entity_count: float | None = None + seed_validation_candidate_count: float | None = None + estimated_seed_validation_chunk_count: float | None = None + augmented_entity_count: float | None = None + augmented_new_final_value_count: float | None = None + artifact_final_entity_count: float | None = None + artifact_final_detector_entity_count: float | None = None + artifact_final_augmenter_entity_count: float | None = None + artifact_final_entity_signature_count: float | None = None + artifact_final_entity_signature_hashes: list[str] = Field(default_factory=list) + artifact_final_entity_signature_labels: dict[str, str] = Field(default_factory=dict) + artifact_final_entity_signature_details: dict[str, dict[str, Any]] = Field(default_factory=dict) + + +class GroupAnalysisRow(BaseModel): + workload_id: str | None = None + workload_category: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + entity_label_set_id: str | None = None + entity_label_count: int | None = None + gliner_threshold: float | None = None + case_count: int + failed_case_count: int = 0 + failed_case_rate: float | None = None + error_stage_count: int = 0 + error_ndd_workflow_count: int = 0 + error_model_workflow_count: int = 0 + median_pipeline_elapsed_sec: float | None = None + median_ndd_elapsed_sec_total: float | None = None + median_observed_total_requests: float | None = None + median_observed_successful_requests: float | None = None + median_observed_input_tokens: float | None = None + median_observed_output_tokens: float | None = None + median_observed_total_tokens: float | None = None + median_observed_failed_requests: float | None = None + median_observed_failed_request_rate: float | None = None + median_observed_bridge_fallback_requests: float | None = None + median_observed_non_bridge_total_requests: float | None = None + median_observed_non_bridge_failed_requests: float | None = None + median_observed_non_bridge_failed_request_rate: float | None = None + total_record_count: int = 0 + median_record_count: float | None = None + total_input_text_tokens: int | None = None + median_input_text_tokens_total: float | None = None + median_records_per_pipeline_sec: float | None = None + median_records_per_ndd_sec: float | None = None + median_input_text_tokens_per_pipeline_sec: float | None = None + median_input_text_tokens_per_ndd_sec: float | None = None + median_topology_endpoint_count: float | None = None + median_topology_gpu_count: float | None = None + median_topology_tensor_parallelism: float | None = None + median_topology_shard_count: float | None = None + median_input_text_tokens_per_endpoint_sec: float | None = None + median_input_text_tokens_per_gpu_sec: float | None = None + median_final_entity_count: float | None = None + total_empty_detection_count: int = 0 + empty_detection_rate: float | None = None + total_empty_detection_with_ground_truth_count: int = 0 + empty_detection_with_ground_truth_rate: float | None = None + total_ground_truth_record_count: int = 0 + sum_ground_truth_entity_count: float | None = None + sum_entity_true_positive_count: float | None = None + sum_entity_false_positive_count: float | None = None + sum_entity_false_negative_count: float | None = None + micro_entity_precision: float | None = None + micro_entity_recall: float | None = None + micro_entity_f1: float | None = None + sum_entity_relaxed_gt_found_count: float | None = None + sum_entity_relaxed_detected_tp_count: float | None = None + sum_entity_relaxed_label_compatible_gt_found_count: float | None = None + sum_entity_relaxed_label_compatible_detected_tp_count: float | None = None + micro_entity_relaxed_precision: float | None = None + micro_entity_relaxed_recall: float | None = None + micro_entity_relaxed_f1: float | None = None + micro_entity_relaxed_label_compatible_precision: float | None = None + micro_entity_relaxed_label_compatible_recall: float | None = None + micro_entity_relaxed_label_compatible_f1: float | None = None + median_entity_relaxed_f1: float | None = None + median_entity_relaxed_label_compatible_f1: float | None = None + median_replacement_missing_final_entity_count: float | None = None + median_replacement_missing_final_value_count: float | None = None + replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + median_replacement_synthetic_original_collision_count: float | None = None + median_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + sum_original_value_leak_count: float | None = None + leaking_case_count: int = 0 + median_original_value_leak_count: float | None = None + median_seed_entity_count: float | None = None + median_seed_validation_candidate_count: float | None = None + median_estimated_seed_validation_chunk_count: float | None = None + median_augmented_entity_count: float | None = None + median_augmented_new_final_value_count: float | None = None + median_artifact_final_entity_count: float | None = None + median_artifact_final_detector_entity_count: float | None = None + median_artifact_final_augmenter_entity_count: float | None = None + median_artifact_final_entity_signature_count: float | None = None + + +class ModelUsageAnalysisRow(BaseModel): + suite_id: str | None = None + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | None = None + repetition: int | None = None + case_id: str + run_id: str + workflow_name: str | None = None + model_alias: str | None = None + model_name: str + model_provider_name: str | None = None + ndd_elapsed_sec: float | None = None + observed_total_requests: int = 0 + observed_successful_requests: int = 0 + observed_failed_requests: int = 0 + observed_input_tokens: int = 0 + observed_output_tokens: int = 0 + observed_total_tokens: int = 0 + observed_reasoning_tokens: int | None = None + observed_failed_request_rate: float | None = None + + +class ModelUsageGroupAnalysisRow(BaseModel): + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | None = None + workflow_name: str | None = None + model_alias: str | None = None + model_name: str + model_provider_name: str | None = None + case_count: int + workflow_count: int + sum_observed_total_requests: int = 0 + sum_observed_successful_requests: int = 0 + sum_observed_failed_requests: int = 0 + sum_observed_input_tokens: int = 0 + sum_observed_output_tokens: int = 0 + sum_observed_total_tokens: int = 0 + sum_observed_reasoning_tokens: int | None = None + observed_failed_request_rate: float | None = None + median_observed_total_requests: float | None = None + median_observed_failed_requests: float | None = None + median_observed_total_tokens: float | None = None + + +class BenchmarkOutputAnalysis(BaseModel): + benchmark_dir: str + detection_artifacts_path: str | None = None + cases: list[CaseAnalysisRow] = Field(default_factory=list) + groups: list[GroupAnalysisRow] = Field(default_factory=list) + model_usage: list[ModelUsageAnalysisRow] = Field(default_factory=list) + model_usage_groups: list[ModelUsageGroupAnalysisRow] = Field(default_factory=list) + + @computed_field + @property + def case_count(self) -> int: + return len(self.cases) + + @computed_field + @property + def group_count(self) -> int: + return len(self.groups) + + @computed_field + @property + def model_usage_count(self) -> int: + return len(self.model_usage) + + @computed_field + @property + def model_usage_group_count(self) -> int: + return len(self.model_usage_groups) + + +class TableSummary(BaseModel): + table: str + rows: int + path: str + + +class AnalysisExportResult(BaseModel): + output_dir: str + format: ExportFormat + tables: list[TableSummary] + manifest_path: str + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def analyze_benchmark_output( + benchmark_dir: Path, + *, + detection_artifacts: Path | None = None, +) -> BenchmarkOutputAnalysis: + measurements = read_jsonl_table(benchmark_dir / "measurements.jsonl", required=True) + artifacts_path = detection_artifacts or benchmark_dir / "detection-artifacts.jsonl" + artifacts = read_jsonl_table(artifacts_path, required=detection_artifacts is not None) + traces = read_trace_summary_table(benchmark_dir / "traces") + cases = [ + _build_case_row(case_id, measurements, artifacts, traces) + for case_id in _case_ids(measurements, artifacts, traces) + ] + model_usage = build_model_usage_rows(measurements) + return BenchmarkOutputAnalysis( + benchmark_dir=str(benchmark_dir), + detection_artifacts_path=str(artifacts_path) if not artifacts.empty else None, + cases=cases, + groups=build_group_rows(cases), + model_usage=model_usage, + model_usage_groups=build_model_usage_group_rows(model_usage), + ) + + +def read_jsonl_table(path: Path, *, required: bool) -> pd.DataFrame: + if not path.exists(): + if required: + raise ValueError(f"input path does not exist: {path}") + return pd.DataFrame() + if path.is_dir(): + raise ValueError(f"input path is a directory: {path}") + raw = pd.read_json(path, lines=True) + if raw.empty: + return raw + return pd.json_normalize(raw.to_dict("records"), sep=".") + + +def read_trace_summary_table(trace_path: Path) -> pd.DataFrame: + """Read DD trace files into a sanitized table with no prompt/response text.""" + if not trace_path.exists(): + return pd.DataFrame() + if trace_path.is_file(): + paths = [trace_path] + elif trace_path.is_dir(): + paths = sorted(trace_path.rglob("*.jsonl")) + else: + raise ValueError(f"trace path is not a file or directory: {trace_path}") + + rows: list[dict[str, Any]] = [] + for path in paths: + for line in path.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + record = json.loads(line) + if not isinstance(record, dict) or record.get("record_type") != "dd_message_trace": + continue + run_tags = record.get("run_tags") if isinstance(record.get("run_tags"), dict) else {} + rows.append( + { + "record_type": "dd_message_trace", + "run_id": record.get("run_id"), + "run_tags.case_id": run_tags.get("case_id"), + "run_tags.workload_id": run_tags.get("workload_id"), + "run_tags.config_id": run_tags.get("config_id"), + "run_tags.experimental_detection_strategy": run_tags.get("experimental_detection_strategy"), + "run_tags.experimental_replacement_strategy": run_tags.get("experimental_replacement_strategy"), + "run_tags.dd_parser_compat": run_tags.get("dd_parser_compat"), + "run_tags.repetition": run_tags.get("repetition"), + "workflow_name": record.get("workflow_name"), + "model_alias": record.get("model_alias"), + "status": record.get("status"), + "error_type": record.get("error_type"), + "is_async": record.get("is_async"), + } + ) + return pd.DataFrame(rows) + + +def _case_ids(*frames: pd.DataFrame) -> list[str]: + values: set[str] = set() + for dataframe in frames: + for column in ("run_tags.case_id", "case_id", "run_id"): + if column in dataframe.columns: + values.update(str(value) for value in dataframe[column].dropna().tolist()) + return sorted(values) + + +def _build_case_row( + case_id: str, + measurements: pd.DataFrame, + artifacts: pd.DataFrame, + traces: pd.DataFrame, +) -> CaseAnalysisRow: + measurement_rows = _rows_for_case(measurements, case_id) + artifact_rows = _rows_for_case(artifacts, case_id) + trace_rows = _rows_for_case(traces, case_id) + record_rows = _records_of_type(measurement_rows, "record") + ndd_rows = _records_of_type(measurement_rows, "ndd_workflow") + model_rows = _model_workflow_rows(measurement_rows) + stage_rows = _records_of_type(measurement_rows, "stage") + pipeline_rows = _pipeline_stage_rows(measurement_rows) + validation_max_entities_per_call = _first_int([measurement_rows], ["detect.validation_max_entities_per_call"]) + request_metrics = _case_request_metrics(model_rows) + pipeline_elapsed_sec = _sum_or_none(pipeline_rows, "elapsed_sec") + ndd_elapsed_sec_total = _sum_or_zero(ndd_rows, "elapsed_sec") + record_count = len(record_rows) + input_text_tokens_total = _sum_int_or_none(record_rows, "text_length_tokens") + records_per_pipeline_sec = _safe_rate(record_count, pipeline_elapsed_sec) + records_per_ndd_sec = _safe_rate(record_count, ndd_elapsed_sec_total) + input_text_tokens_per_pipeline_sec = _safe_rate(input_text_tokens_total, pipeline_elapsed_sec) + input_text_tokens_per_ndd_sec = _safe_rate(input_text_tokens_total, ndd_elapsed_sec_total) + final_entity_count = _coalesce_number( + _sum_or_none(record_rows, "final_entity_count"), + _sum_or_none(artifact_rows, "final_entity_count"), + ) + return CaseAnalysisRow( + suite_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.suite_id", "suite_id"]), + workload_id=_first_value( + [measurement_rows, artifact_rows, trace_rows], ["run_tags.workload_id", "workload_id"] + ), + workload_category=_first_value( + [measurement_rows, artifact_rows, trace_rows], + ["run_tags.workload_category", "run_tags.dataset_category", "workload_category", "dataset_category"], + ), + config_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.config_id", "config_id"]), + experimental_detection_strategy=_first_value([measurement_rows], ["run_tags.experimental_detection_strategy"]), + experimental_replacement_strategy=_first_value( + [measurement_rows, trace_rows], + ["run_tags.experimental_replacement_strategy"], + ), + dd_parser_compat=_first_value([measurement_rows], ["run_tags.dd_parser_compat"]), + entity_label_set_id=_first_value( + [measurement_rows], + [ + "run_tags.entity_label_set_id", + "run_tags.entity_label_set", + "run_tags.label_set", + "detect.entity_label_source", + ], + ), + entity_label_count=_first_int([measurement_rows], ["run_tags.entity_label_count", "detect.entity_label_count"]), + gliner_threshold=_first_float([measurement_rows], ["run_tags.gliner_threshold", "detect.gliner_threshold"]), + repetition=_first_int([measurement_rows, artifact_rows, trace_rows], ["run_tags.repetition", "repetition"]), + case_id=case_id, + run_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_id"]) or case_id, + **_case_failure_metrics(stage_rows=stage_rows, ndd_rows=ndd_rows, model_rows=model_rows), + pipeline_elapsed_sec=pipeline_elapsed_sec, + ndd_workflow_count=len(ndd_rows), + ndd_elapsed_sec_total=ndd_elapsed_sec_total, + **request_metrics, + **_case_trace_metrics(trace_rows, request_metrics=request_metrics), + record_count=record_count, + input_text_tokens_total=input_text_tokens_total, + records_per_pipeline_sec=records_per_pipeline_sec, + records_per_ndd_sec=records_per_ndd_sec, + input_text_tokens_per_pipeline_sec=input_text_tokens_per_pipeline_sec, + input_text_tokens_per_ndd_sec=input_text_tokens_per_ndd_sec, + **_case_topology_metrics( + measurement_rows, + input_text_tokens_per_pipeline_sec=input_text_tokens_per_pipeline_sec, + ), + final_entity_count=final_entity_count, + **_case_empty_detection_metrics(record_rows, record_count=record_count), + **_case_ground_truth_metrics(record_rows, final_entity_count=final_entity_count), + replacement_count=_sum_or_none(record_rows, "replacement_count"), + replacement_missing_final_entity_count=_sum_or_none(record_rows, "replacement_missing_final_entity_count"), + replacement_missing_final_entity_label_counts=_sum_prefixed_ints( + record_rows, + "replacement_missing_final_entity_label_counts.", + ), + replacement_missing_final_value_count=_sum_or_none(record_rows, "replacement_missing_final_value_count"), + replacement_synthetic_original_collision_count=_sum_or_none( + record_rows, + "replacement_synthetic_original_collision_count", + ), + replacement_synthetic_original_collision_label_counts=_sum_prefixed_ints( + record_rows, + "replacement_synthetic_original_collision_label_counts.", + ), + replacement_synthetic_original_collision_value_count=_sum_or_none( + record_rows, + "replacement_synthetic_original_collision_value_count", + ), + original_value_leak_count=_sum_or_none(record_rows, "original_value_leak_count"), + original_value_leak_record_count=_positive_count(record_rows, "original_value_leak_count"), + original_value_leak_label_counts=_sum_prefixed_ints(record_rows, "original_value_leak_label_counts."), + validation_max_entities_per_call=validation_max_entities_per_call, + **_case_artifact_metrics( + artifact_rows, + validation_max_entities_per_call=validation_max_entities_per_call, + ), + ) + + +def _case_request_metrics(model_rows: pd.DataFrame) -> dict[str, int | float | None]: + observed_total_requests = int(_sum_or_zero(model_rows, "observed_total_requests")) + observed_failed_requests = int(_sum_or_zero(model_rows, "observed_failed_requests")) + return { + "observed_total_requests": observed_total_requests, + "observed_successful_requests": int(_sum_or_zero(model_rows, "observed_successful_requests")), + "observed_input_tokens": int(_sum_or_zero(model_rows, "observed_input_tokens")), + "observed_output_tokens": int(_sum_or_zero(model_rows, "observed_output_tokens")), + "observed_total_tokens": int(_sum_or_zero(model_rows, "observed_total_tokens")), + "observed_failed_requests": observed_failed_requests, + "observed_failed_request_rate": _request_failure_rate( + failed=observed_failed_requests, + total=observed_total_requests, + ), + } + + +def _case_trace_metrics( + trace_rows: pd.DataFrame, + *, + request_metrics: dict[str, int | float | None], +) -> dict[str, int | float | None]: + trace_record_count = len(trace_rows) + if trace_record_count == 0: + return { + "dd_trace_record_count": 0, + "dd_trace_error_count": 0, + "dd_trace_sync_client_unavailable_count": 0, + "observed_bridge_fallback_requests": None, + "observed_non_bridge_total_requests": None, + "observed_non_bridge_failed_requests": None, + "observed_non_bridge_failed_request_rate": None, + } + + status = trace_rows["status"].astype(str) if "status" in trace_rows.columns else pd.Series(dtype=str) + error_type = trace_rows["error_type"].astype(str) if "error_type" in trace_rows.columns else pd.Series(dtype=str) + error_count = int((status == "error").sum()) + bridge_fallbacks = int(((status == "error") & (error_type == _SYNC_CLIENT_UNAVAILABLE_ERROR)).sum()) + observed_total = int(request_metrics["observed_total_requests"] or 0) + observed_failed = int(request_metrics["observed_failed_requests"] or 0) + non_bridge_total = max(observed_total - bridge_fallbacks, 0) + non_bridge_failed = max(observed_failed - bridge_fallbacks, 0) + return { + "dd_trace_record_count": trace_record_count, + "dd_trace_error_count": error_count, + "dd_trace_sync_client_unavailable_count": bridge_fallbacks, + "observed_bridge_fallback_requests": bridge_fallbacks, + "observed_non_bridge_total_requests": non_bridge_total, + "observed_non_bridge_failed_requests": non_bridge_failed, + "observed_non_bridge_failed_request_rate": _request_failure_rate( + failed=non_bridge_failed, + total=non_bridge_total, + ), + } + + +def _case_failure_metrics( + *, + stage_rows: pd.DataFrame, + ndd_rows: pd.DataFrame, + model_rows: pd.DataFrame, +) -> dict[str, bool | int]: + error_stage_count = _error_status_count(stage_rows) + error_ndd_workflow_count = _error_status_count(ndd_rows) + error_model_workflow_count = _error_status_count(model_rows) + return { + "case_failed": error_stage_count > 0 or error_ndd_workflow_count > 0 or error_model_workflow_count > 0, + "error_stage_count": error_stage_count, + "error_ndd_workflow_count": error_ndd_workflow_count, + "error_model_workflow_count": error_model_workflow_count, + } + + +def _case_topology_metrics( + measurement_rows: pd.DataFrame, + *, + input_text_tokens_per_pipeline_sec: float | None, +) -> dict[str, float | None]: + endpoint_count = _first_float( + [measurement_rows], + [ + "run_tags.topology_endpoint_count", + "run_tags.endpoint_count", + "run_tags.n_endpoints", + "run_tags.n_llm_endpoints", + ], + ) + gpu_count = _first_float( + [measurement_rows], + [ + "run_tags.topology_gpu_count", + "run_tags.gpu_count", + "run_tags.n_gpus", + "run_tags.n_llm_gpus", + ], + ) + tensor_parallelism = _first_float( + [measurement_rows], + [ + "run_tags.topology_tensor_parallelism", + "run_tags.tensor_parallelism", + "run_tags.gpus_per_endpoint", + "run_tags.tp", + ], + ) + shard_count = _first_float( + [measurement_rows], + ["run_tags.topology_shard_count", "run_tags.shard_count", "run_tags.n_shards"], + ) + return { + "topology_endpoint_count": endpoint_count, + "topology_gpu_count": gpu_count, + "topology_tensor_parallelism": tensor_parallelism, + "topology_shard_count": shard_count, + "input_text_tokens_per_endpoint_sec": _safe_ratio(input_text_tokens_per_pipeline_sec, endpoint_count), + "input_text_tokens_per_gpu_sec": _safe_ratio(input_text_tokens_per_pipeline_sec, gpu_count), + } + + +def _case_empty_detection_metrics(record_rows: pd.DataFrame, *, record_count: int) -> dict[str, int | float | None]: + empty_detection_count = _zero_count(record_rows, "final_entity_count") + ground_truth_record_count = _non_null_count(record_rows, "ground_truth_entity_count") + empty_detection_with_gt_count = _zero_with_positive_count( + record_rows, + zero_column="final_entity_count", + positive_column="ground_truth_entity_count", + ) + return { + "empty_detection_count": empty_detection_count, + "empty_detection_rate": _safe_ratio(empty_detection_count, record_count), + "empty_detection_with_ground_truth_count": empty_detection_with_gt_count, + "empty_detection_with_ground_truth_rate": _safe_ratio( + empty_detection_with_gt_count, + ground_truth_record_count, + ), + "ground_truth_record_count": ground_truth_record_count, + } + + +def _case_ground_truth_metrics( + record_rows: pd.DataFrame, + *, + final_entity_count: float | None, +) -> dict[str, float | None]: + ground_truth_entity_count = _sum_or_none(record_rows, "ground_truth_entity_count") + true_positive = _sum_or_none(record_rows, "entity_true_positive_count") + false_positive = _sum_or_none(record_rows, "entity_false_positive_count") + false_negative = _sum_or_none(record_rows, "entity_false_negative_count") + relaxed_gt_found = _sum_or_none(record_rows, "entity_relaxed_gt_found_count") + relaxed_detected_tp = _sum_or_none(record_rows, "entity_relaxed_detected_tp_count") + label_compatible_gt_found = _sum_or_none(record_rows, "entity_relaxed_label_compatible_gt_found_count") + label_compatible_detected_tp = _sum_or_none( + record_rows, + "entity_relaxed_label_compatible_detected_tp_count", + ) + strict_precision = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_positive)) + strict_recall = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_negative)) + relaxed_precision = _safe_ratio(relaxed_detected_tp, final_entity_count) + relaxed_recall = _safe_ratio(relaxed_gt_found, ground_truth_entity_count) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, final_entity_count) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, ground_truth_entity_count) + return { + "ground_truth_entity_count": ground_truth_entity_count, + "entity_true_positive_count": true_positive, + "entity_false_positive_count": false_positive, + "entity_false_negative_count": false_negative, + "entity_precision": strict_precision, + "entity_recall": strict_recall, + "entity_f1": _f1(strict_precision, strict_recall), + "entity_relaxed_gt_found_count": relaxed_gt_found, + "entity_relaxed_detected_tp_count": relaxed_detected_tp, + "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, + "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, + "entity_relaxed_precision": relaxed_precision, + "entity_relaxed_recall": relaxed_recall, + "entity_relaxed_f1": _f1(relaxed_precision, relaxed_recall), + "entity_relaxed_label_compatible_precision": label_compatible_precision, + "entity_relaxed_label_compatible_recall": label_compatible_recall, + "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), + } + + +def _error_status_count(rows: pd.DataFrame) -> int: + if "status" not in rows.columns: + return 0 + statuses = rows["status"].astype(str).str.lower() + return int(statuses.isin({"error", "failed"}).sum()) + + +def _case_artifact_metrics( + artifact_rows: pd.DataFrame, + *, + validation_max_entities_per_call: int | None, +) -> dict[str, int | float | list[str] | dict[str, str] | None]: + signature_hashes = _artifact_signature_hashes(artifact_rows) + return { + "detection_artifact_rows": len(artifact_rows), + "seed_entity_count": _sum_or_none(artifact_rows, "seed_entity_count"), + "seed_validation_candidate_count": _sum_or_none(artifact_rows, "seed_validation_candidate_count"), + "estimated_seed_validation_chunk_count": _estimated_validation_chunk_count( + artifact_rows, + validation_max_entities_per_call=validation_max_entities_per_call, + ), + "augmented_entity_count": _sum_or_none(artifact_rows, "augmented_entity_count"), + "augmented_new_final_value_count": _sum_or_none(artifact_rows, "augmented_new_final_value_count"), + "artifact_final_entity_count": _sum_or_none(artifact_rows, "final_entity_count"), + "artifact_final_detector_entity_count": _sum_or_none(artifact_rows, "final_source_counts.detector"), + "artifact_final_augmenter_entity_count": _sum_or_none(artifact_rows, "final_source_counts.augmenter"), + "artifact_final_entity_signature_count": _signature_count(artifact_rows, signature_hashes=signature_hashes), + "artifact_final_entity_signature_hashes": signature_hashes, + "artifact_final_entity_signature_labels": _artifact_signature_labels(artifact_rows), + "artifact_final_entity_signature_details": _artifact_signature_details(artifact_rows), + } + + +def _rows_for_case(dataframe: pd.DataFrame, case_id: str) -> pd.DataFrame: + if dataframe.empty: + return dataframe + masks = [ + dataframe[column].astype(str) == case_id + for column in ("run_tags.case_id", "case_id", "run_id") + if column in dataframe.columns + ] + if not masks: + return dataframe.iloc[0:0] + mask = masks[0] + for next_mask in masks[1:]: + mask = mask | next_mask + return dataframe[mask] + + +def _records_of_type(dataframe: pd.DataFrame, record_type: str) -> pd.DataFrame: + if "record_type" not in dataframe.columns: + return dataframe.iloc[0:0] + return dataframe[dataframe["record_type"] == record_type] + + +def _records_of_types(dataframe: pd.DataFrame, record_types: set[str]) -> pd.DataFrame: + if "record_type" not in dataframe.columns: + return dataframe.iloc[0:0] + return dataframe[dataframe["record_type"].isin(record_types)] + + +def _model_workflow_rows(dataframe: pd.DataFrame) -> pd.DataFrame: + return _records_of_types(dataframe, {"ndd_workflow", "model_workflow"}) + + +def _pipeline_stage_rows(dataframe: pd.DataFrame) -> pd.DataFrame: + stages = _records_of_type(dataframe, "stage") + if "stage" not in stages.columns: + return stages.iloc[0:0] + return stages[stages["stage"] == "Anonymizer._run_internal"] + + +_MODEL_USAGE_SUFFIXES = { + ".request_usage.total_requests": "observed_total_requests", + ".request_usage.successful_requests": "observed_successful_requests", + ".request_usage.failed_requests": "observed_failed_requests", + ".token_usage.input_tokens": "observed_input_tokens", + ".token_usage.output_tokens": "observed_output_tokens", + ".token_usage.total_tokens": "observed_total_tokens", + ".token_usage.reasoning_tokens": "observed_reasoning_tokens", +} + +_MODEL_USAGE_METADATA_SUFFIXES = { + ".model_alias": "model_alias", + ".model_name": "model_name", + ".model_provider_name": "model_provider_name", +} + + +def build_model_usage_rows(measurements: pd.DataFrame) -> list[ModelUsageAnalysisRow]: + model_rows = _model_workflow_rows(measurements) + if model_rows.empty: + return [] + model_usage_keys = _model_usage_keys(model_rows.columns) + rows: list[ModelUsageAnalysisRow] = [] + for _, measurement in model_rows.iterrows(): + data = measurement.to_dict() + case_id = _string_from_row(data, ["run_tags.case_id", "run_id"]) + run_id = _string_from_row(data, ["run_id", "run_tags.case_id"]) + if case_id is None or run_id is None: + continue + for model_usage_key in model_usage_keys: + usage = _model_usage_metrics(data, model_usage_key) + if not _has_observed_model_usage(usage): + continue + metadata = _model_usage_metadata(data, model_usage_key) + rows.append( + ModelUsageAnalysisRow( + suite_id=_string_from_row(data, ["run_tags.suite_id"]), + workload_id=_string_from_row(data, ["run_tags.workload_id"]), + config_id=_string_from_row(data, ["run_tags.config_id"]), + experimental_detection_strategy=_string_from_row( + data, ["run_tags.experimental_detection_strategy"] + ), + experimental_replacement_strategy=_string_from_row( + data, ["run_tags.experimental_replacement_strategy"] + ), + dd_parser_compat=_string_from_row(data, ["run_tags.dd_parser_compat"]), + repetition=_int_from_row(data, ["run_tags.repetition"]), + case_id=case_id, + run_id=run_id, + workflow_name=_string_from_row(data, ["workflow_name"]), + model_alias=metadata.get("model_alias"), + model_name=metadata.get("model_name") or model_usage_key, + model_provider_name=metadata.get("model_provider_name"), + ndd_elapsed_sec=_float_from_row(data, ["elapsed_sec"]), + **usage, + ) + ) + return rows + + +def _model_usage_keys(columns: pd.Index) -> list[str]: + keys: set[str] = set() + for column in columns: + parsed = _model_usage_column_parts(str(column)) + if parsed is not None: + keys.add(parsed[0]) + return sorted(keys) + + +def _model_usage_column_parts(column: str) -> tuple[str, str] | None: + prefix = "model_usage." + if not column.startswith(prefix): + return None + for suffix, metric in {**_MODEL_USAGE_SUFFIXES, **_MODEL_USAGE_METADATA_SUFFIXES}.items(): + if column.endswith(suffix): + return column[len(prefix) : -len(suffix)], metric + return None + + +def _model_usage_metrics(data: dict[str, Any], model_usage_key: str) -> dict[str, int | float | None]: + values: dict[str, int | float | None] = { + "observed_total_requests": 0, + "observed_successful_requests": 0, + "observed_failed_requests": 0, + "observed_input_tokens": 0, + "observed_output_tokens": 0, + "observed_total_tokens": 0, + "observed_reasoning_tokens": None, + "observed_failed_request_rate": None, + } + for suffix, metric in _MODEL_USAGE_SUFFIXES.items(): + value = data.get(f"model_usage.{model_usage_key}{suffix}") + if value is None or pd.isna(value): + continue + values[metric] = _coerce_int(value) + values["observed_failed_request_rate"] = _request_failure_rate( + failed=values["observed_failed_requests"], + total=values["observed_total_requests"], + ) + return values + + +def _model_usage_metadata(data: dict[str, Any], model_usage_key: str) -> dict[str, str | None]: + values: dict[str, str | None] = { + "model_alias": None, + "model_name": None, + "model_provider_name": None, + } + for suffix, field_name in _MODEL_USAGE_METADATA_SUFFIXES.items(): + value = data.get(f"model_usage.{model_usage_key}{suffix}") + if value is None or pd.isna(value): + continue + values[field_name] = str(value) + return values + + +def _has_observed_model_usage(usage: dict[str, int | float | None]) -> bool: + return any(value not in (None, 0) for value in usage.values()) + + +def _string_from_row(data: dict[str, Any], columns: list[str]) -> str | None: + for column in columns: + value = data.get(column) + if value is not None and not pd.isna(value): + return str(value) + return None + + +def _int_from_row(data: dict[str, Any], columns: list[str]) -> int | None: + value = _string_from_row(data, columns) + return int(float(value)) if value is not None else None + + +def _float_from_row(data: dict[str, Any], columns: list[str]) -> float | None: + value = _string_from_row(data, columns) + return float(value) if value is not None else None + + +def _coerce_int(value: Any) -> int: + return int(float(value)) + + +def _first_value(frames: list[pd.DataFrame], columns: list[str]) -> str | None: + for frame in frames: + for column in columns: + if column not in frame.columns: + continue + values = frame[column].dropna() + if not values.empty: + return str(values.iloc[0]) + return None + + +def _first_int(frames: list[pd.DataFrame], columns: list[str]) -> int | None: + value = _first_value(frames, columns) + return int(float(value)) if value is not None else None + + +def _first_float(frames: list[pd.DataFrame], columns: list[str]) -> float | None: + value = _first_value(frames, columns) + return float(value) if value is not None else None + + +def _coalesce_number(*values: float | None) -> float | None: + for value in values: + if value is not None: + return value + return None + + +def _artifact_signature_hashes(artifact_rows: pd.DataFrame) -> list[str]: + if "final_entity_signature_hashes" not in artifact_rows.columns: + return [] + values: set[str] = set() + for raw in artifact_rows["final_entity_signature_hashes"].dropna(): + values.update(_coerce_string_list(raw)) + return sorted(values) + + +def _artifact_signature_labels(artifact_rows: pd.DataFrame) -> dict[str, str]: + labels: dict[str, str] = {} + if "final_entity_signature_labels" in artifact_rows.columns: + for raw in artifact_rows["final_entity_signature_labels"].dropna(): + labels.update(_coerce_string_dict(raw)) + for column in artifact_rows.columns: + prefix = "final_entity_signature_labels." + if not column.startswith(prefix): + continue + signature_hash = column.removeprefix(prefix) + for value in artifact_rows[column].dropna(): + labels[signature_hash] = str(value) + return dict(sorted(labels.items())) + + +def _artifact_signature_details(artifact_rows: pd.DataFrame) -> dict[str, dict[str, Any]]: + details: dict[str, dict[str, Any]] = {} + if "final_entity_signature_details" in artifact_rows.columns: + for raw in artifact_rows["final_entity_signature_details"].dropna(): + details.update(_coerce_detail_map(raw)) + prefix = "final_entity_signature_details." + for column in artifact_rows.columns: + if not column.startswith(prefix): + continue + remainder = column.removeprefix(prefix) + signature_hash, _, field = remainder.partition(".") + if not signature_hash or not field: + continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue + for value in artifact_rows[column].dropna(): + details.setdefault(signature_hash, {})[field] = _json_scalar(value) + return dict(sorted(details.items())) + + +def _coerce_detail_map(raw: object) -> dict[str, dict[str, Any]]: + if isinstance(raw, str): + try: + raw = json.loads(raw) + except json.JSONDecodeError: + return {} + if not isinstance(raw, dict): + return {} + details: dict[str, dict[str, Any]] = {} + for signature_hash, value in raw.items(): + if isinstance(value, dict): + details[str(signature_hash)] = { + str(key): _json_scalar(item) for key, item in value.items() if str(key) in _SIGNATURE_DETAIL_FIELDS + } + return details + + +def _json_scalar(value: object) -> Any: + if hasattr(value, "item"): + try: + return value.item() + except ValueError: + return value + return value + + +def _coerce_string_list(raw: object) -> list[str]: + if isinstance(raw, list): + return [str(item) for item in raw] + return [] + + +def _coerce_string_dict(raw: object) -> dict[str, str]: + if isinstance(raw, dict): + return {str(key): str(value) for key, value in raw.items()} + return {} + + +def _signature_count(artifact_rows: pd.DataFrame, *, signature_hashes: list[str]) -> float | None: + if signature_hashes: + return float(len(signature_hashes)) + return _sum_or_none(artifact_rows, "final_entity_signature_count") + + +def _sum_or_zero(dataframe: pd.DataFrame, column: str) -> float: + if column not in dataframe.columns: + return 0.0 + return float(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) + + +def _sum_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + if column not in dataframe.columns: + return None + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.sum()) + + +def _positive_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + values = pd.to_numeric(dataframe[column], errors="coerce").fillna(0) + return int((values > 0).sum()) + + +def _zero_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + return int((values == 0).sum()) + + +def _non_null_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + return int(pd.to_numeric(dataframe[column], errors="coerce").notna().sum()) + + +def _zero_with_positive_count(dataframe: pd.DataFrame, *, zero_column: str, positive_column: str) -> int: + if zero_column not in dataframe.columns or positive_column not in dataframe.columns: + return 0 + zero_values = pd.to_numeric(dataframe[zero_column], errors="coerce") + positive_values = pd.to_numeric(dataframe[positive_column], errors="coerce") + return int(((zero_values == 0) & (positive_values > 0)).sum()) + + +def _sum_prefixed_ints(dataframe: pd.DataFrame, prefix: str) -> dict[str, int]: + totals: dict[str, int] = {} + base_column = prefix.removesuffix(".") + if base_column in dataframe.columns: + for value in dataframe[base_column].dropna().tolist(): + for key, count in _coerce_count_mapping(value).items(): + totals[key] = totals.get(key, 0) + count + for column in sorted(col for col in dataframe.columns if col.startswith(prefix)): + value = _sum_int_or_zero(dataframe, column) + if value: + totals[column.removeprefix(prefix)] = value + return totals + + +def _coerce_count_mapping(value: object) -> dict[str, int]: + payload = value + if isinstance(value, str): + try: + payload = json.loads(value) + except json.JSONDecodeError: + return {} + if not isinstance(payload, dict): + return {} + counts: dict[str, int] = {} + for key, count in payload.items(): + numeric = pd.to_numeric(pd.Series([count]), errors="coerce").dropna() + if not numeric.empty and numeric.iloc[0]: + counts[str(key)] = int(numeric.iloc[0]) + return counts + + +def build_group_rows(cases: list[CaseAnalysisRow]) -> list[GroupAnalysisRow]: + if not cases: + return [] + table = pd.DataFrame([case.model_dump() for case in cases]) + rows: list[GroupAnalysisRow] = [] + group_columns = [ + "workload_id", + "workload_category", + "config_id", + "experimental_detection_strategy", + "experimental_replacement_strategy", + "entity_label_set_id", + "entity_label_count", + "gliner_threshold", + ] + for keys, group in table.groupby(group_columns, dropna=False): + rows.append(_build_group_row(keys, group)) + return rows + + +def build_model_usage_group_rows(model_usage: list[ModelUsageAnalysisRow]) -> list[ModelUsageGroupAnalysisRow]: + if not model_usage: + return [] + table = pd.DataFrame([row.model_dump() for row in model_usage]) + rows: list[ModelUsageGroupAnalysisRow] = [] + group_columns = [ + "workload_id", + "config_id", + "experimental_detection_strategy", + "experimental_replacement_strategy", + "dd_parser_compat", + "workflow_name", + "model_alias", + "model_name", + "model_provider_name", + ] + for keys, group in table.groupby(group_columns, dropna=False): + rows.append(_build_model_usage_group_row(keys, group)) + return rows + + +def _build_model_usage_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> ModelUsageGroupAnalysisRow: + ( + workload_id, + config_id, + detection_strategy, + replacement_strategy, + dd_parser_compat, + workflow_name, + model_alias, + model_name, + provider_name, + ) = keys + reasoning_sum = _sum_int_or_none(group, "observed_reasoning_tokens") + total_requests = _sum_int_or_zero(group, "observed_total_requests") + failed_requests = _sum_int_or_zero(group, "observed_failed_requests") + return ModelUsageGroupAnalysisRow( + workload_id=_none_if_nan(workload_id), + config_id=_none_if_nan(config_id), + experimental_detection_strategy=_none_if_nan(detection_strategy), + experimental_replacement_strategy=_none_if_nan(replacement_strategy), + dd_parser_compat=_none_if_nan(dd_parser_compat), + workflow_name=_none_if_nan(workflow_name), + model_alias=_none_if_nan(model_alias), + model_name=str(model_name), + model_provider_name=_none_if_nan(provider_name), + case_count=int(group["case_id"].nunique()), + workflow_count=len(group), + sum_observed_total_requests=total_requests, + sum_observed_successful_requests=_sum_int_or_zero(group, "observed_successful_requests"), + sum_observed_failed_requests=failed_requests, + sum_observed_input_tokens=_sum_int_or_zero(group, "observed_input_tokens"), + sum_observed_output_tokens=_sum_int_or_zero(group, "observed_output_tokens"), + sum_observed_total_tokens=_sum_int_or_zero(group, "observed_total_tokens"), + sum_observed_reasoning_tokens=reasoning_sum, + observed_failed_request_rate=_request_failure_rate(failed=failed_requests, total=total_requests), + median_observed_total_requests=_median_or_none(group, "observed_total_requests"), + median_observed_failed_requests=_median_or_none(group, "observed_failed_requests"), + median_observed_total_tokens=_median_or_none(group, "observed_total_tokens"), + ) + + +def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysisRow: + ( + workload_id, + workload_category, + config_id, + detection_strategy, + replacement_strategy, + entity_label_set_id, + entity_label_count, + gliner_threshold, + ) = keys + case_count = int(group["case_id"].nunique()) + failed_case_count = _sum_bool_or_zero(group, "case_failed") + total_record_count = _sum_int_or_zero(group, "record_count") + total_input_text_tokens = _sum_int_or_none(group, "input_text_tokens_total") + total_empty_detection_count = _sum_int_or_zero(group, "empty_detection_count") + total_ground_truth_record_count = _sum_int_or_zero(group, "ground_truth_record_count") + total_empty_detection_with_gt_count = _sum_int_or_zero(group, "empty_detection_with_ground_truth_count") + final_entity_count = _sum_or_none(group, "final_entity_count") + ground_truth_entity_count = _sum_or_none(group, "ground_truth_entity_count") + true_positive = _sum_or_none(group, "entity_true_positive_count") + false_positive = _sum_or_none(group, "entity_false_positive_count") + false_negative = _sum_or_none(group, "entity_false_negative_count") + strict_precision = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_positive)) + strict_recall = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_negative)) + relaxed_gt_found = _sum_or_none(group, "entity_relaxed_gt_found_count") + relaxed_detected_tp = _sum_or_none(group, "entity_relaxed_detected_tp_count") + label_compatible_gt_found = _sum_or_none(group, "entity_relaxed_label_compatible_gt_found_count") + label_compatible_detected_tp = _sum_or_none(group, "entity_relaxed_label_compatible_detected_tp_count") + relaxed_precision = _safe_ratio(relaxed_detected_tp, final_entity_count) + relaxed_recall = _safe_ratio(relaxed_gt_found, ground_truth_entity_count) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, final_entity_count) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, ground_truth_entity_count) + return GroupAnalysisRow( + workload_id=_none_if_nan(workload_id), + workload_category=_none_if_nan(workload_category), + config_id=_none_if_nan(config_id), + experimental_detection_strategy=_none_if_nan(detection_strategy), + experimental_replacement_strategy=_none_if_nan(replacement_strategy), + entity_label_set_id=_none_if_nan(entity_label_set_id), + entity_label_count=_int_if_not_nan(entity_label_count), + gliner_threshold=_float_if_not_nan(gliner_threshold), + case_count=case_count, + failed_case_count=failed_case_count, + failed_case_rate=_request_failure_rate(failed=failed_case_count, total=case_count), + error_stage_count=_sum_int_or_zero(group, "error_stage_count"), + error_ndd_workflow_count=_sum_int_or_zero(group, "error_ndd_workflow_count"), + error_model_workflow_count=_sum_int_or_zero(group, "error_model_workflow_count"), + median_pipeline_elapsed_sec=_median_or_none(group, "pipeline_elapsed_sec"), + median_ndd_elapsed_sec_total=_median_or_none(group, "ndd_elapsed_sec_total"), + median_observed_total_requests=_median_or_none(group, "observed_total_requests"), + median_observed_successful_requests=_median_or_none(group, "observed_successful_requests"), + median_observed_input_tokens=_median_or_none(group, "observed_input_tokens"), + median_observed_output_tokens=_median_or_none(group, "observed_output_tokens"), + median_observed_total_tokens=_median_or_none(group, "observed_total_tokens"), + median_observed_failed_requests=_median_or_none(group, "observed_failed_requests"), + median_observed_failed_request_rate=_median_or_none(group, "observed_failed_request_rate"), + median_observed_bridge_fallback_requests=_median_or_none(group, "observed_bridge_fallback_requests"), + median_observed_non_bridge_total_requests=_median_or_none(group, "observed_non_bridge_total_requests"), + median_observed_non_bridge_failed_requests=_median_or_none(group, "observed_non_bridge_failed_requests"), + median_observed_non_bridge_failed_request_rate=_median_or_none( + group, + "observed_non_bridge_failed_request_rate", + ), + total_record_count=total_record_count, + median_record_count=_median_or_none(group, "record_count"), + total_input_text_tokens=total_input_text_tokens, + median_input_text_tokens_total=_median_or_none(group, "input_text_tokens_total"), + median_records_per_pipeline_sec=_median_or_none(group, "records_per_pipeline_sec"), + median_records_per_ndd_sec=_median_or_none(group, "records_per_ndd_sec"), + median_input_text_tokens_per_pipeline_sec=_median_or_none(group, "input_text_tokens_per_pipeline_sec"), + median_input_text_tokens_per_ndd_sec=_median_or_none(group, "input_text_tokens_per_ndd_sec"), + median_topology_endpoint_count=_median_or_none(group, "topology_endpoint_count"), + median_topology_gpu_count=_median_or_none(group, "topology_gpu_count"), + median_topology_tensor_parallelism=_median_or_none(group, "topology_tensor_parallelism"), + median_topology_shard_count=_median_or_none(group, "topology_shard_count"), + median_input_text_tokens_per_endpoint_sec=_median_or_none(group, "input_text_tokens_per_endpoint_sec"), + median_input_text_tokens_per_gpu_sec=_median_or_none(group, "input_text_tokens_per_gpu_sec"), + median_final_entity_count=_median_or_none(group, "final_entity_count"), + total_empty_detection_count=total_empty_detection_count, + empty_detection_rate=_safe_ratio(total_empty_detection_count, total_record_count), + total_empty_detection_with_ground_truth_count=total_empty_detection_with_gt_count, + empty_detection_with_ground_truth_rate=_safe_ratio( + total_empty_detection_with_gt_count, + total_ground_truth_record_count, + ), + total_ground_truth_record_count=total_ground_truth_record_count, + sum_ground_truth_entity_count=ground_truth_entity_count, + sum_entity_true_positive_count=true_positive, + sum_entity_false_positive_count=false_positive, + sum_entity_false_negative_count=false_negative, + micro_entity_precision=strict_precision, + micro_entity_recall=strict_recall, + micro_entity_f1=_f1(strict_precision, strict_recall), + sum_entity_relaxed_gt_found_count=relaxed_gt_found, + sum_entity_relaxed_detected_tp_count=relaxed_detected_tp, + sum_entity_relaxed_label_compatible_gt_found_count=label_compatible_gt_found, + sum_entity_relaxed_label_compatible_detected_tp_count=label_compatible_detected_tp, + micro_entity_relaxed_precision=relaxed_precision, + micro_entity_relaxed_recall=relaxed_recall, + micro_entity_relaxed_f1=_f1(relaxed_precision, relaxed_recall), + micro_entity_relaxed_label_compatible_precision=label_compatible_precision, + micro_entity_relaxed_label_compatible_recall=label_compatible_recall, + micro_entity_relaxed_label_compatible_f1=_f1(label_compatible_precision, label_compatible_recall), + median_entity_relaxed_f1=_median_or_none(group, "entity_relaxed_f1"), + median_entity_relaxed_label_compatible_f1=_median_or_none( + group, + "entity_relaxed_label_compatible_f1", + ), + median_replacement_missing_final_entity_count=_median_or_none( + group, + "replacement_missing_final_entity_count", + ), + median_replacement_missing_final_value_count=_median_or_none(group, "replacement_missing_final_value_count"), + replacement_missing_final_entity_label_counts=_sum_prefixed_ints( + group, + "replacement_missing_final_entity_label_counts.", + ), + median_replacement_synthetic_original_collision_count=_median_or_none( + group, + "replacement_synthetic_original_collision_count", + ), + median_replacement_synthetic_original_collision_value_count=_median_or_none( + group, + "replacement_synthetic_original_collision_value_count", + ), + replacement_synthetic_original_collision_label_counts=_sum_prefixed_ints( + group, + "replacement_synthetic_original_collision_label_counts.", + ), + sum_original_value_leak_count=_sum_or_none(group, "original_value_leak_count"), + leaking_case_count=_positive_count(group, "original_value_leak_count"), + median_original_value_leak_count=_median_or_none(group, "original_value_leak_count"), + median_seed_entity_count=_median_or_none(group, "seed_entity_count"), + median_seed_validation_candidate_count=_median_or_none(group, "seed_validation_candidate_count"), + median_estimated_seed_validation_chunk_count=_median_or_none(group, "estimated_seed_validation_chunk_count"), + median_augmented_entity_count=_median_or_none(group, "augmented_entity_count"), + median_augmented_new_final_value_count=_median_or_none(group, "augmented_new_final_value_count"), + median_artifact_final_entity_count=_median_or_none(group, "artifact_final_entity_count"), + median_artifact_final_detector_entity_count=_median_or_none(group, "artifact_final_detector_entity_count"), + median_artifact_final_augmenter_entity_count=_median_or_none(group, "artifact_final_augmenter_entity_count"), + median_artifact_final_entity_signature_count=_median_or_none(group, "artifact_final_entity_signature_count"), + ) + + +def _none_if_nan(value: object) -> str | None: + if pd.isna(value): + return None + return str(value) + + +def _int_if_not_nan(value: object) -> int | None: + if pd.isna(value): + return None + return int(float(cast(Any, value))) + + +def _float_if_not_nan(value: object) -> float | None: + if pd.isna(value): + return None + return float(cast(Any, value)) + + +def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.median()) + + +def _sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: + return int(_sum_or_zero(dataframe, column)) + + +def _sum_bool_or_zero(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + return int(dataframe[column].fillna(False).astype(bool).sum()) + + +def _sum_int_or_none(dataframe: pd.DataFrame, column: str) -> int | None: + value = _sum_or_none(dataframe, column) + return int(value) if value is not None else None + + +def _request_failure_rate(*, failed: object, total: object) -> float | None: + failed_value = _optional_number(failed) + total_value = _optional_number(total) + if failed_value is None or total_value is None or total_value <= 0: + return None + return failed_value / total_value + + +def _safe_rate(numerator: object, elapsed_sec: object) -> float | None: + numerator_value = _optional_number(numerator) + elapsed_value = _optional_number(elapsed_sec) + if numerator_value is None or elapsed_value is None or elapsed_value <= 0: + return None + return numerator_value / elapsed_value + + +def _safe_ratio(numerator: object, denominator: object) -> float | None: + numerator_value = _optional_number(numerator) + denominator_value = _optional_number(denominator) + if numerator_value is None or denominator_value is None or denominator_value <= 0: + return None + return numerator_value / denominator_value + + +def _sum_optional_numbers(*values: object) -> float | None: + numeric_values = [_optional_number(value) for value in values] + if any(value is None for value in numeric_values): + return None + return sum(cast(float, value) for value in numeric_values) + + +def _f1(precision: float | None, recall: float | None) -> float | None: + if precision is None or recall is None or precision + recall == 0: + return None + return 2 * precision * recall / (precision + recall) + + +def _optional_number(value: object) -> float | None: + if value is None or pd.isna(value): + return None + return float(value) + + +def _estimated_validation_chunk_count( + artifact_rows: pd.DataFrame, + *, + validation_max_entities_per_call: int | None, +) -> float | None: + if validation_max_entities_per_call is None or validation_max_entities_per_call <= 0: + return None + if "seed_validation_candidate_count" not in artifact_rows.columns: + return None + counts = pd.to_numeric(artifact_rows["seed_validation_candidate_count"], errors="coerce").dropna() + if counts.empty: + return None + return float(sum(math.ceil(count / validation_max_entities_per_call) for count in counts if count > 0)) + + +def write_analysis_tables( + result: BenchmarkOutputAnalysis, + output_dir: Path, + export_format: ExportFormat, +) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows( + result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, CaseAnalysisRow + ), + _write_model_rows( + result.groups, output_dir / f"group_analysis.{export_format.value}", export_format, GroupAnalysisRow + ), + _write_model_rows( + result.model_usage, + output_dir / f"model_analysis.{export_format.value}", + export_format, + ModelUsageAnalysisRow, + ), + _write_model_rows( + result.model_usage_groups, + output_dir / f"model_group_analysis.{export_format.value}", + export_format, + ModelUsageGroupAnalysisRow, + ), + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def _write_model_rows( + rows: list[BaseModel], + path: Path, + export_format: ExportFormat, + row_model: type[BaseModel], +) -> TableSummary: + if rows: + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + else: + table = pd.DataFrame(columns=list(row_model.model_fields)) + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + return TableSummary(table=path.stem, rows=len(table), path=str(path)) + + +def render_result(result: BenchmarkOutputAnalysis, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [ + f"Analyzed {result.case_count} case(s) across {result.group_count} group(s); " + f"model rows={result.model_usage_count}" + ] + for group in result.groups: + label = ( + f"{group.workload_id}/{group.config_id}/" + f"{group.experimental_detection_strategy}/{group.experimental_replacement_strategy}" + ) + lines.append( + f"- {label}: cases={group.case_count}, median_entities={group.median_final_entity_count}, " + f"failed_cases={group.failed_case_count}/{group.case_count}, " + f"median_requests={group.median_observed_total_requests}, median_tokens={group.median_observed_total_tokens}, " + f"median_input_tok_s={group.median_input_text_tokens_per_pipeline_sec}, " + f"micro_relaxed_f1={group.micro_entity_relaxed_f1}, " + f"empty_with_gt={group.total_empty_detection_with_ground_truth_count}, " + f"median_failed_request_rate={group.median_observed_failed_request_rate}, " + f"median_aug_new_final={group.median_augmented_new_final_value_count}" + ) + return "\n".join(lines) + + +@app.default +def main( + benchmark_dir: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + detection_artifacts: Annotated[Path | None, cyclopts.Parameter("--detection-artifacts")] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_benchmark_output(benchmark_dir, detection_artifacts=detection_artifacts) + if output is not None: + write_analysis_tables(result, output, format) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/analyze_dd_traces.py b/tools/measurement/analyze_dd_traces.py new file mode 100644 index 00000000..157beec9 --- /dev/null +++ b/tools/measurement/analyze_dd_traces.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze DataDesigner message traces without emitting raw prompt or response text. + +Usage: + uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces + uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces --output analysis --format csv + uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces --json +""" + +from __future__ import annotations + +import json +import logging +import re +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, computed_field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.dd_traces") + +JSON_FENCE_RE = re.compile(r"```\s*json\b", re.IGNORECASE) + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class TraceAnalysisRow(BaseModel): + trace_file: str + line_number: int + suite_id: str | None = None + workload_id: str | None = None + config_id: str | None = None + case_id: str | None = None + run_id: str + workflow_name: str | None = None + model_alias: str | None = None + model_name: str | None = None + model_provider_name: str | None = None + status: str | None = None + error_type: str | None = None + elapsed_sec: float | None = None + message_count: int = 0 + prompt_chars: int = 0 + response_shape: str + response_chars: int = 0 + response_has_thinking: bool = False + response_has_json_fence: bool = False + response_has_embedded_json: bool = False + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +class TraceGroupAnalysisRow(BaseModel): + suite_id: str | None = None + workload_id: str | None = None + config_id: str | None = None + workflow_name: str | None = None + model_alias: str | None = None + model_name: str | None = None + model_provider_name: str | None = None + status: str | None = None + error_type: str | None = None + response_shape: str + trace_record_count: int + error_count: int = 0 + thinking_count: int = 0 + json_fence_count: int = 0 + embedded_json_count: int = 0 + median_elapsed_sec: float | None = None + median_prompt_chars: float | None = None + median_response_chars: float | None = None + sum_input_tokens: int = 0 + sum_output_tokens: int = 0 + sum_total_tokens: int = 0 + + +class TraceAnalysis(BaseModel): + trace_path: str + rows: list[TraceAnalysisRow] = Field(default_factory=list) + groups: list[TraceGroupAnalysisRow] = Field(default_factory=list) + + @computed_field + @property + def trace_record_count(self) -> int: + return len(self.rows) + + @computed_field + @property + def group_count(self) -> int: + return len(self.groups) + + +class TableSummary(BaseModel): + table: str + rows: int + path: str + + +class AnalysisExportResult(BaseModel): + output_dir: str + format: ExportFormat + tables: list[TableSummary] + manifest_path: str + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def analyze_trace_path(trace_path: Path) -> TraceAnalysis: + rows = [ + _trace_row(record, trace_file=path, line_number=line_number) + for path in _iter_trace_files(trace_path) + for line_number, record in _iter_trace_records(path) + ] + return TraceAnalysis( + trace_path=str(trace_path), + rows=rows, + groups=build_trace_group_rows(rows), + ) + + +def _iter_trace_files(trace_path: Path) -> list[Path]: + if not trace_path.exists(): + raise ValueError(f"trace path does not exist: {trace_path}") + if trace_path.is_file(): + return [trace_path] + if not trace_path.is_dir(): + raise ValueError(f"trace path is not a file or directory: {trace_path}") + return sorted(trace_path.rglob("*.jsonl")) + + +def _iter_trace_records(path: Path) -> list[tuple[int, dict[str, Any]]]: + records: list[tuple[int, dict[str, Any]]] = [] + for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): + if not line.strip(): + continue + payload = json.loads(line) + if isinstance(payload, dict) and payload.get("record_type") == "dd_message_trace": + records.append((line_number, payload)) + return records + + +def _trace_row(record: dict[str, Any], *, trace_file: Path, line_number: int) -> TraceAnalysisRow: + response = record.get("response") if isinstance(record.get("response"), dict) else None + response_content = _string_or_none(response.get("content")) if response is not None else None + reasoning_content = _string_or_none(response.get("reasoning_content")) if response is not None else None + run_tags = record.get("run_tags") if isinstance(record.get("run_tags"), dict) else {} + usage = record.get("usage") if isinstance(record.get("usage"), dict) else {} + return TraceAnalysisRow( + trace_file=str(trace_file), + line_number=line_number, + suite_id=_string_or_none(run_tags.get("suite_id")), + workload_id=_string_or_none(run_tags.get("workload_id")), + config_id=_string_or_none(run_tags.get("config_id")), + case_id=_string_or_none(run_tags.get("case_id")), + run_id=str(record.get("run_id") or ""), + workflow_name=_string_or_none(record.get("workflow_name")), + model_alias=_string_or_none(record.get("model_alias")), + model_name=_string_or_none(record.get("model_name")), + model_provider_name=_string_or_none(record.get("model_provider_name")), + status=_string_or_none(record.get("status")), + error_type=_string_or_none(record.get("error_type")), + elapsed_sec=_float_or_none(record.get("elapsed_sec")), + message_count=_message_count(record.get("messages")), + prompt_chars=_message_text_chars(record.get("messages")), + response_shape=_response_shape(response_content), + response_chars=len(response_content or ""), + response_has_thinking=_has_thinking(response_content, reasoning_content), + response_has_json_fence=_has_json_fence(response_content), + response_has_embedded_json=_has_embedded_json(response_content), + input_tokens=_int_or_none(usage.get("input_tokens")), + output_tokens=_int_or_none(usage.get("output_tokens")), + total_tokens=_int_or_none(usage.get("total_tokens")), + ) + + +def build_trace_group_rows(rows: list[TraceAnalysisRow]) -> list[TraceGroupAnalysisRow]: + if not rows: + return [] + table = pd.DataFrame([row.model_dump() for row in rows]) + group_columns = [ + "suite_id", + "workload_id", + "config_id", + "workflow_name", + "model_alias", + "model_name", + "model_provider_name", + "status", + "error_type", + "response_shape", + ] + return [ + _build_trace_group_row(keys, group) for keys, group in table.groupby(group_columns, dropna=False, sort=True) + ] + + +def _build_trace_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> TraceGroupAnalysisRow: + ( + suite_id, + workload_id, + config_id, + workflow_name, + model_alias, + model_name, + provider_name, + status, + error_type, + response_shape, + ) = keys + return TraceGroupAnalysisRow( + suite_id=_none_if_nan(suite_id), + workload_id=_none_if_nan(workload_id), + config_id=_none_if_nan(config_id), + workflow_name=_none_if_nan(workflow_name), + model_alias=_none_if_nan(model_alias), + model_name=_none_if_nan(model_name), + model_provider_name=_none_if_nan(provider_name), + status=_none_if_nan(status), + error_type=_none_if_nan(error_type), + response_shape=str(response_shape), + trace_record_count=len(group), + error_count=int((group["status"] == "error").sum()), + thinking_count=int(group["response_has_thinking"].sum()), + json_fence_count=int(group["response_has_json_fence"].sum()), + embedded_json_count=int(group["response_has_embedded_json"].sum()), + median_elapsed_sec=_median_or_none(group, "elapsed_sec"), + median_prompt_chars=_median_or_none(group, "prompt_chars"), + median_response_chars=_median_or_none(group, "response_chars"), + sum_input_tokens=_sum_int_or_zero(group, "input_tokens"), + sum_output_tokens=_sum_int_or_zero(group, "output_tokens"), + sum_total_tokens=_sum_int_or_zero(group, "total_tokens"), + ) + + +def _response_shape(content: str | None) -> str: + if not content: + return "none" + if _has_json_fence(content): + return "fenced_json" + stripped = content.strip() + try: + json.loads(stripped) + except json.JSONDecodeError: + return "embedded_json" if _has_embedded_json(content) else "text" + return "raw_json" + + +def _has_thinking(content: str | None, reasoning_content: str | None) -> bool: + return bool(reasoning_content) or "" in (content or "") + + +def _has_json_fence(content: str | None) -> bool: + return bool(content and JSON_FENCE_RE.search(content)) + + +def _has_embedded_json(content: str | None) -> bool: + if not content: + return False + decoder = json.JSONDecoder() + for start, char in enumerate(content): + if char not in "{[": + continue + try: + parsed, _ = decoder.raw_decode(content, start) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict | list): + return True + return False + + +def _message_count(messages: object) -> int: + return len(messages) if isinstance(messages, list) else 0 + + +def _message_text_chars(messages: object) -> int: + if not isinstance(messages, list): + return 0 + return sum(_message_content_chars(message) for message in messages) + + +def _message_content_chars(message: object) -> int: + if not isinstance(message, dict): + return 0 + return _content_chars(message.get("content")) + + +def _content_chars(content: object) -> int: + if isinstance(content, str): + return len(content) + if isinstance(content, list): + return sum(_content_item_chars(item) for item in content) + return 0 + + +def _content_item_chars(item: object) -> int: + if isinstance(item, str): + return len(item) + if isinstance(item, dict): + return len(str(item.get("text") or "")) + return 0 + + +def _string_or_none(value: object) -> str | None: + return str(value) if value is not None and not pd.isna(value) else None + + +def _float_or_none(value: object) -> float | None: + return float(value) if value is not None and not pd.isna(value) else None + + +def _int_or_none(value: object) -> int | None: + return int(value) if value is not None and not pd.isna(value) else None + + +def _none_if_nan(value: object) -> str | None: + if pd.isna(value): + return None + return str(value) + + +def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.median()) + + +def _sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + return int(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) + + +def write_analysis_tables(result: TraceAnalysis, output_dir: Path, export_format: ExportFormat) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows(result.rows, output_dir / f"trace_analysis.{export_format.value}", export_format), + _write_model_rows(result.groups, output_dir / f"trace_group_analysis.{export_format.value}", export_format), + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def _write_model_rows(rows: list[BaseModel], path: Path, export_format: ExportFormat) -> TableSummary: + table = pd.DataFrame([row.model_dump() for row in rows]) + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + return TableSummary(table=path.stem, rows=len(table), path=str(path)) + + +def render_result(result: TraceAnalysis, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [f"Analyzed {result.trace_record_count} trace record(s) across {result.group_count} group(s)"] + for group in result.groups: + model_label = group.model_alias or group.model_name + label = f"{group.workload_id}/{group.config_id}/{group.workflow_name}/{model_label}/{group.response_shape}" + lines.append( + f"- {label}: traces={group.trace_record_count}, errors={group.error_count}, " + f"thinking={group.thinking_count}, json_fence={group.json_fence_count}, " + f"tokens={group.sum_total_tokens}" + ) + return "\n".join(lines) + + +@app.default +def main( + trace_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_trace_path(trace_path) + if output is not None: + write_analysis_tables(result, output, format) + except (OSError, ValueError, json.JSONDecodeError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/analyze_detection_artifacts.py b/tools/measurement/analyze_detection_artifacts.py new file mode 100644 index 00000000..0c16da75 --- /dev/null +++ b/tools/measurement/analyze_detection_artifacts.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze detection artifacts for augmentation contribution and label-shape risks. + +Usage: + uv run python tools/measurement/analyze_detection_artifacts.py benchmark/artifacts + uv run python tools/measurement/analyze_detection_artifacts.py benchmark/artifacts --output detection.jsonl + uv run python tools/measurement/analyze_detection_artifacts.py benchmark/artifacts --json +""" + +from __future__ import annotations + +import ast +import hashlib +import json +import logging +import math +import re +import sys +from collections import Counter +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Iterable + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_VALIDATION_CANDIDATES, +) +from anonymizer.engine.schemas import ( + AugmentedEntitiesSchema, + EntitiesSchema, + EntitySchema, + ValidationCandidatesSchema, +) + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.detection_artifacts") + +API_KEY_PREFIX_RE = re.compile(r"^(sk-|sk_|sk-ant-|sk-proj-|ghp_|pat-|hf_|xox[a-z]-|ya29\.|aiza|akia|bearer\s+)", re.I) + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class DetectionArtifactRow(BaseModel): + workflow_name: str + batch_file: str + row_index: int + seed_entity_count: int + seed_validation_candidate_count: int + merged_validation_candidate_count: int + augmented_entity_count: int + final_entity_count: int + augmented_duplicate_seed_value_count: int + augmented_new_value_count: int + augmented_new_final_value_count: int + weak_api_key_shape_count: int + final_entity_signature_count: int + final_entity_signature_hashes: list[str] = Field(default_factory=list) + final_entity_signature_labels: dict[str, str] = Field(default_factory=dict) + final_entity_signature_details: dict[str, dict[str, object]] = Field(default_factory=dict) + weak_api_key_shape_label_counts: dict[str, int] = Field(default_factory=dict) + final_label_counts: dict[str, int] = Field(default_factory=dict) + final_source_counts: dict[str, int] = Field(default_factory=dict) + + +class DetectionArtifactAnalysis(BaseModel): + artifact_path: str + rows: list[DetectionArtifactRow] = Field(default_factory=list) + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def analyze_artifacts( + artifact_path: Path, + *, + parquet_files: Iterable[Path] | None = None, +) -> DetectionArtifactAnalysis: + if not artifact_path.exists() or not artifact_path.is_dir(): + raise ValueError(f"artifact path is not a directory: {artifact_path}") + rows: list[DetectionArtifactRow] = [] + for parquet_file in parquet_files if parquet_files is not None else iter_detection_parquet_files(artifact_path): + rows.extend(_analyze_parquet_file(parquet_file, artifact_root=artifact_path)) + return DetectionArtifactAnalysis(artifact_path=str(artifact_path), rows=rows) + + +def iter_detection_parquet_files(artifact_path: Path) -> list[Path]: + files: list[Path] = [] + for workflow_dir in sorted(path for path in artifact_path.iterdir() if path.is_dir()): + if not workflow_dir.name.startswith("entity-detection"): + continue + files.extend(sorted((workflow_dir / "parquet-files").glob("*.parquet"))) + return files + + +def _analyze_parquet_file(parquet_file: Path, *, artifact_root: Path) -> list[DetectionArtifactRow]: + dataframe = pd.read_parquet(parquet_file) + workflow_name = parquet_file.parents[1].name + batch_file = str(parquet_file.relative_to(artifact_root)) + return [ + _analyze_dataframe_row(row, workflow_name=workflow_name, batch_file=batch_file, row_index=row_index) + for row_index, row in dataframe.iterrows() + ] + + +def _analyze_dataframe_row( + row: pd.Series, + *, + workflow_name: str, + batch_file: str, + row_index: int, +) -> DetectionArtifactRow: + seed_entities = _parse_entities(row.get(COL_SEED_ENTITIES_JSON)) + augmented_entities = _parse_augmented_entities(row.get(COL_AUGMENTED_ENTITIES)) + final_entities = _parse_entities(row.get(COL_DETECTED_ENTITIES)) + return build_detection_artifact_row_from_entities( + workflow_name=workflow_name, + batch_file=batch_file, + row_index=row_index, + seed_entities=seed_entities, + seed_validation_candidate_count=_parse_validation_candidate_count(row.get(COL_SEED_VALIDATION_CANDIDATES)), + merged_validation_candidate_count=_parse_validation_candidate_count(row.get(COL_VALIDATION_CANDIDATES)), + augmented_entities=augmented_entities, + final_entities=final_entities, + ) + + +def build_detection_artifact_row_from_entities( + *, + workflow_name: str, + batch_file: str, + row_index: int, + seed_entities: list[EntitySchema], + seed_validation_candidate_count: int, + merged_validation_candidate_count: int, + augmented_entities: list[EntitySchema], + final_entities: list[EntitySchema], +) -> DetectionArtifactRow: + seed_values = {_value_key(entity.value) for entity in seed_entities} + final_values = {_value_key(entity.value) for entity in final_entities} + augmented_new = [entity for entity in augmented_entities if _value_key(entity.value) not in seed_values] + weak_counts = _weak_api_key_shape_counts(final_entities) + final_entity_signatures = _entity_signature_hashes(final_entities, row_index=int(row_index)) + final_entity_signature_labels = _entity_signature_labels(final_entities, row_index=int(row_index)) + final_entity_signature_details = _entity_signature_details(final_entities, row_index=int(row_index)) + return DetectionArtifactRow( + workflow_name=workflow_name, + batch_file=batch_file, + row_index=int(row_index), + seed_entity_count=len(seed_entities), + seed_validation_candidate_count=seed_validation_candidate_count, + merged_validation_candidate_count=merged_validation_candidate_count, + augmented_entity_count=len(augmented_entities), + final_entity_count=len(final_entities), + augmented_duplicate_seed_value_count=len(augmented_entities) - len(augmented_new), + augmented_new_value_count=len(augmented_new), + augmented_new_final_value_count=sum(1 for entity in augmented_new if _value_key(entity.value) in final_values), + weak_api_key_shape_count=sum(weak_counts.values()), + final_entity_signature_count=len(final_entity_signatures), + final_entity_signature_hashes=final_entity_signatures, + final_entity_signature_labels=final_entity_signature_labels, + final_entity_signature_details=final_entity_signature_details, + weak_api_key_shape_label_counts=dict(weak_counts), + final_label_counts=_count_by(final_entities, "label"), + final_source_counts=_count_by(final_entities, "source"), + ) + + +def _parse_entities(raw: object) -> list[EntitySchema]: + values = _extract_payload_list(raw, key="entities") + parsed = EntitiesSchema.model_validate({"entities": values}) + return [entity for entity in parsed.entities if entity.value and entity.label] + + +def _parse_augmented_entities(raw: object) -> list[EntitySchema]: + values = _extract_payload_list(raw, key="entities") + parsed = AugmentedEntitiesSchema.model_validate({"entities": values}) + return [ + EntitySchema(value=entity.value, label=entity.label, source="augmenter") + for entity in parsed.entities + if entity.value and entity.label + ] + + +def _parse_validation_candidate_count(raw: object) -> int: + values = _extract_payload_list(raw, key="candidates") + parsed = ValidationCandidatesSchema.model_validate({"candidates": values}) + return len(parsed.candidates) + + +def _extract_payload_list(raw: object, *, key: str) -> list[object]: + payload = _coerce_payload(raw) + if isinstance(payload, dict): + return _coerce_list(payload.get(key)) + return _coerce_list(payload) + + +def _coerce_payload(raw: object) -> object: + if _is_missing(raw): + return {} + if hasattr(raw, "tolist"): + raw = raw.tolist() + if not isinstance(raw, str): + return raw + text = raw.strip() + if not text: + return {} + try: + return json.loads(text) + except json.JSONDecodeError: + try: + return ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + + +def _coerce_list(value: object) -> list[object]: + value = _coerce_payload(value) + if isinstance(value, list): + return value + return [] + + +def _is_missing(value: object) -> bool: + return value is None or (isinstance(value, float) and math.isnan(value)) + + +def _value_key(value: str) -> str: + return " ".join(value.casefold().split()) + + +def _entity_signature_hashes(entities: list[EntitySchema], *, row_index: int) -> list[str]: + signatures = {_entity_signature_hash(entity, row_index=row_index) for entity in entities} + return sorted(signatures) + + +def _entity_signature_labels(entities: list[EntitySchema], *, row_index: int) -> dict[str, str]: + labels = {_entity_signature_hash(entity, row_index=row_index): entity.label for entity in entities} + return dict(sorted(labels.items())) + + +def _entity_signature_details(entities: list[EntitySchema], *, row_index: int) -> dict[str, dict[str, object]]: + details = { + _entity_signature_hash(entity, row_index=row_index): { + "label": entity.label, + "source": entity.source, + "row_index": int(row_index), + "start_position": entity.start_position, + "end_position": entity.end_position, + "value_length": len(entity.value), + } + for entity in entities + } + return dict(sorted(details.items())) + + +def _entity_signature_hash(entity: EntitySchema, *, row_index: int) -> str: + payload = json.dumps( + { + "row": row_index, + "label": entity.label, + "start": entity.start_position, + "end": entity.end_position, + }, + ensure_ascii=True, + sort_keys=True, + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] + + +def _weak_api_key_shape_counts(entities: list[EntitySchema]) -> Counter[str]: + counts: Counter[str] = Counter() + for entity in entities: + if entity.label == "api_key" and not _looks_like_api_key(entity.value): + counts[entity.label] += 1 + return counts + + +def _looks_like_api_key(value: str) -> bool: + stripped = value.strip() + if API_KEY_PREFIX_RE.search(stripped): + return True + compact = re.sub(r"[\s'\";:,/]+", "", stripped) + if len(compact) < 20: + return False + return bool(re.search(r"[A-Za-z]", compact)) and bool(re.search(r"\d", compact)) + + +def _count_by(entities: list[EntitySchema], field: str) -> dict[str, int]: + counts = Counter(str(getattr(entity, field)) for entity in entities if getattr(entity, field)) + return dict(sorted(counts.items())) + + +def write_rows(rows: list[DetectionArtifactRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def render_result(result: DetectionArtifactAnalysis, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + total_warnings = sum(row.weak_api_key_shape_count for row in result.rows) + workflows = Counter(row.workflow_name for row in result.rows) + lines = [f"Analyzed {len(result.rows)} detection artifact row(s) from {result.artifact_path}"] + for workflow_name, count in sorted(workflows.items()): + lines.append(f"- {workflow_name}: {count} row(s)") + lines.append(f"Weak api_key shape warnings: {total_warnings}") + return "\n".join(lines) + + +@app.default +def main( + artifact_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.jsonl, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_artifacts(artifact_path) + if output is not None: + write_rows(result.rows, output, format) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/analyze_staged_detection_output.py b/tools/measurement/analyze_staged_detection_output.py new file mode 100644 index 00000000..737d439a --- /dev/null +++ b/tools/measurement/analyze_staged_detection_output.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze benchmark-only DD-free staged detection probe outputs. + +Usage: + uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe + uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe/staged-detection-cases.jsonl + uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe --output analysis --format csv + uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe --json +""" + +from __future__ import annotations + +import json +import logging +import sys +from collections import Counter, defaultdict +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, computed_field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.staged_detection_output") + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain +class StagedCaseAnalysisRow(BaseModel): + source_path: str + case_id: str + row_index: int | None = None + seed_source: str | None = None + status: str | None = None + case_failed: bool = False + elapsed_sec: float | None = None + model_elapsed_sec: float | None = None + model_phase_count: int = 0 + model_request_count: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + seed_entity_count: int = 0 + validation_candidate_count: int = 0 + validation_decision_count: int = 0 + augmented_suggestion_count: int = 0 + final_entity_count: int = 0 + final_entity_signature_count: int = 0 + final_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_final_entity_signature_count: int | None = None + shared_final_entity_signature_count: int | None = None + baseline_only_final_entity_signature_count: int | None = None + direct_only_final_entity_signature_count: int | None = None + baseline_shared_signature_rate: float | None = None + baseline_loss_signature_rate: float | None = None + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + direct_only_label_counts: dict[str, int] = Field(default_factory=dict) + error: str | None = None + + +class StagedGroupAnalysisRow(BaseModel): + seed_source: str | None = None + case_count: int + completed_case_count: int = 0 + error_case_count: int = 0 + failed_case_rate: float | None = None + elapsed_sec_sum: float | None = None + elapsed_sec_mean: float | None = None + model_elapsed_sec_sum: float | None = None + model_elapsed_sec_mean: float | None = None + model_phase_count_sum: int = 0 + model_request_count_sum: int = 0 + prompt_tokens_sum: int = 0 + completion_tokens_sum: int = 0 + total_tokens_sum: int = 0 + final_entity_count_sum: int = 0 + final_entity_signature_count_sum: int = 0 + baseline_final_entity_signature_count_sum: int = 0 + shared_final_entity_signature_count_sum: int = 0 + baseline_only_final_entity_signature_count_sum: int = 0 + direct_only_final_entity_signature_count_sum: int = 0 + baseline_shared_signature_rate: float | None = None + baseline_loss_signature_rate: float | None = None + + +class LabelDeltaAnalysisRow(BaseModel): + seed_source: str | None = None + delta_type: str + label: str + count: int + + +class TableSummary(BaseModel): + table: str + rows: int + path: str + + +class AnalysisExportResult(BaseModel): + output_dir: str + format: ExportFormat + tables: list[TableSummary] + manifest_path: str + + +class StagedDetectionOutputAnalysis(BaseModel): + source_path: str + cases: list[StagedCaseAnalysisRow] = Field(default_factory=list) + groups: list[StagedGroupAnalysisRow] = Field(default_factory=list) + label_deltas: list[LabelDeltaAnalysisRow] = Field(default_factory=list) + + @computed_field + @property + def case_count(self) -> int: + return len(self.cases) + + @computed_field + @property + def group_count(self) -> int: + return len(self.groups) + + @computed_field + @property + def label_delta_count(self) -> int: + return len(self.label_deltas) + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def analyze_staged_detection_output(input_path: Path) -> StagedDetectionOutputAnalysis: + case_path = resolve_case_path(input_path) + case_rows = [_build_case_row(row, source_path=case_path) for row in read_case_records(case_path)] + return StagedDetectionOutputAnalysis( + source_path=str(case_path), + cases=case_rows, + groups=build_group_rows(case_rows), + label_deltas=build_label_delta_rows(case_rows), + ) + + +def resolve_case_path(input_path: Path) -> Path: + if not input_path.exists(): + raise ValueError(f"input path does not exist: {input_path}") + if input_path.is_dir(): + input_path = input_path / "staged-detection-cases.jsonl" + if not input_path.exists(): + raise ValueError(f"staged detection case file does not exist: {input_path}") + if input_path.is_dir(): + raise ValueError(f"input path is a directory: {input_path}") + return input_path + + +def read_case_records(case_path: Path) -> list[dict[str, Any]]: + records: list[dict[str, Any]] = [] + for line in case_path.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + record = json.loads(line) + if not isinstance(record, dict): + raise ValueError(f"JSONL row is not an object in {case_path}") + if record.get("record_type") in (None, "staged_detection_case"): + records.append(record) + return records + + +def _build_case_row(record: dict[str, Any], *, source_path: Path) -> StagedCaseAnalysisRow: + comparison = _dict_value(record.get("comparison")) + baseline_count = _optional_int(comparison.get("baseline_final_entity_signature_count")) + shared_count = _optional_int(comparison.get("shared_final_entity_signature_count")) + baseline_only_count = _optional_int(comparison.get("baseline_only_final_entity_signature_count")) + return StagedCaseAnalysisRow( + source_path=str(source_path), + case_id=str(record.get("case_id") or ""), + row_index=_optional_int(record.get("row_index")), + seed_source=_optional_str(record.get("seed_source")), + status=_optional_str(record.get("status")), + case_failed=str(record.get("status")).lower() == "error", + elapsed_sec=_optional_float(record.get("elapsed_sec")), + model_elapsed_sec=_optional_float(record.get("model_elapsed_sec")), + model_phase_count=_int_value(record.get("model_phase_count")), + model_request_count=_int_value(record.get("model_request_count")), + **_usage_fields(_dict_value(record.get("total_usage"))), + **_entity_count_fields(record), + baseline_final_entity_signature_count=baseline_count, + shared_final_entity_signature_count=shared_count, + baseline_only_final_entity_signature_count=baseline_only_count, + direct_only_final_entity_signature_count=_optional_int( + comparison.get("direct_only_final_entity_signature_count") + ), + baseline_shared_signature_rate=_rate(shared_count, baseline_count), + baseline_loss_signature_rate=_rate(baseline_only_count, baseline_count), + baseline_only_label_counts=_counter_dict(comparison.get("baseline_only_label_counts")), + direct_only_label_counts=_counter_dict(comparison.get("direct_only_label_counts")), + error=_optional_str(record.get("error")), + ) + + +def _usage_fields(usage: dict[str, Any]) -> dict[str, int]: + return { + "prompt_tokens": _int_value(usage.get("prompt_tokens")), + "completion_tokens": _int_value(usage.get("completion_tokens")), + "total_tokens": _int_value(usage.get("total_tokens")), + } + + +def _entity_count_fields(record: dict[str, Any]) -> dict[str, int | dict[str, int]]: + return { + "seed_entity_count": _int_value(record.get("seed_entity_count")), + "validation_candidate_count": _int_value(record.get("validation_candidate_count")), + "validation_decision_count": _int_value(record.get("validation_decision_count")), + "augmented_suggestion_count": _int_value(record.get("augmented_suggestion_count")), + "final_entity_count": _int_value(record.get("final_entity_count")), + "final_entity_signature_count": _int_value(record.get("final_entity_signature_count")), + "final_label_counts": _counter_dict(record.get("final_label_counts")), + } + + +def build_group_rows(cases: list[StagedCaseAnalysisRow]) -> list[StagedGroupAnalysisRow]: + groups: defaultdict[str | None, list[StagedCaseAnalysisRow]] = defaultdict(list) + for case in cases: + groups[case.seed_source].append(case) + return [_build_group_row(seed_source, rows) for seed_source, rows in sorted(groups.items(), key=_group_sort_key)] + + +def _build_group_row(seed_source: str | None, rows: list[StagedCaseAnalysisRow]) -> StagedGroupAnalysisRow: + case_count = len(rows) + error_count = sum(1 for row in rows if row.case_failed) + baseline_total = _sum_optional_int(rows, "baseline_final_entity_signature_count") + shared_total = _sum_optional_int(rows, "shared_final_entity_signature_count") + baseline_only_total = _sum_optional_int(rows, "baseline_only_final_entity_signature_count") + model_request_count = sum(row.model_request_count for row in rows) + return StagedGroupAnalysisRow( + seed_source=seed_source, + case_count=case_count, + completed_case_count=case_count - error_count, + error_case_count=error_count, + failed_case_rate=_rate(error_count, case_count), + elapsed_sec_sum=_sum_optional_float(rows, "elapsed_sec"), + elapsed_sec_mean=_mean_optional_float(rows, "elapsed_sec"), + model_elapsed_sec_sum=_sum_optional_float(rows, "model_elapsed_sec"), + model_elapsed_sec_mean=_mean_optional_float(rows, "model_elapsed_sec"), + model_phase_count_sum=sum(row.model_phase_count for row in rows), + model_request_count_sum=model_request_count, + prompt_tokens_sum=sum(row.prompt_tokens for row in rows), + completion_tokens_sum=sum(row.completion_tokens for row in rows), + total_tokens_sum=sum(row.total_tokens for row in rows), + final_entity_count_sum=sum(row.final_entity_count for row in rows), + final_entity_signature_count_sum=sum(row.final_entity_signature_count for row in rows), + baseline_final_entity_signature_count_sum=baseline_total, + shared_final_entity_signature_count_sum=shared_total, + baseline_only_final_entity_signature_count_sum=baseline_only_total, + direct_only_final_entity_signature_count_sum=_sum_optional_int( + rows, "direct_only_final_entity_signature_count" + ), + baseline_shared_signature_rate=_rate(shared_total, baseline_total), + baseline_loss_signature_rate=_rate(baseline_only_total, baseline_total), + ) + + +def _group_sort_key(item: tuple[str | None, list[StagedCaseAnalysisRow]]) -> str: + return item[0] or "" + + +def build_label_delta_rows(cases: list[StagedCaseAnalysisRow]) -> list[LabelDeltaAnalysisRow]: + counts: Counter[tuple[str | None, str, str]] = Counter() + for case in cases: + for label, count in case.baseline_only_label_counts.items(): + counts[(case.seed_source, "baseline_only", label)] += count + for label, count in case.direct_only_label_counts.items(): + counts[(case.seed_source, "direct_only", label)] += count + return [ + LabelDeltaAnalysisRow(seed_source=seed_source, delta_type=delta_type, label=label, count=count) + for (seed_source, delta_type, label), count in sorted(counts.items(), key=_label_delta_sort_key) + ] + + +def _label_delta_sort_key(item: tuple[tuple[str | None, str, str], int]) -> tuple[str, str, str]: + (seed_source, delta_type, label), _ = item + return (seed_source or "", delta_type, label) + + +def _sum_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: + values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] + return sum(values) if values else None + + +def _mean_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: + values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] + return sum(values) / len(values) if values else None + + +def _sum_optional_int(rows: list[StagedCaseAnalysisRow], field_name: str) -> int: + return sum(value for value in (_optional_int(getattr(row, field_name)) for row in rows) if value is not None) + + +def _rate(numerator: object, denominator: object) -> float | None: + numerator_value = _optional_float(numerator) + denominator_value = _optional_float(denominator) + if numerator_value is None or denominator_value is None or denominator_value <= 0: + return None + return numerator_value / denominator_value + + +def _dict_value(raw: object) -> dict[str, Any]: + return raw if isinstance(raw, dict) else {} + + +def _counter_dict(raw: object) -> dict[str, int]: + return {str(key): _int_value(value) for key, value in _dict_value(raw).items()} + + +def _optional_str(raw: object) -> str | None: + return str(raw) if raw is not None else None + + +def _optional_float(raw: object) -> float | None: + if raw is None: + return None + return float(raw) + + +def _optional_int(raw: object) -> int | None: + if raw is None: + return None + return int(float(raw)) + + +def _int_value(raw: object) -> int: + return _optional_int(raw) or 0 + + +def write_analysis_tables( + result: StagedDetectionOutputAnalysis, + output_dir: Path, + export_format: ExportFormat, +) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows( + result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, StagedCaseAnalysisRow + ), + _write_model_rows( + result.groups, + output_dir / f"group_analysis.{export_format.value}", + export_format, + StagedGroupAnalysisRow, + ), + _write_model_rows( + result.label_deltas, + output_dir / f"label_delta_analysis.{export_format.value}", + export_format, + LabelDeltaAnalysisRow, + ), + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def _write_model_rows( + rows: list[BaseModel], + path: Path, + export_format: ExportFormat, + row_model: type[BaseModel], +) -> TableSummary: + table = _rows_to_table(rows, row_model) + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + return TableSummary(table=path.stem, rows=len(table), path=str(path)) + + +def _rows_to_table(rows: list[BaseModel], row_model: type[BaseModel]) -> pd.DataFrame: + if rows: + return pd.json_normalize([row.model_dump() for row in rows], sep=".") + return pd.DataFrame(columns=list(row_model.model_fields)) + + +def render_result(result: StagedDetectionOutputAnalysis, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [f"Analyzed {result.case_count} staged detection case(s) across {result.group_count} group(s)"] + lines.extend(_render_group_line(group, result.label_deltas) for group in result.groups) + return "\n".join(lines) + + +def _render_group_line(group: StagedGroupAnalysisRow, label_deltas: list[LabelDeltaAnalysisRow]) -> str: + label = group.seed_source or "" + lost = _top_labels(label_deltas, seed_source=group.seed_source, delta_type="baseline_only") + return ( + f"- {label}: cases={group.case_count}, errors={group.error_case_count}, " + f"elapsed_sum={_fmt_float(group.elapsed_sec_sum)}s, " + f"model_elapsed_sum={_fmt_float(group.model_elapsed_sec_sum)}s, " + f"requests={group.model_request_count_sum}, tokens={group.total_tokens_sum}, " + f"shared={group.shared_final_entity_signature_count_sum}/" + f"{group.baseline_final_entity_signature_count_sum}, " + f"baseline_only={group.baseline_only_final_entity_signature_count_sum}, " + f"direct_only={group.direct_only_final_entity_signature_count_sum}, lost_labels={lost}" + ) + +def _top_labels(label_deltas: list[LabelDeltaAnalysisRow], *, seed_source: str | None, delta_type: str) -> str: + matches = [delta for delta in label_deltas if delta.seed_source == seed_source and delta.delta_type == delta_type] + if not matches: + return "{}" + items = sorted(matches, key=lambda delta: (-delta.count, delta.label))[:8] + return "{" + ", ".join(f"{item.label}:{item.count}" for item in items) + "}" + + +def _fmt_float(value: float | None) -> str: + return "n/a" if value is None else f"{value:.3f}" + + +@app.default +def main( + input_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_staged_detection_output(input_path) + if output is not None: + write_analysis_tables(result, output, format) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py new file mode 100644 index 00000000..bdbad0d5 --- /dev/null +++ b/tools/measurement/compare_strategy_pairs.py @@ -0,0 +1,1463 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Compare benchmark case-analysis rows for baseline/candidate strategy pairs. + +Usage: + uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.csv \ + --baseline-strategy default --candidate-strategy detector_native_validate_no_augment + uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.parquet \ + --baseline-config default --candidate-config no-augment --output comparisons.csv + uv run python tools/measurement/compare_strategy_pairs.py baseline/case_analysis.csv \ + --candidate-case-analysis candidate/case_analysis.csv \ + --baseline-strategy default --candidate-strategy native_single_pass +""" + +from __future__ import annotations + +import ast +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.strategy_pairs") + +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class SafetyVerdict(StrEnum): + passed = "pass" + review = "review" + fail = "fail" + + +class PerformanceVerdict(StrEnum): + improved = "improved" + mixed = "mixed" + regressed = "regressed" + unchanged = "unchanged" + unknown = "unknown" + + +class CandidateVerdict(StrEnum): + candidate_viable = "candidate_viable" + review = "review" + reject = "reject" + + +_log_format = LogFormat.plain +_MIN_CANDIDATE_OVERLAP_RATIO = 0.8 +_MAX_CANDIDATE_BOUNDARY_GAP_CHARS = 8 +_MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS = 8 +_BOUNDARY_NORMALIZED_BASELINE_LABELS = {"api_key", "http_cookie", "unique_id"} + + +class ComparisonSummary(BaseModel): + comparison_count: int = 0 + value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) + signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) + safety_verdict_counts: dict[str, int] = Field(default_factory=dict) + performance_verdict_counts: dict[str, int] = Field(default_factory=dict) + candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) + candidate_viable_workloads: list[str] = Field(default_factory=list) + review_workloads: list[str] = Field(default_factory=list) + rejected_workloads: list[str] = Field(default_factory=list) + + +class ComparisonRow(BaseModel): + workload_id: str + baseline_config_id: str + candidate_config_id: str + baseline_strategy: str | None = None + candidate_strategy: str | None = None + baseline_replacement_strategy: str | None = None + candidate_replacement_strategy: str | None = None + baseline_case_count: int + candidate_case_count: int + baseline_failed_case_count: float | None = None + candidate_failed_case_count: float | None = None + failed_case_count_delta: float | None = None + baseline_failed_case_rate: float | None = None + candidate_failed_case_rate: float | None = None + failed_case_rate_delta: float | None = None + baseline_pipeline_elapsed_sec: float | None = None + candidate_pipeline_elapsed_sec: float | None = None + pipeline_elapsed_sec_delta: float | None = None + pipeline_elapsed_sec_delta_pct: float | None = None + baseline_observed_total_requests: float | None = None + candidate_observed_total_requests: float | None = None + observed_total_requests_delta: float | None = None + baseline_observed_total_tokens: float | None = None + candidate_observed_total_tokens: float | None = None + observed_total_tokens_delta: float | None = None + baseline_observed_failed_requests: float | None = None + candidate_observed_failed_requests: float | None = None + observed_failed_requests_delta: float | None = None + baseline_observed_bridge_fallback_requests: float | None = None + candidate_observed_bridge_fallback_requests: float | None = None + observed_bridge_fallback_requests_delta: float | None = None + baseline_observed_non_bridge_total_requests: float | None = None + candidate_observed_non_bridge_total_requests: float | None = None + observed_non_bridge_total_requests_delta: float | None = None + baseline_observed_non_bridge_failed_requests: float | None = None + candidate_observed_non_bridge_failed_requests: float | None = None + observed_non_bridge_failed_requests_delta: float | None = None + baseline_original_value_leak_count: float | None = None + candidate_original_value_leak_count: float | None = None + original_value_leak_count_delta: float | None = None + baseline_original_value_leak_record_count: float | None = None + candidate_original_value_leak_record_count: float | None = None + original_value_leak_record_count_delta: float | None = None + baseline_replacement_missing_final_entity_count: float | None = None + candidate_replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_count_delta: float | None = None + baseline_replacement_missing_final_value_count: float | None = None + candidate_replacement_missing_final_value_count: float | None = None + replacement_missing_final_value_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_count: float | None = None + candidate_replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_value_count: float | None = None + candidate_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_value_count_delta: float | None = None + value_protection_verdict: SafetyVerdict | None = None + signature_parity_verdict: SafetyVerdict | None = None + safety_verdict: SafetyVerdict | None = None + performance_verdict: PerformanceVerdict | None = None + candidate_verdict: CandidateVerdict | None = None + baseline_final_entity_count: float | None = None + candidate_final_entity_count: float | None = None + final_entity_count_delta: float | None = None + baseline_seed_validation_candidate_count: float | None = None + candidate_seed_validation_candidate_count: float | None = None + seed_validation_candidate_count_delta: float | None = None + baseline_augmented_entity_count: float | None = None + candidate_augmented_entity_count: float | None = None + augmented_entity_count_delta: float | None = None + baseline_augmented_new_final_value_count: float | None = None + candidate_augmented_new_final_value_count: float | None = None + augmented_new_final_value_count_delta: float | None = None + baseline_detector_entity_count: float | None = None + candidate_detector_entity_count: float | None = None + baseline_augmenter_entity_count: float | None = None + candidate_augmenter_entity_count: float | None = None + baseline_only_final_entity_signature_count: int | None = None + candidate_only_final_entity_signature_count: int | None = None + shared_final_entity_signature_count: int | None = None + baseline_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + shared_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_covered_signature_count: int | None = None + baseline_only_candidate_overlapping_signature_count: int | None = None + baseline_only_candidate_uncovered_signature_count: int | None = None + baseline_only_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_label_mismatch_signature_count: int | None = None + baseline_only_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_final_entity_signature_count: int | None = None + candidate_stable_final_entity_signature_count: int | None = None + stable_final_entity_signature_count_delta: int | None = None + baseline_stable_candidate_unstable_final_entity_signature_count: int | None = None + candidate_stable_baseline_unstable_final_entity_signature_count: int | None = None + shared_stable_final_entity_signature_count: int | None = None + baseline_stable_candidate_covered_signature_count: int | None = None + baseline_stable_candidate_overlapping_signature_count: int | None = None + baseline_stable_candidate_uncovered_signature_count: int | None = None + baseline_stable_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_label_mismatch_signature_count: int | None = None + baseline_stable_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_stable_baseline_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + shared_stable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + flags: list[str] = Field(default_factory=list) + + +class ComparisonResult(BaseModel): + input_path: str + candidate_input_path: str | None = None + baseline_selector: str + candidate_selector: str + summary: ComparisonSummary = Field(default_factory=ComparisonSummary) + comparisons: list[ComparisonRow] = Field(default_factory=list) + + @property + def comparison_count(self) -> int: + return len(self.comparisons) + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def read_case_analysis(path: Path) -> pd.DataFrame: + if not path.exists() or path.is_dir(): + raise ValueError(f"case analysis path is not a file: {path}") + if path.suffix == ".parquet": + table = pd.read_parquet(path) + elif path.suffix == ".csv": + table = pd.read_csv(path) + elif path.suffix == ".jsonl": + table = pd.read_json(path, lines=True) + else: + raise ValueError(f"unsupported case analysis format: {path.suffix}") + _validate_case_analysis_columns(table) + return table + + +def _validate_case_analysis_columns(table: pd.DataFrame) -> None: + required = {"workload_id", "config_id", "case_id"} + missing = sorted(required - set(table.columns)) + if missing: + raise ValueError(f"case analysis is missing required column(s): {', '.join(missing)}") + + +def compare_case_analysis( + table: pd.DataFrame, + *, + baseline_config: str | None = None, + candidate_config: str | None = None, + baseline_strategy: str | None = None, + candidate_strategy: str | None = None, +) -> list[ComparisonRow]: + return compare_case_tables( + table, + table, + baseline_config=baseline_config, + candidate_config=candidate_config, + baseline_strategy=baseline_strategy, + candidate_strategy=candidate_strategy, + ) + + +def compare_case_tables( + baseline_table: pd.DataFrame, + candidate_table: pd.DataFrame, + *, + baseline_config: str | None = None, + candidate_config: str | None = None, + baseline_strategy: str | None = None, + candidate_strategy: str | None = None, +) -> list[ComparisonRow]: + baseline = _select_rows( + baseline_table, config_id=baseline_config, strategy=baseline_strategy, selector_name="baseline" + ) + candidate = _select_rows( + candidate_table, config_id=candidate_config, strategy=candidate_strategy, selector_name="candidate" + ) + return [ + _compare_workload(workload_id, baseline, candidate) for workload_id in _common_workloads(baseline, candidate) + ] + + +def summarize_comparisons(rows: list[ComparisonRow]) -> ComparisonSummary: + return ComparisonSummary( + comparison_count=len(rows), + value_protection_verdict_counts=_verdict_counts(rows, "value_protection_verdict", list(SafetyVerdict)), + signature_parity_verdict_counts=_verdict_counts(rows, "signature_parity_verdict", list(SafetyVerdict)), + safety_verdict_counts=_verdict_counts(rows, "safety_verdict", list(SafetyVerdict)), + performance_verdict_counts=_verdict_counts(rows, "performance_verdict", list(PerformanceVerdict)), + candidate_verdict_counts=_verdict_counts(rows, "candidate_verdict", list(CandidateVerdict)), + candidate_viable_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.candidate_viable), + review_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.review), + rejected_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.reject), + ) + + +def _verdict_counts(rows: list[ComparisonRow], field: str, values: list[StrEnum]) -> dict[str, int]: + counts = {value.value: 0 for value in values} + for row in rows: + verdict = _verdict_value(getattr(row, field)) + if verdict is not None: + counts[verdict] = counts.get(verdict, 0) + 1 + return counts + + +def _workloads_by_candidate_verdict(rows: list[ComparisonRow], verdict: CandidateVerdict) -> list[str]: + return sorted(row.workload_id for row in rows if _verdict_value(row.candidate_verdict) == verdict.value) + + +def _verdict_value(value: object) -> str | None: + if value is None: + return None + if isinstance(value, StrEnum): + return value.value + return str(value) + + +def _select_rows( + table: pd.DataFrame, + *, + config_id: str | None, + strategy: str | None, + selector_name: str, +) -> pd.DataFrame: + if (config_id is None) == (strategy is None): + raise ValueError(f"{selector_name} selector must specify exactly one of config or strategy") + column, value = ("config_id", config_id) if config_id is not None else ("experimental_detection_strategy", strategy) + if column not in table.columns: + raise ValueError(f"case analysis is missing selector column: {column}") + selected = table[table[column].astype(str) == str(value)] + if selected.empty: + raise ValueError(f"{selector_name} selector matched no rows: {column}={value}") + _validate_unique_config_per_workload(selected, selector_name=selector_name) + return selected + + +def _validate_unique_config_per_workload(rows: pd.DataFrame, *, selector_name: str) -> None: + counts = rows.groupby("workload_id")["config_id"].nunique() + ambiguous = sorted(str(index) for index, count in counts.items() if count > 1) + if ambiguous: + raise ValueError(f"{selector_name} selector matched multiple configs for workload(s): {', '.join(ambiguous)}") + + +def _common_workloads(baseline: pd.DataFrame, candidate: pd.DataFrame) -> list[str]: + baseline_workloads = set(str(value) for value in baseline["workload_id"].dropna()) + candidate_workloads = set(str(value) for value in candidate["workload_id"].dropna()) + common = sorted(baseline_workloads & candidate_workloads) + if not common: + raise ValueError("baseline and candidate selectors have no workloads in common") + return common + + +def _compare_workload(workload_id: str, baseline: pd.DataFrame, candidate: pd.DataFrame) -> ComparisonRow: + baseline_summary = _summarize_selector(baseline[baseline["workload_id"].astype(str) == workload_id]) + candidate_summary = _summarize_selector(candidate[candidate["workload_id"].astype(str) == workload_id]) + return ComparisonRow( + workload_id=workload_id, + baseline_config_id=str(baseline_summary["config_id"]), + candidate_config_id=str(candidate_summary["config_id"]), + baseline_strategy=_optional_string(baseline_summary.get("experimental_detection_strategy")), + candidate_strategy=_optional_string(candidate_summary.get("experimental_detection_strategy")), + baseline_replacement_strategy=_optional_string(baseline_summary.get("experimental_replacement_strategy")), + candidate_replacement_strategy=_optional_string(candidate_summary.get("experimental_replacement_strategy")), + baseline_case_count=int(baseline_summary["case_count"]), + candidate_case_count=int(candidate_summary["case_count"]), + **_comparison_metrics(baseline_summary, candidate_summary), + ) + + +def _summarize_selector(rows: pd.DataFrame) -> dict[str, object]: + if rows.empty: + raise ValueError("selector has no rows for workload") + case_count = int(rows["case_id"].nunique()) + summary: dict[str, object] = { + "config_id": _single_string(rows, "config_id"), + "experimental_detection_strategy": _single_string(rows, "experimental_detection_strategy"), + "experimental_replacement_strategy": _single_string(rows, "experimental_replacement_strategy"), + "case_count": case_count, + } + for column in _NUMERIC_COLUMNS: + summary[column] = _median_or_none(rows, column) + summary["replacement_missing_final_entity_count"] = _sum_or_none(rows, "replacement_missing_final_entity_count") + summary["replacement_missing_final_value_count"] = _sum_or_none(rows, "replacement_missing_final_value_count") + summary["replacement_synthetic_original_collision_count"] = _sum_or_none( + rows, + "replacement_synthetic_original_collision_count", + ) + summary["replacement_synthetic_original_collision_value_count"] = _sum_or_none( + rows, + "replacement_synthetic_original_collision_value_count", + ) + summary["original_value_leak_count"] = _sum_or_none(rows, "original_value_leak_count") + summary["original_value_leak_record_count"] = _sum_or_none(rows, "original_value_leak_record_count") + summary["original_value_leak_label_counts"] = _sum_label_counts(rows, "original_value_leak_label_counts") + summary["replacement_synthetic_original_collision_label_counts"] = _sum_label_counts( + rows, + "replacement_synthetic_original_collision_label_counts", + ) + summary["failed_case_count"] = _failed_case_count(rows) + summary["failed_case_rate"] = _rate(summary["failed_case_count"], case_count) + summary["artifact_final_entity_signature_hashes"] = _signature_hashes(rows) + summary["stable_artifact_final_entity_signature_hashes"] = _stable_signature_hashes(rows) if case_count > 1 else [] + summary["artifact_final_entity_signature_labels"] = _signature_labels(rows) + summary["artifact_final_entity_signature_details"] = _signature_details(rows) + return summary + + +def _single_string(rows: pd.DataFrame, column: str) -> str | None: + if column not in rows.columns: + return None + values = sorted({str(value) for value in rows[column].dropna()}) + return values[0] if values else None + + +_NUMERIC_COLUMNS = [ + "pipeline_elapsed_sec", + "observed_total_requests", + "observed_total_tokens", + "observed_failed_requests", + "observed_bridge_fallback_requests", + "observed_non_bridge_total_requests", + "observed_non_bridge_failed_requests", + "replacement_missing_final_entity_count", + "replacement_missing_final_value_count", + "replacement_synthetic_original_collision_count", + "replacement_synthetic_original_collision_value_count", + "final_entity_count", + "seed_validation_candidate_count", + "augmented_entity_count", + "augmented_new_final_value_count", + "artifact_final_detector_entity_count", + "artifact_final_augmenter_entity_count", +] + + +def _median_or_none(rows: pd.DataFrame, column: str) -> float | None: + if column not in rows.columns: + return None + values = pd.to_numeric(rows[column], errors="coerce").dropna() + return float(values.median()) if not values.empty else None + + +def _sum_or_none(rows: pd.DataFrame, column: str) -> float | None: + if column not in rows.columns: + return None + values = pd.to_numeric(rows[column], errors="coerce").dropna() + return float(values.sum()) if not values.empty else None + + +def _comparison_metrics(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + metrics = _named_metric_deltas(baseline, candidate) + metrics.update(_source_counts(baseline, candidate)) + metrics.update(_entity_signature_deltas(baseline, candidate)) + metrics["flags"] = _comparison_flags( + metrics, + baseline_strategy=_optional_string(baseline.get("experimental_detection_strategy")), + candidate_strategy=_optional_string(candidate.get("experimental_detection_strategy")), + baseline_replacement_strategy=_optional_string(baseline.get("experimental_replacement_strategy")), + candidate_replacement_strategy=_optional_string(candidate.get("experimental_replacement_strategy")), + ) + metrics.update(_comparison_verdicts(metrics)) + return metrics + + +def _named_metric_deltas(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + return { + **_metric_delta("failed_case_count", baseline, candidate), + **_metric_delta("failed_case_rate", baseline, candidate), + **_metric_delta("pipeline_elapsed_sec", baseline, candidate, include_pct=True), + **_metric_delta("observed_total_requests", baseline, candidate), + **_metric_delta("observed_total_tokens", baseline, candidate), + **_metric_delta("observed_failed_requests", baseline, candidate), + **_metric_delta("observed_bridge_fallback_requests", baseline, candidate), + **_metric_delta("observed_non_bridge_total_requests", baseline, candidate), + **_metric_delta("observed_non_bridge_failed_requests", baseline, candidate), + **_metric_delta("original_value_leak_count", baseline, candidate), + **_metric_delta("original_value_leak_record_count", baseline, candidate), + **_metric_delta("replacement_missing_final_entity_count", baseline, candidate), + **_metric_delta("replacement_missing_final_value_count", baseline, candidate), + **_metric_delta("replacement_synthetic_original_collision_count", baseline, candidate), + **_metric_delta("replacement_synthetic_original_collision_value_count", baseline, candidate), + **_metric_delta("final_entity_count", baseline, candidate), + **_metric_delta("seed_validation_candidate_count", baseline, candidate), + **_metric_delta("augmented_entity_count", baseline, candidate), + **_metric_delta("augmented_new_final_value_count", baseline, candidate), + } + + +def _metric_delta( + name: str, + baseline: dict[str, object], + candidate: dict[str, object], + *, + include_pct: bool = False, +) -> dict[str, float | None]: + base = _optional_float(baseline.get(name)) + cand = _optional_float(candidate.get(name)) + values = {f"baseline_{name}": base, f"candidate_{name}": cand, f"{name}_delta": _delta(base, cand)} + if include_pct: + values[f"{name}_delta_pct"] = _delta_pct(base, cand) + return values + + +def _source_counts(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + return { + "baseline_detector_entity_count": _optional_float(baseline.get("artifact_final_detector_entity_count")), + "candidate_detector_entity_count": _optional_float(candidate.get("artifact_final_detector_entity_count")), + "baseline_augmenter_entity_count": _optional_float(baseline.get("artifact_final_augmenter_entity_count")), + "candidate_augmenter_entity_count": _optional_float(candidate.get("artifact_final_augmenter_entity_count")), + "baseline_original_value_leak_label_counts": _coerce_count_map( + baseline.get("original_value_leak_label_counts") + ), + "candidate_original_value_leak_label_counts": _coerce_count_map( + candidate.get("original_value_leak_label_counts") + ), + "baseline_replacement_synthetic_original_collision_label_counts": _coerce_count_map( + baseline.get("replacement_synthetic_original_collision_label_counts") + ), + "candidate_replacement_synthetic_original_collision_label_counts": _coerce_count_map( + candidate.get("replacement_synthetic_original_collision_label_counts") + ), + } + + +def _comparison_flags( + metrics: dict[str, object], + *, + baseline_strategy: str | None, + candidate_strategy: str | None, + baseline_replacement_strategy: str | None, + candidate_replacement_strategy: str | None, +) -> list[str]: + flags: list[str] = [] + _append_if_positive(flags, metrics, "baseline_failed_case_count", "baseline_case_failures") + _append_if_positive(flags, metrics, "candidate_failed_case_count", "candidate_case_failures") + _append_if_positive(flags, metrics, "baseline_original_value_leak_count", "baseline_original_value_leak") + _append_if_positive(flags, metrics, "candidate_original_value_leak_count", "candidate_original_value_leak") + _append_if_positive( + flags, + metrics, + "baseline_replacement_synthetic_original_collision_count", + "baseline_replacement_synthetic_original_collision", + ) + _append_if_positive( + flags, + metrics, + "candidate_replacement_synthetic_original_collision_count", + "candidate_replacement_synthetic_original_collision", + ) + _append_if_positive( + flags, + metrics, + "baseline_replacement_missing_final_entity_count", + "baseline_replacement_missing_final_entity", + ) + _append_if_positive( + flags, + metrics, + "candidate_replacement_missing_final_entity_count", + "candidate_replacement_missing_final_entity", + ) + _append_if_negative(flags, metrics, "final_entity_count_delta", "entity_count_loss") + _append_if_positive(flags, metrics, _signature_loss_metric(metrics), "entity_signature_loss") + _append_if_positive( + flags, + metrics, + "baseline_only_candidate_overlapping_signature_count", + "span_boundary_mismatch", + ) + _append_if_positive( + flags, + metrics, + "baseline_only_candidate_label_mismatch_signature_count", + "covered_label_mismatch", + ) + _append_if_positive( + flags, + metrics, + _stable_signature_loss_metric(metrics), + "stable_entity_signature_loss", + ) + failed_request_delta = ( + "observed_non_bridge_failed_requests_delta" + if _has_metric_pair(metrics, "observed_non_bridge_failed_requests") + else "observed_failed_requests_delta" + ) + _append_if_positive(flags, metrics, failed_request_delta, "failed_request_increase") + _append_if_positive(flags, metrics, "observed_bridge_fallback_requests_delta", "bridge_fallback_increase") + _append_if_positive(flags, metrics, "observed_total_tokens_delta", "token_increase") + _append_if_positive(flags, metrics, "observed_total_requests_delta", "request_increase") + if _candidate_lacks_detector_entities(metrics): + flags.append("no_candidate_detector_entities") + if candidate_strategy in _SKIPS_LLM_VALIDATION_STRATEGIES: + flags.append("candidate_skips_llm_validation") + if _replacement_only_detection_instability( + flags, + baseline_strategy=baseline_strategy, + candidate_strategy=candidate_strategy, + baseline_replacement_strategy=baseline_replacement_strategy, + candidate_replacement_strategy=candidate_replacement_strategy, + ): + flags.append("replacement_only_detection_instability") + return flags + + +_DETECTION_INSTABILITY_FLAGS = { + "covered_label_mismatch", + "entity_count_loss", + "entity_signature_loss", + "span_boundary_mismatch", + "stable_entity_signature_loss", +} + + +def _replacement_only_detection_instability( + flags: list[str], + *, + baseline_strategy: str | None, + candidate_strategy: str | None, + baseline_replacement_strategy: str | None, + candidate_replacement_strategy: str | None, +) -> bool: + if baseline_strategy != candidate_strategy: + return False + if not baseline_replacement_strategy or not candidate_replacement_strategy: + return False + if baseline_replacement_strategy == candidate_replacement_strategy: + return False + return bool(set(flags) & _DETECTION_INSTABILITY_FLAGS) + + +def _signature_loss_metric(metrics: dict[str, object]) -> str: + if metrics.get("baseline_only_candidate_uncovered_signature_count") is not None: + return "baseline_only_candidate_uncovered_signature_count" + return "baseline_only_final_entity_signature_count" + + +def _stable_signature_loss_metric(metrics: dict[str, object]) -> str: + if metrics.get("baseline_stable_candidate_uncovered_signature_count") is not None: + return "baseline_stable_candidate_uncovered_signature_count" + return "baseline_stable_candidate_unstable_final_entity_signature_count" + + +_SKIPS_LLM_VALIDATION_STRATEGIES = {"detector_only"} + + +def _has_metric_pair(metrics: dict[str, object], name: str) -> bool: + return ( + _optional_float(metrics.get(f"baseline_{name}")) is not None + and _optional_float(metrics.get(f"candidate_{name}")) is not None + ) + + +def _comparison_verdicts(metrics: dict[str, object]) -> dict[str, str]: + value_protection_verdict = _value_protection_verdict(metrics) + signature_parity_verdict = _signature_parity_verdict(metrics) + safety_verdict = _safety_verdict(metrics) + performance_verdict = _performance_verdict(metrics) + return { + "value_protection_verdict": value_protection_verdict.value, + "signature_parity_verdict": signature_parity_verdict.value, + "safety_verdict": safety_verdict.value, + "performance_verdict": performance_verdict.value, + "candidate_verdict": _candidate_verdict(safety_verdict, performance_verdict).value, + } + + +def _value_protection_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & { + "candidate_case_failures", + "candidate_original_value_leak", + "candidate_replacement_missing_final_entity", + "candidate_replacement_synthetic_original_collision", + "entity_signature_loss", + "stable_entity_signature_loss", + }: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "baseline_case_failures", + "baseline_original_value_leak", + "baseline_replacement_missing_final_entity", + "baseline_replacement_synthetic_original_collision", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _signature_parity_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & {"candidate_case_failures", "entity_signature_loss", "stable_entity_signature_loss"}: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "span_boundary_mismatch", + "covered_label_mismatch", + "entity_count_loss", + "baseline_case_failures", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _safety_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & { + "candidate_case_failures", + "candidate_original_value_leak", + "candidate_replacement_missing_final_entity", + "candidate_replacement_synthetic_original_collision", + "entity_signature_loss", + "stable_entity_signature_loss", + }: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "no_candidate_detector_entities", + "candidate_skips_llm_validation", + "failed_request_increase", + "bridge_fallback_increase", + "span_boundary_mismatch", + "covered_label_mismatch", + }: + return SafetyVerdict.review + if flags & { + "baseline_case_failures", + "baseline_original_value_leak", + "baseline_replacement_missing_final_entity", + "baseline_replacement_synthetic_original_collision", + "entity_count_loss", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _entity_count_loss_without_signature_artifacts(flags: set[str], metrics: dict[str, object]) -> bool: + return "entity_count_loss" in flags and metrics.get("baseline_only_final_entity_signature_count") is None + + +def _performance_verdict(metrics: dict[str, object]) -> PerformanceVerdict: + deltas = [ + _optional_float(metrics.get("pipeline_elapsed_sec_delta")), + _optional_float(metrics.get("observed_total_requests_delta")), + _optional_float(metrics.get("observed_total_tokens_delta")), + ] + known = [value for value in deltas if value is not None] + if not known: + return PerformanceVerdict.unknown + improved = any(value < 0 for value in known) + regressed = any(value > 0 for value in known) + if improved and regressed: + return PerformanceVerdict.mixed + if improved: + return PerformanceVerdict.improved + if regressed: + return PerformanceVerdict.regressed + return PerformanceVerdict.unchanged + + +def _candidate_verdict( + safety_verdict: SafetyVerdict, + performance_verdict: PerformanceVerdict, +) -> CandidateVerdict: + if safety_verdict == SafetyVerdict.fail: + return CandidateVerdict.reject + if safety_verdict == SafetyVerdict.passed and performance_verdict == PerformanceVerdict.improved: + return CandidateVerdict.candidate_viable + return CandidateVerdict.review + + +def _coerce_flag_list(value: object) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value] + return [] + + +def _candidate_lacks_detector_entities(metrics: dict[str, object]) -> bool: + final_count = _optional_float(metrics.get("candidate_final_entity_count")) + if final_count is None or final_count <= 0: + return False + detector_count = _optional_float(metrics.get("candidate_detector_entity_count")) + if detector_count is not None: + return detector_count == 0 + non_detector_count = _known_non_detector_candidate_count(metrics) + return non_detector_count is not None and non_detector_count >= final_count + + +def _known_non_detector_candidate_count(metrics: dict[str, object]) -> float | None: + known_counts = [ + _optional_float(metrics.get("candidate_augmenter_entity_count")), + ] + if all(value is None for value in known_counts): + return None + return sum(value or 0.0 for value in known_counts) + + +def _entity_signature_deltas(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + baseline_signatures = set(_coerce_signature_list(baseline.get("artifact_final_entity_signature_hashes"))) + candidate_signatures = set(_coerce_signature_list(candidate.get("artifact_final_entity_signature_hashes"))) + baseline_labels = _coerce_signature_labels(baseline.get("artifact_final_entity_signature_labels")) + candidate_labels = _coerce_signature_labels(candidate.get("artifact_final_entity_signature_labels")) + baseline_details = _coerce_signature_details(baseline.get("artifact_final_entity_signature_details")) + candidate_details = _coerce_signature_details(candidate.get("artifact_final_entity_signature_details")) + baseline_signatures.update(baseline_labels) + candidate_signatures.update(candidate_labels) + if not baseline_signatures and not candidate_signatures: + return { + "baseline_only_final_entity_signature_count": None, + "candidate_only_final_entity_signature_count": None, + "shared_final_entity_signature_count": None, + "baseline_only_final_entity_signature_label_counts": {}, + "candidate_only_final_entity_signature_label_counts": {}, + "shared_final_entity_signature_label_counts": {}, + "baseline_only_candidate_covered_signature_count": None, + "baseline_only_candidate_overlapping_signature_count": None, + "baseline_only_candidate_uncovered_signature_count": None, + "baseline_only_candidate_covered_signature_label_counts": {}, + "baseline_only_candidate_overlapping_signature_label_counts": {}, + "baseline_only_candidate_uncovered_signature_label_counts": {}, + "baseline_only_candidate_label_mismatch_signature_count": None, + "baseline_only_candidate_label_mismatch_signature_label_counts": {}, + } + baseline_only = baseline_signatures - candidate_signatures + candidate_only = candidate_signatures - baseline_signatures + shared = baseline_signatures & candidate_signatures + coverage = _candidate_span_coverage( + baseline_only, + baseline_details=baseline_details, + candidate_details=candidate_details, + baseline_labels=baseline_labels, + ) + return { + "baseline_only_final_entity_signature_count": len(baseline_only), + "candidate_only_final_entity_signature_count": len(candidate_only), + "shared_final_entity_signature_count": len(shared), + "baseline_only_final_entity_signature_label_counts": _signature_label_counts( + baseline_only, + baseline_labels, + ), + "candidate_only_final_entity_signature_label_counts": _signature_label_counts( + candidate_only, + candidate_labels, + ), + "shared_final_entity_signature_label_counts": _signature_label_counts(shared, baseline_labels), + **coverage, + **_stable_entity_signature_deltas( + baseline, + candidate, + baseline_labels, + candidate_labels, + baseline_details, + candidate_details, + ), + } + + +def _stable_entity_signature_deltas( + baseline: dict[str, object], + candidate: dict[str, object], + baseline_labels: dict[str, str], + candidate_labels: dict[str, str], + baseline_details: dict[str, dict[str, object]], + candidate_details: dict[str, dict[str, object]], +) -> dict[str, object]: + baseline_case_count = _optional_float(baseline.get("case_count")) + candidate_case_count = _optional_float(candidate.get("case_count")) + if ( + baseline_case_count is None + or candidate_case_count is None + or baseline_case_count < 2 + or candidate_case_count < 2 + ): + return _empty_stable_signature_deltas() + baseline_stable = set(_coerce_signature_list(baseline.get("stable_artifact_final_entity_signature_hashes"))) + candidate_stable = set(_coerce_signature_list(candidate.get("stable_artifact_final_entity_signature_hashes"))) + if not baseline_stable and not candidate_stable: + return _empty_stable_signature_deltas() + baseline_stable_candidate_unstable = baseline_stable - candidate_stable + candidate_stable_baseline_unstable = candidate_stable - baseline_stable + shared_stable = baseline_stable & candidate_stable + stable_candidate_details = { + signature: detail for signature, detail in candidate_details.items() if signature in candidate_stable + } + coverage = _candidate_span_coverage( + baseline_stable_candidate_unstable, + baseline_details=baseline_details, + candidate_details=stable_candidate_details, + baseline_labels=baseline_labels, + prefix="baseline_stable_candidate", + ) + return { + "baseline_stable_final_entity_signature_count": len(baseline_stable), + "candidate_stable_final_entity_signature_count": len(candidate_stable), + "stable_final_entity_signature_count_delta": len(candidate_stable) - len(baseline_stable), + "baseline_stable_candidate_unstable_final_entity_signature_count": len(baseline_stable_candidate_unstable), + "candidate_stable_baseline_unstable_final_entity_signature_count": len(candidate_stable_baseline_unstable), + "shared_stable_final_entity_signature_count": len(shared_stable), + "baseline_stable_candidate_unstable_final_entity_signature_label_counts": _signature_label_counts( + baseline_stable_candidate_unstable, + baseline_labels, + ), + "candidate_stable_baseline_unstable_final_entity_signature_label_counts": _signature_label_counts( + candidate_stable_baseline_unstable, + candidate_labels, + ), + "shared_stable_final_entity_signature_label_counts": _signature_label_counts(shared_stable, baseline_labels), + **coverage, + } + + +def _candidate_span_coverage( + baseline_signatures: set[str], + *, + baseline_details: dict[str, dict[str, object]], + candidate_details: dict[str, dict[str, object]], + baseline_labels: dict[str, str], + prefix: str = "baseline_only_candidate", +) -> dict[str, object]: + contained: set[str] = set() + overlapping: set[str] = set() + label_mismatch: set[str] = set() + for signature in baseline_signatures: + match, labels_mismatch = _candidate_span_match_kind(baseline_details.get(signature), candidate_details.values()) + if match == "contained": + contained.add(signature) + elif match == "overlapping": + overlapping.add(signature) + if labels_mismatch: + label_mismatch.add(signature) + covered = contained | overlapping + uncovered = baseline_signatures - covered + return { + f"{prefix}_covered_signature_count": len(covered), + f"{prefix}_overlapping_signature_count": len(overlapping), + f"{prefix}_uncovered_signature_count": len(uncovered), + f"{prefix}_covered_signature_label_counts": _signature_label_counts(covered, baseline_labels), + f"{prefix}_overlapping_signature_label_counts": _signature_label_counts(overlapping, baseline_labels), + f"{prefix}_uncovered_signature_label_counts": _signature_label_counts(uncovered, baseline_labels), + f"{prefix}_label_mismatch_signature_count": len(label_mismatch), + f"{prefix}_label_mismatch_signature_label_counts": _signature_label_counts(label_mismatch, baseline_labels), + } + + +def _candidate_span_match_kind( + baseline_detail: dict[str, object] | None, + candidate_details: object, +) -> tuple[str | None, bool]: + baseline_row = _optional_int(_detail_value(baseline_detail, "row_index")) + baseline_start = _optional_int(_detail_value(baseline_detail, "start_position")) + baseline_end = _optional_int(_detail_value(baseline_detail, "end_position")) + baseline_label = _optional_string(_detail_value(baseline_detail, "label")) + if baseline_row is None or baseline_start is None or baseline_end is None: + return None, False + baseline_length = baseline_end - baseline_start + if baseline_length <= 0: + return None, False + first_mismatched_match: tuple[str, bool] | None = None + for candidate_detail in candidate_details: + candidate_row = _optional_int(_detail_value(candidate_detail, "row_index")) + candidate_start = _optional_int(_detail_value(candidate_detail, "start_position")) + candidate_end = _optional_int(_detail_value(candidate_detail, "end_position")) + if candidate_row != baseline_row or candidate_start is None or candidate_end is None: + continue + candidate_label = _optional_string(_detail_value(candidate_detail, "label")) + labels_mismatch = _labels_mismatch(baseline_label, candidate_label) + if candidate_start <= baseline_start and candidate_end >= baseline_end: + if not labels_mismatch: + return "contained", False + first_mismatched_match = first_mismatched_match or ("contained", True) + continue + overlap = max(0, min(baseline_end, candidate_end) - max(baseline_start, candidate_start)) + if overlap / baseline_length >= _MIN_CANDIDATE_OVERLAP_RATIO: + if not labels_mismatch: + return "overlapping", False + first_mismatched_match = first_mismatched_match or ("overlapping", True) + continue + if _is_small_boundary_gap( + baseline_start=baseline_start, + baseline_end=baseline_end, + baseline_label=baseline_label, + candidate_start=candidate_start, + candidate_end=candidate_end, + overlap=overlap, + ): + if not labels_mismatch: + return "overlapping", False + first_mismatched_match = first_mismatched_match or ("overlapping", True) + return first_mismatched_match or (None, False) + + +def _labels_mismatch(baseline_label: str | None, candidate_label: str | None) -> bool: + return baseline_label is not None and candidate_label is not None and baseline_label != candidate_label + + +def _is_small_boundary_gap( + *, + baseline_start: int, + baseline_end: int, + baseline_label: str | None, + candidate_start: int, + candidate_end: int, + overlap: int, +) -> bool: + if baseline_label not in _BOUNDARY_NORMALIZED_BASELINE_LABELS: + return False + if overlap < _MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS: + return False + omitted_left = max(0, candidate_start - baseline_start) + omitted_right = max(0, baseline_end - candidate_end) + return omitted_left + omitted_right <= _MAX_CANDIDATE_BOUNDARY_GAP_CHARS + + +def _detail_value(detail: object, key: str) -> object: + if not isinstance(detail, dict): + return None + return detail.get(key) + + +def _empty_stable_signature_deltas() -> dict[str, object]: + return { + "baseline_stable_final_entity_signature_count": None, + "candidate_stable_final_entity_signature_count": None, + "stable_final_entity_signature_count_delta": None, + "baseline_stable_candidate_unstable_final_entity_signature_count": None, + "candidate_stable_baseline_unstable_final_entity_signature_count": None, + "shared_stable_final_entity_signature_count": None, + "baseline_stable_candidate_covered_signature_count": None, + "baseline_stable_candidate_overlapping_signature_count": None, + "baseline_stable_candidate_uncovered_signature_count": None, + "baseline_stable_candidate_covered_signature_label_counts": {}, + "baseline_stable_candidate_overlapping_signature_label_counts": {}, + "baseline_stable_candidate_uncovered_signature_label_counts": {}, + "baseline_stable_candidate_label_mismatch_signature_count": None, + "baseline_stable_candidate_label_mismatch_signature_label_counts": {}, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts": {}, + "candidate_stable_baseline_unstable_final_entity_signature_label_counts": {}, + "shared_stable_final_entity_signature_label_counts": {}, + } + + +def _signature_hashes(rows: pd.DataFrame) -> list[str]: + hashes: set[str] = set() + for signature_set in _signature_hash_sets(rows): + hashes.update(signature_set) + return sorted(hashes) + + +def _stable_signature_hashes(rows: pd.DataFrame) -> list[str]: + signature_sets = _signature_hash_sets(rows) + if not signature_sets: + return [] + return sorted(set.intersection(*signature_sets)) + + +def _signature_hash_sets(rows: pd.DataFrame) -> list[set[str]]: + if "artifact_final_entity_signature_hashes" not in rows.columns: + return [] + return [set(_coerce_signature_list(value)) for value in rows["artifact_final_entity_signature_hashes"].tolist()] + + +def _signature_labels(rows: pd.DataFrame) -> dict[str, str]: + labels: dict[str, str] = {} + if "artifact_final_entity_signature_labels" in rows.columns: + for value in rows["artifact_final_entity_signature_labels"].tolist(): + labels.update(_coerce_signature_labels(value)) + prefix = "artifact_final_entity_signature_labels." + for column in rows.columns: + if not column.startswith(prefix): + continue + signature_hash = column.removeprefix(prefix) + for value in rows[column].tolist(): + if not _is_missing_cell(value): + labels[signature_hash] = str(value) + return dict(sorted(labels.items())) + + +def _signature_details(rows: pd.DataFrame) -> dict[str, dict[str, object]]: + details: dict[str, dict[str, object]] = {} + if "artifact_final_entity_signature_details" in rows.columns: + for value in rows["artifact_final_entity_signature_details"].tolist(): + details.update(_coerce_signature_details(value)) + prefix = "artifact_final_entity_signature_details." + for column in rows.columns: + if not column.startswith(prefix): + continue + remainder = column.removeprefix(prefix) + signature_hash, _, field = remainder.partition(".") + if not signature_hash or not field: + continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue + for value in rows[column].tolist(): + if not _is_missing_cell(value): + details.setdefault(signature_hash, {})[field] = _json_scalar(value) + return dict(sorted(details.items())) + + +def _sum_label_counts(rows: pd.DataFrame, field: str) -> dict[str, int]: + counts: dict[str, int] = {} + if field in rows.columns: + for value in rows[field].tolist(): + _merge_count_map(counts, _coerce_count_map(value)) + prefix = f"{field}." + for column in rows.columns: + if not column.startswith(prefix): + continue + total = _sum_or_none(rows, column) + if total: + counts[column.removeprefix(prefix)] = counts.get(column.removeprefix(prefix), 0) + int(total) + return dict(sorted(counts.items())) + + +def _failed_case_count(rows: pd.DataFrame) -> int: + if "case_failed" in rows.columns: + return int(rows["case_failed"].map(_coerce_bool).sum()) + error_columns = [ + column + for column in ("error_stage_count", "error_ndd_workflow_count", "error_model_workflow_count") + if column in rows.columns + ] + if not error_columns: + return 0 + error_counts = rows[error_columns].apply(pd.to_numeric, errors="coerce").fillna(0).sum(axis=1) + return int((error_counts > 0).sum()) + + +def _coerce_bool(value: object) -> bool: + if _is_missing_cell(value): + return False + if isinstance(value, bool): + return value + if isinstance(value, int | float): + return value != 0 + text = str(value).strip().lower() + return text in {"1", "true", "t", "yes", "y"} + + +def _rate(count: object, total: object) -> float | None: + count_value = _optional_float(count) + total_value = _optional_float(total) + if count_value is None or total_value is None or total_value <= 0: + return None + return count_value / total_value + + +def _signature_label_counts(signatures: set[str], labels: dict[str, str]) -> dict[str, int]: + counts: dict[str, int] = {} + for signature_hash in signatures: + label = labels.get(signature_hash, "unknown") + counts[label] = counts.get(label, 0) + 1 + return dict(sorted(counts.items())) + + +def _coerce_signature_list(value: object) -> list[str]: + if _is_missing_cell(value): + return [] + if hasattr(value, "tolist"): + value = value.tolist() + if isinstance(value, list | tuple | set): + return [str(item) for item in value if not _is_missing_cell(item)] + if not isinstance(value, str): + return [] + text = value.strip() + if not text: + return [] + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return [] + return _coerce_signature_list(parsed) + + +def _coerce_signature_labels(value: object) -> dict[str, str]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + return {str(key): str(item) for key, item in value.items() if not _is_missing_cell(item)} + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_signature_labels(parsed) + + +def _coerce_signature_details(value: object) -> dict[str, dict[str, object]]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + details: dict[str, dict[str, object]] = {} + for signature_hash, raw_detail in value.items(): + if not isinstance(raw_detail, dict): + continue + details[str(signature_hash)] = { + str(key): _json_scalar(item) + for key, item in raw_detail.items() + if str(key) in _SIGNATURE_DETAIL_FIELDS and not _is_missing_cell(item) + } + return dict(sorted(details.items())) + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_signature_details(parsed) + + +def _json_scalar(value: object) -> object: + if hasattr(value, "item"): + try: + return value.item() + except ValueError: + return value + return value + + +def _coerce_count_map(value: object) -> dict[str, int]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + counts: dict[str, int] = {} + for key, item in value.items(): + if _is_missing_cell(item): + continue + count = _optional_float(item) + if count: + counts[str(key)] = int(count) + return dict(sorted(counts.items())) + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_count_map(parsed) + + +def _merge_count_map(target: dict[str, int], update: dict[str, int]) -> None: + for key, value in update.items(): + target[key] = target.get(key, 0) + value + + +def _is_missing_cell(value: object) -> bool: + return value is None or (isinstance(value, float) and pd.isna(value)) + + +def _append_if_negative(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: + value = _optional_float(metrics.get(field)) + if value is not None and value < 0: + flags.append(flag) + + +def _append_if_positive(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: + value = _optional_float(metrics.get(field)) + if value is not None and value > 0: + flags.append(flag) + + +def _optional_float(value: object) -> float | None: + if value is None or pd.isna(value): + return None + return float(value) + + +def _optional_int(value: object) -> int | None: + number = _optional_float(value) + return int(number) if number is not None else None + + +def _optional_string(value: object) -> str | None: + if value is None or pd.isna(value): + return None + return str(value) + + +def _delta(baseline: float | None, candidate: float | None) -> float | None: + return candidate - baseline if baseline is not None and candidate is not None else None + + +def _delta_pct(baseline: float | None, candidate: float | None) -> float | None: + if baseline is None or candidate is None or baseline == 0: + return None + return ((candidate - baseline) / baseline) * 100 + + +def write_comparisons(rows: list[ComparisonRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + table = _normalize_table_cells(table) + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def _normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame: + normalized = table.copy() + for column in normalized.columns: + if normalized[column].map(_is_nested_cell).any(): + normalized[column] = normalized[column].map(_json_cell) + return normalized + + +def _is_nested_cell(value: object) -> bool: + return isinstance(value, dict | list) + + +def _json_cell(value: object) -> object: + if not _is_nested_cell(value): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def render_result(result: ComparisonResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [ + f"Compared {result.comparison_count} workload(s): " + f"viable={result.summary.candidate_verdict_counts.get(CandidateVerdict.candidate_viable.value, 0)}, " + f"review={result.summary.candidate_verdict_counts.get(CandidateVerdict.review.value, 0)}, " + f"reject={result.summary.candidate_verdict_counts.get(CandidateVerdict.reject.value, 0)}" + ] + for row in result.comparisons: + lines.append( + f"- {row.workload_id}: verdict={row.candidate_verdict or 'unknown'} " + f"(safety={row.safety_verdict or 'unknown'}, " + f"value_protection={row.value_protection_verdict or 'unknown'}, " + f"signature_parity={row.signature_parity_verdict or 'unknown'}, " + f"performance={row.performance_verdict or 'unknown'}), " + f"replacement={row.baseline_replacement_strategy or 'unknown'}" + f"->{row.candidate_replacement_strategy or 'unknown'}, " + f"elapsed {_number_pair(row.baseline_pipeline_elapsed_sec, row.candidate_pipeline_elapsed_sec, suffix='s')}, " + f"entities {row.baseline_final_entity_count}->{row.candidate_final_entity_count}, " + f"requests {_number_pair(row.baseline_observed_total_requests, row.candidate_observed_total_requests)}, " + f"tokens {_number_pair(row.baseline_observed_total_tokens, row.candidate_observed_total_tokens)}, " + f"original_value_leaks " + f"{_number_pair(row.baseline_original_value_leak_count, row.candidate_original_value_leak_count)}, " + f"lost_labels={_label_count_summary(row.baseline_only_final_entity_signature_label_counts)}, " + "covered_label_mismatch_labels=" + f"{_label_count_summary(row.baseline_only_candidate_label_mismatch_signature_label_counts)}, " + f"unstable_lost_labels={_label_count_summary(row.baseline_stable_candidate_unstable_final_entity_signature_label_counts)}, " + f"leak_labels={_label_count_summary(row.candidate_original_value_leak_label_counts)}, " + f"flags={','.join(row.flags) if row.flags else 'none'}" + ) + return "\n".join(lines) + + +def _number_pair(baseline: float | None, candidate: float | None, *, suffix: str = "") -> str: + return f"{_format_number(baseline, suffix=suffix)}->{_format_number(candidate, suffix=suffix)}" + + +def _format_number(value: float | None, *, suffix: str = "") -> str: + if value is None: + return "unknown" + return f"{value:.1f}{suffix}" + + +def _label_count_summary(counts: dict[str, int]) -> str: + if not counts: + return "none" + return ",".join(f"{label}:{count}" for label, count in sorted(counts.items())) + + +def _selector_label(*, config: str | None, strategy: str | None) -> str: + if config is not None: + return f"config:{config}" + if strategy is not None: + return f"strategy:{strategy}" + return "" + + +@app.default +def main( + case_analysis: Path, + *, + candidate_case_analysis: Annotated[Path | None, cyclopts.Parameter("--candidate-case-analysis")] = None, + baseline_config: Annotated[str | None, cyclopts.Parameter("--baseline-config")] = None, + candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None, + baseline_strategy: Annotated[str | None, cyclopts.Parameter("--baseline-strategy")] = None, + candidate_strategy: Annotated[str | None, cyclopts.Parameter("--candidate-strategy")] = None, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.csv, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + baseline_table = read_case_analysis(case_analysis) + candidate_table = read_case_analysis(candidate_case_analysis) if candidate_case_analysis else baseline_table + comparisons = compare_case_tables( + baseline_table, + candidate_table, + baseline_config=baseline_config, + candidate_config=candidate_config, + baseline_strategy=baseline_strategy, + candidate_strategy=candidate_strategy, + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + result = ComparisonResult( + input_path=str(case_analysis), + candidate_input_path=str(candidate_case_analysis) if candidate_case_analysis else None, + baseline_selector=_selector_label(config=baseline_config, strategy=baseline_strategy), + candidate_selector=_selector_label(config=candidate_config, strategy=candidate_strategy), + summary=summarize_comparisons(comparisons), + comparisons=comparisons, + ) + if output is not None: + write_comparisons(comparisons, output, format) + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/dd_parser_compat.py b/tools/measurement/dd_parser_compat.py new file mode 100644 index 00000000..d3aa441c --- /dev/null +++ b/tools/measurement/dd_parser_compat.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark-only DataDesigner structured-output parser compatibility modes.""" + +from __future__ import annotations + +import json +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from enum import StrEnum +from typing import Any + +from data_designer.engine.models.parsers.errors import ParserException +from data_designer.engine.models.recipes import response_recipes as recipes +from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate + + +class DDParserCompatMode(StrEnum): + none = "none" + raw_json = "raw_json" + + +@contextmanager +def dd_parser_compat_context(mode: DDParserCompatMode) -> Iterator[None]: + """Temporarily patch DataDesigner parsers for benchmark-only compatibility.""" + if mode == DDParserCompatMode.none: + yield + return + if mode != DDParserCompatMode.raw_json: + raise ValueError(f"unsupported DataDesigner parser compatibility mode: {mode}") + + original_pydantic = recipes.PydanticResponseRecipe._build_parser_fn + original_structured = recipes.StructuredResponseRecipe._build_parser_fn + recipes.PydanticResponseRecipe._build_parser_fn = _tolerant_pydantic_builder(original_pydantic) # type: ignore[method-assign] + recipes.StructuredResponseRecipe._build_parser_fn = _tolerant_structured_builder(original_structured) # type: ignore[method-assign] + try: + yield + finally: + recipes.PydanticResponseRecipe._build_parser_fn = original_pydantic # type: ignore[method-assign] + recipes.StructuredResponseRecipe._build_parser_fn = original_structured # type: ignore[method-assign] + + +def _tolerant_pydantic_builder(original: Callable[..., Callable[[str], Any]]) -> Callable[..., Callable[[str], Any]]: + def build_parser(self: Any) -> Callable[[str], Any]: + base_parse = original(self) + + def parse(response: str) -> Any: + try: + return base_parse(response) + except ParserException as exc: + try: + return self.data_type.model_validate(_load_embedded_json(response)) + except Exception: + raise exc + + return parse + + return build_parser + + +def _tolerant_structured_builder( + original: Callable[..., Callable[[str], dict]], +) -> Callable[..., Callable[[str], dict]]: + def build_parser(self: Any) -> Callable[[str], dict]: + base_parse = original(self) + + def parse(response: str) -> dict: + try: + return base_parse(response) + except ParserException as exc: + try: + return validate(_load_embedded_json(response), **self._validate_args) + except (json.JSONDecodeError, JSONSchemaValidationError, TypeError, ValueError): + raise exc + + return parse + + return build_parser + + +def _load_embedded_json(response: str) -> Any: + """Return the largest JSON object/array embedded in a model response.""" + decoder = json.JSONDecoder() + stripped = response.strip() + try: + return json.loads(stripped) + except json.JSONDecodeError: + pass + + best: tuple[int, int, Any] | None = None + for start, char in enumerate(response): + if char not in "{[": + continue + try: + parsed, end = decoder.raw_decode(response, start) + except json.JSONDecodeError: + continue + if not isinstance(parsed, dict | list): + continue + # Prefer the candidate that consumes the most response text. That + # selects the outer response object instead of nested item objects. + if best is None or end > best[1] or (end == best[1] and start < best[0]): + best = (start, end, parsed) + + if best is None: + raise json.JSONDecodeError("No embedded JSON object or array found", response, 0) + return best[2] diff --git a/tools/measurement/detection_strategies.py b/tools/measurement/detection_strategies.py new file mode 100644 index 00000000..8c45c6ab --- /dev/null +++ b/tools/measurement/detection_strategies.py @@ -0,0 +1,1695 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Experimental detection strategies for benchmark-only performance probes.""" + +from __future__ import annotations + +import json +import time +from collections import Counter +from collections.abc import Callable, Iterator +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + +import pandas as pd +from data_designer.config import custom_column_generator +from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig +from data_designer.config.models import ModelConfig +from dd_parser_compat import _load_embedded_json +from direct_detection_probe import DirectDetectionRequest, DirectGenerationRequest, PromptMode, build_direct_prompt +from staged_detection_probe import ( + CaseStatus as StagedCaseStatus, +) +from staged_detection_probe import ( + DirectCompletion, + DirectDetectionClient, + GlinerSeedClient, + HttpxDirectDetectionClient, + SeedSource, + StagedDetectionCase, + StagedDetectionRequest, + StagedExecutionConfig, + ValidationPromptMode, + _run_augmentation_phase, + _run_validation_phase, + execute_staged_detection_case, + normalize_birth_context_final_entities, +) + +from anonymizer.config.models import DetectionModelSelection +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_MERGED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAGGED_TEXT, + COL_TEXT, + COL_VALIDATED_ENTITIES, + COL_VALIDATION_DECISIONS, + _jinja, +) +from anonymizer.engine.detection import detection_workflow as dw +from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams, make_chunked_validation_generator +from anonymizer.engine.detection.custom_columns import ( + apply_validation_and_finalize, + apply_validation_to_seed_entities, + enrich_validation_decisions, + merge_and_build_candidates, + parse_detected_entities, + prepare_validation_inputs, +) +from anonymizer.engine.detection.postprocess import ( + EntitySpan, + build_tagged_text, + expand_entity_occurrences, + resolve_overlaps, +) +from anonymizer.engine.ndd.adapter import FailedRecord +from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases +from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema, ValidationCandidatesSchema +from anonymizer.measurement import record_model_workflow + +_NATIVE_DIRECT_MODEL_ALIAS = "native-direct" +_GLINER_DIRECT_MODEL_ALIAS = "gliner-direct" +_DIRECT_MODEL_NAME = "configured-native-model" +_DIRECT_MODEL_PROVIDER = "configured-native-provider" +_DIRECT_MAX_TOKENS = 4096 +_DIRECT_TIMEOUT_SEC = 180.0 +_DIRECT_MAX_WORKERS = 4 + + +class ExperimentalDetectionStrategy(StrEnum): + default = "default" + prose_augment_focus = "prose_augment_focus" + compact_validation = "compact_validation" + no_augment = "no_augment" + detector_only = "detector_only" + native_candidate_validate_no_augment = "native_candidate_validate_no_augment" + detector_native_validate_no_augment = "detector_native_validate_no_augment" + detector_native_validate_native_augment = "detector_native_validate_native_augment" + gliner_native_validate_no_augment = "gliner_native_validate_no_augment" + gliner_native_validate_native_augment = "gliner_native_validate_native_augment" + native_single_pass = "native_single_pass" + native_single_pass_recall = "native_single_pass_recall" + native_single_pass_values = "native_single_pass_values" + native_single_pass_values_recall = "native_single_pass_values_recall" + + +_DetectAndValidate = Callable[..., dw.EntityDetectionResult] +_AugmentPrompt = Callable[..., str] +PROSE_AUGMENT_FOCUS_TEXT = """\ +Contextual prose recall focus: +- Re-scan untagged narrative prose for organization and institution names, named facilities, labs, research centers, street or place names, self-described beliefs, occupations, titles, family member names, and other quasi-identifiers that combine with already tagged entities. +- Prefer allowed labels when present, especially organization_name, company_name, place_name, religious_belief, political_view, occupation, first_name, last_name, city, state, university, degree, field_of_study, language, race_ethnicity, and age. +- Do not add generic common nouns, section headings, or labels outside the allowed list when strict labels are required. +""" + + +@dataclass(frozen=True) +class NativeDetectionRuntime: + """Runtime settings for benchmark-only native detection strategies.""" + + endpoint: str | None = None + model: str = _DIRECT_MODEL_NAME + provider: str = _DIRECT_MODEL_PROVIDER + alias: str = _NATIVE_DIRECT_MODEL_ALIAS + max_tokens: int = _DIRECT_MAX_TOKENS + timeout_sec: float = _DIRECT_TIMEOUT_SEC + gliner_endpoint: str | None = None + gliner_model: str = _DIRECT_MODEL_NAME + gliner_provider: str = _DIRECT_MODEL_PROVIDER + gliner_alias: str = _GLINER_DIRECT_MODEL_ALIAS + gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = 0.3 + max_workers: int = _DIRECT_MAX_WORKERS + + +@dataclass(frozen=True) +class _NativeStagedTask: + ordinal: int + index: Any + row: pd.Series + + +@dataclass(frozen=True) +class _NativeStagedRunParams: + labels: list[str] + client: DirectDetectionClient + gliner_seed_client: GlinerSeedClient | None + runtime: NativeDetectionRuntime + data_summary: str | None + validation_prompt_mode: ValidationPromptMode + validation_max_entities_per_call: int + validation_excerpt_window_chars: int + seed_source: SeedSource + workflow_name: str + skip_augmentation: bool + + +@dataclass(frozen=True) +class _NativeStagedRowResult: + ordinal: int + index: Any + output_row: dict[str, Any] | None + failed_record: FailedRecord | None + case: StagedDetectionCase | None + + +@dataclass(frozen=True) +class _DetectorNativeValidationParams: + labels: list[str] + client: DirectDetectionClient + runtime: NativeDetectionRuntime + data_summary: str | None + validation_prompt_mode: ValidationPromptMode + validation_max_entities_per_call: int + validation_excerpt_window_chars: int + workflow_name: str + skip_augmentation: bool + + +@dataclass(frozen=True) +class _DetectorNativeValidationRowResult: + ordinal: int + index: Any + workflow_name: str + runtime: NativeDetectionRuntime + output_row: dict[str, Any] | None + failed_record: FailedRecord | None + completion: DirectCompletion | None + elapsed_sec: float + request_count: int + + +@contextmanager +def experimental_detection_strategy_context( + strategy: ExperimentalDetectionStrategy, + *, + native_client: DirectDetectionClient | None = None, + gliner_seed_client: GlinerSeedClient | None = None, + native_runtime: NativeDetectionRuntime | None = None, +) -> Iterator[None]: + """Temporarily apply a benchmark-only detection strategy.""" + if strategy == ExperimentalDetectionStrategy.default: + yield + return + + original_method = dw.EntityDetectionWorkflow.detect_and_validate_entities + original_augment_prompt = dw._get_augment_prompt + if strategy == ExperimentalDetectionStrategy.prose_augment_focus: + dw._get_augment_prompt = _make_prose_augment_prompt(original_augment_prompt) # type: ignore[assignment] + else: + dw.EntityDetectionWorkflow.detect_and_validate_entities = _method_for_strategy( # type: ignore[method-assign] + strategy, + original=original_method, + native_client=native_client, + gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime or NativeDetectionRuntime(), + ) + try: + yield + finally: + dw.EntityDetectionWorkflow.detect_and_validate_entities = original_method # type: ignore[method-assign] + dw._get_augment_prompt = original_augment_prompt # type: ignore[assignment] + + +def _make_prose_augment_prompt(original: _AugmentPrompt) -> _AugmentPrompt: + def get_augment_prompt(*, data_summary: str | None, labels: list[str], strict_labels: bool = False) -> str: + prompt = original(data_summary=data_summary, labels=labels, strict_labels=strict_labels) + return prompt.replace("Rules:\n", f"{PROSE_AUGMENT_FOCUS_TEXT}\nRules:\n", 1) + + return get_augment_prompt + + +def _method_for_strategy( + strategy: ExperimentalDetectionStrategy, + *, + original: _DetectAndValidate | None = None, + native_client: DirectDetectionClient | None = None, + gliner_seed_client: GlinerSeedClient | None = None, + native_runtime: NativeDetectionRuntime | None = None, +) -> _DetectAndValidate: + runtime = native_runtime or NativeDetectionRuntime() + if strategy == ExperimentalDetectionStrategy.compact_validation: + if original is None: + raise ValueError("compact_validation requires the original detection method") + return _make_default_compact_validation_method(original) + if strategy == ExperimentalDetectionStrategy.no_augment: + return _make_validated_no_augment_method() + if strategy == ExperimentalDetectionStrategy.detector_only: + return _detect_with_detector_only + if strategy == ExperimentalDetectionStrategy.native_candidate_validate_no_augment: + return _make_native_candidate_validate_no_augment_method(native_client=native_client, native_runtime=runtime) + if strategy == ExperimentalDetectionStrategy.detector_native_validate_no_augment: + return _make_detector_native_validate_no_augment_method(native_client=native_client, native_runtime=runtime) + if strategy == ExperimentalDetectionStrategy.detector_native_validate_native_augment: + return _make_detector_native_validate_native_augment_method(native_client=native_client, native_runtime=runtime) + if strategy == ExperimentalDetectionStrategy.gliner_native_validate_no_augment: + return _make_gliner_native_validate_no_augment_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + native_runtime=runtime, + ) + if strategy == ExperimentalDetectionStrategy.gliner_native_validate_native_augment: + return _make_gliner_native_validate_native_augment_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + native_runtime=runtime, + ) + if strategy == ExperimentalDetectionStrategy.native_single_pass: + return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime) + if strategy == ExperimentalDetectionStrategy.native_single_pass_recall: + return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime, recall_prompt=True) + if strategy == ExperimentalDetectionStrategy.native_single_pass_values: + return _make_native_single_pass_method( + native_client=native_client, + native_runtime=runtime, + value_only_prompt=True, + ) + if strategy == ExperimentalDetectionStrategy.native_single_pass_values_recall: + return _make_native_single_pass_method( + native_client=native_client, + native_runtime=runtime, + recall_prompt=True, + value_only_prompt=True, + ) + raise ValueError(f"unsupported experimental detection strategy: {strategy}") + + +def _make_default_compact_validation_method(original: _DetectAndValidate) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + return original( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=False, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + + return detect_and_validate_entities + + +def _make_validated_no_augment_method() -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + return _run_validated_no_augment_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + entity_labels=entity_labels, + data_summary=data_summary, + ) + + return detect_and_validate_entities + + +def _run_validated_no_augment_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + entity_labels: list[str] | None, + data_summary: str | None, +) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + detection_result = workflow._adapter.run_workflow( + dataframe, + model_configs=workflow_model_configs, + columns=_validated_no_augment_columns( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + ), + workflow_name="entity-detection-no-augment", + preview_num_records=preview_num_records, + ) + return dw.EntityDetectionResult( + dataframe=detection_result.dataframe.copy(), + failed_records=detection_result.failed_records, + ) + + +def _validated_no_augment_columns( + *, + selected_models: DetectionModelSelection, + labels: list[str], + data_summary: str | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, +) -> list[LLMTextColumnConfig | CustomColumnConfig]: + validator_params = _validator_params( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + ) + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, prompt=_jinja(COL_TEXT), model_alias=_detector_alias(selected_models) + ), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), + CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), + _validation_decisions_column(selected_models, validator_params), + CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), + CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=apply_validation_to_seed_entities), + CustomColumnConfig(name=COL_AUGMENTED_ENTITIES, generator_function=_empty_augmentation), + CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates), + CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=apply_validation_and_finalize), + ] + + +def _validation_decisions_column( + selected_models: DetectionModelSelection, + validator_params: ChunkedValidationParams, +) -> CustomColumnConfig: + return CustomColumnConfig( + name=COL_VALIDATION_DECISIONS, + generator_function=make_chunked_validation_generator( + resolve_model_aliases("entity_validator", selected_models) + ), + generator_params=validator_params, + drop=True, + ) + + +def _validator_params( + *, + selected_models: DetectionModelSelection, + labels: list[str], + data_summary: str | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, +) -> ChunkedValidationParams: + validator_aliases = resolve_model_aliases("entity_validator", selected_models) + return ChunkedValidationParams( + pool=list(validator_aliases), + max_entities_per_call=validation_max_entities_per_call, + excerpt_window_chars=validation_excerpt_window_chars, + prompt_template=dw._get_validation_prompt(data_summary=data_summary, labels=labels), + ) + + +def _detector_alias(selected_models: DetectionModelSelection) -> str: + return resolve_model_alias("entity_detector", selected_models) + + +def _detect_with_detector_only( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, +) -> dw.EntityDetectionResult: + return _run_detector_only_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + entity_labels=entity_labels, + preview_num_records=preview_num_records, + workflow_name="entity-detection-detector-only", + ) + + +def _run_detector_only_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + entity_labels: list[str] | None, + preview_num_records: int | None, + workflow_name: str, +) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + detection_result = workflow._adapter.run_workflow( + dataframe, + model_configs=workflow_model_configs, + columns=_detector_only_columns(selected_models), + workflow_name=workflow_name, + preview_num_records=preview_num_records, + ) + return dw.EntityDetectionResult( + dataframe=detection_result.dataframe.copy(), + failed_records=detection_result.failed_records, + ) + + +def _detector_only_columns(selected_models: DetectionModelSelection) -> list[LLMTextColumnConfig | CustomColumnConfig]: + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, + prompt=_jinja(COL_TEXT), + model_alias=_detector_alias(selected_models), + ), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), + CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=_copy_seed_entities_json), + CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=_finalize_detector_only), + ] + + +@custom_column_generator(required_columns=[COL_SEED_ENTITIES]) +def _copy_seed_entities_json(row: dict[str, Any]) -> dict[str, Any]: + row[COL_SEED_ENTITIES_JSON] = json.dumps( + [span.as_dict() for span in _entity_spans_from_payload(row.get(COL_SEED_ENTITIES, {}))] + ) + return row + + +@custom_column_generator( + required_columns=[COL_TEXT, COL_SEED_ENTITIES], + side_effect_columns=[COL_TAGGED_TEXT], +) +def _finalize_detector_only(row: dict[str, Any]) -> dict[str, Any]: + text = str(row.get(COL_TEXT, "")) + spans = expand_entity_occurrences(text=text, entities=_entity_spans_from_payload(row.get(COL_SEED_ENTITIES, {}))) + row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump(mode="json") + row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=spans) + return row + + +@custom_column_generator(required_columns=[COL_TEXT]) +def _empty_augmentation(row: dict[str, Any]) -> dict[str, Any]: + row[COL_AUGMENTED_ENTITIES] = AugmentedEntitiesSchema().model_dump(mode="json") + return row + + +def _entity_spans_from_payload(raw_payload: object) -> list[EntitySpan]: + return [ + EntitySpan( + entity_id=entity.id, + value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + for entity in EntitiesSchema.from_raw(raw_payload).entities + ] + + +def _make_native_single_pass_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, + recall_prompt: bool = False, + value_only_prompt: bool = False, +) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + client = _native_client_or_default(native_client, native_runtime) + return _run_native_single_pass_detection( + dataframe, + labels=labels, + client=client, + runtime=native_runtime, + data_summary=data_summary, + preview_num_records=preview_num_records, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ) + + return detect_and_validate_entities + + +def _run_native_single_pass_detection( + dataframe: pd.DataFrame, + *, + labels: list[str], + client: DirectDetectionClient, + runtime: NativeDetectionRuntime, + data_summary: str | None, + preview_num_records: int | None, + recall_prompt: bool, + value_only_prompt: bool, +) -> dw.EntityDetectionResult: + source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records: list[FailedRecord] = [] + + for index, row in source_df.iterrows(): + output_row, failed_record = _execute_native_single_pass_row( + row, + index=index, + labels=labels, + client=client, + runtime=runtime, + data_summary=data_summary, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ) + if failed_record is not None: + failed_records.append(failed_record) + continue + if output_row is None: + failed_records.append(_native_single_pass_failed_record(index, error="native single-pass produced no row")) + continue + output_rows.append(output_row) + output_indices.append(index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe(source_df, output_rows=output_rows, output_indices=output_indices), + failed_records=failed_records, + ) + + +def _execute_native_single_pass_row( + row: pd.Series, + *, + index: object, + labels: list[str], + client: DirectDetectionClient, + runtime: NativeDetectionRuntime, + data_summary: str | None, + recall_prompt: bool, + value_only_prompt: bool, +) -> tuple[dict[str, Any] | None, FailedRecord | None]: + text = str(row.get(COL_TEXT, "")) + started = time.perf_counter() + try: + completion = _complete_native_single_pass( + text=text, + labels=labels, + client=client, + runtime=runtime, + data_summary=data_summary, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ) + except Exception as exc: # noqa: BLE001 - benchmark experiment records case-local failures + _record_native_single_pass_request_error(elapsed_sec=time.perf_counter() - started, runtime=runtime) + return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") + try: + spans = _native_single_pass_spans(text, completion.content, labels=labels) + except Exception as exc: # noqa: BLE001 - parser fragility is a benchmark failure mode + _record_native_single_pass_completion( + completion, + status="error", + output_row_count=0, + failed_record_count=1, + runtime=runtime, + ) + return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") + _record_native_single_pass_completion( + completion, + status="completed", + output_row_count=1, + failed_record_count=0, + runtime=runtime, + ) + return _native_single_pass_result_row(row, spans=spans), None + + +def _complete_native_single_pass( + *, + text: str, + labels: list[str], + client: DirectDetectionClient, + runtime: NativeDetectionRuntime, + data_summary: str | None, + recall_prompt: bool, + value_only_prompt: bool, +) -> Any: + return client.complete( + DirectGenerationRequest( + endpoint=runtime.endpoint or "", + model=runtime.model, + prompt=_native_single_pass_prompt( + text=text, + labels=labels, + data_summary=data_summary, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ), + max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, + ) + ) + + +def _native_single_pass_prompt( + *, + text: str, + labels: list[str], + data_summary: str | None, + recall_prompt: bool = False, + value_only_prompt: bool = False, +) -> str: + if value_only_prompt: + return build_direct_prompt( + DirectDetectionRequest( + case_id="native-single-pass-values", + text=text, + labels=labels, + prompt_mode=PromptMode.recall if recall_prompt else PromptMode.compact, + data_summary=data_summary, + ) + ) + + label_text = dw._format_label_examples(labels) if recall_prompt else ", ".join(labels) + recall_block = _native_single_pass_recall_block() if recall_prompt else "" + summary = f"\nData context: {data_summary}\n" if data_summary else "" + return f"""Extract privacy-sensitive entities from the input text in one pass. +{summary} +Use only these labels: +{label_text} + +Rules: +- Return exact substrings from the input text. +- Do not invent values. +- `start` and `end` must be zero-based Python character offsets, where text[start:end] equals `value`. +- Missing a sensitive value is worse than returning one extra plausible value. +- Skip generic nouns, syntax, and non-sensitive filler. +{recall_block}- Return only a JSON object with this shape: + {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "start": 0, "end": 0, "reason": "short reason"}}]}} + +Input text: +--- +{text} +---""" + + +def _native_single_pass_recall_block() -> str: + return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. +- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. +""" + + +def _native_single_pass_spans(text: str, content: str, *, labels: list[str]) -> list[EntitySpan]: + payload = _load_embedded_json(content) + if not isinstance(payload, dict): + raise ValueError("native single-pass response must be a JSON object") + entities = payload.get("entities") + if not isinstance(entities, list): + raise ValueError("native single-pass response must contain an entities list") + spans: list[EntitySpan] = [] + allowed = set(labels) + for item in entities: + if not isinstance(item, dict): + continue + value = str(item.get("value", "")).strip() + label = str(item.get("label", "")).strip() + if not value or label not in allowed: + continue + offset_span = _native_single_pass_offset_span(text, item, value=value, label=label) + if offset_span is not None: + spans.append(offset_span) + continue + spans.extend(_native_single_pass_value_spans(text, value=value, label=label)) + return resolve_overlaps(spans) + + +def _native_single_pass_offset_span( + text: str, + item: dict[str, Any], + *, + value: str, + label: str, +) -> EntitySpan | None: + start = _coerce_native_offset(item.get("start")) + end = _coerce_native_offset(item.get("end")) + if start is None or end is None or start < 0 or end <= start or end > len(text): + return None + if text[start:end] != value: + return None + return _native_single_pass_span(value=value, label=label, start=start, end=end) + + +def _coerce_native_offset(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, str) and value.strip().isdigit(): + return int(value.strip()) + return None + + +def _native_single_pass_value_spans(text: str, *, value: str, label: str) -> list[EntitySpan]: + spans: list[EntitySpan] = [] + start = 0 + while True: + match_start = text.find(value, start) + if match_start < 0: + return spans + match_end = match_start + len(value) + spans.append(_native_single_pass_span(value=value, label=label, start=match_start, end=match_end)) + start = match_end + + +def _native_single_pass_span(*, value: str, label: str, start: int, end: int) -> EntitySpan: + return EntitySpan( + entity_id=f"{label}_{start}_{end}", + value=value, + label=label, + start_position=start, + end_position=end, + score=1.0, + source="direct_single_pass", + ) + + +def _native_single_pass_result_row(row: pd.Series, *, spans: list[EntitySpan]) -> dict[str, Any]: + text = str(row.get(COL_TEXT, "")) + output_row = row.to_dict() + output_row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump( + mode="json" + ) + output_row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=spans) + return output_row + + +def _native_single_pass_failed_record(index: object, *, error: str | None) -> FailedRecord: + return FailedRecord( + record_id=str(index), + step="entity-detection-native-single-pass", + reason=error or "native single-pass detection failed", + ) + + +def _record_native_single_pass_completion( + completion: Any, + *, + status: str, + output_row_count: int, + failed_record_count: int, + runtime: NativeDetectionRuntime, +) -> None: + record_model_workflow( + workflow_name="entity-detection-native-single-pass", + model_aliases=[runtime.alias], + input_row_count=1, + output_row_count=output_row_count, + failed_record_count=failed_record_count, + elapsed_sec=float(getattr(completion, "elapsed_sec", 0.0) or 0.0), + status=status, + model_usage=_native_single_pass_model_usage( + successful_requests=1, + failed_requests=0, + usage=dict(getattr(completion, "usage", {}) or {}), + runtime=runtime, + ), + ) + + +def _record_native_single_pass_request_error(*, elapsed_sec: float, runtime: NativeDetectionRuntime) -> None: + record_model_workflow( + workflow_name="entity-detection-native-single-pass", + model_aliases=[runtime.alias], + input_row_count=1, + output_row_count=0, + failed_record_count=1, + elapsed_sec=elapsed_sec, + status="error", + model_usage=_native_single_pass_model_usage( + successful_requests=0, + failed_requests=1, + usage={}, + runtime=runtime, + ), + ) + + +def _native_single_pass_model_usage( + *, + successful_requests: int, + failed_requests: int, + usage: dict[str, int], + runtime: NativeDetectionRuntime, +) -> dict[str, dict[str, Any]]: + total_requests = successful_requests + failed_requests + return { + runtime.alias: { + "model_alias": runtime.alias, + "model_name": runtime.model, + "model_provider_name": runtime.provider, + "request_usage": { + "successful_requests": successful_requests, + "failed_requests": failed_requests, + "total_requests": total_requests, + }, + "token_usage": _native_token_usage(usage), + } + } + + +def _native_client_or_default( + native_client: DirectDetectionClient | None, + runtime: NativeDetectionRuntime, +) -> DirectDetectionClient: + if native_client is not None: + return native_client + _require_native_endpoint(runtime) + return HttpxDirectDetectionClient() + + +def _require_native_endpoint(runtime: NativeDetectionRuntime) -> None: + if not runtime.endpoint or not runtime.model: + raise ValueError( + "native detection strategies require configured native endpoint and model; " + "set native_runtime.endpoint and native_runtime.model in the benchmark suite" + ) + + +def _make_native_candidate_validate_no_augment_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: + return _make_native_staged_method( + native_client=native_client, + gliner_seed_client=None, + native_runtime=native_runtime, + seed_source=SeedSource.direct_llm, + workflow_name="entity-detection-native-candidate-validate-no-augment", + skip_augmentation=True, + ) + + +def _make_gliner_native_validate_no_augment_method( + *, + native_client: DirectDetectionClient | None, + gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: + return _make_native_staged_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime, + seed_source=SeedSource.gliner, + workflow_name="entity-detection-gliner-native-validate-no-augment", + skip_augmentation=True, + ) + + +def _make_gliner_native_validate_native_augment_method( + *, + native_client: DirectDetectionClient | None, + gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: + return _make_native_staged_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime, + seed_source=SeedSource.gliner, + workflow_name="entity-detection-gliner-native-validate-native-augment", + skip_augmentation=False, + ) + + +def _make_detector_native_validate_no_augment_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: + return _make_detector_native_validate_method( + native_client=native_client, + native_runtime=native_runtime, + workflow_name="entity-detection-detector-native-validate-no-augment", + seed_workflow_name="entity-detection-detector-native-validate-no-augment-seed", + skip_augmentation=True, + ) + + +def _make_detector_native_validate_native_augment_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: + return _make_detector_native_validate_method( + native_client=native_client, + native_runtime=native_runtime, + workflow_name="entity-detection-detector-native-validate-native-augment", + seed_workflow_name="entity-detection-detector-native-validate-native-augment-seed", + skip_augmentation=False, + ) + + +def _make_detector_native_validate_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, + workflow_name: str, + seed_workflow_name: str, + skip_augmentation: bool, +) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + client = _native_client_or_default(native_client, native_runtime) + return _run_detector_native_validate_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + client=client, + runtime=native_runtime, + data_summary=data_summary, + workflow_name=workflow_name, + seed_workflow_name=seed_workflow_name, + skip_augmentation=skip_augmentation, + ) + + return detect_and_validate_entities + + +def _run_detector_native_validate_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + labels: list[str], + gliner_detection_threshold: float, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + validation_single_chunk_full_text: bool, + client: DirectDetectionClient, + runtime: NativeDetectionRuntime, + data_summary: str | None, + workflow_name: str, + seed_workflow_name: str, + skip_augmentation: bool, +) -> dw.EntityDetectionResult: + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + seed_result = workflow._adapter.run_workflow( + dataframe, + model_configs=workflow_model_configs, + columns=_detector_native_validate_seed_columns(selected_models), + workflow_name=seed_workflow_name, + preview_num_records=preview_num_records, + ) + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records = list(seed_result.failed_records) + validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) + tasks = [ + _NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) + for ordinal, (index, row) in enumerate(seed_result.dataframe.iterrows()) + ] + params = _DetectorNativeValidationParams( + labels=labels, + client=client, + runtime=runtime, + data_summary=data_summary, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + + for result in _execute_detector_native_validate_tasks(tasks, params=params): + _record_detector_native_validation_result(result) + if result.failed_record is not None: + failed_records.append(result.failed_record) + continue + if result.output_row is None: + failed_records.append( + _native_failed_record( + result.index, + workflow_name=workflow_name, + error="native detector-seed validation produced no row", + ) + ) + continue + output_rows.append(result.output_row) + output_indices.append(result.index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe( + seed_result.dataframe, output_rows=output_rows, output_indices=output_indices + ), + failed_records=failed_records, + ) + + +def _execute_detector_native_validate_tasks( + tasks: list[_NativeStagedTask], + *, + params: _DetectorNativeValidationParams, +) -> list[_DetectorNativeValidationRowResult]: + if not tasks: + return [] + worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) + if worker_count == 1: + return [_execute_detector_native_validate_task(task, params=params) for task in tasks] + with ThreadPoolExecutor(max_workers=worker_count) as executor: + return list( + executor.map( + lambda task: _execute_detector_native_validate_task(task, params=params), + tasks, + ) + ) + + +def _execute_detector_native_validate_task( + task: _NativeStagedTask, + *, + params: _DetectorNativeValidationParams, +) -> _DetectorNativeValidationRowResult: + return _execute_detector_native_validate_row( + task.row, + index=task.index, + ordinal=task.ordinal, + labels=params.labels, + client=params.client, + runtime=params.runtime, + data_summary=params.data_summary, + validation_prompt_mode=params.validation_prompt_mode, + validation_max_entities_per_call=params.validation_max_entities_per_call, + validation_excerpt_window_chars=params.validation_excerpt_window_chars, + workflow_name=params.workflow_name, + skip_augmentation=params.skip_augmentation, + ) + + +def _detector_native_validate_seed_columns( + selected_models: DetectionModelSelection, +) -> list[LLMTextColumnConfig | CustomColumnConfig]: + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, + prompt=_jinja(COL_TEXT), + model_alias=_detector_alias(selected_models), + ), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), + CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), + ] + + +def _execute_detector_native_validate_row( + row: pd.Series, + *, + index: object, + ordinal: int, + labels: list[str], + client: DirectDetectionClient, + runtime: NativeDetectionRuntime, + data_summary: str | None, + validation_prompt_mode: ValidationPromptMode, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + workflow_name: str, + skip_augmentation: bool, +) -> _DetectorNativeValidationRowResult: + output_row = row.to_dict() + request = _native_staged_request(row, index=index, ordinal=ordinal, labels=labels, data_summary=data_summary) + config = StagedExecutionConfig( + endpoint=runtime.endpoint or "", + model=runtime.model, + max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + skip_augmentation=skip_augmentation, + ) + request_count = _native_validation_request_count( + output_row, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + ) + if not skip_augmentation: + request_count += 1 + started = time.perf_counter() + try: + validation_completion = _run_validation_phase(output_row, request, client, config) + augmentation_completion = _run_augmentation_phase(output_row, request, client, config) + completion = _combine_detector_native_completions([validation_completion, augmentation_completion]) + merge_and_build_candidates(output_row) + apply_validation_and_finalize(output_row) + normalize_birth_context_final_entities(output_row, allowed_labels=labels) + except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures + return _DetectorNativeValidationRowResult( + ordinal=ordinal, + index=index, + workflow_name=workflow_name, + runtime=runtime, + output_row=None, + failed_record=_native_failed_record( + index, + workflow_name=workflow_name, + error=f"{type(exc).__name__}: {exc}", + ), + completion=None, + elapsed_sec=time.perf_counter() - started, + request_count=request_count, + ) + return _DetectorNativeValidationRowResult( + ordinal=ordinal, + index=index, + workflow_name=workflow_name, + runtime=runtime, + output_row=output_row, + failed_record=None, + completion=completion, + elapsed_sec=time.perf_counter() - started, + request_count=request_count, + ) + + +def _native_validation_request_count( + row: dict[str, Any], + *, + validation_prompt_mode: ValidationPromptMode, + validation_max_entities_per_call: int, +) -> int: + candidate_count = len(ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})).candidates) + if candidate_count == 0: + return 0 + if validation_prompt_mode == ValidationPromptMode.full_text: + return 1 + return (candidate_count + validation_max_entities_per_call - 1) // validation_max_entities_per_call + + +def _combine_detector_native_completions(completions: list[DirectCompletion]) -> DirectCompletion: + return DirectCompletion( + content=completions[-1].content if completions else "", + elapsed_sec=sum(float(completion.elapsed_sec or 0.0) for completion in completions), + usage=_sum_usage_dicts([dict(completion.usage or {}) for completion in completions]), + ) + + +def _record_detector_native_validation_completion( + completion: DirectCompletion, + *, + request_count: int, + workflow_name: str, + runtime: NativeDetectionRuntime, +) -> None: + record_model_workflow( + workflow_name=workflow_name, + model_aliases=[runtime.alias], + input_row_count=1, + output_row_count=1, + failed_record_count=0, + elapsed_sec=float(completion.elapsed_sec or 0.0), + model_usage=_native_single_pass_model_usage( + successful_requests=request_count, + failed_requests=0, + usage=dict(completion.usage or {}), + runtime=runtime, + ), + ) + + +def _record_detector_native_validation_result(result: _DetectorNativeValidationRowResult) -> None: + if result.completion is not None: + _record_detector_native_validation_completion( + result.completion, + request_count=result.request_count, + workflow_name=result.workflow_name, + runtime=result.runtime, + ) + return + if result.failed_record is not None: + _record_detector_native_validation_error( + elapsed_sec=result.elapsed_sec, + request_count=result.request_count, + workflow_name=result.workflow_name, + runtime=result.runtime, + ) + + +def _record_detector_native_validation_error( + *, + elapsed_sec: float, + request_count: int, + workflow_name: str, + runtime: NativeDetectionRuntime, +) -> None: + record_model_workflow( + workflow_name=workflow_name, + model_aliases=[runtime.alias], + input_row_count=1, + output_row_count=0, + failed_record_count=1, + elapsed_sec=elapsed_sec, + status="error", + model_usage=_native_single_pass_model_usage( + successful_requests=0, + failed_requests=max(request_count, 1), + usage={}, + runtime=runtime, + ), + ) + + +def _make_native_staged_method( + *, + native_client: DirectDetectionClient | None, + gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, + seed_source: SeedSource, + workflow_name: str, + skip_augmentation: bool, +) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + _ = self, model_configs, selected_models, gliner_detection_threshold + labels = dw._resolve_detection_labels(entity_labels) + client = _native_client_or_default(native_client, native_runtime) + return _run_native_staged_detection( + dataframe, + labels=labels, + client=client, + gliner_seed_client=gliner_seed_client, + runtime=native_runtime, + data_summary=data_summary, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + seed_source=seed_source, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + + return detect_and_validate_entities + + +def _run_native_staged_detection( + dataframe: pd.DataFrame, + *, + labels: list[str], + client: DirectDetectionClient, + gliner_seed_client: GlinerSeedClient | None, + runtime: NativeDetectionRuntime, + data_summary: str | None, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + validation_single_chunk_full_text: bool, + seed_source: SeedSource, + workflow_name: str, + skip_augmentation: bool, +) -> dw.EntityDetectionResult: + source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records: list[FailedRecord] = [] + validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) + params = _NativeStagedRunParams( + labels=labels, + client=client, + gliner_seed_client=gliner_seed_client, + runtime=runtime, + data_summary=data_summary, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + seed_source=seed_source, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + tasks = [ + _NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) + for ordinal, (index, row) in enumerate(source_df.iterrows()) + ] + + for result in _execute_native_staged_tasks(tasks, params=params): + if result.failed_record is not None: + failed_records.append(result.failed_record) + continue + if result.output_row is None: + failed_records.append( + _native_failed_record( + result.index, + workflow_name=workflow_name, + error="native staged detection produced no row", + ) + ) + continue + if result.case is not None: + _record_native_direct_usage(result.case, workflow_name=workflow_name, runtime=runtime) + output_rows.append(result.output_row) + output_indices.append(result.index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe(source_df, output_rows=output_rows, output_indices=output_indices), + failed_records=failed_records, + ) + + +def _execute_native_staged_tasks( + tasks: list[_NativeStagedTask], + *, + params: _NativeStagedRunParams, +) -> list[_NativeStagedRowResult]: + if not tasks: + return [] + worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) + if worker_count == 1: + return [_execute_native_staged_task(task, params=params) for task in tasks] + with ThreadPoolExecutor(max_workers=worker_count) as executor: + return list(executor.map(lambda task: _execute_native_staged_task(task, params=params), tasks)) + + +def _native_staged_worker_count(task_count: int, *, runtime: NativeDetectionRuntime) -> int: + return max(1, min(task_count, runtime.max_workers)) + + +def _execute_native_staged_task( + task: _NativeStagedTask, + *, + params: _NativeStagedRunParams, +) -> _NativeStagedRowResult: + try: + request = _native_staged_request( + task.row, + index=task.index, + ordinal=task.ordinal, + labels=params.labels, + data_summary=params.data_summary, + ) + execution = execute_staged_detection_case( + request, + client=params.client, + seed_client=params.gliner_seed_client, + seed_source=params.seed_source, + endpoint=params.runtime.endpoint or "", + model=params.runtime.model, + gliner_endpoint=params.runtime.gliner_endpoint or "", + gliner_model=params.runtime.gliner_model, + gliner_api_key_env=params.runtime.gliner_api_key_env, + gliner_threshold=params.runtime.gliner_threshold, + max_tokens=params.runtime.max_tokens, + timeout_sec=params.runtime.timeout_sec, + skip_augmentation=params.skip_augmentation, + validation_prompt_mode=params.validation_prompt_mode, + validation_max_entities_per_call=params.validation_max_entities_per_call, + validation_excerpt_window_chars=params.validation_excerpt_window_chars, + ) + if execution.case.status != StagedCaseStatus.completed: + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=None, + failed_record=_native_failed_record( + task.index, + workflow_name=params.workflow_name, + error=execution.case.error, + ), + case=execution.case, + ) + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=_native_detection_result_row(task.row, execution_row=execution.row), + failed_record=None, + case=execution.case, + ) + except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=None, + failed_record=_native_failed_record( + task.index, + workflow_name=params.workflow_name, + error=f"{type(exc).__name__}: {exc}", + ), + case=None, + ) + + +def _native_validation_prompt_mode(validation_single_chunk_full_text: bool) -> ValidationPromptMode: + return ValidationPromptMode.full_text if validation_single_chunk_full_text else ValidationPromptMode.chunked_excerpt + + +def _native_failed_record(index: object, *, workflow_name: str, error: str | None) -> FailedRecord: + return FailedRecord( + record_id=str(index), + step=workflow_name, + reason=error or "native staged detection failed", + ) + + +def _record_native_direct_usage( + case: StagedDetectionCase, + *, + workflow_name: str, + runtime: NativeDetectionRuntime, +) -> None: + model_usage = _native_staged_model_usage(case, runtime=runtime) + if not model_usage: + return + record_model_workflow( + workflow_name=workflow_name, + model_aliases=sorted(model_usage), + input_row_count=1, + output_row_count=1 if case.status == StagedCaseStatus.completed else 0, + failed_record_count=0 if case.status == StagedCaseStatus.completed else 1, + elapsed_sec=case.model_elapsed_sec or case.elapsed_sec or 0.0, + model_usage=model_usage, + ) + + +def _native_staged_model_usage( + case: StagedDetectionCase, + *, + runtime: NativeDetectionRuntime, +) -> dict[str, dict[str, Any]]: + usage: dict[str, dict[str, Any]] = {} + native_requests = 0 + native_usage: list[dict[str, int]] = [] + + if case.phase_model_work.seed: + if case.seed_source == SeedSource.gliner: + usage[runtime.gliner_alias] = _direct_model_usage_entry( + alias=runtime.gliner_alias, + model_name=runtime.gliner_model, + provider_name=runtime.gliner_provider, + successful_requests=case.phase_model_requests.seed, + usage=case.phase_usage.seed, + ) + else: + native_requests += case.phase_model_requests.seed + native_usage.append(case.phase_usage.seed) + if case.phase_model_work.validation: + native_requests += case.phase_model_requests.validation + native_usage.append(case.phase_usage.validation) + if case.phase_model_work.augmentation: + native_requests += case.phase_model_requests.augmentation + native_usage.append(case.phase_usage.augmentation) + if native_requests: + usage[runtime.alias] = _direct_model_usage_entry( + alias=runtime.alias, + model_name=runtime.model, + provider_name=runtime.provider, + successful_requests=native_requests, + usage=_sum_usage_dicts(native_usage), + ) + return usage + + +def _direct_model_usage_entry( + *, + alias: str, + model_name: str, + provider_name: str, + successful_requests: int, + usage: dict[str, int], +) -> dict[str, Any]: + return { + "model_alias": alias, + "model_name": model_name, + "model_provider_name": provider_name, + "request_usage": { + "successful_requests": successful_requests, + "failed_requests": 0, + "total_requests": successful_requests, + }, + "token_usage": _native_token_usage(usage), + } + + +def _sum_usage_dicts(usages: list[dict[str, int]]) -> dict[str, int]: + totals: Counter[str] = Counter() + for usage in usages: + for key, value in usage.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def _native_token_usage(usage: dict[str, int]) -> dict[str, int]: + input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0)) + output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0)) + total_tokens = usage.get("total_tokens", input_tokens + output_tokens) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + + +def _native_staged_request( + row: pd.Series, + *, + index: object, + ordinal: int, + labels: list[str], + data_summary: str | None, +) -> StagedDetectionRequest: + return StagedDetectionRequest( + case_id=f"native-staged-{ordinal}", + text=str(row.get(COL_TEXT, "")), + labels=labels, + row_index=_safe_row_index(index, fallback=ordinal), + data_summary=data_summary, + ) + + +def _native_detection_result_row(row: pd.Series, *, execution_row: dict[str, Any]) -> dict[str, Any]: + output_row = row.to_dict() + output_row[COL_DETECTED_ENTITIES] = execution_row.get( + COL_DETECTED_ENTITIES, + EntitiesSchema().model_dump(mode="json"), + ) + output_row[COL_TAGGED_TEXT] = execution_row.get(COL_TAGGED_TEXT, str(row.get(COL_TEXT, ""))) + return output_row + + +def _safe_row_index(index: object, *, fallback: int) -> int: + try: + return int(index) # type: ignore[arg-type] + except (TypeError, ValueError): + return fallback + + +def _native_output_dataframe( + source_df: pd.DataFrame, + *, + output_rows: list[dict[str, Any]], + output_indices: list[Any], +) -> pd.DataFrame: + if output_rows: + return pd.DataFrame(output_rows, index=output_indices) + output = source_df.iloc[0:0].copy() + output[COL_DETECTED_ENTITIES] = pd.Series(dtype="object") + output[COL_TAGGED_TEXT] = pd.Series(dtype="object") + return output diff --git a/tools/measurement/direct_detection_probe.py b/tools/measurement/direct_detection_probe.py new file mode 100644 index 00000000..8659fec0 --- /dev/null +++ b/tools/measurement/direct_detection_probe.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run benchmark-only DD-free direct entity extraction probes. + +Usage: + uv run python tools/measurement/direct_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography --labels age,city,first_name,last_name,occupation \ + --output /tmp/direct-probe --overwrite +""" + +from __future__ import annotations + +import json +import logging +import os +import shutil +import sys +from collections import Counter +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any, Protocol + +import cyclopts +import httpx +import pandas as pd +from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities +from dd_parser_compat import _load_embedded_json +from pydantic import BaseModel, Field, ValidationError, model_validator + +from anonymizer.engine.detection.detection_workflow import _format_label_examples +from anonymizer.engine.detection.postprocess import EntitySpan, apply_augmented_entities, expand_entity_occurrences +from anonymizer.engine.schemas import EntitySchema + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.direct_detection_probe") + +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" +_UNCONFIGURED_MODEL = "configured-native-model" + + +class CaseStatus(StrEnum): + completed = "completed" + error = "error" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class PromptMode(StrEnum): + compact = "compact" + recall = "recall" + + +class DirectCompletion(BaseModel): + content: str + elapsed_sec: float + usage: dict[str, Any] = Field(default_factory=dict) + + +class DirectDetectionRequest(BaseModel): + case_id: str + text: str + labels: list[str] = Field(min_length=1) + row_index: int = 0 + prompt_mode: PromptMode = PromptMode.compact + data_summary: str | None = None + + @model_validator(mode="after") + def normalize_labels(self) -> DirectDetectionRequest: + self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) + if not self.labels: + raise ValueError("labels must contain at least one non-empty label") + return self + + +class DirectGenerationRequest(BaseModel): + endpoint: str + model: str + prompt: str + max_tokens: int = Field(gt=0) + temperature: float = 0.0 + top_p: float = 1.0 + timeout_sec: float = Field(gt=0) + json_mode: bool = True + disable_thinking: bool = True + + +class SignatureComparison(BaseModel): + baseline_final_entity_signature_count: int + shared_final_entity_signature_count: int + baseline_only_final_entity_signature_count: int + direct_only_final_entity_signature_count: int + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + direct_only_label_counts: dict[str, int] = Field(default_factory=dict) + + +class DirectDetectionCase(BaseModel): + case_id: str + row_index: int + status: CaseStatus + elapsed_sec: float | None = None + usage: dict[str, Any] = Field(default_factory=dict) + raw_suggestion_count: int = 0 + allowed_suggestion_count: int = 0 + final_entity_count: int = 0 + final_entity_signature_count: int = 0 + final_entity_signature_hashes: list[str] = Field(default_factory=list) + final_label_counts: dict[str, int] = Field(default_factory=dict) + comparison: SignatureComparison | None = None + artifact: DetectionArtifactRow | None = None + error: str | None = None + + +class DirectDetectionRun(BaseModel): + input_path: str + text_column: str + endpoint: str + model: str + prompt_mode: PromptMode + labels: list[str] + rows: list[DirectDetectionCase] = Field(default_factory=list) + + @property + def error_count(self) -> int: + return sum(1 for row in self.rows if row.status == CaseStatus.error) + + +class DirectDetectionClient(Protocol): + def complete(self, request: DirectGenerationRequest) -> DirectCompletion: ... + + +class HttpxDirectDetectionClient: + def complete(self, request: DirectGenerationRequest) -> DirectCompletion: + payload = build_chat_payload(request) + response = httpx.post( + f"{request.endpoint.rstrip('/')}/chat/completions", + json=payload, + timeout=request.timeout_sec, + ) + response.raise_for_status() + data = response.json() + return DirectCompletion( + content=str(data["choices"][0]["message"].get("content") or ""), + elapsed_sec=float(response.elapsed.total_seconds()), + usage=data.get("usage") or {}, + ) + + +_log_format = LogFormat.plain + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") + return + logger.error("bad_input error=%s", error) + + +def build_chat_payload(request: DirectGenerationRequest) -> dict[str, Any]: + payload: dict[str, Any] = { + "model": request.model, + "messages": [ + {"role": "system", "content": "You are a precise entity extraction engine. Return JSON only."}, + {"role": "user", "content": request.prompt}, + ], + "temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": request.max_tokens, + } + if request.json_mode: + payload["response_format"] = {"type": "json_object"} + if request.disable_thinking: + payload["chat_template_kwargs"] = {"enable_thinking": False} + return payload + + +def build_direct_prompt(request: DirectDetectionRequest) -> str: + label_text = ( + _format_label_examples(request.labels) + if request.prompt_mode == PromptMode.recall + else ", ".join(request.labels) + ) + recall_block = _recall_block() if request.prompt_mode == PromptMode.recall else "" + summary = f"\nData context: {request.data_summary}\n" if request.data_summary else "" + return f"""Extract privacy-sensitive entities from the input text. +{summary} +Use only these labels: +{label_text} + +Rules: +- Return exact substrings from the input text. +- Do not invent values. +- Prefer specific labels from the allowed list. +- Skip generic nouns, syntax, and non-sensitive filler. +{recall_block}- Return only a JSON object with this shape: + {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} + +Input text: +--- +{request.text} +---""" + + +def _recall_block() -> str: + return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. +- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. +""" + + +def run_direct_detection_case( + request: DirectDetectionRequest, + *, + client: DirectDetectionClient, + endpoint: str | None = None, + model: str | None = None, + max_tokens: int = 4096, + timeout_sec: float = 180.0, +) -> DirectDetectionCase: + endpoint = endpoint or _UNCONFIGURED_ENDPOINT + model = model or _UNCONFIGURED_MODEL + try: + completion = client.complete( + DirectGenerationRequest( + endpoint=endpoint, + model=model, + prompt=build_direct_prompt(request), + max_tokens=max_tokens, + timeout_sec=timeout_sec, + ) + ) + suggestions = _extract_suggestions(completion.content) + allowed_suggestions = filter_direct_suggestions(suggestions, request.labels) + artifact = finalize_direct_suggestions( + text=request.text, + suggestions=allowed_suggestions, + labels=request.labels, + row_index=request.row_index, + workflow_name="direct-detection", + ) + return _completed_case(request, completion, suggestions, allowed_suggestions, artifact) + except Exception as exc: # noqa: BLE001 - benchmark probe records per-case failures + return DirectDetectionCase( + case_id=request.case_id, + row_index=request.row_index, + status=CaseStatus.error, + error=f"{type(exc).__name__}: {exc}", + ) + + +def _extract_suggestions(content: str) -> list[dict[str, Any]]: + payload = _load_embedded_json(content) + if not isinstance(payload, dict): + return [] + suggestions = payload.get("entities") + return suggestions if isinstance(suggestions, list) else [] + + +def finalize_direct_suggestions( + *, + text: str, + suggestions: list[dict[str, Any]], + labels: list[str], + row_index: int, + workflow_name: str, +) -> DetectionArtifactRow: + cleaned = filter_direct_suggestions(suggestions, labels) + direct_spans = apply_augmented_entities(text=text, entities=[], augmented_output={"entities": cleaned}) + entities = [ + _span_to_entity_schema(span, source="direct_llm") for span in expand_entity_occurrences(text, direct_spans) + ] + return build_detection_artifact_row_from_entities( + workflow_name=workflow_name, + batch_file="direct-detection", + row_index=row_index, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=entities, + final_entities=entities, + ) + + +def filter_direct_suggestions(suggestions: list[dict[str, Any]], labels: list[str]) -> list[dict[str, str]]: + allowed = set(labels) + cleaned = [ + {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} + for item in suggestions + if isinstance(item, dict) + ] + return [item for item in cleaned if item["value"] and item["label"] in allowed] + + +def _span_to_entity_schema(span: EntitySpan, *, source: str) -> EntitySchema: + span_source = source if span.source == "augmenter" else span.source + return EntitySchema( + value=span.value, + label=span.label, + start_position=span.start_position, + end_position=span.end_position, + score=span.score, + source=span_source, + ) + + +def _completed_case( + request: DirectDetectionRequest, + completion: DirectCompletion, + suggestions: list[dict[str, Any]], + allowed_suggestions: list[dict[str, str]], + artifact: DetectionArtifactRow, +) -> DirectDetectionCase: + return DirectDetectionCase( + case_id=request.case_id, + row_index=request.row_index, + status=CaseStatus.completed, + elapsed_sec=completion.elapsed_sec, + usage=completion.usage, + raw_suggestion_count=len(suggestions), + allowed_suggestion_count=len(allowed_suggestions), + final_entity_count=artifact.final_entity_count, + final_entity_signature_count=artifact.final_entity_signature_count, + final_entity_signature_hashes=artifact.final_entity_signature_hashes, + final_label_counts=artifact.final_label_counts, + artifact=artifact, + ) + + +def compare_signature_sets( + *, + baseline_hashes: set[str], + baseline_labels: dict[str, str], + direct_hashes: set[str], + direct_labels: dict[str, str], +) -> SignatureComparison: + baseline_only = baseline_hashes - direct_hashes + direct_only = direct_hashes - baseline_hashes + return SignatureComparison( + baseline_final_entity_signature_count=len(baseline_hashes), + shared_final_entity_signature_count=len(baseline_hashes & direct_hashes), + baseline_only_final_entity_signature_count=len(baseline_only), + direct_only_final_entity_signature_count=len(direct_only), + baseline_only_label_counts=_label_counts(baseline_only, baseline_labels), + direct_only_label_counts=_label_counts(direct_only, direct_labels), + ) + + +def _label_counts(hashes: set[str], labels: dict[str, str]) -> dict[str, int]: + return dict(sorted(Counter(labels.get(item, "unknown") for item in hashes).items())) + + +def apply_baseline_comparisons( + cases: list[DirectDetectionCase], + baseline_artifacts: Path, +) -> list[DirectDetectionCase]: + baseline = _read_baseline_artifacts(baseline_artifacts) + compared: list[DirectDetectionCase] = [] + for case in cases: + if case.status != CaseStatus.completed or case.artifact is None: + compared.append(case) + continue + baseline_row = baseline.get(case.row_index) + if baseline_row is None: + compared.append(case) + continue + compared.append(_case_with_comparison(case, baseline_row)) + return compared + + +def _case_with_comparison(case: DirectDetectionCase, baseline_row: dict[str, Any]) -> DirectDetectionCase: + baseline_hashes = _baseline_signature_hashes(baseline_row) + if baseline_hashes is None: + return case + comparison = compare_signature_sets( + baseline_hashes=baseline_hashes, + baseline_labels=_signature_labels(baseline_row), + direct_hashes=set(case.artifact.final_entity_signature_hashes if case.artifact else []), + direct_labels=case.artifact.final_entity_signature_labels if case.artifact else {}, + ) + return case.model_copy(update={"comparison": comparison}) + + +def _baseline_signature_hashes(row: dict[str, Any]) -> set[str] | None: + hashes = row.get("final_entity_signature_hashes") + if not isinstance(hashes, list): + return None + return {str(item) for item in hashes} + + +def _read_baseline_artifacts(path: Path) -> dict[int, dict[str, Any]]: + baseline: dict[int, dict[str, Any]] = {} + with path.open(encoding="utf-8") as source: + for line in source: + if not line.strip(): + continue + row = json.loads(line) + row_index = int(row.get("row_index", 0)) + if row_index in baseline: + raise ValueError( + f"baseline artifacts has multiple rows for row_index={row_index}; " + "pass a per-case sidecar or pre-filter the artifact file" + ) + baseline[row_index] = row + return baseline + + +def _signature_labels(row: dict[str, Any]) -> dict[str, str]: + return { + key.removeprefix("final_entity_signature_labels."): str(value) + for key, value in row.items() + if key.startswith("final_entity_signature_labels.") and value is not None + } + + +def run_probe( + input_path: Path, + *, + text_column: str, + labels: list[str], + output: Path | None = None, + overwrite: bool = False, + endpoint: str | None = None, + model: str | None = None, + limit: int = 1, + offset: int = 0, + prompt_mode: PromptMode = PromptMode.compact, + baseline_artifacts: Path | None = None, +) -> DirectDetectionRun: + endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) + model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) + requests = _load_requests( + input_path, text_column=text_column, labels=labels, limit=limit, offset=offset, prompt_mode=prompt_mode + ) + client = HttpxDirectDetectionClient() + cases = [run_direct_detection_case(request, client=client, endpoint=endpoint, model=model) for request in requests] + if baseline_artifacts is not None: + cases = apply_baseline_comparisons(cases, baseline_artifacts) + result = DirectDetectionRun( + input_path=str(input_path), + text_column=text_column, + endpoint=endpoint, + model=model, + prompt_mode=prompt_mode, + labels=labels, + rows=cases, + ) + if output is not None: + write_outputs(result, output, overwrite=overwrite) + return result + + +def _required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: + value = explicit or os.environ.get(env_var) + if not value: + raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") + return value + + +def _load_requests( + input_path: Path, + *, + text_column: str, + labels: list[str], + limit: int, + offset: int, + prompt_mode: PromptMode, +) -> list[DirectDetectionRequest]: + dataframe = pd.read_csv(input_path) + if text_column not in dataframe.columns: + raise ValueError(f"text column {text_column!r} not found in {input_path}") + selected = dataframe.iloc[offset : offset + limit] + return [ + DirectDetectionRequest( + case_id=f"{input_path.stem}-row-{int(index)}", + text=str(row[text_column]), + labels=labels, + row_index=int(index), + prompt_mode=prompt_mode, + ) + for index, row in selected.iterrows() + ] + + +def write_outputs(result: DirectDetectionRun, output_dir: Path, *, overwrite: bool) -> None: + if output_dir.exists(): + if not overwrite: + raise ValueError(f"output directory already exists: {output_dir}") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True) + _write_jsonl(output_dir / "direct-detection-cases.jsonl", [_case_payload(case) for case in result.rows]) + _write_jsonl(output_dir / "direct-detection-artifacts.jsonl", [_artifact_payload(case) for case in result.rows]) + (output_dir / "summary.json").write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def _case_payload(case: DirectDetectionCase) -> dict[str, Any]: + payload = case.model_dump(exclude={"artifact"}) + payload["record_type"] = "direct_detection_case" + return payload + + +def _artifact_payload(case: DirectDetectionCase) -> dict[str, Any]: + payload = case.artifact.model_dump() if case.artifact is not None else {} + payload.update({"case_id": case.case_id, "row_index": case.row_index, "record_type": "direct_detection_artifact"}) + return payload + + +def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8") as target: + for row in rows: + target.write(json.dumps(row, ensure_ascii=True, sort_keys=True) + "\n") + + +def render_result(result: DirectDetectionRun, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + completed = len(result.rows) - result.error_count + return f"Ran {completed}/{len(result.rows)} direct detection case(s); errors={result.error_count}" + + +def parse_labels(raw: str) -> list[str]: + return [label.strip() for label in raw.split(",") if label.strip()] + + +@app.default +def main( + input_path: Path, + *, + text_column: Annotated[str, cyclopts.Parameter("--text-column")], + labels: Annotated[str, cyclopts.Parameter("--labels")], + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, + model: Annotated[str | None, cyclopts.Parameter("--model")] = None, + limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, + offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, + prompt_mode: Annotated[PromptMode, cyclopts.Parameter("--prompt-mode")] = PromptMode.compact, + baseline_artifacts: Annotated[Path | None, cyclopts.Parameter("--baseline-artifacts")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = run_probe( + input_path, + text_column=text_column, + labels=parse_labels(labels), + output=output, + overwrite=overwrite, + endpoint=endpoint, + model=model, + limit=limit, + offset=offset, + prompt_mode=prompt_mode, + baseline_artifacts=baseline_artifacts, + ) + except (ValueError, ValidationError, httpx.HTTPError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + if result.error_count: + raise SystemExit(1) + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/examples/repo-data-smoke.yaml b/tools/measurement/examples/repo-data-smoke.yaml new file mode 100644 index 00000000..5009f054 --- /dev/null +++ b/tools/measurement/examples/repo-data-smoke.yaml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +suite_id: repo-data-smoke +workloads: + - id: biographies + source: ../../../docs/data/NVIDIA_synthetic_biographies.csv + text_column: biography + row_limit: 5 + - id: legal + source: ../../../docs/data/TAB_legal_sample25.csv + text_column: text + row_limit: 5 +configs: + - id: biographies-redact-default + replace: redact + - id: legal-hash-agent-labels + detect: + entity_labels: [person, email, api_key, password] + replace: + strategy: hash + digest_length: 12 +matrix: + - workload: biographies + config: biographies-redact-default + - workload: legal + config: legal-hash-agent-labels diff --git a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh new file mode 100644 index 00000000..9000f03f --- /dev/null +++ b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +output_dir="${1:-/tmp/anonymizer-repo-data-smoke-dd-traces}" +trace_mode="${DD_TRACE_MODE:-last-message}" +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +repo_root="$(cd "${script_dir}/../../.." && pwd)" +suite_file="${BENCHMARK_SUITE:-${script_dir}/repo-data-smoke.yaml}" + +cd "${repo_root}" + +uv run python tools/measurement/run_benchmarks.py \ + "${suite_file}" \ + --output "${output_dir}" \ + --overwrite \ + --dd-trace "${trace_mode}" \ + --trace-dir "${output_dir}/traces" diff --git a/tools/measurement/export_measurements.py b/tools/measurement/export_measurements.py new file mode 100755 index 00000000..796fcc40 --- /dev/null +++ b/tools/measurement/export_measurements.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Export Anonymizer measurement JSONL into per-record-type tables. + +Usage: + uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables + uv run python tools/measurement/export_measurements.py measurements.jsonl -o tables --format csv + uv run python tools/measurement/export_measurements.py measurements.jsonl -o tables --json +""" + +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.export") + +MANIFEST_FILENAME = "manifest.json" + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class TableSummary(BaseModel): + record_type: str + rows: int + columns: int + path: str + + +class ExportResult(BaseModel): + input_path: str + output_dir: str + format: ExportFormat + total_rows: int + tables: list[TableSummary] = Field(default_factory=list) + manifest_path: str + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def read_measurements(path: Path) -> pd.DataFrame: + if not path.exists(): + raise ValueError(f"input path does not exist: {path}") + if path.is_dir(): + raise ValueError(f"input path is a directory: {path}") + if path.suffix == ".json": + dataframe = pd.read_json(path) + else: + dataframe = pd.read_json(path, lines=True) + if "record_type" not in dataframe.columns: + raise ValueError("measurement input must contain a record_type field") + return dataframe + + +def normalize_table(rows: pd.DataFrame) -> pd.DataFrame: + relevant_rows = rows.dropna(axis="columns", how="all") + normalized = pd.json_normalize(relevant_rows.to_dict("records"), sep=".") + for column in normalized.columns: + if normalized[column].map(_is_nested_value).any(): + normalized[column] = normalized[column].map(_json_cell) + return normalized + + +def _is_nested_value(value: object) -> bool: + return isinstance(value, dict | list) + + +def _json_cell(value: object) -> object: + if not _is_nested_value(value): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def write_table(table: pd.DataFrame, path: Path, export_format: ExportFormat) -> None: + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + + +def export_tables( + dataframe: pd.DataFrame, + *, + input_path: Path, + output_dir: Path, + export_format: ExportFormat, + overwrite: bool, +) -> ExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _export_one_table(record_type, rows, output_dir=output_dir, export_format=export_format, overwrite=overwrite) + for record_type, rows in dataframe.groupby("record_type", sort=False) + ] + result = ExportResult( + input_path=str(input_path), + output_dir=str(output_dir), + format=export_format, + total_rows=len(dataframe), + tables=tables, + manifest_path=str(output_dir / MANIFEST_FILENAME), + ) + write_manifest(result, output_dir / MANIFEST_FILENAME, overwrite=overwrite) + return result + + +def _export_one_table( + record_type: str, + rows: pd.DataFrame, + *, + output_dir: Path, + export_format: ExportFormat, + overwrite: bool, +) -> TableSummary: + table = normalize_table(rows) + path = output_dir / f"{record_type}.{export_format.value}" + ensure_can_write(path, overwrite=overwrite) + write_table(table, path, export_format) + return TableSummary(record_type=record_type, rows=len(table), columns=len(table.columns), path=str(path)) + + +def write_manifest(result: ExportResult, path: Path, *, overwrite: bool) -> None: + ensure_can_write(path, overwrite=overwrite) + path.write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def ensure_can_write(path: Path, *, overwrite: bool) -> None: + if path.exists() and not overwrite: + raise ValueError(f"output already exists: {path}; pass --overwrite to replace it") + + +def render_result(result: ExportResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [f"Wrote {len(result.tables)} table(s) from {result.total_rows} measurement record(s)"] + lines.append(f"Output: {result.output_dir}") + for table in result.tables: + lines.append(f"- {table.record_type}: {table.rows} rows, {table.columns} columns -> {table.path}") + lines.append(f"Manifest: {result.manifest_path}") + return "\n".join(lines) + + +@app.default +def main( + input_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + output_dir = output or input_path.with_suffix("").with_name(f"{input_path.stem}-tables") + try: + dataframe = read_measurements(input_path) + result = export_tables( + dataframe, + input_path=input_path, + output_dir=output_dir, + export_format=format, + overwrite=overwrite, + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/extract_signature_deltas.py b/tools/measurement/extract_signature_deltas.py new file mode 100644 index 00000000..7732c147 --- /dev/null +++ b/tools/measurement/extract_signature_deltas.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Extract masked context for entity signature differences between benchmark artifacts. + +Usage: + uv run python tools/measurement/extract_signature_deltas.py \ + baseline/detection-artifacts.jsonl candidate/detection-artifacts.jsonl \ + --baseline-artifact-root baseline/artifacts --candidate-artifact-root candidate/artifacts \ + --baseline-config legal-default --candidate-config legal-rules-guardrail \ + --workload legal-r2 --output deltas.csv --format csv +""" + +from __future__ import annotations + +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from analyze_detection_artifacts import _entity_signature_hash +from pydantic import BaseModel, Field, ValidationError + +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT +from anonymizer.engine.schemas import EntitiesSchema, EntitySchema + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.signature_deltas") + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class DeltaSide(StrEnum): + baseline_only = "baseline_only" + candidate_only = "candidate_only" + + +class ContextResolution(StrEnum): + parquet = "parquet" + artifact_details = "artifact_details" + metadata_only = "metadata_only" + + +_log_format = LogFormat.plain +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} + + +class SignatureDeltaRow(BaseModel): + workload_id: str + row_index: int + side: DeltaSide + config_id: str + signature_hash: str + label: str | None = None + source: str | None = None + start_position: int | None = None + end_position: int | None = None + value_length: int | None = None + masked_context: str | None = None + resolution: ContextResolution = ContextResolution.metadata_only + batch_file: str | None = None + + +class SignatureDeltaResult(BaseModel): + baseline_artifacts: str + candidate_artifacts: str + workload_id: str | None = None + baseline_config: str | None = None + candidate_config: str | None = None + delta_count: int + rows: list[SignatureDeltaRow] = Field(default_factory=list) + + +class _ArtifactSide(BaseModel): + artifacts_path: str + artifact_root: str + config_id: str | None = None + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def extract_signature_deltas( + baseline_artifacts: Path, + candidate_artifacts: Path, + *, + baseline_artifact_root: Path, + candidate_artifact_root: Path, + baseline_config: str | None = None, + candidate_config: str | None = None, + workload: str | None = None, + context_window: int = 40, +) -> SignatureDeltaResult: + baseline = _select_artifact_rows( + _read_artifact_rows(baseline_artifacts), config_id=baseline_config, workload=workload + ) + candidate = _select_artifact_rows( + _read_artifact_rows(candidate_artifacts), + config_id=candidate_config, + workload=workload, + ) + rows = _compare_artifact_rows( + baseline, + candidate, + baseline_side=_artifact_side(baseline_artifacts, baseline_artifact_root, baseline_config), + candidate_side=_artifact_side(candidate_artifacts, candidate_artifact_root, candidate_config), + context_window=context_window, + ) + return SignatureDeltaResult( + baseline_artifacts=str(baseline_artifacts), + candidate_artifacts=str(candidate_artifacts), + workload_id=workload, + baseline_config=baseline_config, + candidate_config=candidate_config, + delta_count=len(rows), + rows=rows, + ) + + +def _artifact_side(artifacts_path: Path, artifact_root: Path, config_id: str | None) -> _ArtifactSide: + return _ArtifactSide( + artifacts_path=str(artifacts_path), + artifact_root=str(artifact_root), + config_id=config_id, + ) + + +def _read_artifact_rows(path: Path) -> list[dict[str, object]]: + if not path.exists() or path.is_dir(): + raise ValueError(f"detection artifact path is not a file: {path}") + with path.open(encoding="utf-8") as source: + return [json.loads(line) for line in source if line.strip()] + + +def _select_artifact_rows( + rows: list[dict[str, object]], + *, + config_id: str | None, + workload: str | None, +) -> dict[tuple[str, int], dict[str, object]]: + selected = [_row for _row in rows if _matches(_row, config_id=config_id, workload=workload)] + if not selected: + raise ValueError("artifact selector matched no rows") + return {_artifact_key(row): row for row in selected} + + +def _matches(row: dict[str, object], *, config_id: str | None, workload: str | None) -> bool: + if config_id is not None and str(row.get("config_id")) != config_id: + return False + if workload is not None and str(row.get("workload_id")) != workload: + return False + return True + + +def _artifact_key(row: dict[str, object]) -> tuple[str, int]: + return str(row.get("workload_id")), int(row.get("row_index", 0)) + + +def _compare_artifact_rows( + baseline: dict[tuple[str, int], dict[str, object]], + candidate: dict[tuple[str, int], dict[str, object]], + *, + baseline_side: _ArtifactSide, + candidate_side: _ArtifactSide, + context_window: int, +) -> list[SignatureDeltaRow]: + rows: list[SignatureDeltaRow] = [] + for key in sorted(set(baseline) & set(candidate)): + rows.extend( + _row_signature_deltas(key, baseline[key], candidate[key], baseline_side, candidate_side, context_window) + ) + return rows + + +def _row_signature_deltas( + key: tuple[str, int], + baseline_row: dict[str, object], + candidate_row: dict[str, object], + baseline_side: _ArtifactSide, + candidate_side: _ArtifactSide, + context_window: int, +) -> list[SignatureDeltaRow]: + baseline_signatures = _signature_set(baseline_row) + candidate_signatures = _signature_set(candidate_row) + return [ + *_delta_rows( + key, + baseline_row, + baseline_signatures - candidate_signatures, + DeltaSide.baseline_only, + baseline_side, + context_window, + ), + *_delta_rows( + key, + candidate_row, + candidate_signatures - baseline_signatures, + DeltaSide.candidate_only, + candidate_side, + context_window, + ), + ] + + +def _signature_set(row: dict[str, object]) -> set[str]: + raw = row.get("final_entity_signature_hashes", []) + if isinstance(raw, str): + raw = json.loads(raw) + return {str(item) for item in raw} if isinstance(raw, list) else set() + + +def _delta_rows( + key: tuple[str, int], + artifact_row: dict[str, object], + signatures: set[str], + side: DeltaSide, + artifact_side: _ArtifactSide, + context_window: int, +) -> list[SignatureDeltaRow]: + labels = _signature_labels(artifact_row) + return [ + _contextual_delta_row(key, artifact_row, signature, labels.get(signature), side, artifact_side, context_window) + for signature in sorted(signatures) + ] + + +def _signature_labels(row: dict[str, object]) -> dict[str, str]: + labels = row.get("final_entity_signature_labels", {}) + if isinstance(labels, str): + labels = json.loads(labels) + flattened = { + key.removeprefix("final_entity_signature_labels."): str(value) + for key, value in row.items() + if key.startswith("final_entity_signature_labels.") + } + return {**(labels if isinstance(labels, dict) else {}), **flattened} + + +def _contextual_delta_row( + key: tuple[str, int], + artifact_row: dict[str, object], + signature: str, + label: str | None, + side: DeltaSide, + artifact_side: _ArtifactSide, + context_window: int, +) -> SignatureDeltaRow: + context = _resolve_signature_context( + artifact_row, signature, label, Path(artifact_side.artifact_root), context_window + ) + return SignatureDeltaRow( + workload_id=key[0], + row_index=key[1], + side=side, + config_id=str(artifact_row.get("config_id") or artifact_side.config_id or ""), + signature_hash=signature, + label=label, + batch_file=_optional_string(artifact_row.get("batch_file")), + **context, + ) + + +def _resolve_signature_context( + artifact_row: dict[str, object], + signature: str, + label: str | None, + artifact_root: Path, + context_window: int, +) -> dict[str, object]: + parquet_context = _parquet_entity_context(artifact_row, signature, artifact_root, context_window) + if parquet_context is not None: + return parquet_context + detail_context = _artifact_detail_context(artifact_row, signature, label, artifact_root, context_window) + return detail_context or {"resolution": ContextResolution.metadata_only} + + +def _parquet_entity_context( + artifact_row: dict[str, object], + signature: str, + artifact_root: Path, + context_window: int, +) -> dict[str, object] | None: + record = _artifact_record(artifact_row, artifact_root) + if record is None: + return None + text, row_index, row = record + for entity in EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES)).entities: + if _entity_signature_hash(entity, row_index=row_index) == signature: + return _entity_context(entity, text, signature, context_window, ContextResolution.parquet) + return None + + +def _artifact_detail_context( + artifact_row: dict[str, object], + signature: str, + label: str | None, + artifact_root: Path, + context_window: int, +) -> dict[str, object] | None: + details = _signature_details(artifact_row).get(signature) + if details is None: + return None + start_position = _optional_int(details.get("start_position")) + end_position = _optional_int(details.get("end_position")) + resolved_label = _optional_string(details.get("label")) or label + if start_position is None or end_position is None or resolved_label is None: + return _metadata_context_from_details(details) + record = _artifact_record(artifact_row, artifact_root) + masked_context = None + if record is not None: + text, _row_index, _row = record + masked_context = _masked_context_from_details( + text, + label=resolved_label, + signature_hash=signature, + start_position=start_position, + end_position=end_position, + window=context_window, + ) + return { + "source": _optional_string(details.get("source")), + "start_position": start_position, + "end_position": end_position, + "value_length": _optional_int(details.get("value_length")), + "masked_context": masked_context, + "resolution": ContextResolution.artifact_details + if masked_context is not None + else ContextResolution.metadata_only, + } + + +def _signature_details(row: dict[str, object]) -> dict[str, dict[str, object]]: + details = _coerce_detail_map(row.get("final_entity_signature_details", {})) + prefix = "final_entity_signature_details." + for key, value in row.items(): + if not key.startswith(prefix): + continue + remainder = key.removeprefix(prefix) + signature_hash, _, field = remainder.partition(".") + if not signature_hash or not field: + continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue + details.setdefault(signature_hash, {})[field] = value + return details + + +def _coerce_detail_map(raw: object) -> dict[str, dict[str, object]]: + if isinstance(raw, str): + try: + raw = json.loads(raw) + except json.JSONDecodeError: + return {} + if not isinstance(raw, dict): + return {} + return { + str(key): {str(field): item for field, item in value.items() if str(field) in _SIGNATURE_DETAIL_FIELDS} + for key, value in raw.items() + if isinstance(value, dict) + } + + +def _metadata_context_from_details(details: dict[str, object]) -> dict[str, object]: + return { + "source": _optional_string(details.get("source")), + "start_position": _optional_int(details.get("start_position")), + "end_position": _optional_int(details.get("end_position")), + "value_length": _optional_int(details.get("value_length")), + "resolution": ContextResolution.metadata_only, + } + + +def _artifact_record(artifact_row: dict[str, object], artifact_root: Path) -> tuple[str, int, pd.Series] | None: + batch_file = _optional_string(artifact_row.get("batch_file")) + if batch_file is None: + return None + parquet_file = artifact_root / batch_file + if not parquet_file.exists(): + return None + row_index = int(artifact_row.get("row_index", 0)) + row = pd.read_parquet(parquet_file).iloc[row_index] + return str(row.get(COL_TEXT, "")), row_index, row + + +def _entity_context( + entity: EntitySchema, + text: str, + signature_hash: str, + context_window: int, + resolution: ContextResolution, +) -> dict[str, object]: + return { + "source": entity.source, + "start_position": entity.start_position, + "end_position": entity.end_position, + "value_length": len(entity.value), + "masked_context": _masked_context(text, entity, signature_hash, context_window), + "resolution": resolution, + } + + +def _masked_context(text: str, entity: EntitySchema, signature_hash: str, window: int) -> str: + before = text[max(0, entity.start_position - window) : entity.start_position] + after = text[entity.end_position : entity.end_position + window] + placeholder = f"[{entity.label.upper()}:{signature_hash}]" + return (before + placeholder + after).replace("\n", " ") + + +def _masked_context_from_details( + text: str, + *, + label: str, + signature_hash: str, + start_position: int, + end_position: int, + window: int, +) -> str: + before = text[max(0, start_position - window) : start_position] + after = text[end_position : end_position + window] + placeholder = f"[{label.upper()}:{signature_hash}]" + return (before + placeholder + after).replace("\n", " ") + + +def _optional_int(value: object) -> int | None: + try: + if pd.isna(value): + return None + except (TypeError, ValueError): + pass + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _optional_string(value: object) -> str | None: + if value is None: + return None + try: + if pd.isna(value): + return None + except (TypeError, ValueError): + pass + return str(value) + + +def write_rows(rows: list[SignatureDeltaRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def render_result(result: SignatureDeltaResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + counts = pd.Series([row.side.value for row in result.rows]).value_counts().to_dict() if result.rows else {} + return ( + f"Extracted {result.delta_count} signature delta(s)" + f"; baseline_only={counts.get('baseline_only', 0)}" + f"; candidate_only={counts.get('candidate_only', 0)}" + ) + + +@app.default +def main( + baseline_artifacts: Path, + candidate_artifacts: Path, + *, + baseline_artifact_root: Annotated[Path, cyclopts.Parameter("--baseline-artifact-root")], + candidate_artifact_root: Annotated[Path, cyclopts.Parameter("--candidate-artifact-root")], + baseline_config: Annotated[str | None, cyclopts.Parameter("--baseline-config")] = None, + candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None, + workload: Annotated[str | None, cyclopts.Parameter("--workload")] = None, + context_window: Annotated[int, cyclopts.Parameter("--context-window")] = 40, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.jsonl, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = extract_signature_deltas( + baseline_artifacts, + candidate_artifacts, + baseline_artifact_root=baseline_artifact_root, + candidate_artifact_root=candidate_artifact_root, + baseline_config=baseline_config, + candidate_config=candidate_config, + workload=workload, + context_window=context_window, + ) + except (ValueError, ValidationError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + if output is not None: + write_rows(result.rows, output, format) + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/replacement_strategies.py b/tools/measurement/replacement_strategies.py new file mode 100644 index 00000000..0089ac72 --- /dev/null +++ b/tools/measurement/replacement_strategies.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Experimental replacement strategies for benchmark-only performance probes.""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from enum import StrEnum + +import pandas as pd +from data_designer.config.models import ModelConfig + +from anonymizer.config.models import ReplaceModelSelection +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE +from anonymizer.engine.replace import llm_replace_workflow as lrw +from anonymizer.engine.replace.structured_substitute import apply_structured_substitution_maps + + +class ExperimentalReplacementStrategy(StrEnum): + default = "default" + local_structured_substitute = "local_structured_substitute" + + +@contextmanager +def experimental_replacement_strategy_context(strategy: ExperimentalReplacementStrategy) -> Iterator[None]: + """Temporarily apply a benchmark-only replacement strategy.""" + if strategy == ExperimentalReplacementStrategy.default: + yield + return + + original_method = lrw.LlmReplaceWorkflow.generate_map_only + if strategy == ExperimentalReplacementStrategy.local_structured_substitute: + lrw.LlmReplaceWorkflow.generate_map_only = _local_structured_generate_map_only # type: ignore[method-assign] + else: + raise ValueError(f"unsupported experimental replacement strategy: {strategy}") + try: + yield + finally: + lrw.LlmReplaceWorkflow.generate_map_only = original_method # type: ignore[method-assign] + + +def _local_structured_generate_map_only( + self: lrw.LlmReplaceWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: ReplaceModelSelection, + instructions: str | None = None, + entities_column: str = COL_ENTITIES_BY_VALUE, + preview_num_records: int | None = None, +) -> lrw.LlmReplaceResult: + _ = self, model_configs, selected_models, instructions, preview_num_records + return lrw.LlmReplaceResult( + dataframe=apply_structured_substitution_maps(dataframe, entities_column=entities_column), + failed_records=[], + ) diff --git a/tools/measurement/replay_replacement_strategies.py b/tools/measurement/replay_replacement_strategies.py new file mode 100644 index 00000000..f5512aa5 --- /dev/null +++ b/tools/measurement/replay_replacement_strategies.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Replay substitute strategies on one fixed detection trace. + +Usage: + uv run python tools/measurement/replay_replacement_strategies.py data.csv \ + --text-column text --labels api_key,http_cookie,password,pin,unique_id,user_name \ + --model-configs /stable-cache/anonymizer/local-vllm-json-models.yaml \ + --model-providers /stable-cache/anonymizer/local-vllm-providers.yaml \ + --dd-parser-compat raw_json --json +""" + +from __future__ import annotations + +import json +import logging +import sys +import time +from collections import Counter +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context +from pydantic import BaseModel, Field, ValidationError +from replacement_strategies import ( + ExperimentalReplacementStrategy, + experimental_replacement_strategy_context, +) + +from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect +from anonymizer.config.replace_strategies import Redact, Substitute +from anonymizer.engine.constants import ( + COL_FINAL_ENTITIES, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, +) +from anonymizer.engine.schemas import EntitiesSchema, EntityReplacementMapSchema +from anonymizer.interface.anonymizer import Anonymizer, _unrename_output_columns +from anonymizer.measurement import _output_contains_original_value + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.replay_replacement_strategies") + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class ReplayStatus(StrEnum): + completed = "completed" + error = "error" + + +class ReplacementReplayStrategy(StrEnum): + dd_substitute = "dd_substitute" + local_structured_substitute = "local_structured_substitute" + + +class ReplacementRowMetrics(BaseModel): + final_entity_count: int = 0 + replacement_count: int = 0 + missing_count: int = 0 + collision_count: int = 0 + leak_count: int = 0 + duplicate_synthetic_count: int = 0 + missing_labels: dict[str, int] = Field(default_factory=dict) + collision_labels: dict[str, int] = Field(default_factory=dict) + leak_labels: dict[str, int] = Field(default_factory=dict) + + +class ReplacementReplaySummary(BaseModel): + strategy: ReplacementReplayStrategy + status: ReplayStatus + repetition_count: int = 1 + elapsed_sec: float | None = None + row_count: int = 0 + final_entity_count: int = 0 + replacement_count: int = 0 + missing_count: int = 0 + collision_count: int = 0 + leak_count: int = 0 + duplicate_synthetic_count: int = 0 + missing_labels: dict[str, int] = Field(default_factory=dict) + collision_labels: dict[str, int] = Field(default_factory=dict) + leak_labels: dict[str, int] = Field(default_factory=dict) + replacement_map_sources: dict[str, int] = Field(default_factory=dict) + error: str | None = None + + +class ReplacementReplayResult(BaseModel): + input_path: str + text_column: str + labels: list[str] + nrows: int | None = None + replacement_repetitions: int = 1 + dd_parser_compat: DDParserCompatMode + detect_elapsed_sec: float + detected_final_entity_count: int + strategies: list[ReplacementReplaySummary] + + +class ReplacementReplayComparisonRow(BaseModel): + workload_id: str + baseline_config_id: str = "dd_substitute_replay" + candidate_config_id: str = "local_structured_substitute_replay" + baseline_strategy: str = "default" + candidate_strategy: str = "default" + baseline_replacement_strategy: str = "default" + candidate_replacement_strategy: str = "local_structured_substitute" + baseline_case_count: int + candidate_case_count: int + baseline_pipeline_elapsed_sec: float | None = None + candidate_pipeline_elapsed_sec: float | None = None + pipeline_elapsed_sec_delta: float | None = None + pipeline_elapsed_sec_delta_pct: float | None = None + baseline_final_entity_count: int + candidate_final_entity_count: int + final_entity_count_delta: int + baseline_replacement_count: int + candidate_replacement_count: int + replacement_count_delta: int + baseline_replacement_missing_final_entity_count: int + candidate_replacement_missing_final_entity_count: int + replacement_missing_final_entity_count_delta: int + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_count: int + candidate_replacement_synthetic_original_collision_count: int + replacement_synthetic_original_collision_count_delta: int + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_duplicate_synthetic_replacement_count: int + candidate_duplicate_synthetic_replacement_count: int + duplicate_synthetic_replacement_count_delta: int + baseline_original_value_leak_count: int + candidate_original_value_leak_count: int + original_value_leak_count_delta: int + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + value_protection_verdict: str + signature_parity_verdict: str + safety_verdict: str + performance_verdict: str + candidate_verdict: str + flags: list[str] = Field(default_factory=list) + + +_log_format = LogFormat.plain + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") + else: + sys.stderr.write(f"ERROR: bad_input error={error}\n") + + +def parse_labels(raw: str) -> list[str]: + labels = [label.strip() for label in raw.split(",") if label.strip()] + if not labels: + raise ValueError("--labels must contain at least one non-empty label") + return list(dict.fromkeys(labels)) + + +def run_replacement_replay( + *, + source: Path, + text_column: str, + labels: list[str], + nrows: int | None, + model_configs: Path | None, + model_providers: Path | None, + artifact_path: Path, + dd_parser_compat: DDParserCompatMode, + replacement_repetitions: int, +) -> ReplacementReplayResult: + anonymizer = Anonymizer( + model_configs=model_configs, + model_providers=model_providers, + artifact_path=artifact_path, + ) + detect_elapsed, replay_df = build_replay_dataframe( + anonymizer, + source=source, + text_column=text_column, + labels=labels, + nrows=nrows, + dd_parser_compat=dd_parser_compat, + ) + strategies = [ + run_strategy_repetitions( + anonymizer, + replay_df, + ReplacementReplayStrategy.dd_substitute, + dd_parser_compat, + repetitions=replacement_repetitions, + ), + run_strategy_repetitions( + anonymizer, + replay_df, + ReplacementReplayStrategy.local_structured_substitute, + dd_parser_compat, + repetitions=replacement_repetitions, + ), + ] + return ReplacementReplayResult( + input_path=str(source), + text_column=text_column, + labels=labels, + nrows=nrows, + replacement_repetitions=replacement_repetitions, + dd_parser_compat=dd_parser_compat, + detect_elapsed_sec=detect_elapsed, + detected_final_entity_count=count_final_entities(replay_df), + strategies=strategies, + ) + + +def build_replay_dataframe( + anonymizer: Anonymizer, + *, + source: Path, + text_column: str, + labels: list[str], + nrows: int | None, + dd_parser_compat: DDParserCompatMode, +) -> tuple[float, pd.DataFrame]: + input_data = AnonymizerInput(source=str(source), text_column=text_column) + config = AnonymizerConfig(detect=Detect(entity_labels=labels), replace=Redact()) + with dd_parser_compat_context(dd_parser_compat): + start = time.perf_counter() + if nrows is None: + detected = anonymizer.run(config=config, data=input_data) + else: + detected = anonymizer.preview(config=config, data=input_data, num_records=nrows) + elapsed = time.perf_counter() - start + internal_df = _unrename_output_columns(detected.trace_dataframe, resolved_text_column=detected.resolved_text_column) + return elapsed, strip_replacement_columns(internal_df) + + +def strip_replacement_columns(dataframe: pd.DataFrame) -> pd.DataFrame: + return dataframe.drop(columns=[COL_REPLACEMENT_MAP, COL_REPLACEMENT_MAP_SOURCE, COL_REPLACED_TEXT], errors="ignore") + + +def run_strategy_repetitions( + anonymizer: Anonymizer, + dataframe: pd.DataFrame, + strategy: ReplacementReplayStrategy, + dd_parser_compat: DDParserCompatMode, + *, + repetitions: int, +) -> ReplacementReplaySummary: + summaries = [run_strategy(anonymizer, dataframe, strategy, dd_parser_compat) for _ in range(repetitions)] + return aggregate_strategy_summaries(strategy=strategy, summaries=summaries) + + +def aggregate_strategy_summaries( + *, + strategy: ReplacementReplayStrategy, + summaries: list[ReplacementReplaySummary], +) -> ReplacementReplaySummary: + sources: Counter[str] = Counter() + missing_labels: Counter[str] = Counter() + collision_labels: Counter[str] = Counter() + leak_labels: Counter[str] = Counter() + errors: list[str] = [] + elapsed_values: list[float] = [] + status = ReplayStatus.completed + for summary in summaries: + sources.update(summary.replacement_map_sources) + missing_labels.update(summary.missing_labels) + collision_labels.update(summary.collision_labels) + leak_labels.update(summary.leak_labels) + if summary.elapsed_sec is not None: + elapsed_values.append(summary.elapsed_sec) + if summary.status == ReplayStatus.error: + status = ReplayStatus.error + if summary.error: + errors.append(summary.error) + return ReplacementReplaySummary( + strategy=strategy, + status=status, + repetition_count=len(summaries), + elapsed_sec=sum(elapsed_values) if elapsed_values else None, + row_count=sum(summary.row_count for summary in summaries), + final_entity_count=sum(summary.final_entity_count for summary in summaries), + replacement_count=sum(summary.replacement_count for summary in summaries), + missing_count=sum(summary.missing_count for summary in summaries), + collision_count=sum(summary.collision_count for summary in summaries), + leak_count=sum(summary.leak_count for summary in summaries), + duplicate_synthetic_count=sum(summary.duplicate_synthetic_count for summary in summaries), + missing_labels=dict(sorted(missing_labels.items())), + collision_labels=dict(sorted(collision_labels.items())), + leak_labels=dict(sorted(leak_labels.items())), + replacement_map_sources=dict(sorted(sources.items())), + error="; ".join(errors) if errors else None, + ) + + +def run_strategy( + anonymizer: Anonymizer, + dataframe: pd.DataFrame, + strategy: ReplacementReplayStrategy, + dd_parser_compat: DDParserCompatMode, +) -> ReplacementReplaySummary: + try: + start = time.perf_counter() + result_df = execute_strategy(anonymizer, dataframe.copy(), strategy, dd_parser_compat) + elapsed = time.perf_counter() - start + except Exception as exc: + logger.exception("replacement replay strategy failed: %s", strategy) + return ReplacementReplaySummary(strategy=strategy, status=ReplayStatus.error, error=str(exc)) + return summarize_replacement_dataframe(result_df, strategy=strategy, elapsed_sec=elapsed) + + +def execute_strategy( + anonymizer: Anonymizer, + dataframe: pd.DataFrame, + strategy: ReplacementReplayStrategy, + dd_parser_compat: DDParserCompatMode, +) -> pd.DataFrame: + if strategy == ReplacementReplayStrategy.local_structured_substitute: + with experimental_replacement_strategy_context(ExperimentalReplacementStrategy.local_structured_substitute): + return execute_substitute(anonymizer, dataframe) + with dd_parser_compat_context(dd_parser_compat): + return execute_substitute(anonymizer, dataframe) + + +def execute_substitute(anonymizer: Anonymizer, dataframe: pd.DataFrame) -> pd.DataFrame: + result = anonymizer._replace_runner.run( # benchmark probe against the configured runner + dataframe, + replace_method=Substitute(), + model_configs=anonymizer._model_configs, + selected_models=anonymizer._selected_models.replace, + ) + return result.dataframe + + +def summarize_replacement_dataframe( + dataframe: pd.DataFrame, + *, + strategy: ReplacementReplayStrategy, + elapsed_sec: float, +) -> ReplacementReplaySummary: + rows = [replacement_row_metrics(row) for _, row in dataframe.iterrows()] + return ReplacementReplaySummary( + strategy=strategy, + status=ReplayStatus.completed, + elapsed_sec=elapsed_sec, + row_count=len(rows), + final_entity_count=sum(row.final_entity_count for row in rows), + replacement_count=sum(row.replacement_count for row in rows), + missing_count=sum(row.missing_count for row in rows), + collision_count=sum(row.collision_count for row in rows), + leak_count=sum(row.leak_count for row in rows), + duplicate_synthetic_count=sum(row.duplicate_synthetic_count for row in rows), + missing_labels=sum_label_counts(row.missing_labels for row in rows), + collision_labels=sum_label_counts(row.collision_labels for row in rows), + leak_labels=sum_label_counts(row.leak_labels for row in rows), + replacement_map_sources=count_sources(dataframe), + ) + + +def replacement_row_metrics(row: Any) -> ReplacementRowMetrics: + entities = [entity.model_dump() for entity in EntitiesSchema.from_raw(row[COL_FINAL_ENTITIES]).entities] + replacements = parse_replacements(row[COL_REPLACEMENT_MAP]) + missing = missing_replacements(entities, replacements) + collisions = synthetic_original_collisions(entities, replacements) + leaks = leaked_entities(entities, str(row[COL_REPLACED_TEXT])) + synthetic_values = [item["synthetic"] for item in replacements if item.get("synthetic")] + return ReplacementRowMetrics( + final_entity_count=len(entities), + replacement_count=len(replacements), + missing_count=len(missing), + collision_count=len(collisions), + leak_count=len(leaks), + duplicate_synthetic_count=max(0, len(synthetic_values) - len(set(synthetic_values))), + missing_labels=count_labels(missing), + collision_labels=count_labels(collisions), + leak_labels=count_labels(leaks), + ) + + +def parse_replacements(raw: object) -> list[dict[str, str]]: + if hasattr(raw, "model_dump"): + raw = raw.model_dump(mode="python") + if isinstance(raw, str): + raw = json.loads(raw) + return [item.model_dump() for item in EntityReplacementMapSchema.model_validate(raw).replacements] + + +def missing_replacements(entities: list[dict[str, Any]], replacements: list[dict[str, str]]) -> list[dict[str, Any]]: + replacement_pairs = {(item["original"], item["label"]) for item in replacements} + return [entity for entity in entities if (entity.get("value"), entity.get("label")) not in replacement_pairs] + + +def synthetic_original_collisions( + entities: list[dict[str, Any]], + replacements: list[dict[str, str]], +) -> list[dict[str, str]]: + original_values = {entity["value"] for entity in entities if entity.get("value")} + return [item for item in replacements if item.get("synthetic") in original_values] + + +def leaked_entities(entities: list[dict[str, Any]], output_text: str) -> list[dict[str, Any]]: + return [ + entity + for entity in entities + if entity.get("value") and _output_contains_original_value(output_text, str(entity["value"])) + ] + + +def count_labels(items: list[dict[str, Any]]) -> dict[str, int]: + return dict(sorted(Counter(str(item.get("label") or "") for item in items if item.get("label")).items())) + + +def sum_label_counts(values: object) -> dict[str, int]: + counts: Counter[str] = Counter() + for mapping in values: + counts.update(mapping) + return dict(sorted(counts.items())) + + +def count_sources(dataframe: pd.DataFrame) -> dict[str, int]: + if COL_REPLACEMENT_MAP_SOURCE not in dataframe.columns: + return {} + return dict(sorted(Counter(str(source) for source in dataframe[COL_REPLACEMENT_MAP_SOURCE]).items())) + + +def count_final_entities(dataframe: pd.DataFrame) -> int: + return sum(len(EntitiesSchema.from_raw(raw).entities) for raw in dataframe[COL_FINAL_ENTITIES]) + + +def render_result(result: ReplacementReplayResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [ + f"input={result.input_path}", + f"labels={','.join(result.labels)}", + f"nrows={result.nrows if result.nrows is not None else 'all'}", + f"replacement_repetitions={result.replacement_repetitions}", + f"detected_final_entities={result.detected_final_entity_count}", + f"detection_elapsed_sec={result.detect_elapsed_sec:.3f}", + ] + lines.extend(render_strategy(summary) for summary in result.strategies) + return "\n".join(lines) + + +def render_strategy(summary: ReplacementReplaySummary) -> str: + if summary.status == ReplayStatus.error: + return f"{summary.strategy}: error={summary.error}" + return ( + f"{summary.strategy}: repetitions={summary.repetition_count} elapsed_sec={summary.elapsed_sec:.3f} " + f"replacements={summary.replacement_count} missing={summary.missing_count} " + f"leaks={summary.leak_count} collisions={summary.collision_count}" + ) + + +def replay_comparison_row( + result: ReplacementReplayResult, *, workload_id: str | None = None +) -> ReplacementReplayComparisonRow: + summaries = {summary.strategy: summary for summary in result.strategies} + baseline = summaries.get(ReplacementReplayStrategy.dd_substitute) + candidate = summaries.get(ReplacementReplayStrategy.local_structured_substitute) + if baseline is None or candidate is None: + raise ValueError( + "replacement replay result must include dd_substitute and local_structured_substitute summaries" + ) + flags = replay_comparison_flags(baseline, candidate) + value_protection = replay_value_protection_verdict(flags) + signature_parity = replay_signature_parity_verdict(baseline, candidate, flags) + safety = replay_safety_verdict(value_protection, signature_parity) + performance = replay_performance_verdict(baseline.elapsed_sec, candidate.elapsed_sec) + return ReplacementReplayComparisonRow( + workload_id=workload_id or Path(result.input_path).stem, + baseline_case_count=baseline.row_count, + candidate_case_count=candidate.row_count, + baseline_pipeline_elapsed_sec=baseline.elapsed_sec, + candidate_pipeline_elapsed_sec=candidate.elapsed_sec, + pipeline_elapsed_sec_delta=delta(baseline.elapsed_sec, candidate.elapsed_sec), + pipeline_elapsed_sec_delta_pct=delta_pct(baseline.elapsed_sec, candidate.elapsed_sec), + baseline_final_entity_count=baseline.final_entity_count, + candidate_final_entity_count=candidate.final_entity_count, + final_entity_count_delta=candidate.final_entity_count - baseline.final_entity_count, + baseline_replacement_count=baseline.replacement_count, + candidate_replacement_count=candidate.replacement_count, + replacement_count_delta=candidate.replacement_count - baseline.replacement_count, + baseline_replacement_missing_final_entity_count=baseline.missing_count, + candidate_replacement_missing_final_entity_count=candidate.missing_count, + replacement_missing_final_entity_count_delta=candidate.missing_count - baseline.missing_count, + baseline_replacement_missing_final_entity_label_counts=baseline.missing_labels, + candidate_replacement_missing_final_entity_label_counts=candidate.missing_labels, + baseline_replacement_synthetic_original_collision_count=baseline.collision_count, + candidate_replacement_synthetic_original_collision_count=candidate.collision_count, + replacement_synthetic_original_collision_count_delta=candidate.collision_count - baseline.collision_count, + baseline_replacement_synthetic_original_collision_label_counts=baseline.collision_labels, + candidate_replacement_synthetic_original_collision_label_counts=candidate.collision_labels, + baseline_duplicate_synthetic_replacement_count=baseline.duplicate_synthetic_count, + candidate_duplicate_synthetic_replacement_count=candidate.duplicate_synthetic_count, + duplicate_synthetic_replacement_count_delta=( + candidate.duplicate_synthetic_count - baseline.duplicate_synthetic_count + ), + baseline_original_value_leak_count=baseline.leak_count, + candidate_original_value_leak_count=candidate.leak_count, + original_value_leak_count_delta=candidate.leak_count - baseline.leak_count, + baseline_original_value_leak_label_counts=baseline.leak_labels, + candidate_original_value_leak_label_counts=candidate.leak_labels, + value_protection_verdict=value_protection, + signature_parity_verdict=signature_parity, + safety_verdict=safety, + performance_verdict=performance, + candidate_verdict=replay_candidate_verdict(safety, performance), + flags=flags, + ) + + +def replay_comparison_flags( + baseline: ReplacementReplaySummary, + candidate: ReplacementReplaySummary, +) -> list[str]: + flags: list[str] = [] + if baseline.status == ReplayStatus.error: + flags.append("baseline_case_failures") + if candidate.status == ReplayStatus.error: + flags.append("candidate_case_failures") + if baseline.missing_count: + flags.append("baseline_replacement_missing_final_entity") + if candidate.missing_count: + flags.append("candidate_replacement_missing_final_entity") + if baseline.missing_count and candidate.missing_count < baseline.missing_count: + flags.append("candidate_reduces_baseline_replacement_missing_final_entity") + if baseline.missing_count and candidate.missing_count == 0: + flags.append("candidate_covers_baseline_replacement_missing_final_entity") + if baseline.collision_count: + flags.append("baseline_replacement_synthetic_original_collision") + if candidate.collision_count: + flags.append("candidate_replacement_synthetic_original_collision") + if baseline.collision_count and candidate.collision_count < baseline.collision_count: + flags.append("candidate_reduces_baseline_replacement_synthetic_original_collision") + if baseline.collision_count and candidate.collision_count == 0: + flags.append("candidate_covers_baseline_replacement_synthetic_original_collision") + if baseline.duplicate_synthetic_count: + flags.append("baseline_duplicate_synthetic_replacement") + if candidate.duplicate_synthetic_count: + flags.append("candidate_duplicate_synthetic_replacement") + if baseline.leak_count: + flags.append("baseline_original_value_leak") + if candidate.leak_count: + flags.append("candidate_original_value_leak") + if baseline.leak_count and candidate.leak_count < baseline.leak_count: + flags.append("candidate_reduces_baseline_original_value_leak") + if baseline.leak_count and candidate.leak_count == 0: + flags.append("candidate_covers_baseline_original_value_leak") + if baseline.final_entity_count != candidate.final_entity_count: + flags.append( + "entity_count_loss" if candidate.final_entity_count < baseline.final_entity_count else "entity_count_delta" + ) + if baseline.replacement_count != candidate.replacement_count: + flags.append( + "replacement_count_loss" + if candidate.replacement_count < baseline.replacement_count + else "replacement_count_delta" + ) + return flags + + +def replay_value_protection_verdict(flags: list[str]) -> str: + flag_set = set(flags) + if flag_set & { + "candidate_case_failures", + "candidate_original_value_leak", + "candidate_replacement_missing_final_entity", + "candidate_replacement_synthetic_original_collision", + "entity_count_loss", + "replacement_count_loss", + }: + return "fail" + if "baseline_case_failures" in flag_set: + return "review" + baseline_defects = { + "baseline_original_value_leak", + "baseline_replacement_missing_final_entity", + "baseline_replacement_synthetic_original_collision", + } + corrected_baseline_defects = { + "baseline_original_value_leak": "candidate_covers_baseline_original_value_leak", + "baseline_replacement_missing_final_entity": ("candidate_covers_baseline_replacement_missing_final_entity"), + "baseline_replacement_synthetic_original_collision": ( + "candidate_covers_baseline_replacement_synthetic_original_collision" + ), + } + for defect in baseline_defects & flag_set: + if corrected_baseline_defects[defect] not in flag_set: + return "review" + return "pass" + + +def replay_signature_parity_verdict( + baseline: ReplacementReplaySummary, + candidate: ReplacementReplaySummary, + flags: list[str], +) -> str: + if candidate.status == ReplayStatus.error: + return "fail" + if baseline.status == ReplayStatus.error: + return "review" + if candidate.final_entity_count < baseline.final_entity_count: + return "fail" + if candidate.final_entity_count != baseline.final_entity_count: + return "review" + if candidate.replacement_count < baseline.replacement_count: + return "fail" + if candidate.replacement_count != baseline.replacement_count: + return "review" + if flags: + return "review" + return "pass" + + +def replay_safety_verdict(value_protection: str, signature_parity: str) -> str: + if "fail" in {value_protection, signature_parity}: + return "fail" + if "review" in {value_protection, signature_parity}: + return "review" + return "pass" + + +def replay_performance_verdict(baseline_elapsed: float | None, candidate_elapsed: float | None) -> str: + elapsed_delta = delta(baseline_elapsed, candidate_elapsed) + if elapsed_delta is None: + return "unknown" + if elapsed_delta < 0: + return "improved" + if elapsed_delta > 0: + return "regressed" + return "unchanged" + + +def replay_candidate_verdict(safety: str, performance: str) -> str: + if safety == "fail": + return "reject" + if safety == "pass" and performance == "improved": + return "candidate_viable" + return "review" + + +def delta(baseline: float | None, candidate: float | None) -> float | None: + if baseline is None or candidate is None: + return None + return candidate - baseline + + +def delta_pct(baseline: float | None, candidate: float | None) -> float | None: + value = delta(baseline, candidate) + if value is None or baseline in {None, 0}: + return None + return value / baseline * 100 + + +def write_replay_comparison( + result: ReplacementReplayResult, + output: Path, + *, + workload_id: str | None = None, +) -> None: + row = replay_comparison_row(result, workload_id=workload_id).model_dump(mode="json") + row["flags"] = json.dumps(row["flags"], ensure_ascii=True) + pd.DataFrame([row]).to_csv(output, index=False) + + +@app.default +def main( + source: Path, + *, + text_column: Annotated[str, cyclopts.Parameter("--text-column")] = "text", + labels: Annotated[str, cyclopts.Parameter("--labels")], + nrows: Annotated[int | None, cyclopts.Parameter("--nrows")] = None, + replacement_repetitions: Annotated[int, cyclopts.Parameter("--replacement-repetitions")] = 1, + model_configs: Annotated[Path | None, cyclopts.Parameter("--model-configs")] = None, + model_providers: Annotated[Path | None, cyclopts.Parameter("--model-providers")] = None, + artifact_path: Annotated[Path, cyclopts.Parameter("--artifact-path")] = Path( + "/tmp/anonymizer-replacement-replay-artifacts" + ), + dd_parser_compat: Annotated[DDParserCompatMode, cyclopts.Parameter("--dd-parser-compat")] = ( + DDParserCompatMode.none + ), + output: Annotated[Path | None, cyclopts.Parameter("--output")] = None, + comparison_output: Annotated[Path | None, cyclopts.Parameter("--comparison-output")] = None, + workload_id: Annotated[str | None, cyclopts.Parameter("--workload-id")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + if nrows is not None and nrows <= 0: + raise ValueError("--nrows must be greater than zero") + if replacement_repetitions <= 0: + raise ValueError("--replacement-repetitions must be greater than zero") + result = run_replacement_replay( + source=source, + text_column=text_column, + labels=parse_labels(labels), + nrows=nrows, + model_configs=model_configs, + model_providers=model_providers, + artifact_path=artifact_path, + dd_parser_compat=dd_parser_compat, + replacement_repetitions=replacement_repetitions, + ) + except (ValidationError, ValueError, FileNotFoundError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + rendered = render_result(result, json_output=json_output) + if output is not None: + output.write_text(rendered + "\n", encoding="utf-8") + else: + sys.stdout.write(rendered + "\n") + if comparison_output is not None: + write_replay_comparison(result, comparison_output, workload_id=workload_id) + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py new file mode 100755 index 00000000..e6a871de --- /dev/null +++ b/tools/measurement/run_benchmarks.py @@ -0,0 +1,1527 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run Anonymizer benchmark suites and export measurement tables. + +Usage: + uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark-runs/suite + uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json +""" + +import json +import logging +import os +import shutil +import sys +import time +from dataclasses import dataclass +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +import pyarrow.parquet as pq +import yaml +from analyze_detection_artifacts import ( + analyze_artifacts, + build_detection_artifact_row_from_entities, + iter_detection_parquet_files, +) +from data_designer.config.models import ModelProvider +from data_designer.config.utils.io_helpers import load_config_file +from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context +from detection_strategies import ( + ExperimentalDetectionStrategy, + NativeDetectionRuntime, + experimental_detection_strategy_context, +) +from export_measurements import ExportFormat, export_tables, read_measurements +from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator +from replacement_strategies import ( + ExperimentalReplacementStrategy, + experimental_replacement_strategy_context, +) + +from anonymizer.config.anonymizer_config import ( + AnonymizerConfig, + AnonymizerInput, + Detect, + Rewrite, + infer_input_source_suffix, + is_remote_input_source, +) +from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute +from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_FINAL_ENTITIES +from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS +from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references +from anonymizer.engine.replace.structured_substitute import SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS +from anonymizer.engine.schemas import EntitiesSchema +from anonymizer.interface.anonymizer import Anonymizer +from anonymizer.measurement import MeasurementConfig, configured_measurement_session + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.benchmark") + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class CaseStatus(StrEnum): + planned = "planned" + completed = "completed" + error = "error" + + +class DDTraceMode(StrEnum): + none = "none" + last_message = "last_message" + all_messages = "all_messages" + + +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" +_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" + + +class NativeRuntimeSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + runtime_id: str | None = None + endpoint: str | None = None + endpoint_env: str | None = _NATIVE_ENDPOINT_ENV + model: str | None = None + model_env: str | None = _NATIVE_MODEL_ENV + provider: str = "native" + alias: str = "native-direct" + max_tokens: int = Field(default=4096, gt=0) + timeout_sec: float = Field(default=180.0, gt=0) + gliner_endpoint: str | None = None + gliner_endpoint_env: str | None = _GLINER_ENDPOINT_ENV + gliner_model: str | None = None + gliner_model_env: str | None = _GLINER_MODEL_ENV + gliner_provider: str = "gliner" + gliner_alias: str = "gliner-direct" + gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) + max_workers: int = Field(default=4, ge=1) + + +class ReplaceKind(StrEnum): + redact = "redact" + hash = "hash" + annotate = "annotate" + substitute = "substitute" + + +class WorkloadSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + source: str + text_column: str = "text" + id_column: str | None = None + data_summary: str | None = None + row_limit: int | None = Field(default=None, ge=1) + row_offset: int = Field(default=0, ge=0) + + +class ReplaceSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + strategy: ReplaceKind + format_template: str | None = None + normalize_label: bool | None = None + algorithm: str | None = None + digest_length: int | None = None + instructions: str | None = None + + +class RewriteSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + protect: str | None = None + preserve: str | None = None + instructions: str | None = None + risk_tolerance: RiskTolerance = RiskTolerance.low + max_repair_iterations: int = 3 + strict_entity_protection: bool = False + + +class ConfigSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + detect: dict[str, Any] = Field(default_factory=dict) + replace: str | ReplaceSpec | None = None + rewrite: RewriteSpec | None = None + emit_telemetry: bool = False + experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default + experimental_replacement_strategy: ExperimentalReplacementStrategy = ExperimentalReplacementStrategy.default + native_runtime: NativeRuntimeSpec | None = None + + @model_validator(mode="after") + def validate_mode(self) -> "ConfigSpec": + if self.replace is None and self.rewrite is None: + raise ValueError("config must define replace or rewrite") + if self.replace is not None and self.rewrite is not None: + raise ValueError("config cannot define both replace and rewrite") + return self + + +class MatrixEntry(BaseModel): + model_config = ConfigDict(extra="forbid") + + workload: str + config: str + repetitions: int = Field(default=1, ge=1) + + +def _duplicates(values: list[str]) -> list[str]: + seen: set[str] = set() + duplicates: set[str] = set() + for value in values: + if value in seen: + duplicates.add(value) + seen.add(value) + return sorted(duplicates) + + +class BenchmarkSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + suite_id: str + model_configs: str | None = None + model_providers: str | None = None + artifact_path: str | None = None + dd_parser_compat: DDParserCompatMode = DDParserCompatMode.none + native_runtime: NativeRuntimeSpec | None = None + case_retries: int = Field(default=0, ge=0) + case_retry_backoff_sec: float = Field(default=0.0, ge=0.0) + workloads: list[WorkloadSpec] = Field(min_length=1) + configs: list[ConfigSpec] = Field(min_length=1) + matrix: list[MatrixEntry] | None = Field(default=None, min_length=1) + + @model_validator(mode="after") + def validate_ids(self) -> "BenchmarkSpec": + workload_ids = [workload.id for workload in self.workloads] + config_ids = [config.id for config in self.configs] + if duplicate_workloads := _duplicates(workload_ids): + raise ValueError(f"duplicate workload id(s): {', '.join(duplicate_workloads)}") + if duplicate_configs := _duplicates(config_ids): + raise ValueError(f"duplicate config id(s): {', '.join(duplicate_configs)}") + self._validate_matrix_references(set(workload_ids), set(config_ids)) + return self + + def _validate_matrix_references(self, workload_ids: set[str], config_ids: set[str]) -> None: + if self.matrix is None: + return + missing_workloads = sorted({entry.workload for entry in self.matrix} - workload_ids) + missing_configs = sorted({entry.config for entry in self.matrix} - config_ids) + if missing_workloads: + raise ValueError(f"matrix references unknown workload id(s): {', '.join(missing_workloads)}") + if missing_configs: + raise ValueError(f"matrix references unknown config id(s): {', '.join(missing_configs)}") + + +class BenchmarkCase(BaseModel): + suite_id: str + workload_id: str + config_id: str + repetition: int + case_id: str + status: CaseStatus = CaseStatus.planned + elapsed_sec: float | None = None + measurement_path: str | None = None + detection_artifact_path: str | None = None + trace_path: str | None = None + error: str | None = None + attempt_count: int = 0 + attempt_errors: list[str] = Field(default_factory=list) + + +class BenchmarkResult(BaseModel): + suite_id: str + output_dir: str + measurement_path: str + summary_path: str + table_dir: str | None + detection_artifact_analysis_path: str | None = None + cases: list[BenchmarkCase] + + +@dataclass(frozen=True) +class _CaseRunPaths: + raw_path: Path + artifact_output_path: Path + trace_path: Path | None + artifact_snapshot: dict[str, int] | None + + +@dataclass(frozen=True) +class _CaseExecution: + input_data: AnonymizerInput + trace_dataframe: pd.DataFrame | None = None + + +_TRACE_FINAL_ARTIFACT_STRATEGIES = { + ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_native_augment, + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + ExperimentalDetectionStrategy.native_single_pass, + ExperimentalDetectionStrategy.native_single_pass_recall, + ExperimentalDetectionStrategy.native_single_pass_values, + ExperimentalDetectionStrategy.native_single_pass_values_recall, +} +_NATIVE_RUNTIME_STRATEGIES = { + ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_native_augment, + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + ExperimentalDetectionStrategy.native_single_pass, + ExperimentalDetectionStrategy.native_single_pass_recall, + ExperimentalDetectionStrategy.native_single_pass_values, + ExperimentalDetectionStrategy.native_single_pass_values_recall, +} +_GLINER_NATIVE_RUNTIME_STRATEGIES = { + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + ExperimentalDetectionStrategy.gliner_native_validate_native_augment, +} + +_FINAL_ARTIFACT_KEYS = { + "final_entity_count", + "weak_api_key_shape_count", + "final_entity_signature_count", + "final_entity_signature_hashes", + "final_entity_signature_labels", + "final_entity_signature_details", + "weak_api_key_shape_label_counts", + "final_label_counts", + "final_source_counts", +} + +_FINAL_ARTIFACT_PREFIXES = tuple(f"{key}." for key in _FINAL_ARTIFACT_KEYS if key != "final_entity_signature_hashes") + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def load_spec(path: Path) -> BenchmarkSpec: + if not path.exists() or path.is_dir(): + raise ValueError(f"spec path is not a file: {path}") + raw = yaml.safe_load(path.read_text(encoding="utf-8")) + if not isinstance(raw, dict): + raise ValueError("benchmark spec must be a YAML mapping") + return BenchmarkSpec.model_validate(raw) + + +def build_cases(spec: BenchmarkSpec) -> list[BenchmarkCase]: + matrix = spec.matrix or _cross_product_matrix(spec) + return [ + BenchmarkCase( + suite_id=spec.suite_id, + workload_id=entry.workload, + config_id=entry.config, + repetition=repetition, + case_id=f"{entry.workload}__{entry.config}__r{repetition:03d}", + ) + for entry in matrix + for repetition in range(entry.repetitions) + ] + + +def _cross_product_matrix(spec: BenchmarkSpec) -> list[MatrixEntry]: + return [ + MatrixEntry(workload=workload.id, config=config.id, repetitions=1) + for workload in spec.workloads + for config in spec.configs + ] + + +def prepare_output_dir(output_dir: Path, *, overwrite: bool, dry_run: bool) -> None: + if dry_run: + return + if output_dir.exists() and not output_dir.is_dir(): + raise ValueError(f"output path exists and is not a directory: {output_dir}") + if output_dir.exists(): + if overwrite: + shutil.rmtree(output_dir) + elif any(output_dir.iterdir()): + raise ValueError(f"output directory is not empty: {output_dir}; pass --overwrite to replace it") + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "raw").mkdir(exist_ok=True) + + +def preflight_suite(spec: BenchmarkSpec, *, spec_path: Path) -> None: + """Validate cheap suite inputs before any benchmark case consumes model time.""" + base_dir = spec_path.parent + errors: list[str] = [] + parsed_models = _preflight_model_configs(spec, base_dir=base_dir, errors=errors) + + _preflight_model_providers_with_errors(spec, base_dir=base_dir, errors=errors) + errors.extend(_preflight_workload_errors(spec, base_dir=base_dir)) + errors.extend(_preflight_config_errors(spec, parsed_models=parsed_models)) + if errors: + raise ValueError("Benchmark preflight failed:\n- " + "\n- ".join(errors)) + + +def _preflight_model_configs(spec: BenchmarkSpec, *, base_dir: Path, errors: list[str]) -> Any | None: + try: + return parse_model_configs(_resolve_config_source(spec.model_configs, base_dir)) + except Exception as exc: + errors.append(f"model_configs invalid: {exc}") + return None + + +def _preflight_model_providers_with_errors( + spec: BenchmarkSpec, + *, + base_dir: Path, + errors: list[str], +) -> None: + try: + _preflight_model_providers(spec, base_dir=base_dir) + except Exception as exc: + errors.append(f"model_providers invalid: {exc}") + + +def _preflight_workload_errors(spec: BenchmarkSpec, *, base_dir: Path) -> list[str]: + errors: list[str] = [] + for workload in spec.workloads: + try: + _preflight_workload(workload, base_dir=base_dir) + except Exception as exc: + errors.append(str(exc)) + return errors + + +def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) -> list[str]: + errors: list[str] = [] + active_config_ids = _active_config_ids(spec) + for config in spec.configs: + if config.id not in active_config_ids: + continue + try: + anonymizer_config = build_anonymizer_config(config) + except Exception as exc: + errors.append(f"config '{config.id}' invalid: {exc}") + continue + try: + _preflight_native_runtime(config, spec=spec) + except Exception as exc: + errors.append(f"config '{config.id}' native_runtime invalid: {exc}") + try: + _preflight_experimental_replacement_strategy(config, anonymizer_config) + except Exception as exc: + errors.append(f"config '{config.id}' experimental_replacement_strategy invalid: {exc}") + if parsed_models is None: + continue + try: + validate_model_alias_references( + parsed_models.model_configs, + parsed_models.selected_models, + check_substitute=isinstance(anonymizer_config.replace, Substitute) + or anonymizer_config.rewrite is not None, + check_rewrite=anonymizer_config.rewrite is not None, + ) + except ValueError as exc: + errors.append(f"config '{config.id}' model aliases invalid: {exc}") + return errors + + +def _active_config_ids(spec: BenchmarkSpec) -> set[str]: + if spec.matrix is None: + return {config.id for config in spec.configs} + return {entry.config for entry in spec.matrix} + + +def _preflight_native_runtime(config: ConfigSpec, *, spec: BenchmarkSpec) -> None: + strategy = config.experimental_detection_strategy + if strategy not in _NATIVE_RUNTIME_STRATEGIES: + return + runtime = _resolve_native_runtime_spec(spec, config) + if not runtime.runtime_id: + raise ValueError("native strategies require native_runtime.runtime_id") + if not runtime.endpoint or not runtime.model: + raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") + if strategy in _GLINER_NATIVE_RUNTIME_STRATEGIES: + if not runtime.gliner_endpoint or not runtime.gliner_model: + raise ValueError("GLiNER-native strategies require native_runtime.gliner_endpoint and gliner_model") + if not os.environ.get(runtime.gliner_api_key_env): + raise ValueError(f"{runtime.gliner_api_key_env} is not set for GLiNER-native strategy") + + +def _resolve_native_runtime_spec(spec: BenchmarkSpec, config: ConfigSpec) -> NativeRuntimeSpec: + runtime = spec.native_runtime or NativeRuntimeSpec() + if config.native_runtime is not None: + runtime = runtime.model_copy(update=config.native_runtime.model_dump(exclude_unset=True)) + return runtime.model_copy( + update={ + "endpoint": _resolve_runtime_value(runtime.endpoint, runtime.endpoint_env), + "model": _resolve_runtime_value(runtime.model, runtime.model_env), + "gliner_endpoint": _resolve_runtime_value(runtime.gliner_endpoint, runtime.gliner_endpoint_env), + "gliner_model": _resolve_runtime_value(runtime.gliner_model, runtime.gliner_model_env), + } + ) + + +def _resolve_runtime_value(explicit: str | None, env_var: str | None) -> str | None: + if explicit: + return explicit + return os.environ.get(env_var) if env_var else None + + +def _native_detection_runtime(spec: BenchmarkSpec, config: ConfigSpec) -> NativeDetectionRuntime: + runtime = _resolve_native_runtime_spec(spec, config) + if not runtime.runtime_id: + raise ValueError("native strategies require native_runtime.runtime_id") + if not runtime.endpoint or not runtime.model: + raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") + return NativeDetectionRuntime( + endpoint=runtime.endpoint, + model=runtime.model, + provider=runtime.provider, + alias=runtime.alias, + max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, + gliner_endpoint=runtime.gliner_endpoint, + gliner_model=runtime.gliner_model or "", + gliner_provider=runtime.gliner_provider, + gliner_alias=runtime.gliner_alias, + gliner_api_key_env=runtime.gliner_api_key_env, + gliner_threshold=runtime.gliner_threshold, + max_workers=runtime.max_workers, + ) + + +def _preflight_experimental_replacement_strategy( + config: ConfigSpec, + anonymizer_config: AnonymizerConfig, +) -> None: + if config.experimental_replacement_strategy == ExperimentalReplacementStrategy.default: + return + if config.experimental_replacement_strategy != ExperimentalReplacementStrategy.local_structured_substitute: + raise ValueError(f"unsupported strategy: {config.experimental_replacement_strategy}") + if not isinstance(anonymizer_config.replace, Substitute): + raise ValueError("local_structured_substitute requires replace: substitute") + entity_labels = anonymizer_config.detect.entity_labels + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + if entity_labels is None: + raise ValueError( + "`local_structured_substitute` requires explicit detect.entity_labels limited to " + f"structured substitute labels: {supported}" + ) + unsupported = sorted(set(entity_labels) - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) + if unsupported: + raise ValueError( + f"unsupported local structured substitute labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + +def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: + raw = _resolve_config_source(spec.model_providers, base_dir) + if raw is None: + return + config_source: str | Path = raw + if isinstance(raw, str) and "\n" not in raw: + candidate = Path(raw.strip()).expanduser() + if candidate.suffix in (".yaml", ".yml"): + if not candidate.is_file(): + raise FileNotFoundError(f"Providers config file not found: {candidate}") + config_source = candidate + config_dict = yaml.safe_load(raw) if isinstance(raw, str) and "\n" in raw else load_config_file(config_source) + raw_providers = config_dict.get("providers") if isinstance(config_dict, dict) else None + if not isinstance(raw_providers, list): + raise ValueError("model_providers YAML must contain a top-level 'providers' list.") + for provider in raw_providers: + ModelProvider.model_validate(provider) + + +def _preflight_workload(workload: WorkloadSpec, *, base_dir: Path) -> None: + resolved_source = _resolve_input_source(workload.source, base_dir) + if _workload_has_row_slice(workload) and not _is_local_input_source(str(resolved_source)): + raise ValueError(f"workload '{workload.id}' row slicing requires a local workload source") + input_data = AnonymizerInput( + source=str(resolved_source), + text_column=workload.text_column, + id_column=workload.id_column, + data_summary=workload.data_summary, + ) + columns = _input_columns(input_data.source) + if columns is None: + return + if workload.text_column not in columns: + raise ValueError( + f"workload '{workload.id}' text_column '{workload.text_column}' not found in {input_data.source}; " + f"available columns: {sorted(columns)}" + ) + if workload.id_column is not None and workload.id_column not in columns: + raise ValueError( + f"workload '{workload.id}' id_column '{workload.id_column}' not found in {input_data.source}; " + f"available columns: {sorted(columns)}" + ) + + +def _input_columns(source: str) -> set[str] | None: + suffix = infer_input_source_suffix(source) + if suffix not in SUPPORTED_IO_FORMATS: + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. Use {supported_formats}.") + if is_remote_input_source(source): + return None + if suffix == ".csv": + return set(pd.read_csv(source, nrows=0).columns) + return set(pq.ParquetFile(source).schema_arrow.names) + + +def run_suite( + spec: BenchmarkSpec, + *, + spec_path: Path, + output_dir: Path, + export: bool, + fail_fast: bool, + dd_trace: DDTraceMode, + trace_dir: Path | None, +) -> BenchmarkResult: + contexts = _build_contexts( + spec, + spec_path=spec_path, + output_dir=output_dir, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) + anonymizer = Anonymizer(**contexts["anonymizer_kwargs"]) + cases = _run_cases(spec, contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast, export=export) + measurement_path = combine_measurements(cases, output_dir / "measurements.jsonl") + should_export = _should_export_measurements(export=export, measurement_path=measurement_path) + table_dir = _export_suite_tables(measurement_path, output_dir=output_dir, should_export=should_export) + artifact_analysis_path = _combine_suite_detection_artifacts( + cases, output_dir=output_dir, should_export=should_export + ) + result = _benchmark_result( + spec, + output_dir=output_dir, + measurement_path=measurement_path, + table_dir=table_dir, + artifact_analysis_path=artifact_analysis_path, + cases=cases, + ) + write_summary(result) + return result + + +def _run_cases( + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + fail_fast: bool, + export: bool, +) -> list[BenchmarkCase]: + return [ + _run_case( + case, + spec, + contexts=contexts, + anonymizer=anonymizer, + fail_fast=fail_fast, + export_detection_artifacts=export, + ) + for case in build_cases(spec) + ] + + +def _should_export_measurements(*, export: bool, measurement_path: Path) -> bool: + return export and measurement_path.stat().st_size > 0 + + +def _export_suite_tables(measurement_path: Path, *, output_dir: Path, should_export: bool) -> Path | None: + if not should_export: + return None + return export_measurement_tables(measurement_path, output_dir / "tables") + + +def _combine_suite_detection_artifacts( + cases: list[BenchmarkCase], + *, + output_dir: Path, + should_export: bool, +) -> Path | None: + if not should_export: + return None + return combine_detection_artifact_analysis(cases, output_dir / "detection-artifacts.jsonl") + + +def _benchmark_result( + spec: BenchmarkSpec, + *, + output_dir: Path, + measurement_path: Path, + table_dir: Path | None, + artifact_analysis_path: Path | None, + cases: list[BenchmarkCase], +) -> BenchmarkResult: + return BenchmarkResult( + suite_id=spec.suite_id, + output_dir=str(output_dir), + measurement_path=str(measurement_path), + summary_path=str(output_dir / "summary.json"), + table_dir=str(table_dir) if table_dir is not None else None, + detection_artifact_analysis_path=str(artifact_analysis_path) if artifact_analysis_path is not None else None, + cases=cases, + ) + + +def _build_contexts( + spec: BenchmarkSpec, + *, + spec_path: Path, + output_dir: Path, + dd_trace: DDTraceMode, + trace_dir: Path | None, +) -> dict[str, Any]: + base_dir = spec_path.parent + artifact_path = _resolve_optional_path(spec.artifact_path, base_dir) or output_dir / "artifacts" + return { + "base_dir": base_dir, + "workloads": {workload.id: workload for workload in spec.workloads}, + "configs": {config.id: config for config in spec.configs}, + "raw_dir": output_dir / "raw", + "dd_trace": dd_trace, + "trace_dir": trace_dir or output_dir / "traces", + "dd_parser_compat": spec.dd_parser_compat, + "artifact_path": artifact_path, + "anonymizer_kwargs": { + "model_configs": _resolve_config_source(spec.model_configs, base_dir), + "model_providers": _resolve_config_source(spec.model_providers, base_dir), + "artifact_path": artifact_path, + }, + } + + +def _run_case( + case: BenchmarkCase, + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + fail_fast: bool, + export_detection_artifacts: bool, +) -> BenchmarkCase: + started = time.perf_counter() + attempt_errors: list[str] = [] + max_attempts = 1 if fail_fast else spec.case_retries + 1 + for attempt_number in range(1, max_attempts + 1): + paths = _case_run_paths(case, contexts=contexts, export_detection_artifacts=export_detection_artifacts) + try: + return _run_case_success( + case, + spec, + contexts=contexts, + anonymizer=anonymizer, + paths=paths, + started=started, + attempt_count=attempt_number, + attempt_errors=attempt_errors, + ) + except Exception as exc: + if fail_fast: + raise + attempt_errors.append(str(exc)) + if attempt_number >= max_attempts: + return _run_case_error( + case, + contexts=contexts, + paths=paths, + started=started, + error=exc, + attempt_count=attempt_number, + attempt_errors=attempt_errors, + ) + _sleep_before_case_retry(spec, case=case, attempt_number=attempt_number, error=exc) + + raise RuntimeError("unreachable benchmark retry state") + + +def _sleep_before_case_retry( + spec: BenchmarkSpec, + *, + case: BenchmarkCase, + attempt_number: int, + error: Exception, +) -> None: + logger.warning( + "case %s attempt %d failed; retrying after %.2fs: %s", + case.case_id, + attempt_number, + spec.case_retry_backoff_sec, + error, + ) + if spec.case_retry_backoff_sec > 0: + time.sleep(spec.case_retry_backoff_sec) + + +def _case_run_paths( + case: BenchmarkCase, + *, + contexts: dict[str, Any], + export_detection_artifacts: bool, +) -> _CaseRunPaths: + return _CaseRunPaths( + raw_path=contexts["raw_dir"] / f"{case.case_id}.jsonl", + artifact_output_path=contexts["raw_dir"] / f"{case.case_id}.detection-artifacts.jsonl", + trace_path=_case_trace_path(case, contexts=contexts), + artifact_snapshot=snapshot_detection_artifacts(contexts["artifact_path"]) + if export_detection_artifacts + else None, + ) + + +def _run_case_success( + case: BenchmarkCase, + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + paths: _CaseRunPaths, + started: float, + attempt_count: int, + attempt_errors: list[str], +) -> BenchmarkCase: + workload = _get_item(contexts["workloads"], case.workload_id, "workload") + config = _get_item(contexts["configs"], case.config_id, "config") + execution = _execute_case( + anonymizer, + workload, + config, + raw_path=paths.raw_path, + trace_path=paths.trace_path, + case=case, + spec=spec, + base_dir=contexts["base_dir"], + dd_trace=contexts["dd_trace"], + dd_parser_compat=contexts["dd_parser_compat"], + ) + detection_artifact_path = _case_detection_artifact_path( + contexts, + paths, + case=case, + config=config, + execution=execution, + ) + return _case_with_result( + case, + status=CaseStatus.completed, + started=started, + raw_path=paths.raw_path, + detection_artifact_path=detection_artifact_path, + trace_path=paths.trace_path, + attempt_count=attempt_count, + attempt_errors=attempt_errors, + ) + + +def _run_case_error( + case: BenchmarkCase, + *, + contexts: dict[str, Any], + paths: _CaseRunPaths, + started: float, + error: Exception, + attempt_count: int, + attempt_errors: list[str], +) -> BenchmarkCase: + detection_artifact_path = _export_case_detection_artifacts_if_requested( + contexts, + paths.artifact_output_path, + case=case, + artifact_snapshot=paths.artifact_snapshot, + ) + return _case_with_result( + case, + status=CaseStatus.error, + started=started, + raw_path=paths.raw_path, + detection_artifact_path=detection_artifact_path, + trace_path=paths.trace_path, + error=str(error), + attempt_count=attempt_count, + attempt_errors=attempt_errors, + ) + + +def _case_detection_artifact_path( + contexts: dict[str, Any], + paths: _CaseRunPaths, + *, + case: BenchmarkCase, + config: ConfigSpec, + execution: _CaseExecution, +) -> Path | None: + detection_artifact_path = _export_case_detection_artifacts_if_requested( + contexts, + paths.artifact_output_path, + case=case, + artifact_snapshot=paths.artifact_snapshot, + ) + detection_artifact_path = _trace_final_artifact_path_if_requested( + config, + detection_artifact_path, + paths.artifact_output_path, + case=case, + trace_dataframe=execution.trace_dataframe, + ) + if detection_artifact_path is not None or paths.artifact_snapshot is None: + return detection_artifact_path + return None + + +def _trace_final_artifact_path_if_requested( + config: ConfigSpec, + detection_artifact_path: Path | None, + output_path: Path, + *, + case: BenchmarkCase, + trace_dataframe: pd.DataFrame | None, +) -> Path | None: + if config.experimental_detection_strategy not in _TRACE_FINAL_ARTIFACT_STRATEGIES: + return detection_artifact_path + if trace_dataframe is None: + return detection_artifact_path + return patch_case_detection_artifacts_from_trace_dataframe( + detection_artifact_path or output_path, + trace_dataframe, + case=case, + ) + + +def patch_case_detection_artifacts_from_trace_dataframe( + output_path: Path, + trace_dataframe: pd.DataFrame, + *, + case: BenchmarkCase | None = None, +) -> Path | None: + final_rows = _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe) + if not final_rows: + return None + rows = _read_detection_artifact_payloads(output_path) if output_path.exists() else [] + patched = _merge_final_entity_artifact_rows(rows, final_rows) + if case is not None: + patched = [_with_case_metadata(row, case=case) for row in patched] + write_detection_artifact_payloads(patched, output_path) + return output_path + + +def _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe: pd.DataFrame) -> list[dict[str, Any]]: + entities_column = _trace_final_entities_column(trace_dataframe) + if entities_column is None: + return [] + return [ + _final_entity_artifact_row(raw_entities, row_index=row_index) + for row_index, raw_entities in enumerate(trace_dataframe[entities_column]) + ] + + +def _trace_final_entities_column(trace_dataframe: pd.DataFrame) -> str | None: + if COL_FINAL_ENTITIES in trace_dataframe.columns: + return COL_FINAL_ENTITIES + if COL_DETECTED_ENTITIES in trace_dataframe.columns: + return COL_DETECTED_ENTITIES + return None + + +def _final_entity_artifact_row(raw_entities: object, *, row_index: int) -> dict[str, Any]: + entities = EntitiesSchema.from_raw(raw_entities).entities + return build_detection_artifact_row_from_entities( + workflow_name="entity-detection-final-trace", + batch_file="trace_dataframe", + row_index=row_index, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=entities, + ).model_dump() + + +def _read_detection_artifact_payloads(output_path: Path) -> list[dict[str, Any]]: + with output_path.open(encoding="utf-8") as source: + return [json.loads(line) for line in source if line.strip()] + + +def _merge_final_entity_artifact_rows( + rows: list[dict[str, Any]], + final_rows: list[dict[str, Any]], +) -> list[dict[str, Any]]: + patched = [_patch_final_entity_artifact_row(row, final_row) for row, final_row in zip(rows, final_rows)] + return patched + rows[len(final_rows) :] + final_rows[len(rows) :] + + +def _patch_final_entity_artifact_row(row: dict[str, Any], final_row: dict[str, Any]) -> dict[str, Any]: + clean_row = _without_final_entity_artifact_fields(row) + return {**clean_row, **_final_entity_artifact_fields(final_row)} + + +def _without_final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: + return { + key: value + for key, value in row.items() + if key not in _FINAL_ARTIFACT_KEYS and not key.startswith(_FINAL_ARTIFACT_PREFIXES) + } + + +def _final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: + return {key: row[key] for key in _FINAL_ARTIFACT_KEYS} + + +def _case_with_result( + case: BenchmarkCase, + *, + status: CaseStatus, + started: float, + raw_path: Path, + detection_artifact_path: Path | None, + trace_path: Path | None, + attempt_count: int, + attempt_errors: list[str], + error: str | None = None, +) -> BenchmarkCase: + return case.model_copy( + update={ + "status": status, + "elapsed_sec": time.perf_counter() - started, + "measurement_path": str(raw_path), + "detection_artifact_path": (str(detection_artifact_path) if detection_artifact_path is not None else None), + "trace_path": str(trace_path) if trace_path is not None else None, + "error": error, + "attempt_count": attempt_count, + "attempt_errors": list(attempt_errors), + } + ) + + +def _export_case_detection_artifacts_if_requested( + contexts: dict[str, Any], + output_path: Path, + *, + case: BenchmarkCase, + artifact_snapshot: dict[str, int] | None, +) -> Path | None: + if artifact_snapshot is None: + return None + return export_case_detection_artifact_analysis( + contexts["artifact_path"], + output_path, + case=case, + artifact_snapshot=artifact_snapshot, + ) + + +def _case_trace_path(case: BenchmarkCase, *, contexts: dict[str, Any]) -> Path | None: + if contexts["dd_trace"] == DDTraceMode.none: + return None + return contexts["trace_dir"] / f"{case.case_id}.jsonl" + + +def _execute_case( + anonymizer: Anonymizer, + workload: WorkloadSpec, + config: ConfigSpec, + *, + raw_path: Path, + trace_path: Path | None, + case: BenchmarkCase, + spec: BenchmarkSpec, + base_dir: Path, + dd_trace: DDTraceMode, + dd_parser_compat: DDParserCompatMode, +) -> _CaseExecution: + anonymizer_config = build_anonymizer_config(config) + input_data = build_input( + workload, + base_dir, + slice_dir=raw_path.parent / "inputs", + case_id=case.case_id, + ) + measurement = MeasurementConfig( + output_path=raw_path, + run_id=case.case_id, + run_tags=_run_tags(case, spec), + streaming=True, + keep_records=False, + dd_trace=dd_trace.value, + dd_trace_path=trace_path, + fail_on_write_error=True, + ) + with configured_measurement_session(measurement): + with dd_parser_compat_context(dd_parser_compat): + detection_context_kwargs: dict[str, Any] = {} + if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: + detection_context_kwargs["native_runtime"] = _native_detection_runtime(spec, config) + with experimental_detection_strategy_context( + config.experimental_detection_strategy, + **detection_context_kwargs, + ): + with experimental_replacement_strategy_context(config.experimental_replacement_strategy): + result = anonymizer.run( + config=anonymizer_config, + data=input_data, + ) + return _CaseExecution(input_data=input_data, trace_dataframe=getattr(result, "trace_dataframe", None)) + + +def build_input( + workload: WorkloadSpec, + base_dir: Path, + *, + slice_dir: Path | None = None, + case_id: str | None = None, +) -> AnonymizerInput: + resolved_source = _resolve_input_source(workload.source, base_dir) + source = ( + _materialize_sliced_source(workload, resolved_source, slice_dir=slice_dir, case_id=case_id) + if _workload_has_row_slice(workload) + else resolved_source + ) + return AnonymizerInput( + source=str(source), + text_column=workload.text_column, + id_column=workload.id_column, + data_summary=workload.data_summary, + ) + + +def _workload_has_row_slice(workload: WorkloadSpec) -> bool: + return workload.row_limit is not None or workload.row_offset > 0 + + +def _is_local_input_source(source: str) -> bool: + return "://" not in source + + +def _materialize_sliced_source( + workload: WorkloadSpec, + source: str | Path, + *, + slice_dir: Path | None, + case_id: str | None, +) -> Path: + if not _is_local_input_source(str(source)): + raise ValueError(f"workload '{workload.id}' row slicing requires a local workload source") + if slice_dir is None or case_id is None: + raise ValueError("row slicing requires slice_dir and case_id") + source_path = Path(source) + suffix = infer_input_source_suffix(str(source_path)) + dataframe = _read_local_input_dataframe(source_path, suffix=suffix) + sliced = dataframe.iloc[_slice_bounds(workload)] + slice_dir.mkdir(parents=True, exist_ok=True) + destination = slice_dir / f"{_safe_case_filename(case_id)}{suffix}" + _write_local_input_dataframe(sliced, destination, suffix=suffix) + return destination + + +def _slice_bounds(workload: WorkloadSpec) -> slice: + start = workload.row_offset + stop = start + workload.row_limit if workload.row_limit is not None else None + return slice(start, stop) + + +def _read_local_input_dataframe(source: Path, *, suffix: str) -> pd.DataFrame: + if suffix == ".csv": + return pd.read_csv(source) + if suffix == ".parquet": + return pd.read_parquet(source) + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. Use {supported_formats}.") + + +def _write_local_input_dataframe(dataframe: pd.DataFrame, destination: Path, *, suffix: str) -> None: + if suffix == ".csv": + dataframe.to_csv(destination, index=False) + return + if suffix == ".parquet": + dataframe.to_parquet(destination, index=False) + return + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. Use {supported_formats}.") + + +def _safe_case_filename(case_id: str) -> str: + return "".join(char if char.isalnum() or char in "._-" else "_" for char in case_id) + + +def build_anonymizer_config(config: ConfigSpec) -> AnonymizerConfig: + detect = Detect.model_validate(config.detect) + if config.replace is not None: + return AnonymizerConfig( + detect=detect, replace=build_replace(config.replace), emit_telemetry=config.emit_telemetry + ) + return AnonymizerConfig(detect=detect, rewrite=build_rewrite(config.rewrite), emit_telemetry=config.emit_telemetry) + + +def build_replace(raw: str | ReplaceSpec) -> Redact | Hash | Annotate | Substitute: + spec = ReplaceSpec(strategy=ReplaceKind(raw)) if isinstance(raw, str) else raw + if spec.strategy == ReplaceKind.redact: + return Redact(**_present({"format_template": spec.format_template, "normalize_label": spec.normalize_label})) + if spec.strategy == ReplaceKind.hash: + return Hash( + **_present( + { + "format_template": spec.format_template, + "algorithm": spec.algorithm, + "digest_length": spec.digest_length, + } + ) + ) + if spec.strategy == ReplaceKind.annotate: + return Annotate(**_present({"format_template": spec.format_template})) + return Substitute(**_present({"instructions": spec.instructions})) + + +def build_rewrite(spec: RewriteSpec | None) -> Rewrite: + if spec is None: + raise ValueError("rewrite config is missing") + privacy_goal = _privacy_goal(spec) + return Rewrite( + privacy_goal=privacy_goal, + instructions=spec.instructions, + risk_tolerance=spec.risk_tolerance, + max_repair_iterations=spec.max_repair_iterations, + strict_entity_protection=spec.strict_entity_protection, + ) + + +def _privacy_goal(spec: RewriteSpec) -> PrivacyGoal | None: + if spec.protect is None and spec.preserve is None: + return None + return PrivacyGoal( + protect=spec.protect or DEFAULT_PROTECT_TEXT, + preserve=spec.preserve or DEFAULT_PRESERVE_TEXT, + ) + + +def combine_measurements(cases: list[BenchmarkCase], destination: Path) -> Path: + with destination.open("w", encoding="utf-8") as output: + for case in cases: + if case.measurement_path is None: + continue + source = Path(case.measurement_path) + if source.exists(): + output.write(source.read_text(encoding="utf-8")) + return destination + + +def combine_detection_artifact_analysis(cases: list[BenchmarkCase], destination: Path) -> Path | None: + chunks: list[str] = [] + for case in cases: + if case.detection_artifact_path is None: + continue + source = Path(case.detection_artifact_path) + if source.exists(): + chunks.append(_jsonl_chunk(source.read_text(encoding="utf-8"))) + if not chunks: + return None + destination.write_text("".join(chunks), encoding="utf-8") + return destination + + +def _jsonl_chunk(text: str) -> str: + if not text or text.endswith("\n"): + return text + return text + "\n" + + +def export_measurement_tables(measurement_path: Path, table_dir: Path) -> Path: + dataframe = read_measurements(measurement_path) + export_tables( + dataframe, input_path=measurement_path, output_dir=table_dir, export_format=ExportFormat.parquet, overwrite=True + ) + return table_dir + + +def snapshot_detection_artifacts(artifact_path: Path) -> dict[str, int]: + if not artifact_path.exists(): + return {} + return { + str(parquet_file.relative_to(artifact_path)): parquet_file.stat().st_mtime_ns + for parquet_file in iter_detection_parquet_files(artifact_path) + } + + +def changed_detection_artifact_files(artifact_path: Path, snapshot: dict[str, int]) -> list[Path]: + if not artifact_path.exists(): + return [] + changed: list[Path] = [] + for parquet_file in iter_detection_parquet_files(artifact_path): + key = str(parquet_file.relative_to(artifact_path)) + if snapshot.get(key) != parquet_file.stat().st_mtime_ns: + changed.append(parquet_file) + return changed + + +def export_detection_artifact_analysis( + artifact_path: Path, + output_path: Path, + *, + artifact_snapshot: dict[str, int] | None = None, +) -> Path | None: + if not artifact_path.exists(): + return None + parquet_files = ( + changed_detection_artifact_files(artifact_path, artifact_snapshot) if artifact_snapshot is not None else None + ) + analysis = analyze_artifacts(artifact_path, parquet_files=parquet_files) + if not analysis.rows: + return None + write_detection_artifact_payloads([row.model_dump() for row in analysis.rows], output_path) + return output_path + + +def export_case_detection_artifact_analysis( + artifact_path: Path, + output_path: Path, + *, + case: BenchmarkCase, + artifact_snapshot: dict[str, int], +) -> Path | None: + if not artifact_path.exists(): + return None + parquet_files = changed_detection_artifact_files(artifact_path, artifact_snapshot) + analysis = analyze_artifacts(artifact_path, parquet_files=parquet_files) + if not analysis.rows: + return None + write_detection_artifact_payloads( + [_with_case_metadata(row.model_dump(), case=case) for row in analysis.rows], + output_path, + ) + return output_path + + +def _with_case_metadata(row: dict[str, Any], *, case: BenchmarkCase) -> dict[str, Any]: + return { + "suite_id": case.suite_id, + "workload_id": case.workload_id, + "config_id": case.config_id, + "repetition": case.repetition, + "case_id": case.case_id, + "run_id": case.case_id, + **row, + } + + +def write_detection_artifact_payloads(rows: list[dict[str, Any]], output_path: Path) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + pd.json_normalize(rows, sep=".").to_json(output_path, orient="records", lines=True) + + +def write_summary(result: BenchmarkResult) -> None: + Path(result.summary_path).write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def render_result(result: BenchmarkResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + completed = sum(case.status == CaseStatus.completed for case in result.cases) + errored = sum(case.status == CaseStatus.error for case in result.cases) + planned = sum(case.status == CaseStatus.planned for case in result.cases) + if planned and completed == 0 and errored == 0: + return f"Planned {planned} case(s); output={result.output_dir}" + return f"Ran {completed}/{len(result.cases)} case(s); errors={errored}; output={result.output_dir}" + + +def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: + config = next(item for item in spec.configs if item.id == case.config_id) + tags = { + "suite_id": spec.suite_id, + "workload_id": case.workload_id, + "config_id": case.config_id, + "repetition": case.repetition, + "case_id": case.case_id, + "experimental_detection_strategy": config.experimental_detection_strategy.value, + "experimental_replacement_strategy": config.experimental_replacement_strategy.value, + "dd_parser_compat": spec.dd_parser_compat.value, + } + if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: + tags.update(_native_runtime_tags(_resolve_native_runtime_spec(spec, config))) + return tags + + +def _native_runtime_tags(runtime: NativeRuntimeSpec) -> dict[str, Any]: + return _present( + { + "native_runtime_id": runtime.runtime_id, + "native_endpoint_env": runtime.endpoint_env, + "native_model": runtime.model, + "native_model_env": runtime.model_env, + "native_provider": runtime.provider, + "native_alias": runtime.alias, + "native_max_tokens": runtime.max_tokens, + "native_timeout_sec": runtime.timeout_sec, + "native_max_workers": runtime.max_workers, + "gliner_endpoint_env": runtime.gliner_endpoint_env, + "gliner_model": runtime.gliner_model, + "gliner_model_env": runtime.gliner_model_env, + "gliner_provider": runtime.gliner_provider, + "gliner_alias": runtime.gliner_alias, + "gliner_api_key_env": runtime.gliner_api_key_env, + "gliner_threshold": runtime.gliner_threshold, + } + ) + + +def _present(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} + + +def _get_item(items: dict[str, Any], item_id: str, item_type: str) -> Any: + if item_id not in items: + raise ValueError(f"unknown {item_type}: {item_id}") + return items[item_id] + + +def _resolve_input_source(source: str, base_dir: Path) -> str | Path: + if "://" in source: + return source + return _resolve_path(source, base_dir) + + +def _resolve_optional_path(raw: str | None, base_dir: Path) -> Path | None: + if raw is None: + return None + return _resolve_path(raw, base_dir) + + +def _resolve_config_source(raw: str | None, base_dir: Path) -> str | None: + if raw is None or "\n" in raw: + return raw + candidate = Path(raw).expanduser() + if candidate.suffix in {".yaml", ".yml"}: + return str(_resolve_path(raw, base_dir)) + return raw + + +def _resolve_path(raw: str, base_dir: Path) -> Path: + path = Path(raw).expanduser() + return path if path.is_absolute() else base_dir / path + + +def dry_run_result( + spec: BenchmarkSpec, + *, + output_dir: Path, + export: bool, + dd_trace: DDTraceMode, + trace_dir: Path | None, +) -> BenchmarkResult: + cases = build_cases(spec) + if dd_trace != DDTraceMode.none: + resolved_trace_dir = trace_dir or output_dir / "traces" + cases = [ + case.model_copy(update={"trace_path": str(resolved_trace_dir / f"{case.case_id}.jsonl")}) for case in cases + ] + return BenchmarkResult( + suite_id=spec.suite_id, + output_dir=str(output_dir), + measurement_path=str(output_dir / "measurements.jsonl"), + summary_path=str(output_dir / "summary.json"), + table_dir=str(output_dir / "tables") if export else None, + detection_artifact_analysis_path=str(output_dir / "detection-artifacts.jsonl") if export else None, + cases=cases, + ) + + +@app.default +def main( + spec: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + dry_run: Annotated[bool, cyclopts.Parameter("--dry-run")] = False, + export: Annotated[bool, cyclopts.Parameter("--export")] = True, + fail_fast: Annotated[bool, cyclopts.Parameter("--fail-fast")] = False, + dd_trace: Annotated[DDTraceMode, cyclopts.Parameter("--dd-trace")] = DDTraceMode.none, + trace_dir: Annotated[Path | None, cyclopts.Parameter("--trace-dir")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = run_or_plan( + spec, + output=output, + overwrite=overwrite, + dry_run=dry_run, + export=export, + fail_fast=fail_fast, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) + except (ValueError, ValidationError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + if any(case.status == CaseStatus.error for case in result.cases): + raise SystemExit(1) + + +def run_or_plan( + spec_path: Path, + *, + output: Path | None, + overwrite: bool, + dry_run: bool, + export: bool, + fail_fast: bool, + dd_trace: DDTraceMode = DDTraceMode.none, + trace_dir: Path | None = None, +) -> BenchmarkResult: + benchmark_spec = load_spec(spec_path) + output_dir = output or Path("benchmark-runs") / benchmark_spec.suite_id + if trace_dir is not None and dd_trace == DDTraceMode.none: + raise ValueError("--trace-dir requires --dd-trace") + preflight_suite(benchmark_spec, spec_path=spec_path) + if dry_run: + return dry_run_result( + benchmark_spec, + output_dir=output_dir, + export=export, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) + prepare_output_dir(output_dir, overwrite=overwrite, dry_run=dry_run) + return run_suite( + benchmark_spec, + spec_path=spec_path, + output_dir=output_dir, + export=export, + fail_fast=fail_fast, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/screen_strategy_comparisons.py b/tools/measurement/screen_strategy_comparisons.py new file mode 100644 index 00000000..0cd9e141 --- /dev/null +++ b/tools/measurement/screen_strategy_comparisons.py @@ -0,0 +1,1134 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Screen strategy-comparison CSVs across benchmark analysis directories. + +Usage: + uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ + uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ --output strategy-screen.csv + uv run python tools/measurement/screen_strategy_comparisons.py run-a/analysis run-b/analysis --json +""" + +from __future__ import annotations + +import ast +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, model_validator + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.strategy_screen") + +COMPARISON_COLUMNS = { + "workload_id", + "baseline_config_id", + "candidate_config_id", + "safety_verdict", + "performance_verdict", + "candidate_verdict", +} + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class GroupBy(StrEnum): + strategy = "strategy" + strategy_workload_family = "strategy_workload_family" + strategy_workload = "strategy_workload" + + +class ScreenRow(BaseModel): + source_path: str + workload_id: str + workload_family: str | None = None + baseline_config_id: str + candidate_config_id: str + baseline_strategy: str | None = None + candidate_strategy: str | None = None + baseline_replacement_strategy: str | None = None + candidate_replacement_strategy: str | None = None + baseline_case_count: int | None = None + candidate_case_count: int | None = None + value_protection_verdict: str | None = None + signature_parity_verdict: str | None = None + safety_verdict: str + performance_verdict: str + candidate_verdict: str + evidence_level: str = "legacy" + pipeline_elapsed_sec_delta_pct: float | None = None + observed_total_requests_delta: float | None = None + observed_total_tokens_delta: float | None = None + final_entity_count_delta: float | None = None + augmented_entity_count_delta: float | None = None + augmented_new_final_value_count_delta: float | None = None + baseline_original_value_leak_count: float | None = None + candidate_original_value_leak_count: float | None = None + original_value_leak_count_delta: float | None = None + baseline_original_value_leak_record_count: float | None = None + candidate_original_value_leak_record_count: float | None = None + original_value_leak_record_count_delta: float | None = None + baseline_replacement_missing_final_entity_count: float | None = None + candidate_replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_count: float | None = None + candidate_replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_value_count: float | None = None + candidate_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_value_count_delta: float | None = None + baseline_duplicate_synthetic_replacement_count: float | None = None + candidate_duplicate_synthetic_replacement_count: float | None = None + duplicate_synthetic_replacement_count_delta: float | None = None + baseline_only_final_entity_signature_count: float | None = None + candidate_only_final_entity_signature_count: float | None = None + shared_final_entity_signature_count: float | None = None + baseline_stable_final_entity_signature_count: int | None = None + candidate_stable_final_entity_signature_count: int | None = None + baseline_stable_candidate_unstable_final_entity_signature_count: float | None = None + candidate_stable_baseline_unstable_final_entity_signature_count: float | None = None + shared_stable_final_entity_signature_count: int | None = None + flags: list[str] = Field(default_factory=list) + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) + stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + + @model_validator(mode="after") + def fill_workload_family(self) -> "ScreenRow": + if self.workload_family is None: + self.workload_family = workload_family(self.workload_id) + return self + + +class ScreenSummary(BaseModel): + viable_count: int = 0 + review_count: int = 0 + reject_count: int = 0 + candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) + value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) + signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) + safety_verdict_counts: dict[str, int] = Field(default_factory=dict) + performance_verdict_counts: dict[str, int] = Field(default_factory=dict) + evidence_level_counts: dict[str, int] = Field(default_factory=dict) + + +class ScreenGroup(BaseModel): + group_key: str + candidate_strategy: str | None = None + candidate_replacement_strategy: str | None = None + candidate_config_ids: list[str] = Field(default_factory=list) + workload_ids: list[str] = Field(default_factory=list) + workload_families: list[str] = Field(default_factory=list) + row_count: int = 0 + min_baseline_case_count: int | None = None + min_candidate_case_count: int | None = None + viable_count: int = 0 + review_count: int = 0 + reject_count: int = 0 + has_conflicting_verdicts: bool = False + recommendation: str = "unknown" + value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) + signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) + performance_verdict_counts: dict[str, int] = Field(default_factory=dict) + evidence_level_counts: dict[str, int] = Field(default_factory=dict) + split_verdict_candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) + best_pipeline_elapsed_sec_delta_pct: float | None = None + best_observed_total_tokens_delta: float | None = None + best_observed_total_requests_delta: float | None = None + worst_pipeline_elapsed_sec_delta_pct: float | None = None + worst_observed_total_tokens_delta: float | None = None + worst_observed_total_requests_delta: float | None = None + min_shared_stable_final_entity_signature_count: int | None = None + sum_baseline_replacement_missing_final_entity_count: float | None = None + sum_candidate_replacement_missing_final_entity_count: float | None = None + sum_baseline_original_value_leak_count: float | None = None + sum_candidate_original_value_leak_count: float | None = None + sum_baseline_original_value_leak_record_count: float | None = None + sum_candidate_original_value_leak_record_count: float | None = None + sum_baseline_replacement_synthetic_original_collision_count: float | None = None + sum_candidate_replacement_synthetic_original_collision_count: float | None = None + sum_baseline_replacement_synthetic_original_collision_value_count: float | None = None + sum_candidate_replacement_synthetic_original_collision_value_count: float | None = None + sum_baseline_duplicate_synthetic_replacement_count: float | None = None + sum_candidate_duplicate_synthetic_replacement_count: float | None = None + baseline_defect_improvement_count: int = 0 + flag_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) + stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + + @model_validator(mode="after") + def fill_recommendation(self) -> "ScreenGroup": + if self.recommendation == "unknown": + self.recommendation = group_recommendation(self) + return self + + +class ScreenResult(BaseModel): + input_paths: list[str] + scanned_file_count: int + comparison_file_count: int + row_count: int + duplicate_row_count: int = 0 + summary: ScreenSummary = Field(default_factory=ScreenSummary) + rows: list[ScreenRow] = Field(default_factory=list) + groups: list[ScreenGroup] = Field(default_factory=list) + + +_log_format = LogFormat.plain + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def screen_comparison_paths( + paths: list[Path], + *, + group_by: GroupBy = GroupBy.strategy, + config_aliases: dict[str, str] | None = None, + source_includes: list[str] | None = None, + source_excludes: list[str] | None = None, +) -> ScreenResult: + files = filter_source_paths( + list(iter_csv_files(paths)), + includes=source_includes or [], + excludes=source_excludes or [], + ) + rows: list[ScreenRow] = [] + comparison_file_count = 0 + for csv_file in files: + table = read_csv_or_empty(csv_file) + if table is None: + continue + if not is_comparison_table(table): + continue + comparison_file_count += 1 + rows.extend(build_rows_from_table(table, source_path=csv_file)) + deduped_rows = sorted(dedupe_rows(rows), key=screen_sort_key) + groups = sorted( + summarize_groups(deduped_rows, group_by=group_by, config_aliases=config_aliases or {}), + key=group_sort_key, + ) + return ScreenResult( + input_paths=[str(path) for path in paths], + scanned_file_count=len(files), + comparison_file_count=comparison_file_count, + row_count=len(deduped_rows), + duplicate_row_count=len(rows) - len(deduped_rows), + summary=summarize_rows(deduped_rows), + rows=deduped_rows, + groups=groups, + ) + + +def iter_csv_files(paths: list[Path]) -> list[Path]: + files: list[Path] = [] + for path in paths: + if not path.exists(): + raise ValueError(f"comparison path does not exist: {path}") + if path.is_file(): + if path.suffix == ".csv": + files.append(path) + continue + files.extend(sorted(path.rglob("*.csv"))) + return sorted(set(files)) + + +def filter_source_paths(paths: list[Path], *, includes: list[str], excludes: list[str]) -> list[Path]: + return [path for path in paths if source_path_matches(path, includes=includes, excludes=excludes)] + + +def source_path_matches(path: Path, *, includes: list[str], excludes: list[str]) -> bool: + text = str(path) + if includes and not any(fragment in text for fragment in includes): + return False + return not any(fragment in text for fragment in excludes) + + +def is_comparison_table(table: pd.DataFrame) -> bool: + if "source_path" in table.columns: + return False + return COMPARISON_COLUMNS.issubset(set(table.columns)) + + +def read_csv_or_empty(csv_file: Path) -> pd.DataFrame | None: + try: + return pd.read_csv(csv_file) + except pd.errors.EmptyDataError: + return None + + +def read_config_aliases(path: Path | None) -> dict[str, str]: + if path is None: + return {} + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"config aliases file is not valid JSON: {path}") from exc + if not isinstance(payload, dict): + raise ValueError("config aliases file must contain a JSON object mapping config IDs to aliases") + return {str(key): str(value) for key, value in payload.items()} + + +def build_rows_from_table(table: pd.DataFrame, *, source_path: Path) -> list[ScreenRow]: + return [build_row(row, source_path=source_path) for _, row in table.iterrows()] + + +def build_row(row: pd.Series, *, source_path: Path) -> ScreenRow: + return ScreenRow( + source_path=str(source_path), + workload_id=required_string(row, "workload_id"), + workload_family=workload_family(required_string(row, "workload_id")), + baseline_config_id=required_string(row, "baseline_config_id"), + candidate_config_id=required_string(row, "candidate_config_id"), + baseline_strategy=optional_string(row.get("baseline_strategy")), + candidate_strategy=optional_string(row.get("candidate_strategy")), + baseline_replacement_strategy=optional_string(row.get("baseline_replacement_strategy")), + candidate_replacement_strategy=optional_string(row.get("candidate_replacement_strategy")), + baseline_case_count=optional_int(row.get("baseline_case_count")), + candidate_case_count=optional_int(row.get("candidate_case_count")), + value_protection_verdict=optional_string(row.get("value_protection_verdict")), + signature_parity_verdict=optional_string(row.get("signature_parity_verdict")), + safety_verdict=required_string(row, "safety_verdict"), + performance_verdict=required_string(row, "performance_verdict"), + candidate_verdict=required_string(row, "candidate_verdict"), + evidence_level=comparison_evidence_level(row), + pipeline_elapsed_sec_delta_pct=optional_float(row.get("pipeline_elapsed_sec_delta_pct")), + observed_total_requests_delta=optional_float(row.get("observed_total_requests_delta")), + observed_total_tokens_delta=optional_float(row.get("observed_total_tokens_delta")), + final_entity_count_delta=optional_float(row.get("final_entity_count_delta")), + augmented_entity_count_delta=optional_float(row.get("augmented_entity_count_delta")), + augmented_new_final_value_count_delta=optional_float(row.get("augmented_new_final_value_count_delta")), + baseline_original_value_leak_count=optional_float(row.get("baseline_original_value_leak_count")), + candidate_original_value_leak_count=optional_float(row.get("candidate_original_value_leak_count")), + original_value_leak_count_delta=optional_float(row.get("original_value_leak_count_delta")), + baseline_original_value_leak_record_count=optional_float(row.get("baseline_original_value_leak_record_count")), + candidate_original_value_leak_record_count=optional_float( + row.get("candidate_original_value_leak_record_count") + ), + original_value_leak_record_count_delta=optional_float(row.get("original_value_leak_record_count_delta")), + baseline_replacement_missing_final_entity_count=optional_float( + row.get("baseline_replacement_missing_final_entity_count") + ), + candidate_replacement_missing_final_entity_count=optional_float( + row.get("candidate_replacement_missing_final_entity_count") + ), + replacement_missing_final_entity_count_delta=optional_float( + row.get("replacement_missing_final_entity_count_delta") + ), + baseline_replacement_synthetic_original_collision_count=optional_float( + row.get("baseline_replacement_synthetic_original_collision_count") + ), + candidate_replacement_synthetic_original_collision_count=optional_float( + row.get("candidate_replacement_synthetic_original_collision_count") + ), + replacement_synthetic_original_collision_count_delta=optional_float( + row.get("replacement_synthetic_original_collision_count_delta") + ), + baseline_replacement_synthetic_original_collision_value_count=optional_float( + row.get("baseline_replacement_synthetic_original_collision_value_count") + ), + candidate_replacement_synthetic_original_collision_value_count=optional_float( + row.get("candidate_replacement_synthetic_original_collision_value_count") + ), + replacement_synthetic_original_collision_value_count_delta=optional_float( + row.get("replacement_synthetic_original_collision_value_count_delta") + ), + baseline_duplicate_synthetic_replacement_count=optional_float( + row.get("baseline_duplicate_synthetic_replacement_count") + ), + candidate_duplicate_synthetic_replacement_count=optional_float( + row.get("candidate_duplicate_synthetic_replacement_count") + ), + duplicate_synthetic_replacement_count_delta=optional_float( + row.get("duplicate_synthetic_replacement_count_delta") + ), + baseline_only_final_entity_signature_count=optional_float( + row.get("baseline_only_final_entity_signature_count") + ), + candidate_only_final_entity_signature_count=optional_float( + row.get("candidate_only_final_entity_signature_count") + ), + shared_final_entity_signature_count=optional_float(row.get("shared_final_entity_signature_count")), + baseline_stable_final_entity_signature_count=optional_int( + row.get("baseline_stable_final_entity_signature_count") + ), + candidate_stable_final_entity_signature_count=optional_int( + row.get("candidate_stable_final_entity_signature_count") + ), + baseline_stable_candidate_unstable_final_entity_signature_count=optional_float( + row.get("baseline_stable_candidate_unstable_final_entity_signature_count") + ), + candidate_stable_baseline_unstable_final_entity_signature_count=optional_float( + row.get("candidate_stable_baseline_unstable_final_entity_signature_count") + ), + shared_stable_final_entity_signature_count=optional_int(row.get("shared_stable_final_entity_signature_count")), + flags=parse_flags(row.get("flags")), + baseline_only_label_counts=preferred_count_columns( + row, + preferred_prefix="baseline_only_candidate_uncovered_signature_label_counts", + fallback_prefix="baseline_only_final_entity_signature_label_counts", + preferred_count_column="baseline_only_candidate_uncovered_signature_count", + ), + label_mismatch_label_counts=count_columns( + row, + "baseline_only_candidate_label_mismatch_signature_label_counts", + ), + baseline_replacement_missing_final_entity_label_counts=count_columns( + row, + "baseline_replacement_missing_final_entity_label_counts", + ), + candidate_replacement_missing_final_entity_label_counts=count_columns( + row, + "candidate_replacement_missing_final_entity_label_counts", + ), + baseline_original_value_leak_label_counts=count_columns(row, "baseline_original_value_leak_label_counts"), + candidate_original_value_leak_label_counts=count_columns(row, "candidate_original_value_leak_label_counts"), + baseline_replacement_synthetic_original_collision_label_counts=count_columns( + row, + "baseline_replacement_synthetic_original_collision_label_counts", + ), + candidate_replacement_synthetic_original_collision_label_counts=count_columns( + row, + "candidate_replacement_synthetic_original_collision_label_counts", + ), + stable_lost_label_counts=preferred_count_columns( + row, + preferred_prefix="baseline_stable_candidate_uncovered_signature_label_counts", + fallback_prefix="baseline_stable_candidate_unstable_final_entity_signature_label_counts", + preferred_count_column="baseline_stable_candidate_uncovered_signature_count", + ), + ) + + +def comparison_evidence_level(row: pd.Series) -> str: + if optional_string(row.get("value_protection_verdict")) and optional_string(row.get("signature_parity_verdict")): + return "split_verdicts" + if ( + not is_missing(row.get("baseline_stable_final_entity_signature_count")) + or not is_missing(row.get("candidate_stable_final_entity_signature_count")) + or not is_missing(row.get("shared_stable_final_entity_signature_count")) + ): + return "stable_signatures" + if ( + not is_missing(row.get("baseline_only_final_entity_signature_count")) + or not is_missing(row.get("candidate_only_final_entity_signature_count")) + or not is_missing(row.get("shared_final_entity_signature_count")) + ): + return "signature_counts" + return "legacy" + + +def summarize_rows(rows: list[ScreenRow]) -> ScreenSummary: + candidate_counts = count_values(row.candidate_verdict for row in rows) + return ScreenSummary( + viable_count=candidate_counts.get("candidate_viable", 0), + review_count=candidate_counts.get("review", 0), + reject_count=candidate_counts.get("reject", 0), + candidate_verdict_counts=candidate_counts, + value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), + signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), + safety_verdict_counts=count_values(row.safety_verdict for row in rows), + performance_verdict_counts=count_values(row.performance_verdict for row in rows), + evidence_level_counts=count_values(row.evidence_level for row in rows), + ) + + +def summarize_groups( + rows: list[ScreenRow], + *, + group_by: GroupBy = GroupBy.strategy, + config_aliases: dict[str, str] | None = None, +) -> list[ScreenGroup]: + return [ + build_group(group_key, group_rows) + for group_key, group_rows in grouped_rows( + rows, + group_by=group_by, + config_aliases=config_aliases or {}, + ).items() + ] + + +def grouped_rows( + rows: list[ScreenRow], + *, + group_by: GroupBy, + config_aliases: dict[str, str], +) -> dict[str, list[ScreenRow]]: + groups: dict[str, list[ScreenRow]] = {} + for row in rows: + groups.setdefault(group_key_for_row(row, group_by=group_by, config_aliases=config_aliases), []).append(row) + return dict(sorted(groups.items())) + + +def group_key_for_row(row: ScreenRow, *, group_by: GroupBy, config_aliases: dict[str, str]) -> str: + base = group_base_for_row(row, config_aliases=config_aliases) + if group_by == GroupBy.strategy_workload_family: + return f"{base}|family:{row.workload_family}" + if group_by == GroupBy.strategy_workload: + return f"{base}|workload:{row.workload_id}" + return base + + +def group_base_for_row(row: ScreenRow, *, config_aliases: dict[str, str]) -> str: + replacement = _non_default_replacement_strategy(row) + if row.candidate_strategy and row.candidate_strategy != "default": + base = f"strategy:{row.candidate_strategy}" + return f"{base}|replacement:{replacement}" if replacement else base + if replacement: + return f"replacement:{replacement}" + if alias := config_aliases.get(row.candidate_config_id): + return f"alias:{alias}" + return f"config:{row.candidate_config_id}" + + +def _non_default_replacement_strategy(row: ScreenRow) -> str | None: + if row.candidate_replacement_strategy and row.candidate_replacement_strategy != "default": + return row.candidate_replacement_strategy + return None + + +def workload_family(workload_id: str) -> str: + parts = [part for part in workload_id.split("-") if part] + while parts and _is_workload_slice_suffix(parts[-1]): + parts.pop() + return "-".join(parts) if parts else workload_id + + +def _is_workload_slice_suffix(value: str) -> bool: + return ( + value.isdigit() + or value == "slice" + or (value.startswith("r") and value[1:].isdigit()) + or (value.startswith("offset") and value[len("offset") :].isdigit()) + ) + + +def build_group(group_key: str, rows: list[ScreenRow]) -> ScreenGroup: + verdict_counts = count_values(row.candidate_verdict for row in rows) + group = ScreenGroup( + group_key=group_key, + candidate_strategy=single_optional_value(row.candidate_strategy for row in rows), + candidate_replacement_strategy=single_optional_value(row.candidate_replacement_strategy for row in rows), + candidate_config_ids=sorted({row.candidate_config_id for row in rows}), + workload_ids=sorted({row.workload_id for row in rows}), + workload_families=sorted({row.workload_family or workload_family(row.workload_id) for row in rows}), + row_count=len(rows), + min_baseline_case_count=min_present_int(row.baseline_case_count for row in rows), + min_candidate_case_count=min_present_int(row.candidate_case_count for row in rows), + viable_count=verdict_counts.get("candidate_viable", 0), + review_count=verdict_counts.get("review", 0), + reject_count=verdict_counts.get("reject", 0), + has_conflicting_verdicts=len(verdict_counts) > 1, + value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), + signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), + performance_verdict_counts=count_values(row.performance_verdict for row in rows), + evidence_level_counts=count_values(row.evidence_level for row in rows), + split_verdict_candidate_verdict_counts=count_values( + row.candidate_verdict for row in rows if row.evidence_level == "split_verdicts" + ), + best_pipeline_elapsed_sec_delta_pct=min_present(row.pipeline_elapsed_sec_delta_pct for row in rows), + best_observed_total_tokens_delta=min_present(row.observed_total_tokens_delta for row in rows), + best_observed_total_requests_delta=min_present(row.observed_total_requests_delta for row in rows), + worst_pipeline_elapsed_sec_delta_pct=max_present(row.pipeline_elapsed_sec_delta_pct for row in rows), + worst_observed_total_tokens_delta=max_present(row.observed_total_tokens_delta for row in rows), + worst_observed_total_requests_delta=max_present(row.observed_total_requests_delta for row in rows), + min_shared_stable_final_entity_signature_count=min_present_int( + row.shared_stable_final_entity_signature_count for row in rows + ), + sum_baseline_replacement_missing_final_entity_count=sum_present( + row.baseline_replacement_missing_final_entity_count for row in rows + ), + sum_candidate_replacement_missing_final_entity_count=sum_present( + row.candidate_replacement_missing_final_entity_count for row in rows + ), + sum_baseline_original_value_leak_count=sum_present(row.baseline_original_value_leak_count for row in rows), + sum_candidate_original_value_leak_count=sum_present(row.candidate_original_value_leak_count for row in rows), + sum_baseline_original_value_leak_record_count=sum_present( + row.baseline_original_value_leak_record_count for row in rows + ), + sum_candidate_original_value_leak_record_count=sum_present( + row.candidate_original_value_leak_record_count for row in rows + ), + sum_baseline_replacement_synthetic_original_collision_count=sum_present( + row.baseline_replacement_synthetic_original_collision_count for row in rows + ), + sum_candidate_replacement_synthetic_original_collision_count=sum_present( + row.candidate_replacement_synthetic_original_collision_count for row in rows + ), + sum_baseline_replacement_synthetic_original_collision_value_count=sum_present( + row.baseline_replacement_synthetic_original_collision_value_count for row in rows + ), + sum_candidate_replacement_synthetic_original_collision_value_count=sum_present( + row.candidate_replacement_synthetic_original_collision_value_count for row in rows + ), + sum_baseline_duplicate_synthetic_replacement_count=sum_present( + row.baseline_duplicate_synthetic_replacement_count for row in rows + ), + sum_candidate_duplicate_synthetic_replacement_count=sum_present( + row.candidate_duplicate_synthetic_replacement_count for row in rows + ), + baseline_defect_improvement_count=sum(1 for row in rows if row_has_baseline_defect_improvement(row)), + flag_counts=sum_string_counts(row.flags for row in rows), + baseline_only_label_counts=sum_dict_counts(row.baseline_only_label_counts for row in rows), + label_mismatch_label_counts=sum_dict_counts(row.label_mismatch_label_counts for row in rows), + stable_lost_label_counts=sum_dict_counts(row.stable_lost_label_counts for row in rows), + baseline_replacement_missing_final_entity_label_counts=sum_dict_counts( + row.baseline_replacement_missing_final_entity_label_counts for row in rows + ), + candidate_replacement_missing_final_entity_label_counts=sum_dict_counts( + row.candidate_replacement_missing_final_entity_label_counts for row in rows + ), + baseline_original_value_leak_label_counts=sum_dict_counts( + row.baseline_original_value_leak_label_counts for row in rows + ), + candidate_original_value_leak_label_counts=sum_dict_counts( + row.candidate_original_value_leak_label_counts for row in rows + ), + baseline_replacement_synthetic_original_collision_label_counts=sum_dict_counts( + row.baseline_replacement_synthetic_original_collision_label_counts for row in rows + ), + candidate_replacement_synthetic_original_collision_label_counts=sum_dict_counts( + row.candidate_replacement_synthetic_original_collision_label_counts for row in rows + ), + ) + group.recommendation = group_recommendation(group) + return group + + +def group_recommendation(group: ScreenGroup) -> str: + if group.viable_count and group.reject_count: + return "conflicting_evidence" + if group_needs_split_verdict_rerun(group): + return "needs_split_verdict_rerun" + if group_needs_viable_split_verdict(group): + return "needs_viable_split_verdict" + if group.viable_count and group.review_count and group_has_baseline_defect_improvement_reviews(group): + return "promising_with_baseline_defect_improvements" + if group.viable_count and group.review_count: + return "promising_needs_review" + if group.viable_count: + if group.row_count == 1: + return "single_slice_viable" + return "candidate_family_viable" + if group.reject_count: + return "reject" + if group.review_count: + if is_baseline_defect_improvement_group(group): + return "candidate_covers_baseline_defects" + if is_replacement_replay_review_group(group): + return "replacement_replay_review" + if is_reliability_review_group(group): + return "reliability_review" + if is_label_policy_review_group(group): + return "label_policy_review" + if group.performance_verdict_counts.get("improved", 0) == group.review_count: + return "review_only" + if group.performance_verdict_counts.get("improved", 0) or group.performance_verdict_counts.get("mixed", 0): + return "review_mixed_performance" + return "no_performance_win" + return "unknown" + + +def group_needs_split_verdict_rerun(group: ScreenGroup) -> bool: + if not group.evidence_level_counts: + return False + if group.reject_count: + return False + split_verdict_count = group.evidence_level_counts.get("split_verdicts", 0) + if split_verdict_count == group.row_count: + return False + if group.viable_count and group.review_count: + return split_verdict_count == 0 + if group.review_count == group.row_count and split_verdict_count: + return True + return False + + +def group_needs_viable_split_verdict(group: ScreenGroup) -> bool: + if not (group.viable_count and group.review_count): + return False + if not group.evidence_level_counts.get("split_verdicts", 0): + return False + return not bool(group.split_verdict_candidate_verdict_counts.get("candidate_viable", 0)) + + +def is_replacement_replay_review_group(group: ScreenGroup) -> bool: + if not group.candidate_replacement_strategy or group.candidate_replacement_strategy == "default": + return False + if group.review_count != group.row_count: + return False + if group.performance_verdict_counts.get("improved", 0) != group.review_count: + return False + return bool(group.flag_counts.get("replacement_only_detection_instability", 0)) + + +_BASELINE_DEFECT_IMPROVEMENT_FLAGS = { + "candidate_covers_baseline_original_value_leak", + "candidate_covers_baseline_replacement_missing_final_entity", + "candidate_covers_baseline_replacement_synthetic_original_collision", +} + + +def row_has_baseline_defect_improvement(row: ScreenRow) -> bool: + return bool(set(row.flags) & _BASELINE_DEFECT_IMPROVEMENT_FLAGS) + + +def group_has_baseline_defect_improvement_reviews(group: ScreenGroup) -> bool: + if not group.review_count: + return False + return group.baseline_defect_improvement_count == group.review_count + + +def is_baseline_defect_improvement_group(group: ScreenGroup) -> bool: + if group.review_count != group.row_count: + return False + if group.performance_verdict_counts.get("improved", 0) != group.review_count: + return False + if group.value_protection_verdict_counts.get("pass", 0) != group.review_count: + return False + if group.signature_parity_verdict_counts.get("review", 0) != group.review_count: + return False + return group.baseline_defect_improvement_count == group.review_count + + +_RELIABILITY_REVIEW_FLAGS = {"failed_request_increase", "bridge_fallback_increase"} + + +def is_reliability_review_group(group: ScreenGroup) -> bool: + if group.review_count != group.row_count: + return False + if group.performance_verdict_counts.get("improved", 0) != group.review_count: + return False + return any(group.flag_counts.get(flag, 0) > 0 for flag in _RELIABILITY_REVIEW_FLAGS) + + +def is_label_policy_review_group(group: ScreenGroup) -> bool: + if group.review_count != group.row_count: + return False + if group.performance_verdict_counts.get("improved", 0) != group.review_count: + return False + if group.value_protection_verdict_counts.get("pass", 0) != group.review_count: + return False + if group.signature_parity_verdict_counts.get("review", 0) != group.review_count: + return False + return bool(group.label_mismatch_label_counts or group.flag_counts.get("covered_label_mismatch")) + + +def group_performance_summary(group: ScreenGroup) -> str: + if group.performance_verdict_counts: + return label_summary(group.performance_verdict_counts) + return "unknown" + + +def single_optional_value(values: object) -> str | None: + unique = sorted({str(value) for value in values if value is not None}) + return unique[0] if len(unique) == 1 else None + + +def min_present(values: object) -> float | None: + present = [float(value) for value in values if value is not None] + return min(present) if present else None + + +def min_present_int(values: object) -> int | None: + value = min_present(values) + return int(value) if value is not None else None + + +def max_present(values: object) -> float | None: + present = [float(value) for value in values if value is not None] + return max(present) if present else None + + +def sum_present(values: object) -> float | None: + present = [float(value) for value in values if value is not None] + return sum(present) if present else None + + +def sum_string_counts(values: object) -> dict[str, int]: + counts: dict[str, int] = {} + for items in values: + for item in items: + counts[str(item)] = counts.get(str(item), 0) + 1 + return dict(sorted(counts.items())) + + +def sum_dict_counts(values: object) -> dict[str, int]: + counts: dict[str, int] = {} + for mapping in values: + for key, value in mapping.items(): + counts[str(key)] = counts.get(str(key), 0) + int(value) + return dict(sorted(counts.items())) + + +def group_sort_key(group: ScreenGroup) -> tuple[int, int, float, float, str]: + conflict_rank = 1 if group.has_conflicting_verdicts else 0 + verdict_rank = 0 if group.viable_count else 1 if group.review_count else 2 + elapsed = group.best_pipeline_elapsed_sec_delta_pct + tokens = group.best_observed_total_tokens_delta + return ( + conflict_rank, + verdict_rank, + elapsed if elapsed is not None else float("inf"), + tokens if tokens is not None else float("inf"), + group.group_key, + ) + + +def dedupe_rows(rows: list[ScreenRow]) -> list[ScreenRow]: + deduped: dict[str, ScreenRow] = {} + for row in rows: + payload = row.model_dump() + payload.pop("source_path", None) + key = json.dumps(payload, ensure_ascii=True, sort_keys=True) + deduped.setdefault(key, row) + return list(deduped.values()) + + +def count_values(values: object) -> dict[str, int]: + counts: dict[str, int] = {} + for value in values: + key = str(value) + counts[key] = counts.get(key, 0) + 1 + return dict(sorted(counts.items())) + + +def count_present_values(values: object) -> dict[str, int]: + return count_values(value for value in values if value is not None) + + +def screen_sort_key(row: ScreenRow) -> tuple[int, float, float, str, str]: + verdict_rank = {"candidate_viable": 0, "review": 1, "reject": 2}.get(row.candidate_verdict, 3) + elapsed_delta = row.pipeline_elapsed_sec_delta_pct + token_delta = row.observed_total_tokens_delta + return ( + verdict_rank, + elapsed_delta if elapsed_delta is not None else float("inf"), + token_delta if token_delta is not None else float("inf"), + row.workload_id, + row.candidate_config_id, + ) + + +def required_string(row: pd.Series, column: str) -> str: + value = optional_string(row.get(column)) + if value is None: + raise ValueError(f"comparison row missing required value: {column}") + return value + + +def optional_string(value: object) -> str | None: + if is_missing(value): + return None + return str(value) + + +def optional_float(value: object) -> float | None: + if is_missing(value): + return None + return float(value) + + +def optional_int(value: object) -> int | None: + value = optional_float(value) + return int(value) if value is not None else None + + +def is_missing(value: object) -> bool: + return value is None or (isinstance(value, float) and pd.isna(value)) + + +def parse_flags(value: object) -> list[str]: + if is_missing(value): + return [] + if isinstance(value, list): + return [str(item) for item in value] + if not isinstance(value, str): + return [str(value)] + parsed = parse_nested(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + return [value] if value else [] + + +def parse_nested(value: str) -> object: + try: + return json.loads(value) + except json.JSONDecodeError: + try: + return ast.literal_eval(value) + except (SyntaxError, ValueError): + return value + + +def count_columns(row: pd.Series, prefix: str) -> dict[str, int]: + label_counts: dict[str, int] = {} + column_prefix = f"{prefix}." + for column, raw_value in row.items(): + if not str(column).startswith(column_prefix): + continue + value = optional_float(raw_value) + if value is not None and value > 0: + label_counts[str(column)[len(column_prefix) :]] = int(value) + return dict(sorted(label_counts.items())) + + +def preferred_count_columns( + row: pd.Series, + *, + preferred_prefix: str, + fallback_prefix: str, + preferred_count_column: str, +) -> dict[str, int]: + if has_comparison_field(row, preferred_prefix) or not is_missing(row.get(preferred_count_column)): + return count_columns(row, preferred_prefix) + return count_columns(row, fallback_prefix) + + +def has_comparison_field(row: pd.Series, prefix: str) -> bool: + column_prefix = f"{prefix}." + return any(str(column).startswith(column_prefix) for column in row.index) + + +def write_rows(rows: list[ScreenRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + table = normalize_table_cells(table) + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def write_groups(groups: list[ScreenGroup], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.DataFrame([group.model_dump() for group in groups]) + table = normalize_table_cells(table) + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame: + normalized = table.copy() + for column in normalized.columns: + if normalized[column].map(is_nested_cell).any(): + normalized[column] = normalized[column].map(json_cell) + return normalized + + +def is_nested_cell(value: object) -> bool: + return isinstance(value, dict | list) + + +def json_cell(value: object) -> object: + if not is_nested_cell(value): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def render_result(result: ScreenResult, *, json_output: bool, limit: int) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [ + f"Screened {result.row_count} comparison row(s) from " + f"{result.comparison_file_count}/{result.scanned_file_count} CSV file(s): " + f"viable={result.summary.viable_count}, review={result.summary.review_count}, " + f"reject={result.summary.reject_count}, duplicates_skipped={result.duplicate_row_count}", + ] + for row in result.rows[:limit]: + lines.append(render_row(row)) + if len(result.rows) > limit: + lines.append(f"... {len(result.rows) - limit} more row(s)") + lines.append("Candidate groups:") + for group in result.groups[:limit]: + lines.append(render_group(group)) + if len(result.groups) > limit: + lines.append(f"... {len(result.groups) - limit} more group(s)") + return "\n".join(lines) + + +def render_row(row: ScreenRow) -> str: + return ( + f"- {row.workload_id}: {row.baseline_config_id}->{row.candidate_config_id} " + f"verdict={row.candidate_verdict} safety={row.safety_verdict} " + f"value_protection={row.value_protection_verdict or 'unknown'} " + f"signature_parity={row.signature_parity_verdict or 'unknown'} " + f"evidence={row.evidence_level} " + f"perf={row.performance_verdict} elapsed_delta={format_number(row.pipeline_elapsed_sec_delta_pct, '%')} " + f"replacement={row.baseline_replacement_strategy or 'unknown'}" + f"->{row.candidate_replacement_strategy or 'unknown'} " + f"tokens_delta={format_number(row.observed_total_tokens_delta)} " + f"cases={format_count(row.baseline_case_count)}/{format_count(row.candidate_case_count)} " + f"shared_stable={format_count(row.shared_stable_final_entity_signature_count)} " + f"aug_new_final_delta={format_number(row.augmented_new_final_value_count_delta)} " + f"baseline_missing_replacements={format_number(row.baseline_replacement_missing_final_entity_count)} " + f"candidate_missing_replacements={format_number(row.candidate_replacement_missing_final_entity_count)} " + f"baseline_original_value_leaks={format_number(row.baseline_original_value_leak_count)} " + f"candidate_original_value_leaks={format_number(row.candidate_original_value_leak_count)} " + f"baseline_replacement_collisions=" + f"{format_number(row.baseline_replacement_synthetic_original_collision_count)} " + f"candidate_replacement_collisions=" + f"{format_number(row.candidate_replacement_synthetic_original_collision_count)} " + f"baseline_duplicate_synthetics={format_number(row.baseline_duplicate_synthetic_replacement_count)} " + f"candidate_duplicate_synthetics={format_number(row.candidate_duplicate_synthetic_replacement_count)} " + f"lost={label_summary(row.baseline_only_label_counts)} " + f"label_mismatch={label_summary(row.label_mismatch_label_counts)} " + f"stable_lost={label_summary(row.stable_lost_label_counts)} " + f"baseline_missing_labels={label_summary(row.baseline_replacement_missing_final_entity_label_counts)} " + f"candidate_missing_labels={label_summary(row.candidate_replacement_missing_final_entity_label_counts)} " + f"baseline_leak_labels={label_summary(row.baseline_original_value_leak_label_counts)} " + f"leak_labels={label_summary(row.candidate_original_value_leak_label_counts)} " + f"baseline_collision_labels=" + f"{label_summary(row.baseline_replacement_synthetic_original_collision_label_counts)} " + f"collision_labels=" + f"{label_summary(row.candidate_replacement_synthetic_original_collision_label_counts)} " + f"flags={','.join(row.flags) if row.flags else 'none'}" + ) + + +def render_group(group: ScreenGroup) -> str: + return ( + f"- {group.group_key}: rows={group.row_count} viable={group.viable_count} " + f"review={group.review_count} reject={group.reject_count} " + f"conflict={str(group.has_conflicting_verdicts).lower()} recommendation={group.recommendation} " + f"replacement={group.candidate_replacement_strategy or 'unknown'} " + f"value_protection_counts={label_summary(group.value_protection_verdict_counts)} " + f"signature_parity_counts={label_summary(group.signature_parity_verdict_counts)} " + f"evidence_counts={label_summary(group.evidence_level_counts)} " + f"split_verdict_candidate_counts={label_summary(group.split_verdict_candidate_verdict_counts)} " + f"perf_counts={group_performance_summary(group)} " + f"best_elapsed_delta={format_number(group.best_pipeline_elapsed_sec_delta_pct, '%')} " + f"worst_elapsed_delta={format_number(group.worst_pipeline_elapsed_sec_delta_pct, '%')} " + f"best_tokens_delta={format_number(group.best_observed_total_tokens_delta)} " + f"worst_tokens_delta={format_number(group.worst_observed_total_tokens_delta)} " + f"best_requests_delta={format_number(group.best_observed_total_requests_delta)} " + f"worst_requests_delta={format_number(group.worst_observed_total_requests_delta)} " + f"min_cases={format_count(group.min_baseline_case_count)}/{format_count(group.min_candidate_case_count)} " + f"min_shared_stable={format_count(group.min_shared_stable_final_entity_signature_count)} " + f"baseline_defect_improvements={group.baseline_defect_improvement_count} " + f"baseline_missing_replacements=" + f"{format_number(group.sum_baseline_replacement_missing_final_entity_count)} " + f"candidate_missing_replacements=" + f"{format_number(group.sum_candidate_replacement_missing_final_entity_count)} " + f"baseline_original_value_leaks={format_number(group.sum_baseline_original_value_leak_count)} " + f"candidate_original_value_leaks={format_number(group.sum_candidate_original_value_leak_count)} " + f"baseline_replacement_collisions=" + f"{format_number(group.sum_baseline_replacement_synthetic_original_collision_count)} " + f"candidate_replacement_collisions=" + f"{format_number(group.sum_candidate_replacement_synthetic_original_collision_count)} " + f"baseline_duplicate_synthetics={format_number(group.sum_baseline_duplicate_synthetic_replacement_count)} " + f"candidate_duplicate_synthetics={format_number(group.sum_candidate_duplicate_synthetic_replacement_count)} " + f"lost={label_summary(group.baseline_only_label_counts)} " + f"label_mismatch={label_summary(group.label_mismatch_label_counts)} " + f"stable_lost={label_summary(group.stable_lost_label_counts)} " + f"baseline_missing_labels={label_summary(group.baseline_replacement_missing_final_entity_label_counts)} " + f"candidate_missing_labels={label_summary(group.candidate_replacement_missing_final_entity_label_counts)} " + f"baseline_leak_labels={label_summary(group.baseline_original_value_leak_label_counts)} " + f"leak_labels={label_summary(group.candidate_original_value_leak_label_counts)} " + f"baseline_collision_labels=" + f"{label_summary(group.baseline_replacement_synthetic_original_collision_label_counts)} " + f"collision_labels=" + f"{label_summary(group.candidate_replacement_synthetic_original_collision_label_counts)} " + f"flags={label_summary(group.flag_counts)}" + ) + + +def format_number(value: float | None, suffix: str = "") -> str: + if value is None: + return "unknown" + return f"{value:.1f}{suffix}" + + +def format_count(value: int | None) -> str: + return str(value) if value is not None else "unknown" + + +def label_summary(counts: dict[str, int]) -> str: + if not counts: + return "none" + return ",".join(f"{label}:{count}" for label, count in counts.items()) + + +@app.default +def main( + comparison_paths: list[Path], + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + group_output: Annotated[Path | None, cyclopts.Parameter("--group-output")] = None, + group_by: Annotated[GroupBy, cyclopts.Parameter("--group-by")] = GroupBy.strategy, + config_aliases: Annotated[Path | None, cyclopts.Parameter("--config-aliases")] = None, + source_include: Annotated[list[str] | None, cyclopts.Parameter("--source-include")] = None, + source_exclude: Annotated[list[str] | None, cyclopts.Parameter("--source-exclude")] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.csv, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + limit: Annotated[int, cyclopts.Parameter("--limit")] = 20, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = screen_comparison_paths( + comparison_paths, + group_by=group_by, + config_aliases=read_config_aliases(config_aliases), + source_includes=source_include or [], + source_excludes=source_exclude or [], + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + if output is not None: + write_rows(result.rows, output, format) + if group_output is not None: + write_groups(result.groups, group_output, format) + sys.stdout.write(render_result(result, json_output=json_output, limit=limit) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/staged_detection_probe.py b/tools/measurement/staged_detection_probe.py new file mode 100644 index 00000000..ddb1a5d0 --- /dev/null +++ b/tools/measurement/staged_detection_probe.py @@ -0,0 +1,1390 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run benchmark-only DD-free staged entity detection probes. + +Usage: + uv run python tools/measurement/staged_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography --labels age,city,first_name,last_name,occupation \ + --output /tmp/staged-probe --overwrite +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +import sys +import time +from collections import Counter +from dataclasses import dataclass, replace +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any, Protocol + +import cyclopts +import httpx +import pandas as pd +from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities +from dd_parser_compat import _load_embedded_json +from direct_detection_probe import ( + CaseStatus, + DirectCompletion, + DirectDetectionClient, + DirectDetectionRequest, + DirectGenerationRequest, + HttpxDirectDetectionClient, + LogFormat, + PromptMode, + SignatureComparison, + build_direct_prompt, + compare_signature_sets, + parse_labels, +) +from pydantic import BaseModel, Field, ValidationError, model_validator + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, + COL_VALIDATED_ENTITIES, + COL_VALIDATION_CANDIDATES, + COL_VALIDATION_DECISIONS, +) +from anonymizer.engine.detection.chunked_validation import ( + build_chunk_excerpt, + chunk_candidates, + merge_chunk_decisions, + order_candidates_by_position, +) +from anonymizer.engine.detection.custom_columns import ( + apply_validation_and_finalize, + apply_validation_to_seed_entities, + enrich_validation_decisions, + merge_and_build_candidates, + prepare_validation_inputs, +) +from anonymizer.engine.detection.postprocess import ( + VALIDATION_CONTEXT_WINDOW, + EntitySpan, + TagNotation, + apply_augmented_entities, + build_tagged_text, + get_tag_notation, + parse_raw_entities, +) +from anonymizer.engine.schemas import ( + EntitiesSchema, + EntitySchema, + RawValidationDecisionsSchema, + ValidatedDecisionsSchema, + ValidationCandidatesSchema, +) + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.staged_detection_probe") + +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" +_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" +_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" +_UNCONFIGURED_MODEL = "configured-native-model" +_UNCONFIGURED_GLINER_ENDPOINT = "configured-gliner-endpoint" +_UNCONFIGURED_GLINER_MODEL = "configured-gliner-model" +_log_format = LogFormat.plain +_DATE_OF_BIRTH_CONTEXT_RE = re.compile(r"\b(born|birth|date of birth|dob)\b", re.IGNORECASE) + + +class SeedSource(StrEnum): + direct_llm = "direct_llm" + gliner = "gliner" + + +class ValidationPromptMode(StrEnum): + full_text = "full_text" + chunked_excerpt = "chunked_excerpt" + + +class GlinerDetectionRequest(BaseModel): + endpoint: str + model: str + text: str + labels: list[str] = Field(min_length=1) + threshold: float = Field(default=0.3, ge=0.0, le=1.0) + max_tokens: int = Field(default=4096, gt=0) + timeout_sec: float = Field(default=120.0, gt=0) + api_key_env: str = "NVIDIA_API_KEY" + + +class GlinerSeedClient(Protocol): + def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: ... + + +class StagedExecutionConfig(BaseModel): + endpoint: str = _UNCONFIGURED_ENDPOINT + model: str = _UNCONFIGURED_MODEL + seed_source: SeedSource = SeedSource.direct_llm + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT + gliner_model: str = _UNCONFIGURED_GLINER_MODEL + gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) + max_tokens: int = Field(default=4096, gt=0) + timeout_sec: float = Field(default=180.0, gt=0) + skip_augmentation: bool = False + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text + validation_max_entities_per_call: int = Field(default=10, gt=0) + validation_excerpt_window_chars: int = Field(default=160, gt=0) + + +class HttpxGlinerSeedClient: + def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: + api_key = os.environ.get(request.api_key_env) + if not api_key: + raise ValueError(f"{request.api_key_env} is not set") + response = httpx.post( + f"{request.endpoint.rstrip('/')}/chat/completions", + headers={"Authorization": f"Bearer {api_key}"}, + json=_gliner_payload(request), + timeout=request.timeout_sec, + ) + response.raise_for_status() + data = response.json() + return DirectCompletion( + content=_completion_content(data), + elapsed_sec=float(response.elapsed.total_seconds()), + usage=data.get("usage") or {}, + ) + + +def _gliner_payload(request: GlinerDetectionRequest) -> dict[str, Any]: + return { + "model": request.model, + "messages": [{"role": "user", "content": request.text}], + "temperature": 0, + "max_tokens": request.max_tokens, + "labels": request.labels, + "threshold": request.threshold, + "chunk_length": 384, + "overlap": 128, + "flat_ner": False, + } + + +def _completion_content(data: dict[str, Any]) -> str: + choice = (data.get("choices") or [{}])[0] + message = choice.get("message") if isinstance(choice, dict) else {} + if isinstance(message, dict): + return str(message.get("content") or "") + return str(choice.get("text") or "") if isinstance(choice, dict) else "" + + +class StagedDetectionRequest(BaseModel): + case_id: str + text: str + labels: list[str] = Field(min_length=1) + row_index: int = 0 + data_summary: str | None = None + + @model_validator(mode="after") + def normalize_labels(self) -> StagedDetectionRequest: + self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) + if not self.labels: + raise ValueError("labels must contain at least one non-empty label") + return self + + +class PhaseUsage(BaseModel): + seed: dict[str, Any] = Field(default_factory=dict) + validation: dict[str, Any] = Field(default_factory=dict) + augmentation: dict[str, Any] = Field(default_factory=dict) + + +class PhaseModelWork(BaseModel): + seed: bool = False + validation: bool = False + augmentation: bool = False + + +class PhaseSkipReasons(BaseModel): + seed: str | None = None + validation: str | None = None + augmentation: str | None = None + + +class PhaseModelRequests(BaseModel): + seed: int = 0 + validation: int = 0 + augmentation: int = 0 + + +class StagedDetectionCase(BaseModel): + case_id: str + row_index: int + seed_source: SeedSource = SeedSource.direct_llm + status: CaseStatus + elapsed_sec: float | None = None + model_elapsed_sec: float | None = None + phase_usage: PhaseUsage = Field(default_factory=PhaseUsage) + phase_model_work: PhaseModelWork = Field(default_factory=PhaseModelWork) + phase_skip_reasons: PhaseSkipReasons = Field(default_factory=PhaseSkipReasons) + phase_model_requests: PhaseModelRequests = Field(default_factory=PhaseModelRequests) + total_usage: dict[str, int] = Field(default_factory=dict) + model_phase_count: int = 0 + model_request_count: int = 0 + seed_suggestion_count: int = 0 + seed_entity_count: int = 0 + validation_candidate_count: int = 0 + validation_decision_count: int = 0 + augmented_suggestion_count: int = 0 + final_entity_count: int = 0 + final_entity_signature_count: int = 0 + final_entity_signature_hashes: list[str] = Field(default_factory=list) + final_label_counts: dict[str, int] = Field(default_factory=dict) + comparison: SignatureComparison | None = None + artifact: DetectionArtifactRow | None = None + error: str | None = None + + +class StagedDetectionRun(BaseModel): + input_path: str + text_column: str + endpoint: str + model: str + labels: list[str] + rows: list[StagedDetectionCase] = Field(default_factory=list) + + @property + def error_count(self) -> int: + return sum(1 for row in self.rows if row.status == CaseStatus.error) + + +@dataclass(frozen=True) +class StagedDetectionExecution: + case: StagedDetectionCase + row: dict[str, Any] + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") + return + logger.error("bad_input error=%s", error) + + +def run_staged_detection_case( + request: StagedDetectionRequest, + *, + client: DirectDetectionClient, + seed_client: GlinerSeedClient | None = None, + seed_source: SeedSource = SeedSource.direct_llm, + endpoint: str = _UNCONFIGURED_ENDPOINT, + model: str = _UNCONFIGURED_MODEL, + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, + gliner_model: str = _UNCONFIGURED_GLINER_MODEL, + gliner_api_key_env: str = "NVIDIA_API_KEY", + gliner_threshold: float = 0.3, + max_tokens: int = 4096, + timeout_sec: float = 180.0, + skip_augmentation: bool = False, + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, + validation_max_entities_per_call: int = 10, + validation_excerpt_window_chars: int = 160, +) -> StagedDetectionCase: + return execute_staged_detection_case( + request, + client=client, + seed_client=seed_client, + seed_source=seed_source, + endpoint=endpoint, + model=model, + gliner_endpoint=gliner_endpoint, + gliner_model=gliner_model, + gliner_api_key_env=gliner_api_key_env, + gliner_threshold=gliner_threshold, + max_tokens=max_tokens, + timeout_sec=timeout_sec, + skip_augmentation=skip_augmentation, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + ).case + + +def execute_staged_detection_case( + request: StagedDetectionRequest, + *, + client: DirectDetectionClient, + seed_client: GlinerSeedClient | None = None, + seed_source: SeedSource = SeedSource.direct_llm, + endpoint: str = _UNCONFIGURED_ENDPOINT, + model: str = _UNCONFIGURED_MODEL, + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, + gliner_model: str = _UNCONFIGURED_GLINER_MODEL, + gliner_api_key_env: str = "NVIDIA_API_KEY", + gliner_threshold: float = 0.3, + max_tokens: int = 4096, + timeout_sec: float = 180.0, + skip_augmentation: bool = False, + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, + validation_max_entities_per_call: int = 10, + validation_excerpt_window_chars: int = 160, +) -> StagedDetectionExecution: + config = _execution_config_from_params(locals()) + try: + return _run_staged_detection_execution( + request, + client=client, + seed_client=seed_client, + config=config, + ) + except Exception as exc: # noqa: BLE001 - benchmark probe records per-case failures + return StagedDetectionExecution(case=_errored_case(request, config.seed_source, exc), row={}) + + +def _execution_config_from_params(params: dict[str, Any]) -> StagedExecutionConfig: + return StagedExecutionConfig( + **{key: value for key, value in params.items() if key in StagedExecutionConfig.model_fields} + ) + + +def _errored_case(request: StagedDetectionRequest, seed_source: SeedSource, exc: Exception) -> StagedDetectionCase: + return StagedDetectionCase( + case_id=request.case_id, + row_index=request.row_index, + seed_source=seed_source, + status=CaseStatus.error, + error=f"{type(exc).__name__}: {exc}", + ) + + +def _run_staged_detection_execution( + request: StagedDetectionRequest, + *, + client: DirectDetectionClient, + seed_client: GlinerSeedClient | None, + config: StagedExecutionConfig, +) -> StagedDetectionExecution: + started = time.perf_counter() + row, seed_suggestion_count, seed_completion = _run_seed_phase( + request, + client=client, + seed_client=seed_client, + config=config, + ) + validation_completion = _run_validation_phase(row, request, client, config) + augmentation_completion = _run_augmentation_phase(row, request, client, config) + artifact = _finalize_row(row, request) + return StagedDetectionExecution( + case=_completed_case( + request, + seed_completion, + validation_completion, + augmentation_completion, + artifact, + row, + config=config, + seed_suggestion_count=seed_suggestion_count, + elapsed_sec=time.perf_counter() - started, + ), + row=row, + ) + + +def _run_seed_phase( + request: StagedDetectionRequest, + *, + client: DirectDetectionClient, + seed_client: GlinerSeedClient | None, + config: StagedExecutionConfig, +) -> tuple[dict[str, Any], int, DirectCompletion]: + if config.seed_source == SeedSource.gliner: + return _run_gliner_seed_phase(request, seed_client or HttpxGlinerSeedClient(), config) + return _run_direct_llm_seed_phase(request, client, config) + + +def _run_gliner_seed_phase( + request: StagedDetectionRequest, + detector: GlinerSeedClient, + config: StagedExecutionConfig, +) -> tuple[dict[str, Any], int, DirectCompletion]: + completion = detector.detect( + GlinerDetectionRequest( + endpoint=config.gliner_endpoint, + model=config.gliner_model, + text=request.text, + labels=request.labels, + threshold=config.gliner_threshold, + max_tokens=config.max_tokens, + timeout_sec=config.timeout_sec, + api_key_env=config.gliner_api_key_env, + ) + ) + seed_spans = parse_raw_entities(raw_response=completion.content, text=request.text) + return _seed_row_from_spans(request, seed_spans), _raw_detector_entity_count(completion.content), completion + + +def _run_direct_llm_seed_phase( + request: StagedDetectionRequest, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> tuple[dict[str, Any], int, DirectCompletion]: + completion = _complete( + client, + prompt=_seed_prompt(request), + config=config, + ) + row, seed_suggestions = _seed_row_from_llm(request, completion.content) + return row, len(seed_suggestions), completion + + +def _complete( + client: DirectDetectionClient, + *, + prompt: str, + config: StagedExecutionConfig, +) -> DirectCompletion: + return client.complete( + DirectGenerationRequest( + endpoint=config.endpoint, + model=config.model, + prompt=prompt, + max_tokens=config.max_tokens, + timeout_sec=config.timeout_sec, + ) + ) + + +def _seed_prompt(request: StagedDetectionRequest) -> str: + return build_direct_prompt(_direct_seed_request(request)) + + +def _direct_seed_request(request: StagedDetectionRequest) -> DirectDetectionRequest: + return DirectDetectionRequest( + case_id=request.case_id, + text=request.text, + labels=request.labels, + row_index=request.row_index, + prompt_mode=PromptMode.compact, + data_summary=request.data_summary, + ) + + +def _seed_row_from_llm(request: StagedDetectionRequest, content: str) -> tuple[dict[str, Any], list[dict[str, Any]]]: + seed_spans, suggestions = _direct_seed_spans(request, content) + return _seed_row_from_spans(request, seed_spans), suggestions + + +def _direct_seed_spans(request: StagedDetectionRequest, content: str) -> tuple[list[EntitySpan], list[dict[str, Any]]]: + suggestions = _extract_entity_suggestions(content) + seed_spans = _spans_from_suggestions(request.text, suggestions, labels=request.labels, source="direct_seed") + return seed_spans, suggestions + + +def _seed_row_from_spans(request: StagedDetectionRequest, seed_spans: list[EntitySpan]) -> dict[str, Any]: + allowed = set(request.labels) + seed_spans = [_normalize_seed_span(request.text, span, allowed_labels=allowed) for span in seed_spans] + seed_spans = [span for span in seed_spans if span.label in allowed] + row = {COL_TEXT: request.text, COL_TAG_NOTATION: get_tag_notation(text=request.text)} + row[COL_RAW_DETECTED] = "" + row[COL_SEED_ENTITIES] = EntitiesSchema(entities=[_span_to_entity_schema(span) for span in seed_spans]).model_dump( + mode="json" + ) + prepare_validation_inputs(row) + return row + + +def _normalize_seed_span(text: str, span: EntitySpan, *, allowed_labels: set[str]) -> EntitySpan: + if _should_promote_birth_context_date_seed(text, span, allowed_labels=allowed_labels): + return replace(span, entity_id=_seed_entity_id("date_of_birth", span), label="date_of_birth") + return span + + +def _should_promote_birth_context_date_seed( + text: str, + span: EntitySpan, + *, + allowed_labels: set[str], +) -> bool: + return span.label == "date" and "date_of_birth" in allowed_labels and _span_has_date_of_birth_context(text, span) + + +def _span_has_date_of_birth_context(text: str, span: EntitySpan) -> bool: + before_start = max(0, span.start_position - VALIDATION_CONTEXT_WINDOW) + after_end = min(len(text), span.end_position + VALIDATION_CONTEXT_WINDOW) + return _contains_date_of_birth_context( + text[before_start : span.start_position], span.value, text[span.end_position : after_end] + ) + + +def _seed_entity_id(label: str, span: EntitySpan) -> str: + return f"{label}_{span.start_position}_{span.end_position}" + + +def _run_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) + if not candidates.candidates: + row[COL_VALIDATION_DECISIONS] = {"decisions": []} + row[COL_VALIDATED_ENTITIES] = {"decisions": []} + apply_validation_to_seed_entities(row) + return DirectCompletion(content='{"decisions": []}', elapsed_sec=0.0, usage={}) + if config.validation_prompt_mode == ValidationPromptMode.chunked_excerpt: + completion = _run_chunked_validation_phase(row, request, candidates, client, config) + else: + completion = _run_full_text_validation_phase(row, request, candidates, client, config) + row[COL_VALIDATION_DECISIONS] = _extract_validation_decisions( + completion.content, + candidates=candidates, + labels=request.labels, + ) + enrich_validation_decisions(row) + apply_validation_to_seed_entities(row) + return completion + + +def _run_full_text_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + completion = _complete( + client, + prompt=_validation_prompt(request, candidates), + config=config, + ) + return completion + + +def _run_chunked_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + completions: list[DirectCompletion] = [] + chunk_results: list[RawValidationDecisionsSchema] = [] + for chunk in _validation_chunks(row, candidates, config): + completion = _complete(client, prompt=_chunk_validation_prompt(request, chunk), config=config) + completions.append(completion) + chunk_results.append( + RawValidationDecisionsSchema.from_raw( + _extract_validation_decisions(completion.content, candidates=chunk[1], labels=request.labels) + ) + ) + decisions = merge_chunk_decisions(chunk_results, candidates) + return _combine_completions(completions, content=json.dumps(decisions, ensure_ascii=True, sort_keys=True)) + + +def _validation_chunks( + row: dict[str, Any], + candidates: ValidationCandidatesSchema, + config: StagedExecutionConfig, +) -> list[tuple[str, ValidationCandidatesSchema]]: + seed_spans = _seed_entity_spans(row) + ordered = order_candidates_by_position(candidates, seed_spans) + return [ + (_validation_excerpt(row, chunk, seed_spans, config), _chunk_candidate_schema(chunk)) + for chunk in chunk_candidates(ordered, config.validation_max_entities_per_call) + ] + + +def _validation_excerpt( + row: dict[str, Any], + chunk: list[tuple[Any, EntitySpan]], + seed_spans: list[EntitySpan], + config: StagedExecutionConfig, +) -> str: + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + return build_chunk_excerpt( + text=str(row.get(COL_TEXT, "")), + chunk_spans=[span for _candidate, span in chunk], + all_spans=seed_spans, + window_chars=config.validation_excerpt_window_chars, + notation=notation, + ) + + +def _chunk_candidate_schema(chunk: list[tuple[Any, EntitySpan]]) -> ValidationCandidatesSchema: + return ValidationCandidatesSchema(candidates=[candidate for candidate, _span in chunk]) + + +def _chunk_validation_prompt( + request: StagedDetectionRequest, + chunk: tuple[str, ValidationCandidatesSchema], +) -> str: + excerpt, candidates = chunk + return _validation_prompt(request, candidates, input_text=excerpt) + + +def _seed_entity_spans(row: dict[str, Any]) -> list[EntitySpan]: + return [ + EntitySpan( + entity_id=entity.id, + value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + for entity in EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities + ] + + +def _combine_completions(completions: list[DirectCompletion], *, content: str) -> DirectCompletion: + return DirectCompletion( + content=content, + elapsed_sec=sum(completion.elapsed_sec for completion in completions), + usage=_sum_completion_usage(completions), + ) + + +def _sum_completion_usage(completions: list[DirectCompletion]) -> dict[str, int]: + totals: Counter[str] = Counter() + for completion in completions: + for key, value in completion.usage.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def _validation_prompt( + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + *, + input_text: str | None = None, +) -> str: + text_for_prompt = input_text if input_text is not None else request.text + label_guidance = _validation_label_guidance(request.labels) + return f"""Validate candidate privacy-sensitive entities. +Use only these decisions: keep, reclass, drop. +Use only these labels when reclassifying: {", ".join(request.labels)}. +Prefer keep when the candidate already has the right specific label. For example, reclass date_of_birth to date only when the surrounding context is not birth-related. +{label_guidance} +Return JSON only with this shape: +{{"decisions": [{{"id": "candidate id", "decision": "keep|reclass|drop", "proposed_label": "", "reason": "short reason"}}]}} + +Context text: +--- +{text_for_prompt} +--- + +Candidates: +{candidates.model_dump_json()} +""" + + +def _validation_label_guidance(labels: list[str]) -> str: + allowed = set(labels) + lines: list[str] = [] + if "degree" in allowed: + lines.append( + "Prefer degree for credential names and degree abbreviations such as Bachelor of Science, BSc, BA, MA, MSc, PhD, or JD; use education_level for broad levels such as undergraduate, graduate, or high school." + ) + if not lines: + return "" + return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) + + +def _run_augmentation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + if _should_skip_augmentation(request, config): + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + return DirectCompletion(content="", elapsed_sec=0.0, usage={}) + completion = _complete( + client, + prompt=_augmentation_prompt(request, row), + config=config, + ) + row[COL_AUGMENTED_ENTITIES] = { + "entities": _allowed_entity_suggestions(_extract_entity_suggestions(completion.content), labels=request.labels) + } + return completion + + +def _should_skip_augmentation(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: + return config.skip_augmentation + + +def _augmentation_prompt(request: StagedDetectionRequest, row: dict[str, Any]) -> str: + seed_entities = row.get(COL_SEED_ENTITIES_JSON, "[]") + label_guidance = _augmentation_label_guidance(request.labels) + return f"""Find additional sensitive entities not already covered by the validated seed entities. +Use only these labels: {", ".join(request.labels)}. +Return exact substrings only. Do not invent values. +{label_guidance} +Return JSON only with this shape: +{{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} + +Input text: +--- +{request.text} +--- + +Validated seed entities: +{seed_entities} +""" + + +def _augmentation_label_guidance(labels: list[str]) -> str: + allowed = set(labels) + lines: list[str] = [] + if "first_name" in allowed: + lines.append( + "For family/member lists, split personal names connected by 'and' into separate first_name or last_name entities. Do not label a list of people as organization_name." + ) + if "last_name" in allowed and ({"place_name", "company_name", "organization_name"} & allowed): + lines.append( + "If a surname appears inside a street, place, organization, company, or possessive business phrase, also return the surname substring as last_name when it is not already covered." + ) + if not lines: + return "" + return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) + + +def _finalize_row(row: dict[str, Any], request: StagedDetectionRequest) -> DetectionArtifactRow: + merge_and_build_candidates(row) + apply_validation_and_finalize(row) + normalize_birth_context_final_entities(row, allowed_labels=request.labels) + seed_entities = EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities + final_entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities + augmented = _augmented_entity_schemas(row, request) + return build_detection_artifact_row_from_entities( + workflow_name="staged-direct-detection", + batch_file="staged-direct-detection", + row_index=request.row_index, + seed_entities=list(seed_entities), + seed_validation_candidate_count=_candidate_count(row.get(COL_SEED_VALIDATION_CANDIDATES)), + merged_validation_candidate_count=_candidate_count(row.get(COL_VALIDATION_CANDIDATES)), + augmented_entities=augmented, + final_entities=list(final_entities), + ) + + +def _completed_case( + request: StagedDetectionRequest, + seed: DirectCompletion, + validation: DirectCompletion, + augmentation: DirectCompletion, + artifact: DetectionArtifactRow, + row: dict[str, Any], + config: StagedExecutionConfig, + seed_suggestion_count: int, + elapsed_sec: float, +) -> StagedDetectionCase: + phase_usage = PhaseUsage(seed=seed.usage, validation=validation.usage, augmentation=augmentation.usage) + phase_model_work = _phase_model_work(request, artifact, config) + phase_model_requests = _phase_model_requests(artifact, phase_model_work, config) + model_elapsed_sec = seed.elapsed_sec + validation.elapsed_sec + augmentation.elapsed_sec + return StagedDetectionCase( + case_id=request.case_id, + row_index=request.row_index, + seed_source=config.seed_source, + status=CaseStatus.completed, + elapsed_sec=elapsed_sec, + model_elapsed_sec=model_elapsed_sec, + phase_usage=phase_usage, + phase_model_work=phase_model_work, + phase_skip_reasons=_phase_skip_reasons(request, artifact, config), + phase_model_requests=phase_model_requests, + total_usage=_sum_usage(phase_usage), + model_phase_count=_model_phase_count(phase_model_work), + model_request_count=_model_request_count(phase_model_requests), + seed_suggestion_count=seed_suggestion_count, + seed_entity_count=artifact.seed_entity_count, + validation_candidate_count=artifact.seed_validation_candidate_count, + validation_decision_count=_validated_decision_count(row.get(COL_VALIDATED_ENTITIES)), + augmented_suggestion_count=artifact.augmented_entity_count, + final_entity_count=artifact.final_entity_count, + final_entity_signature_count=artifact.final_entity_signature_count, + final_entity_signature_hashes=artifact.final_entity_signature_hashes, + final_label_counts=artifact.final_label_counts, + artifact=artifact, + ) + + +def _phase_model_work( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> PhaseModelWork: + return PhaseModelWork( + seed=_uses_seed_model(request, config), + validation=_uses_validation_model(request, artifact, config), + augmentation=not _should_skip_augmentation(request, config), + ) + + +def _uses_seed_model(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: + return config.seed_source in {SeedSource.direct_llm, SeedSource.gliner} + + +def _uses_validation_model( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> bool: + return artifact.seed_validation_candidate_count > 0 + + +def _phase_skip_reasons( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> PhaseSkipReasons: + return PhaseSkipReasons( + seed=_seed_skip_reason(request, config), + validation=_validation_skip_reason(request, artifact, config), + augmentation=_augmentation_skip_reason(request, config), + ) + + +def _seed_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: + return None + + +def _validation_skip_reason( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> str | None: + if artifact.seed_validation_candidate_count == 0: + return "no_seed_candidates" + return None + + +def _augmentation_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: + if config.skip_augmentation: + return "disabled" + return None + + +def _model_phase_count(phase_model_work: PhaseModelWork) -> int: + return sum((phase_model_work.seed, phase_model_work.validation, phase_model_work.augmentation)) + + +def _phase_model_requests( + artifact: DetectionArtifactRow, + phase_model_work: PhaseModelWork, + config: StagedExecutionConfig, +) -> PhaseModelRequests: + return PhaseModelRequests( + seed=1 if phase_model_work.seed else 0, + validation=_validation_model_request_count(artifact, phase_model_work, config), + augmentation=1 if phase_model_work.augmentation else 0, + ) + + +def _validation_model_request_count( + artifact: DetectionArtifactRow, + phase_model_work: PhaseModelWork, + config: StagedExecutionConfig, +) -> int: + if not phase_model_work.validation: + return 0 + if config.validation_prompt_mode == ValidationPromptMode.full_text: + return 1 + return _ceil_div(artifact.seed_validation_candidate_count, config.validation_max_entities_per_call) + + +def _ceil_div(numerator: int, denominator: int) -> int: + return (numerator + denominator - 1) // denominator + + +def _model_request_count(phase_model_requests: PhaseModelRequests) -> int: + return phase_model_requests.seed + phase_model_requests.validation + phase_model_requests.augmentation + + +def _extract_entity_suggestions(content: str) -> list[dict[str, str]]: + payload = _load_embedded_json(content) + raw_entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return [ + {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} + for item in raw_entities + if isinstance(item, dict) and item.get("value") and item.get("label") + ] + + +def _allowed_entity_suggestions(suggestions: list[dict[str, str]], *, labels: list[str]) -> list[dict[str, str]]: + allowed = set(labels) + return [suggestion for suggestion in suggestions if suggestion["label"] in allowed] + + +def _extract_validation_decisions( + content: str, + *, + candidates: ValidationCandidatesSchema | None = None, + labels: list[str] | None = None, +) -> dict[str, Any]: + payload = _load_embedded_json(content) + decisions = payload.get("decisions", []) if isinstance(payload, dict) else [] + if not isinstance(decisions, list): + decisions = [] + if candidates is not None: + decisions = _preserve_specific_seed_labels(decisions, candidates) + if labels is not None: + decisions = _keep_invalid_reclass_labels(decisions, allowed_labels=set(labels)) + return {"decisions": decisions} + + +def _preserve_specific_seed_labels( + decisions: list[object], + candidates: ValidationCandidatesSchema, +) -> list[object]: + candidate_by_id = {candidate.id: candidate for candidate in candidates.candidates} + normalized: list[object] = [] + for decision in decisions: + if not isinstance(decision, dict): + normalized.append(decision) + continue + candidate = candidate_by_id.get(str(decision.get("id") or "")) + if candidate is None or not _should_preserve_specific_seed_label(decision, candidate): + normalized.append(decision) + continue + normalized.append( + { + **decision, + "decision": "keep", + "proposed_label": "", + "reason": _preserved_label_reason(decision.get("reason")), + } + ) + return normalized + + +def _should_preserve_specific_seed_label(decision: dict[str, object], candidate: Any) -> bool: + if str(decision.get("decision") or "") != "reclass": + return False + proposed_label = str(decision.get("proposed_label") or "") + if candidate.label == "date_of_birth" and proposed_label == "date": + return _has_date_of_birth_context(candidate) + return False + + +def _has_date_of_birth_context(candidate: Any) -> bool: + return _contains_date_of_birth_context(candidate.context_before, candidate.value, candidate.context_after) + + +def _keep_invalid_reclass_labels(decisions: list[object], *, allowed_labels: set[str]) -> list[object]: + normalized: list[object] = [] + for decision in decisions: + if not isinstance(decision, dict): + normalized.append(decision) + continue + if not _has_invalid_reclass_label(decision, allowed_labels=allowed_labels): + normalized.append(decision) + continue + normalized.append( + { + **decision, + "decision": "keep", + "proposed_label": "", + "reason": _invalid_reclass_label_reason(decision.get("reason"), decision.get("proposed_label")), + } + ) + return normalized + + +def _has_invalid_reclass_label(decision: dict[str, object], *, allowed_labels: set[str]) -> bool: + if str(decision.get("decision") or "") != "reclass": + return False + proposed_label = str(decision.get("proposed_label") or "").strip() + return proposed_label not in allowed_labels + + +def _invalid_reclass_label_reason(reason: object, proposed_label: object) -> str: + text = str(reason or "").strip() + suffix = f"ignored invalid reclass label {str(proposed_label or '').strip()!r}" + return f"{text}; {suffix}" if text else suffix + + +def normalize_birth_context_final_entities( + row: dict[str, Any], *, allowed_labels: list[str] | set[str] +) -> dict[str, Any]: + """Demote native date_of_birth spans without birth context back to date.""" + allowed = set(allowed_labels) + if not {"date", "date_of_birth"}.issubset(allowed): + return row + text = str(row.get(COL_TEXT, "")) + entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities + normalized: list[EntitySchema] = [] + changed = False + for entity in entities: + if entity.label == "date_of_birth" and not _entity_has_date_of_birth_context(text, entity): + normalized.append(entity.model_copy(update={"id": _entity_id("date", entity), "label": "date"})) + changed = True + else: + normalized.append(entity) + if changed: + row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=normalized).model_dump(mode="json") + row[COL_TAGGED_TEXT] = build_tagged_text( + text=text, + entities=[_entity_schema_to_span(entity) for entity in normalized], + ) + return row + + +def _entity_has_date_of_birth_context(text: str, entity: EntitySchema) -> bool: + before_start = max(0, entity.start_position - VALIDATION_CONTEXT_WINDOW) + after_end = min(len(text), entity.end_position + VALIDATION_CONTEXT_WINDOW) + return _contains_date_of_birth_context( + text[before_start : entity.start_position], + entity.value, + text[entity.end_position : after_end], + ) + + +def _entity_id(label: str, entity: EntitySchema) -> str: + return f"{label}_{entity.start_position}_{entity.end_position}" + + +def _entity_schema_to_span(entity: EntitySchema) -> EntitySpan: + return EntitySpan( + entity_id=entity.id, + value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + + +def _contains_date_of_birth_context(context_before: object, value: object, context_after: object) -> bool: + context = f"{context_before} {value} {context_after}" + return bool(_DATE_OF_BIRTH_CONTEXT_RE.search(context)) + + +def _preserved_label_reason(reason: object) -> str: + text = str(reason or "").strip() + suffix = "preserved specific seed label from birth-related context" + return f"{text}; {suffix}" if text else suffix + + +def _raw_detector_entity_count(content: str) -> int: + payload = _load_embedded_json(content) + entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return len(entities) if isinstance(entities, list) else 0 + + +def _spans_from_suggestions( + text: str, suggestions: list[dict[str, str]], *, labels: list[str], source: str +) -> list[EntitySpan]: + allowed = set(labels) + cleaned = [item for item in suggestions if item["value"] and item["label"] in allowed] + spans = apply_augmented_entities(text=text, entities=[], augmented_output={"entities": cleaned}) + return [ + EntitySpan( + entity_id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position, + end_position=span.end_position, + score=span.score, + source=source, + ) + for span in spans + ] + + +def _span_to_entity_schema(span: EntitySpan) -> EntitySchema: + return EntitySchema( + id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position, + end_position=span.end_position, + score=span.score, + source=span.source, + ) + + +def _augmented_entity_schemas(row: dict[str, Any], request: StagedDetectionRequest) -> list[EntitySchema]: + spans = _spans_from_suggestions( + request.text, + _extract_entities_from_payload(row.get(COL_AUGMENTED_ENTITIES)), + labels=request.labels, + source="augmenter", + ) + return [_span_to_entity_schema(span) for span in spans] + + +def _extract_entities_from_payload(payload: object) -> list[dict[str, str]]: + entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return [ + {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} + for item in entities + if isinstance(item, dict) + ] + + +def _candidate_count(raw: object) -> int: + return len(ValidationCandidatesSchema.from_raw(raw).candidates) + + +def _validated_decision_count(raw: object) -> int: + return len(ValidatedDecisionsSchema.from_raw(raw).decisions) + + +def _sum_usage(usage: PhaseUsage) -> dict[str, int]: + totals: Counter[str] = Counter() + for phase in (usage.seed, usage.validation, usage.augmentation): + for key, value in phase.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def apply_baseline_comparisons( + cases: list[StagedDetectionCase], + baseline_artifacts: Path, +) -> list[StagedDetectionCase]: + baseline = _read_baseline_artifacts(baseline_artifacts) + return [_case_with_comparison(case, baseline.get(case.row_index)) for case in cases] + + +def _case_with_comparison(case: StagedDetectionCase, baseline_row: dict[str, Any] | None) -> StagedDetectionCase: + if baseline_row is None or case.status != CaseStatus.completed or case.artifact is None: + return case + baseline_hashes = _baseline_signature_hashes(baseline_row) + if baseline_hashes is None: + return case + comparison = compare_signature_sets( + baseline_hashes=baseline_hashes, + baseline_labels=_signature_labels(baseline_row), + direct_hashes=set(case.artifact.final_entity_signature_hashes), + direct_labels=case.artifact.final_entity_signature_labels, + ) + return case.model_copy(update={"comparison": comparison}) + + +def _baseline_signature_hashes(row: dict[str, Any]) -> set[str] | None: + hashes = row.get("final_entity_signature_hashes") + if not isinstance(hashes, list): + return None + return {str(item) for item in hashes} + + +def _read_baseline_artifacts(path: Path) -> dict[int, dict[str, Any]]: + baseline: dict[int, dict[str, Any]] = {} + with path.open(encoding="utf-8") as source: + for line in source: + if not line.strip(): + continue + row = json.loads(line) + row_index = int(row.get("row_index", 0)) + if row_index in baseline: + raise ValueError(f"baseline artifacts has multiple rows for row_index={row_index}") + baseline[row_index] = row + return baseline + + +def _signature_labels(row: dict[str, Any]) -> dict[str, str]: + return { + key.removeprefix("final_entity_signature_labels."): str(value) + for key, value in row.items() + if key.startswith("final_entity_signature_labels.") and value is not None + } + + +def run_probe( + input_path: Path, + *, + text_column: str, + labels: list[str], + output: Path | None = None, + overwrite: bool = False, + endpoint: str | None = None, + model: str | None = None, + seed_source: SeedSource = SeedSource.direct_llm, + gliner_endpoint: str | None = None, + gliner_model: str | None = None, + gliner_api_key_env: str = "NVIDIA_API_KEY", + gliner_threshold: float = 0.3, + skip_augmentation: bool = False, + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, + validation_max_entities_per_call: int = 10, + validation_excerpt_window_chars: int = 160, + limit: int = 1, + offset: int = 0, + baseline_artifacts: Path | None = None, +) -> StagedDetectionRun: + endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) + model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) + if seed_source == SeedSource.gliner: + gliner_endpoint = _required_runtime_value( + "gliner_endpoint", explicit=gliner_endpoint, env_var=_GLINER_ENDPOINT_ENV + ) + gliner_model = _required_runtime_value("gliner_model", explicit=gliner_model, env_var=_GLINER_MODEL_ENV) + if not os.environ.get(gliner_api_key_env): + raise ValueError(f"{gliner_api_key_env} is not set") + else: + gliner_endpoint = gliner_endpoint or _UNCONFIGURED_GLINER_ENDPOINT + gliner_model = gliner_model or _UNCONFIGURED_GLINER_MODEL + requests = _load_requests(input_path, text_column=text_column, labels=labels, limit=limit, offset=offset) + config = _execution_config_from_params(locals()) + cases = _run_probe_cases(requests, config) + if baseline_artifacts is not None: + cases = apply_baseline_comparisons(cases, baseline_artifacts) + result = StagedDetectionRun( + input_path=str(input_path), text_column=text_column, endpoint=endpoint, model=model, labels=labels, rows=cases + ) + if output is not None: + write_outputs(result, output, overwrite=overwrite) + return result + + +def _required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: + value = explicit or os.environ.get(env_var) + if not value: + raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") + return value + + +def _run_probe_cases( + requests: list[StagedDetectionRequest], + config: StagedExecutionConfig, +) -> list[StagedDetectionCase]: + client = HttpxDirectDetectionClient() + seed_client = HttpxGlinerSeedClient() if config.seed_source == SeedSource.gliner else None + return [ + run_staged_detection_case( + request, + client=client, + seed_client=seed_client, + seed_source=config.seed_source, + endpoint=config.endpoint, + model=config.model, + gliner_endpoint=config.gliner_endpoint, + gliner_model=config.gliner_model, + gliner_api_key_env=config.gliner_api_key_env, + gliner_threshold=config.gliner_threshold, + skip_augmentation=config.skip_augmentation, + validation_prompt_mode=config.validation_prompt_mode, + validation_max_entities_per_call=config.validation_max_entities_per_call, + validation_excerpt_window_chars=config.validation_excerpt_window_chars, + ) + for request in requests + ] + + +def _load_requests( + input_path: Path, *, text_column: str, labels: list[str], limit: int, offset: int +) -> list[StagedDetectionRequest]: + dataframe = pd.read_csv(input_path) + if text_column not in dataframe.columns: + raise ValueError(f"text column {text_column!r} not found in {input_path}") + selected = dataframe.iloc[offset : offset + limit] + return [ + StagedDetectionRequest( + case_id=f"{input_path.stem}-row-{int(index)}", + text=str(row[text_column]), + labels=labels, + row_index=int(index), + ) + for index, row in selected.iterrows() + ] + + +def write_outputs(result: StagedDetectionRun, output_dir: Path, *, overwrite: bool) -> None: + if output_dir.exists(): + if not overwrite: + raise ValueError(f"output directory already exists: {output_dir}") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True) + _write_jsonl(output_dir / "staged-detection-cases.jsonl", [_case_payload(case) for case in result.rows]) + _write_jsonl(output_dir / "staged-detection-artifacts.jsonl", [_artifact_payload(case) for case in result.rows]) + (output_dir / "summary.json").write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def _case_payload(case: StagedDetectionCase) -> dict[str, Any]: + payload = case.model_dump(exclude={"artifact"}) + payload["record_type"] = "staged_detection_case" + return payload + + +def _artifact_payload(case: StagedDetectionCase) -> dict[str, Any]: + payload = case.artifact.model_dump() if case.artifact is not None else {} + payload.update({"case_id": case.case_id, "row_index": case.row_index, "record_type": "staged_detection_artifact"}) + return payload + + +def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8") as target: + for row in rows: + target.write(json.dumps(row, ensure_ascii=True, sort_keys=True) + "\n") + + +def render_result(result: StagedDetectionRun, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + completed = len(result.rows) - result.error_count + return f"Ran {completed}/{len(result.rows)} staged detection case(s); errors={result.error_count}" + + +@app.default +def main( + input_path: Path, + *, + text_column: Annotated[str, cyclopts.Parameter("--text-column")], + labels: Annotated[str, cyclopts.Parameter("--labels")], + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, + model: Annotated[str | None, cyclopts.Parameter("--model")] = None, + seed_source: Annotated[SeedSource, cyclopts.Parameter("--seed-source")] = SeedSource.direct_llm, + gliner_endpoint: Annotated[str | None, cyclopts.Parameter("--gliner-endpoint")] = None, + gliner_model: Annotated[str | None, cyclopts.Parameter("--gliner-model")] = None, + gliner_api_key_env: Annotated[str, cyclopts.Parameter("--gliner-api-key-env")] = "NVIDIA_API_KEY", + gliner_threshold: Annotated[float, cyclopts.Parameter("--gliner-threshold")] = 0.3, + skip_augmentation: Annotated[bool, cyclopts.Parameter("--skip-augmentation")] = False, + validation_prompt_mode: Annotated[ + ValidationPromptMode, cyclopts.Parameter("--validation-prompt-mode") + ] = ValidationPromptMode.full_text, + validation_max_entities_per_call: Annotated[int, cyclopts.Parameter("--validation-max-entities-per-call")] = 10, + validation_excerpt_window_chars: Annotated[int, cyclopts.Parameter("--validation-excerpt-window-chars")] = 160, + limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, + offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, + baseline_artifacts: Annotated[Path | None, cyclopts.Parameter("--baseline-artifacts")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + params = locals() + json_output = bool(params.pop("json_output")) + result = _run_main_probe(params) + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + if result.error_count: + raise SystemExit(1) + + +def _run_main_probe(params: dict[str, Any]) -> StagedDetectionRun: + configure_logging(params.pop("log_format")) + params["labels"] = parse_labels(params["labels"]) + try: + return run_probe(**params) + except (ValueError, ValidationError, httpx.HTTPError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + + +if __name__ == "__main__": + app() diff --git a/uv.lock b/uv.lock index ae31c1ea..45102849 100644 --- a/uv.lock +++ b/uv.lock @@ -2425,6 +2425,7 @@ dependencies = [ { name = "data-designer" }, { name = "httpx" }, { name = "pydantic" }, + { name = "pydantic-settings" }, { name = "pygments" }, { name = "tiktoken" }, ] @@ -2462,6 +2463,7 @@ requires-dist = [ { name = "data-designer", specifier = "==0.6.0" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "pydantic", specifier = ">=2.9,<3" }, + { name = "pydantic-settings", specifier = ">=2.12,<3" }, { name = "pygments", specifier = ">=2.20.0" }, { name = "tiktoken", specifier = ">=0.9.0" }, ]