From b75ab36bdd5b4f848ee6c3c24a7dae9edc2bde32 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 2 Jun 2026 21:38:50 +0000 Subject: [PATCH 01/26] feat: add anonymizer measurement instrumentation Signed-off-by: Aaron Gonzales --- pyproject.toml | 1 + .../engine/detection/detection_workflow.py | 91 +- src/anonymizer/engine/ndd/adapter.py | 161 ++- .../engine/replace/replace_runner.py | 52 +- .../engine/rewrite/rewrite_workflow.py | 148 +-- src/anonymizer/interface/anonymizer.py | 51 + src/anonymizer/measurement.py | 950 ++++++++++++++++++ tests/test_measurement.py | 667 ++++++++++++ uv.lock | 2 + 9 files changed, 1974 insertions(+), 149 deletions(-) create mode 100644 src/anonymizer/measurement.py create mode 100644 tests/test_measurement.py diff --git a/pyproject.toml b/pyproject.toml index 29865798..972fbb12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ license = "Apache-2.0" dependencies = [ "data-designer==0.6.0", "pydantic>=2.9,<3", + "pydantic-settings>=2.12,<3", "cyclopts>=3", "pygments>=2.20.0", "cryptography>=46.0.6", diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index 87eb644b..e482e912 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -59,6 +59,7 @@ EntitiesSchema, LatentEntitiesSchema, ) +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.detection") @@ -266,54 +267,64 @@ def run( ``identify_latent_entities`` if ``tag_latent_entities`` is True (rewrite mode). Merges failures from both stages. 
""" - if tag_latent_entities and privacy_goal is None: - raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)") - - compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities - detected_result = self.detect_and_validate_entities( - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - - if tag_latent_entities: - latent_result = self.identify_latent_entities( - detected_result.dataframe, + with stage_timer( + "EntityDetectionWorkflow.run", + input_row_count=len(dataframe), + tag_latent_entities=tag_latent_entities, + ) as measurement: + if tag_latent_entities and privacy_goal is None: + raise ValueError("privacy_goal is required when tag_latent_entities=True (rewrite mode)") + + compute_grouped = True if compute_grouped_entities is None else compute_grouped_entities + detected_result = self.detect_and_validate_entities( + dataframe, model_configs=model_configs, selected_models=selected_models, gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, entity_labels=entity_labels, - privacy_goal=privacy_goal, data_summary=data_summary, preview_num_records=preview_num_records, ) - final_df = latent_result.dataframe.copy() - final_failures = [*detected_result.failed_records, *latent_result.failed_records] - else: - final_df = detected_result.dataframe.copy() - final_failures = detected_result.failed_records - - # When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS), - # the augmenter is strict and out-of-scope labels are 
filtered. - # entity_labels=None is the only way to get permissive augmentation. - # TODO(docs): document this None-vs-explicit contract in user-facing docs. - if COL_DETECTED_ENTITIES in final_df.columns: - allowed = set(entity_labels) if entity_labels is not None else None - final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply( - lambda raw: _materialize_final_entities(raw, allowed_labels=allowed) + + if tag_latent_entities: + latent_result = self.identify_latent_entities( + detected_result.dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + entity_labels=entity_labels, + privacy_goal=privacy_goal, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + final_df = latent_result.dataframe.copy() + final_failures = [*detected_result.failed_records, *latent_result.failed_records] + else: + final_df = detected_result.dataframe.copy() + final_failures = detected_result.failed_records + + # When entity_labels is explicitly provided (even if it matches DEFAULT_ENTITY_LABELS), + # the augmenter is strict and out-of-scope labels are filtered. + # entity_labels=None is the only way to get permissive augmentation. + # TODO(docs): document this None-vs-explicit contract in user-facing docs. 
+ if COL_DETECTED_ENTITIES in final_df.columns: + allowed = set(entity_labels) if entity_labels is not None else None + final_df[COL_FINAL_ENTITIES] = final_df[COL_DETECTED_ENTITIES].apply( + lambda raw: _materialize_final_entities(raw, allowed_labels=allowed) + ) + if compute_grouped: + final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value) + result = EntityDetectionResult( + dataframe=final_df, + failed_records=final_failures, ) - if compute_grouped: - final_df[COL_ENTITIES_BY_VALUE] = final_df[COL_FINAL_ENTITIES].apply(_build_entities_by_value) - return EntityDetectionResult( - dataframe=final_df, - failed_records=final_failures, - ) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result def _inject_detector_params( self, diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 8aa9b920..13e390c9 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -6,10 +6,12 @@ import json import logging import tempfile +import time import uuid +from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder @@ -18,6 +20,7 @@ from data_designer.config.seed_source import LocalFileSeedSource from anonymizer.interface.errors import AnonymizerWorkflowError +from anonymizer.measurement import current_collector, record_ndd_workflow if TYPE_CHECKING: import pandas as pd @@ -86,7 +89,15 @@ def run_workflow( logger.debug("NDD workflow '%s' starting with %d records", workflow_name, len(workflow_input_df)) col_names = [c.name for c in columns] logger.debug("NDD workflow '%s': %d columns %s", workflow_name, len(col_names), col_names) - model_aliases = 
[m.alias for m in model_configs] + available_model_aliases = [m.alias for m in model_configs] + model_aliases = _extract_workflow_model_aliases(columns) or available_model_aliases + record_count = ( + min(preview_num_records, len(workflow_input_df)) + if preview_num_records is not None + else len(workflow_input_df) + ) + started = time.perf_counter() + usage_probe = _DataDesignerUsageProbe(self._data_designer, enabled=current_collector() is not None) with tempfile.TemporaryDirectory(prefix=f"anonymizer_{workflow_name}_") as tmp_dir: seed_path = str(Path(tmp_dir) / "seed.parquet") @@ -97,33 +108,29 @@ def run_workflow( for column in columns: config_builder.add_column(column) - record_count = ( - min(preview_num_records, len(workflow_input_df)) - if preview_num_records is not None - else len(workflow_input_df) - ) try: - if preview_num_records is None: - run_results = self._data_designer.create( - config_builder, - num_records=len(workflow_input_df), - dataset_name=workflow_name, - ) - output_df = run_results.load_dataset() - else: - preview_results = self._data_designer.preview( - config_builder, - num_records=record_count, - ) - if preview_results.dataset is None: - output_df = workflow_input_df.iloc[0:0].copy() + with usage_probe: + if preview_num_records is None: + run_results = self._data_designer.create( + config_builder, + num_records=len(workflow_input_df), + dataset_name=workflow_name, + ) + output_df = run_results.load_dataset() else: - output_df = preview_results.dataset + preview_results = self._data_designer.preview( + config_builder, + num_records=record_count, + ) + if preview_results.dataset is None: + output_df = workflow_input_df.iloc[0:0].copy() + else: + output_df = preview_results.dataset except Exception as exc: logger.warning( "Workflow failed for %d input record(s) on model(s) %s: %s", record_count, - model_aliases, + available_model_aliases, exc, ) logger.debug( @@ -131,6 +138,20 @@ def run_workflow( workflow_name, col_names, ) + 
record_ndd_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=record_count, + seed_row_count=len(workflow_input_df), + output_row_count=None, + failed_record_count=None, + elapsed_sec=time.perf_counter() - started, + status="error", + preview_num_records=preview_num_records, + column_count=len(col_names), + column_names=col_names, + model_usage=usage_probe.model_usage(), + ) raise AnonymizerWorkflowError(f"Workflow failed: {exc}") from exc logger.debug("NDD workflow '%s' returned %d records", workflow_name, len(output_df)) @@ -143,6 +164,19 @@ def run_workflow( ), output_df=output_df, ) + record_ndd_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=record_count, + seed_row_count=len(workflow_input_df), + output_row_count=len(output_df), + failed_record_count=len(failed_records), + elapsed_sec=time.perf_counter() - started, + preview_num_records=preview_num_records, + column_count=len(col_names), + column_names=col_names, + model_usage=usage_probe.model_usage(), + ) return WorkflowRunResult(dataframe=output_df, failed_records=failed_records) def _attach_record_ids(self, df: pd.DataFrame) -> pd.DataFrame: @@ -225,3 +259,84 @@ def _detect_missing_records( ) for record_id in missing_ids ] + + +def _extract_workflow_model_aliases(columns: list[ColumnConfigT]) -> list[str]: + aliases: list[str] = [] + for column in columns: + aliases.extend(_as_alias_list(getattr(column, "model_alias", None))) + generator = getattr(column, "generator_function", None) + metadata = getattr(generator, "custom_column_metadata", None) + if isinstance(metadata, dict): + aliases.extend(_as_alias_list(metadata.get("model_aliases"))) + return list(dict.fromkeys(alias for alias in aliases if alias)) + + +def _as_alias_list(raw: Any) -> list[str]: + if raw is None: + return [] + if isinstance(raw, str): + return [raw] + if isinstance(raw, (list, tuple, set)): + return [str(item) for item in raw if str(item)] + return 
[str(raw)] + + +class _DataDesignerUsageProbe: + """Capture DataDesigner model usage from the per-run private ResourceProvider.""" + + def __init__(self, data_designer: DataDesigner, *, enabled: bool) -> None: + self._data_designer = data_designer + self._enabled = enabled + self._original_create_resource_provider: Any | None = None + self._resource_providers: list[Any] = [] + + def __enter__(self) -> _DataDesignerUsageProbe: + if not self._enabled: + return self + + original = getattr(self._data_designer, "_create_resource_provider", None) + if not callable(original): + return self + + self._original_create_resource_provider = original + + def wrapper(*args: Any, **kwargs: Any) -> Any: + resource_provider = original(*args, **kwargs) + self._resource_providers.append(resource_provider) + return resource_provider + + setattr(self._data_designer, "_create_resource_provider", wrapper) + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + if self._original_create_resource_provider is not None: + setattr(self._data_designer, "_create_resource_provider", self._original_create_resource_provider) + + def model_usage(self) -> dict[str, Any] | None: + usage: dict[str, Any] = {} + for resource_provider in self._resource_providers: + model_registry = getattr(resource_provider, "model_registry", None) + snapshot = _get_model_usage_snapshot(model_registry) + if not snapshot: + continue + for model_name, stats in snapshot.items(): + usage[str(model_name)] = _model_usage_as_json(stats) + return usage or None + + +def _get_model_usage_snapshot(model_registry: object) -> Mapping[str, object] | None: + get_snapshot = getattr(model_registry, "get_model_usage_snapshot", None) + if not callable(get_snapshot): + return None + snapshot = get_snapshot() + if isinstance(snapshot, Mapping): + return snapshot + return None + + +def _model_usage_as_json(stats: object) -> Any: + model_dump = getattr(stats, "model_dump", None) + if 
callable(model_dump): + return model_dump(mode="json") + return stats diff --git a/src/anonymizer/engine/replace/replace_runner.py b/src/anonymizer/engine/replace/replace_runner.py index d6501834..f3a95adc 100644 --- a/src/anonymizer/engine/replace/replace_runner.py +++ b/src/anonymizer/engine/replace/replace_runner.py @@ -28,6 +28,7 @@ from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, FailedRecord, NddAdapter from anonymizer.engine.replace.llm_replace_workflow import LlmReplaceWorkflow from anonymizer.engine.replace.strategies import apply_local_replace_strategy, apply_replacement_map +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.replace") @@ -73,27 +74,38 @@ def run( Evaluation is a separate concern — call ``evaluate()`` on the resulting dataframe when you want the LLM alignment scores. """ - logger.debug("replacement strategy: %s on %d records", type(replace_method).__name__, len(dataframe)) - - if isinstance(replace_method, (Annotate, Redact, Hash)): - local_df = apply_local_replace_strategy(dataframe, strategy=replace_method) - failed_records: list[FailedRecord] = [] - elif isinstance(replace_method, Substitute): - if self._llm_workflow is None: - raise ValueError("Substitute requires an llm_workflow, but none was provided.") - map_result = self._llm_workflow.generate_map_only( - dataframe, - model_configs=model_configs, - selected_models=selected_models, - instructions=replace_method.instructions, - preview_num_records=preview_num_records, - ) - local_df = apply_replacement_map(map_result.dataframe) - failed_records = list(map_result.failed_records) - else: - raise ValueError(f"Unsupported replace method: {type(replace_method).__name__}") + strategy = type(replace_method).__name__ + with stage_timer( + "ReplacementWorkflow.run", + strategy=strategy, + input_row_count=len(dataframe), + ) as measurement: + logger.debug("replacement strategy: %s on %d records", strategy, len(dataframe)) + + if 
isinstance(replace_method, (Annotate, Redact, Hash)): + local_df = apply_local_replace_strategy(dataframe, strategy=replace_method) + failed_records: list[FailedRecord] = [] + elif isinstance(replace_method, Substitute): + if self._llm_workflow is None: + raise ValueError("Substitute requires an llm_workflow, but none was provided.") + map_result = self._llm_workflow.generate_map_only( + dataframe, + model_configs=model_configs, + selected_models=selected_models, + instructions=replace_method.instructions, + preview_num_records=preview_num_records, + ) + local_df = apply_replacement_map(map_result.dataframe) + failed_records = list(map_result.failed_records) + else: + raise ValueError(f"Unsupported replace method: {type(replace_method).__name__}") - return ReplacementResult(dataframe=local_df, failed_records=failed_records) + result = ReplacementResult(dataframe=local_df, failed_records=failed_records) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result def evaluate( self, diff --git a/src/anonymizer/engine/rewrite/rewrite_workflow.py b/src/anonymizer/engine/rewrite/rewrite_workflow.py index 88c2b9c3..88fe07f0 100644 --- a/src/anonymizer/engine/rewrite/rewrite_workflow.py +++ b/src/anonymizer/engine/rewrite/rewrite_workflow.py @@ -37,6 +37,7 @@ from anonymizer.engine.rewrite.sensitivity_disposition import SensitivityDispositionWorkflow from anonymizer.engine.rewrite.workflow_utils import derive_seed_columns, select_seed_cols from anonymizer.engine.row_partitioning import merge_and_reorder, split_rows +from anonymizer.measurement import stage_timer logger = logging.getLogger("anonymizer.rewrite.workflow") @@ -196,80 +197,95 @@ def run( preview_num_records: int | None = None, strict_entity_protection: bool = False, ) -> RewriteResult: - all_failed: list[FailedRecord] = [] - - entity_rows, passthrough_rows = split_rows(dataframe, column=COL_ENTITIES_BY_VALUE, predicate=_has_entities) + 
with stage_timer("RewriteWorkflow.run", input_row_count=len(dataframe)) as measurement: + all_failed: list[FailedRecord] = [] - # Fast path: no entities anywhere - if entity_rows.empty: - _apply_passthrough_defaults(passthrough_rows) - result_df = merge_and_reorder(passthrough_rows) - return RewriteResult(dataframe=result_df, failed_records=all_failed) + entity_rows, passthrough_rows = split_rows(dataframe, column=COL_ENTITIES_BY_VALUE, predicate=_has_entities) + measurement.update( + entity_row_count=len(entity_rows), + passthrough_row_count=len(passthrough_rows), + ) - # --- Step 1: replacement map (needs only detection output) --- - replace_workflow = LlmReplaceWorkflow(adapter=self._adapter) - replace_result = replace_workflow.generate_map_only( - entity_rows, - model_configs=model_configs, - selected_models=replace_model_selection, - ) - entity_rows = _join_new_columns(entity_rows, replace_result.dataframe) - all_failed.extend(replace_result.failed_records) + # Fast path: no entities anywhere + if entity_rows.empty: + _apply_passthrough_defaults(passthrough_rows) + result_df = merge_and_reorder(passthrough_rows) + result = RewriteResult(dataframe=result_df, failed_records=all_failed) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result + + # --- Step 1: replacement map (needs only detection output) --- + replace_workflow = LlmReplaceWorkflow(adapter=self._adapter) + replace_result = replace_workflow.generate_map_only( + entity_rows, + model_configs=model_configs, + selected_models=replace_model_selection, + ) + entity_rows = _join_new_columns(entity_rows, replace_result.dataframe) + all_failed.extend(replace_result.failed_records) + + # --- Step 2: domain, disposition, QA, rewrite (single adapter call) --- + pipeline_columns = [ + *self._domain_wf.columns(selected_models=selected_models, data_summary=data_summary), + *self._disposition_wf.columns( + 
selected_models=selected_models, + privacy_goal=privacy_goal, + data_summary=data_summary, + strict_entity_protection=strict_entity_protection, + ), + *self._qa_wf.columns(selected_models=selected_models), + *self._rewrite_gen_wf.columns( + selected_models=selected_models, + privacy_goal=privacy_goal, + data_summary=data_summary, + ), + ] + + pipeline_seed = select_seed_cols(entity_rows, derive_seed_columns(pipeline_columns, entity_rows)) + pipeline_result = self._adapter.run_workflow( + pipeline_seed, + model_configs=model_configs, + columns=pipeline_columns, + workflow_name="rewrite-pipeline", + preview_num_records=preview_num_records, + ) + entity_rows = _join_new_columns(entity_rows, pipeline_result.dataframe) + all_failed.extend(pipeline_result.failed_records) - # --- Step 2: domain, disposition, QA, rewrite (single adapter call) --- - pipeline_columns = [ - *self._domain_wf.columns(selected_models=selected_models, data_summary=data_summary), - *self._disposition_wf.columns( + # --- Step 5: evaluate-repair loop --- + entity_rows, eval_repair_failed = self._run_evaluate_repair_loop( + entity_rows, + model_configs=model_configs, selected_models=selected_models, privacy_goal=privacy_goal, - data_summary=data_summary, - strict_entity_protection=strict_entity_protection, - ), - *self._qa_wf.columns(selected_models=selected_models), - *self._rewrite_gen_wf.columns( + evaluation=evaluation, + preview_num_records=preview_num_records, + ) + all_failed.extend(eval_repair_failed) + + # --- Step 6: final judge (non-critical) --- + entity_rows, judge_failed = self._run_final_judge( + entity_rows, + model_configs=model_configs, selected_models=selected_models, privacy_goal=privacy_goal, - data_summary=data_summary, - ), - ] - - pipeline_seed = select_seed_cols(entity_rows, derive_seed_columns(pipeline_columns, entity_rows)) - pipeline_result = self._adapter.run_workflow( - pipeline_seed, - model_configs=model_configs, - columns=pipeline_columns, - 
workflow_name="rewrite-pipeline", - preview_num_records=preview_num_records, - ) - entity_rows = _join_new_columns(entity_rows, pipeline_result.dataframe) - all_failed.extend(pipeline_result.failed_records) - - # --- Step 5: evaluate-repair loop --- - entity_rows, eval_repair_failed = self._run_evaluate_repair_loop( - entity_rows, - model_configs=model_configs, - selected_models=selected_models, - privacy_goal=privacy_goal, - evaluation=evaluation, - preview_num_records=preview_num_records, - ) - all_failed.extend(eval_repair_failed) - - # --- Step 6: final judge (non-critical) --- - entity_rows, judge_failed = self._run_final_judge( - entity_rows, - model_configs=model_configs, - selected_models=selected_models, - privacy_goal=privacy_goal, - evaluation=evaluation, - preview_num_records=preview_num_records, - ) - all_failed.extend(judge_failed) + evaluation=evaluation, + preview_num_records=preview_num_records, + ) + all_failed.extend(judge_failed) - # --- Merge and return --- - _apply_passthrough_defaults(passthrough_rows) - combined = merge_and_reorder(entity_rows, passthrough_rows) - return RewriteResult(dataframe=combined, failed_records=all_failed) + # --- Merge and return --- + _apply_passthrough_defaults(passthrough_rows) + combined = merge_and_reorder(entity_rows, passthrough_rows) + result = RewriteResult(dataframe=combined, failed_records=all_failed) + measurement.update( + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + ) + return result # --------------------------------------------------------------------------- # Evaluate-repair loop diff --git a/src/anonymizer/interface/anonymizer.py b/src/anonymizer/interface/anonymizer.py index ec08164a..b19762d3 100644 --- a/src/anonymizer/interface/anonymizer.py +++ b/src/anonymizer/interface/anonymizer.py @@ -59,6 +59,11 @@ from anonymizer.interface.errors import InvalidConfigError from anonymizer.interface.results import AnonymizerResult, PreviewResult from 
anonymizer.logging import LOG_INDENT, configure_logging, reapply_log_levels +from anonymizer.measurement import ( + record_record_metrics, + record_run_metadata, + stage_timer, +) from anonymizer.telemetry import ( NOT_APPLICABLE, AnonymizerEvent, @@ -331,6 +336,45 @@ def _run_internal( data: AnonymizerInput, context: ResolvedInput, preview_num_records: int | None, + ) -> AnonymizerResult: + input_df = context.dataframe + mode = "replace" if config.replace is not None else "rewrite" + strategy = type(config.replace).__name__ if config.replace is not None else "Rewrite" + with stage_timer( + "Anonymizer._run_internal", + mode=mode, + strategy=strategy, + input_row_count=len(input_df), + preview_num_records=preview_num_records, + ) as measurement: + record_run_metadata( + config=config, + data=data, + mode=mode, + strategy=strategy, + input_row_count=len(input_df), + preview_num_records=preview_num_records, + model_configs=self._model_configs, + ) + result = self._run_internal_impl( + config=config, + data=data, + context=context, + preview_num_records=preview_num_records, + ) + measurement.update( + output_row_count=len(result.trace_dataframe), + failed_record_count=len(result.failed_records), + ) + return result + + def _run_internal_impl( + self, + *, + config: AnonymizerConfig, + data: AnonymizerInput, + context: ResolvedInput, + preview_num_records: int | None, ) -> AnonymizerResult: input_df = context.dataframe num_records = len(input_df) @@ -455,6 +499,13 @@ def _run_internal( text_col = context.resolved_text_column renamed_trace = _rename_output_columns(final_df, resolved_text_column=text_col) logger.info("🎉 Pipeline complete — %d records processed, %d total failures", num_records, len(all_failures)) + record_record_metrics( + final_df, + mode="replace" if config.replace is not None else "rewrite", + strategy=type(config.replace).__name__ if config.replace is not None else "Rewrite", + text_column=COL_TEXT, + 
validation_max_entities_per_call=config.detect.validation_max_entities_per_call, + ) return AnonymizerResult( dataframe=_build_user_dataframe(renamed_trace, resolved_text_column=text_col), trace_dataframe=renamed_trace, diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py new file mode 100644 index 00000000..a56d8d3e --- /dev/null +++ b/src/anonymizer/measurement.py @@ -0,0 +1,950 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import hmac +import json +import logging +import math +import platform +import secrets +import time +import uuid +from collections import Counter +from collections.abc import Iterator, Mapping +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +from numbers import Integral +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast +from urllib.parse import urlparse + +from pydantic import Field, ValidationError +from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError + +if TYPE_CHECKING: + import pandas as pd + + +_ACTIVE_COLLECTOR: ContextVar[MeasurementCollector | None] = ContextVar( + "anonymizer_measurement_collector", + default=None, +) +_GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") +MEASUREMENT_SCHEMA_VERSION = 1 +DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" + +logger = logging.getLogger("anonymizer.measurement") + + +class _MeasurementWriter(Protocol): + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: ... 
+ + +class _JsonlMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + for record in records: + f.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + +class _JsonMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(records, f, ensure_ascii=True, indent=2, sort_keys=True) + + +def _writer_for_format(output_format: Literal["jsonl", "json"]) -> _MeasurementWriter: + if output_format == "json": + return _JsonMeasurementWriter() + return _JsonlMeasurementWriter() + + +class _MeasurementEnvSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix=DEFAULT_MEASUREMENT_ENV_PREFIX, + env_ignore_empty=True, + extra="ignore", + ) + + output_path: str | None = None + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + fail_on_write_error: bool = False + run_id: str | None = None + run_tags: dict[str, Any] = Field(default_factory=dict) + + +class MeasurementCollector: + """In-memory collector for local benchmark and throughput records. + + Records contain counts, labels, lengths, aliases, timings, and run-scoped + HMACs. They must not contain raw text, entity values, prompts, generated + outputs, replacement maps, provider secrets, or API keys. 
+ """ + + def __init__( + self, + *, + run_id: str | None = None, + record_hash_key: bytes | str | None = None, + record_level: bool = True, + run_tags: Mapping[str, Any] | None = None, + ) -> None: + self.run_id = run_id or uuid.uuid4().hex + self.record_level = record_level + self.run_tags = cast(dict[str, Any], _json_safe(dict(run_tags or {}))) + if record_hash_key is None: + self._record_hash_key = secrets.token_bytes(32) + elif isinstance(record_hash_key, str): + self._record_hash_key = record_hash_key.encode("utf-8") + else: + self._record_hash_key = bytes(record_hash_key) + self._records: list[dict[str, Any]] = [] + + @property + def records(self) -> list[dict[str, Any]]: + """Return a shallow copy of collected measurement records.""" + return list(self._records) + + def record(self, record_type: str, **fields: Any) -> None: + """Append one machine-readable measurement record.""" + record = { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": record_type, + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + self._records.append(_json_safe(record)) + + def record_hash(self, *, row_index: object, text: str) -> str: + """Return a run-scoped HMAC for joining records without storing text.""" + serialized = json.dumps( + {"row_index": str(row_index), "text": text}, + default=str, + sort_keys=True, + separators=(",", ":"), + ) + return hmac.new(self._record_hash_key, serialized.encode("utf-8"), hashlib.sha256).hexdigest() + + def write_jsonl(self, path: str | Path) -> None: + """Write records as newline-delimited JSON.""" + _JsonlMeasurementWriter().write(self._records, path) + + def write_json(self, path: str | Path) -> None: + """Write records as a JSON array.""" + _JsonMeasurementWriter().write(self._records, path) + + def to_dataframe(self) -> pd.DataFrame: + """Return records as a pandas DataFrame for benchmark tooling.""" + import pandas as pd + + return pd.DataFrame(self._records) + + 
+@dataclass(frozen=True) +class MeasurementConfig: + """Configuration for writing structured measurement records around a run.""" + + output_path: str | Path + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + run_id: str | None = None + record_hash_key: bytes | str | None = None + run_tags: Mapping[str, Any] | None = None + fail_on_write_error: bool = False + + def __post_init__(self) -> None: + if self.output_format not in {"jsonl", "json"}: + raise ValueError("output_format must be 'jsonl' or 'json'") + + @classmethod + def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> MeasurementConfig | None: + """Build measurement config from environment variables, or None if output is unset. + + This is intentionally opt-in. Anonymizer API and CLI calls do not read + measurement environment variables unless benchmark/tooling code calls this + helper explicitly. + """ + try: + settings = _load_measurement_env_settings(prefix=prefix) + except (SettingsError, ValidationError) as exc: + raise ValueError(_measurement_env_error_message(exc, prefix=prefix)) from None + + if settings.output_path is None: + return None + return cls( + output_path=settings.output_path, + output_format=settings.output_format, + record_level=settings.record_level, + run_id=settings.run_id, + run_tags=settings.run_tags, + fail_on_write_error=settings.fail_on_write_error, + ) + + @classmethod + def from_sources( + cls, + explicit: MeasurementConfig | None = None, + *, + env: bool = False, + prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX, + ) -> MeasurementConfig | None: + """Resolve measurement config from explicit config first, then optional env.""" + if explicit is not None: + return explicit + if env: + return cls.from_env(prefix=prefix) + return None + + def write_collector(self, collector: MeasurementCollector) -> None: + """Write a collector using this config's output format.""" + _writer_for_format(self.output_format).write(collector.records, 
@contextmanager
def stage_timer(stage: str, **fields: Any) -> Iterator[dict[str, Any]]:
    """Time the enclosed block and emit one ``"stage"`` record for it.

    When no collector is active the block simply runs and nothing is recorded.
    The ``fields`` dict itself is yielded and read again on exit, so keys the
    caller adds inside the ``with`` body (for example ``output_row_count``)
    end up in the record.

    Args:
        stage: Name stored in the record's ``stage`` field.
        **fields: Extra key/value pairs copied into the record.

    Yields:
        The mutable ``fields`` dict.
    """
    active = current_collector()
    if active is None:
        yield fields
        return

    outcome = "completed"
    began = time.perf_counter()
    try:
        yield fields
    except BaseException:
        # Mark the stage as failed but let the exception propagate untouched.
        outcome = "error"
        raise
    finally:
        duration = time.perf_counter() - began
        # Absent or non-numeric counts become -1, which the throughput helper
        # treats as "unknown".
        throughput = _row_throughput_fields(
            elapsed_sec=duration,
            input_row_count=_coerce_int(fields.get("input_row_count"), default=-1),
            output_row_count=_coerce_int(fields.get("output_row_count"), default=-1),
        )
        active.record(
            "stage",
            stage=stage,
            status=outcome,
            elapsed_sec=duration,
            **fields,
            **throughput,
        )
def record_ndd_workflow(
    *,
    workflow_name: str,
    model_aliases: list[str],
    input_row_count: int,
    output_row_count: int | None,
    failed_record_count: int | None,
    elapsed_sec: float,
    status: str = "completed",
    seed_row_count: int | None = None,
    preview_num_records: int | None = None,
    column_count: int | None = None,
    column_names: list[str] | None = None,
    model_usage: Mapping[str, Any] | None = None,
) -> None:
    """Emit one ``"ndd_workflow"`` record describing a DataDesigner workflow run.

    Does nothing when no collector is active. Besides the arguments, the
    record carries the token/request totals summarised from ``model_usage``
    and the throughput rates derived from those totals and ``elapsed_sec``.

    Args:
        workflow_name: Name of the executed workflow.
        model_aliases: Aliases involved; stored de-duplicated and sorted.
        input_row_count: Number of rows handed to the workflow.
        output_row_count: Number of rows produced, if known.
        failed_record_count: Number of failed records, if known.
        elapsed_sec: Wall time of the workflow in seconds.
        status: Outcome label stored in the record.
        seed_row_count: Stored as given.
        preview_num_records: Stored as given.
        column_count: Stored as given.
        column_names: Stored as given; ``None`` becomes an empty list.
        model_usage: Per-model usage mapping; stored as a shallow ``dict`` copy.
    """
    collector = current_collector()
    if collector is None:
        return

    usage_totals = _summarize_model_usage(model_usage)
    unique_aliases = sorted(set(model_aliases))
    recorded_columns = column_names or []
    usage_copy = dict(model_usage or {})
    rates = _throughput_fields(
        elapsed_sec=elapsed_sec,
        input_row_count=input_row_count,
        output_row_count=output_row_count,
        total_tokens=usage_totals["observed_total_tokens"],
        total_requests=usage_totals["observed_total_requests"],
        successful_requests=usage_totals["observed_successful_requests"],
    )
    # Keyword order below determines the key order of the stored record.
    collector.record(
        "ndd_workflow",
        workflow_name=workflow_name,
        status=status,
        model_aliases=unique_aliases,
        input_row_count=input_row_count,
        seed_row_count=seed_row_count,
        output_row_count=output_row_count,
        failed_record_count=failed_record_count,
        elapsed_sec=elapsed_sec,
        preview_num_records=preview_num_records,
        column_count=column_count,
        column_names=recorded_columns,
        model_usage=usage_copy,
        **usage_totals,
        **rates,
    )
def record_run_metadata(
    *,
    config: Any,
    data: Any,
    mode: str,
    strategy: str,
    input_row_count: int,
    preview_num_records: int | None,
    model_configs: list[Any],
) -> None:
    """Emit the single ``"run"`` record summarising an anonymizer run's setup.

    Does nothing when no collector is active. The input source string is not
    stored: the record holds its keyed hash plus the kind/scheme/suffix
    summary from ``_source_metadata``. Free-text settings such as
    instructions and the privacy goal are reduced to presence flags by the
    config helpers.

    Args:
        config: Run configuration; ``detect``, ``replace`` and ``rewrite``
            are read with ``getattr`` and may be missing.
        data: Input description; ``source``, ``text_column``, ``id_column``
            and ``data_summary`` are read with ``getattr``.
        mode: Run mode label stored in the record.
        strategy: Strategy label stored in the record.
        input_row_count: Number of input rows.
        preview_num_records: Stored as given.
        model_configs: Model configurations, each summarised individually.
    """
    collector = current_collector()
    if collector is None:
        return

    detect_config = getattr(config, "detect", None)
    custom_labels = getattr(detect_config, "entity_labels", None)
    if custom_labels is not None:
        label_source = "custom"
        label_count = len(custom_labels)
    else:
        # Function-scope import: only needed when no custom labels are set.
        from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS

        label_source = "default"
        label_count = len(DEFAULT_ENTITY_LABELS)

    raw_source = str(getattr(data, "source", ""))
    source_summary = _source_metadata(raw_source)
    detect_summary = {
        "gliner_threshold": getattr(detect_config, "gliner_threshold", None),
        "entity_label_source": label_source,
        "entity_label_count": label_count,
        "entity_labels": None if custom_labels is None else list(custom_labels),
        "validation_max_entities_per_call": getattr(detect_config, "validation_max_entities_per_call", None),
        "validation_excerpt_window_chars": getattr(detect_config, "validation_excerpt_window_chars", None),
    }
    # Keyword order below determines the key order of the stored record.
    collector.record(
        "run",
        mode=mode,
        strategy=strategy,
        input_row_count=input_row_count,
        preview_num_records=preview_num_records,
        source_hash=collector.record_hash(row_index="source", text=raw_source),
        input_source=source_summary,
        input_text_column=str(getattr(data, "text_column", "")),
        input_has_id_column=bool(getattr(data, "id_column", None)),
        input_has_data_summary=bool(getattr(data, "data_summary", None)),
        detect=detect_summary,
        replace=_replace_config_metadata(getattr(config, "replace", None)),
        rewrite=_rewrite_config_metadata(getattr(config, "rewrite", None)),
        models=[_model_config_metadata(entry) for entry in model_configs],
        runtime=_runtime_metadata(),
    )
def record_record_metrics(
    dataframe: pd.DataFrame,
    *,
    mode: str,
    strategy: str,
    text_column: str,
    validation_max_entities_per_call: int,
) -> None:
    """Record per-row count, length, and nominal-call metrics from a trace DataFrame.

    Emits one ``"record"`` measurement per DataFrame row. Nothing is recorded
    when no collector is active or when the collector has ``record_level``
    disabled. The row text is used only for its length, token count, and
    keyed hash; the text itself is not placed in the record.

    Args:
        dataframe: Trace DataFrame; optional columns are detected by name and
            their metrics are reported as ``None`` (or omitted) when absent.
        mode: Run mode label; ``"rewrite"`` enables the repair-iteration field.
        strategy: Strategy label passed to the call estimator.
        text_column: Name of the column holding the row text.
        validation_max_entities_per_call: Chunk size used to derive the
            validation chunk count from the candidate count.
    """
    collector = current_collector()
    if collector is None:
        return
    if not collector.record_level:
        return

    # Function-scope import of the engine's column-name constants.
    from anonymizer.engine.constants import (
        COL_ANY_HIGH_LEAKED,
        COL_ENTITIES_BY_VALUE,
        COL_FINAL_ENTITIES,
        COL_LEAKAGE_MASS,
        COL_NEEDS_HUMAN_REVIEW,
        COL_NEEDS_REPAIR,
        COL_REPAIR_ITERATIONS,
        COL_REPLACEMENT_MAP,
        COL_SEED_VALIDATION_CANDIDATES,
        COL_UTILITY_SCORE,
        COL_WEIGHTED_LEAKAGE_RATE,
    )

    # First known ground-truth column present in the frame, if any.
    ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None)
    for row_index, row in dataframe.iterrows():
        text = str(row.get(text_column, ""))
        text_length_tokens = _count_text_tokens(text)
        final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES))
        final_entity_count = len(final_entities)
        # Passing None (no ground-truth column) makes every accuracy field None.
        ground_truth_metrics = _entity_ground_truth_metrics(
            final_entities,
            _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None,
        )
        # Replacement fields are omitted entirely when the column is absent.
        replacement_metrics = (
            _replacement_map_metrics(row.get(COL_REPLACEMENT_MAP)) if COL_REPLACEMENT_MAP in dataframe.columns else {}
        )
        detected_candidate_count = (
            _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",))
            if COL_SEED_VALIDATION_CANDIDATES in dataframe.columns
            else None
        )
        validation_chunk_count = _validation_chunk_count(
            detected_candidate_count,
            validation_max_entities_per_call=validation_max_entities_per_call,
        )
        # Without a grouped-entities column, fall back to the final entity count.
        grouped_entity_count = (
            _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",))
            if COL_ENTITIES_BY_VALUE in dataframe.columns
            else final_entity_count
        )
        repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0)
        llm_calls_by_stage = estimate_llm_calls_by_stage(
            mode=mode,
            strategy=strategy,
            has_grouped_entities=grouped_entity_count > 0,
            validation_chunk_count=validation_chunk_count,
            repair_iterations=repair_iterations,
        )
        # A total is only reported when every per-stage estimate is known.
        total_estimated = (
            sum(llm_calls_by_stage.values())
            if all(value is not None for value in llm_calls_by_stage.values())
            else None
        )
        collector.record(
            "record",
            mode=mode,
            strategy=strategy,
            row_index=_safe_row_index(row_index),
            record_hash=collector.record_hash(row_index=row_index, text=text),
            text_length_chars=len(text),
            text_length_chars_bucket=_size_bucket(len(text)),
            text_length_tokens=text_length_tokens,
            text_length_tokens_bucket=_size_bucket(text_length_tokens),
            final_entity_count=final_entity_count,
            # Entities with a missing or empty label are left out of the tally.
            final_entity_label_counts=dict(
                sorted(Counter(e.get("label", "") for e in final_entities if e.get("label")).items())
            ),
            **ground_truth_metrics,
            **replacement_metrics,
            utility_score=_coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in dataframe.columns else None,
            leakage_mass=_coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in dataframe.columns else None,
            weighted_leakage_rate=(
                _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE))
                if COL_WEIGHTED_LEAKAGE_RATE in dataframe.columns
                else None
            ),
            any_high_leaked=(
                _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in dataframe.columns else None
            ),
            needs_human_review=(
                _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in dataframe.columns else None
            ),
            needs_repair=_coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in dataframe.columns else None,
            detected_candidate_count=detected_candidate_count,
            validation_chunk_count=validation_chunk_count,
            # Outside rewrite mode the recorded repair count is forced to 0.
            repair_iterations=repair_iterations if mode == "rewrite" else 0,
            llm_calls_estimated_by_stage=llm_calls_by_stage,
            llm_calls_estimated_total=total_estimated,
        )
parsed.scheme, + "suffix": Path(parsed.path).suffix.lower() or None, + } + if parsed.scheme == "file": + return { + "kind": "local_file", + "scheme": "file", + "suffix": Path(parsed.path).suffix.lower() or None, + } + return { + "kind": "local_file" if source else "unknown", + "scheme": None, + "suffix": Path(source).suffix.lower() or None, + } + + +def _replace_config_metadata(replace_config: Any | None) -> dict[str, Any] | None: + if replace_config is None: + return None + + metadata: dict[str, Any] = { + "strategy": type(replace_config).__name__, + "has_instructions": bool(getattr(replace_config, "instructions", None)), + } + for attr in ("normalize_label", "algorithm", "digest_length"): + if hasattr(replace_config, attr): + metadata[attr] = getattr(replace_config, attr) + if hasattr(replace_config, "format_template"): + metadata["has_format_template"] = True + return metadata + + +def _rewrite_config_metadata(rewrite_config: Any | None) -> dict[str, Any] | None: + if rewrite_config is None: + return None + return { + "risk_tolerance": _enum_value(getattr(rewrite_config, "risk_tolerance", None)), + "max_repair_iterations": getattr(rewrite_config, "max_repair_iterations", None), + "strict_entity_protection": getattr(rewrite_config, "strict_entity_protection", None), + "has_privacy_goal": bool(getattr(rewrite_config, "privacy_goal", None)), + "has_instructions": bool(getattr(rewrite_config, "instructions", None)), + } + + +def _model_config_metadata(model_config: Any) -> dict[str, Any]: + inference_parameters = getattr(model_config, "inference_parameters", None) + return { + "alias": getattr(model_config, "alias", None), + "model": getattr(model_config, "model", None), + "provider": _enum_value(getattr(model_config, "provider", None)), + "base_url": bool(getattr(model_config, "base_url", None)), + "max_parallel_requests": getattr(inference_parameters, "max_parallel_requests", None), + } + + +def _runtime_metadata() -> dict[str, Any]: + try: + anonymizer_version = 
def estimate_llm_calls_by_stage(
    *,
    mode: str,
    strategy: str,
    has_grouped_entities: bool,
    validation_chunk_count: int | None,
    repair_iterations: int = 0,
) -> dict[str, int | None]:
    """Return the nominal per-stage model-call counts for a single record.

    Args:
        mode: ``"rewrite"`` selects the extended rewrite breakdown; any other
            value yields only the detection and replace-map entries.
        strategy: Replace strategy name; ``"Substitute"`` adds a replace-map
            call outside rewrite mode.
        has_grouped_entities: Whether the record has any grouped entities.
            Without them, every stage except detection is estimated at 0.
        validation_chunk_count: Number of validation chunks, or ``None`` when
            unknown, in which case ``entity_detection`` is ``None`` too.
        repair_iterations: Number of repair rounds; only used in rewrite mode.

    Returns:
        A dict mapping stage name to estimated call count.
    """
    # NOTE(review): the fixed per-stage constants (2, 5, 3) are nominal call
    # counts whose derivation is not visible here - confirm against the workflows.
    if validation_chunk_count is None:
        detection_estimate = None
    else:
        detection_estimate = 2 + validation_chunk_count

    in_rewrite_mode = mode == "rewrite"
    wants_replace_map = has_grouped_entities and (in_rewrite_mode or strategy == "Substitute")
    replace_map_estimate = 1 if wants_replace_map else 0

    if not in_rewrite_mode:
        return {
            "entity_detection": detection_estimate,
            "replace_map_generation": replace_map_estimate,
        }

    if not has_grouped_entities:
        # Nothing to rewrite: every rewrite-body stage is estimated at zero.
        return {
            "entity_detection": detection_estimate,
            "latent_entity_detection": 0,
            "replace_map_generation": replace_map_estimate,
            "rewrite_pipeline": 0,
            "rewrite_evaluate": 0,
            "rewrite_repair": 0,
            "rewrite_final_judge": 0,
        }

    return {
        "entity_detection": detection_estimate,
        "latent_entity_detection": 1,
        "replace_map_generation": replace_map_estimate,
        "rewrite_pipeline": 5,
        "rewrite_evaluate": 3 * (1 + repair_iterations),
        "rewrite_repair": repair_iterations,
        "rewrite_final_judge": 1,
    }
validation_max_entities_per_call: int, +) -> int | None: + if detected_candidate_count is None: + return None + if detected_candidate_count <= 0: + return 0 + return int(math.ceil(detected_candidate_count / validation_max_entities_per_call)) + + +def _safe_row_index(row_index: object) -> int | None: + if isinstance(row_index, bool): + return None + if isinstance(row_index, Integral): + return int(row_index) + return None + + +def _entities_from_raw(raw: object) -> list[dict[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + items = cast(Mapping[str, Any], payload).get("entities", []) + elif isinstance(payload, list): + items = payload + else: + items = [] + return [dict(cast(Mapping[str, Any], item)) for item in items if isinstance(item, Mapping)] + + +def _entity_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]] | None, +) -> dict[str, Any]: + if ground_truth_entities is None: + return { + "ground_truth_entity_count": None, + "ground_truth_entity_label_counts": None, + "entity_true_positive_count": None, + "entity_false_positive_count": None, + "entity_false_negative_count": None, + "entity_precision": None, + "entity_recall": None, + "entity_f1": None, + } + + predicted = _entity_identity_set(final_entities) + expected = _entity_identity_set(ground_truth_entities) + true_positive = len(predicted & expected) + false_positive = len(predicted - expected) + false_negative = len(expected - predicted) + precision = _safe_ratio(true_positive, true_positive + false_positive) + recall = _safe_ratio(true_positive, true_positive + false_negative) + f1 = ( + None + if precision is None or recall is None or precision + recall == 0 + else 2 * precision * recall / (precision + recall) + ) + return { + "ground_truth_entity_count": len(ground_truth_entities), + "ground_truth_entity_label_counts": dict( + sorted(Counter(e.get("label", "") for e in ground_truth_entities if 
e.get("label")).items()) + ), + "entity_true_positive_count": true_positive, + "entity_false_positive_count": false_positive, + "entity_false_negative_count": false_negative, + "entity_precision": precision, + "entity_recall": recall, + "entity_f1": f1, + } + + +def _entity_identity_set(entities: list[dict[str, Any]]) -> set[tuple[str, str]]: + identities: set[tuple[str, str]] = set() + for entity in entities: + label = entity.get("label") + value = entity.get("value") + if label is None or value is None: + continue + identities.add((str(value), str(label))) + return identities + + +def _replacement_map_metrics(raw: object) -> dict[str, Any]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + replacements_raw = cast(Mapping[str, Any], payload).get("replacements") + replacements = replacements_raw if isinstance(replacements_raw, list) else [] + elif isinstance(payload, list): + replacements = payload + else: + replacements = [] + + replacement_maps = [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] + synthetic_values = [] + for item in replacement_maps: + synthetic = item.get("replacement", item.get("synthetic")) + if synthetic is not None: + synthetic_values.append(str(synthetic)) + return { + "replacement_count": len(replacement_maps), + "replacement_label_counts": dict( + sorted(Counter(item.get("label", "") for item in replacement_maps if item.get("label")).items()) + ), + "replacement_duplicate_value_count": max(0, len(synthetic_values) - len(set(synthetic_values))), + } + + +def _count_items(raw: object, *, primary_key: str, fallback_keys: tuple[str, ...] 
= ()) -> int: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + payload_map = cast(Mapping[str, Any], payload) + for key in (primary_key, *fallback_keys): + items = payload_map.get(key) + if isinstance(items, list): + return len(items) + return 0 + if isinstance(payload, list): + return len(payload) + return 0 + + +def _coerce_payload(raw: object) -> object: + model_dump = getattr(raw, "model_dump", None) + if callable(model_dump): + return model_dump(mode="python") + if isinstance(raw, str): + try: + return json.loads(raw) + except json.JSONDecodeError: + return {} + if raw is None: + return {} + return raw + + +def _coerce_int(raw: object, *, default: int) -> int: + try: + return int(cast(Any, raw)) + except (TypeError, ValueError): + return default + + +def _coerce_float(raw: object) -> float | None: + try: + value = float(cast(Any, raw)) + except (TypeError, ValueError): + return None + return None if math.isnan(value) else value + + +def _coerce_bool(raw: object) -> bool | None: + if raw is None: + return None + if isinstance(raw, float) and math.isnan(raw): + return None + if isinstance(raw, bool): + return raw + if isinstance(raw, str): + lowered = raw.strip().lower() + if lowered in {"true", "1", "yes"}: + return True + if lowered in {"false", "0", "no"}: + return False + return None + try: + return bool(cast(Any, raw)) + except (TypeError, ValueError): + return None + + +def _safe_rate(numerator: int | float | None, elapsed_sec: float) -> float | None: + if numerator is None or elapsed_sec <= 0: + return None + return float(numerator) / elapsed_sec + + +def _safe_ratio(numerator: int | float | None, denominator: int | float | None) -> float | None: + if numerator is None or denominator is None or denominator == 0: + return None + return float(numerator) / float(denominator) + + +def _size_bucket(value: int) -> str: + if value == 0: + return "0" + for upper in (128, 512, 2048, 8192): + if value < upper: + return f"1-{upper - 1}" if upper 
== 128 else f"{upper // 4}-{upper - 1}" + return "8192+" + + +def _count_text_tokens(text: str) -> int: + try: + import tiktoken + + tokenizer = tiktoken.get_encoding("cl100k_base") + return len(tokenizer.encode(text, disallowed_special=())) + except Exception: + return len(text.split()) + + +def _summarize_model_usage(model_usage: Mapping[str, Any] | None) -> dict[str, int | None]: + input_tokens = 0 + output_tokens = 0 + total_tokens = 0 + reasoning_tokens = 0 + has_reasoning_tokens = False + successful_requests = 0 + failed_requests = 0 + total_requests = 0 + + for usage in (model_usage or {}).values(): + if not isinstance(usage, Mapping): + continue + + token_usage = usage.get("token_usage") + if isinstance(token_usage, Mapping): + input_tokens += _coerce_int(token_usage.get("input_tokens"), default=0) + output_tokens += _coerce_int(token_usage.get("output_tokens"), default=0) + total_tokens += _coerce_int(token_usage.get("total_tokens"), default=0) + if token_usage.get("reasoning_tokens") is not None: + has_reasoning_tokens = True + reasoning_tokens += _coerce_int(token_usage.get("reasoning_tokens"), default=0) + + request_usage = usage.get("request_usage") + if isinstance(request_usage, Mapping): + successful_requests += _coerce_int(request_usage.get("successful_requests"), default=0) + failed_requests += _coerce_int(request_usage.get("failed_requests"), default=0) + total_requests += _coerce_int(request_usage.get("total_requests"), default=0) + + if total_tokens == 0: + total_tokens = input_tokens + output_tokens + if total_requests == 0: + total_requests = successful_requests + failed_requests + + return { + "observed_input_tokens": input_tokens, + "observed_output_tokens": output_tokens, + "observed_total_tokens": total_tokens, + "observed_reasoning_tokens": reasoning_tokens if has_reasoning_tokens else None, + "observed_successful_requests": successful_requests, + "observed_failed_requests": failed_requests, + "observed_total_requests": total_requests, + 
} + + +def _json_safe(value: object) -> Any: + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + if isinstance(value, list): + return [_json_safe(v) for v in value] + if isinstance(value, tuple): + return [_json_safe(v) for v in value] + if isinstance(value, set): + return sorted(_json_safe(v) for v in value) + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return str(value) diff --git a/tests/test_measurement.py b/tests/test_measurement.py new file mode 100644 index 00000000..540ef4da --- /dev/null +++ b/tests/test_measurement.py @@ -0,0 +1,667 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from types import SimpleNamespace +from typing import cast +from unittest.mock import Mock + +import numpy as np +import pandas as pd +import pytest +from data_designer.config.column_configs import LLMTextColumnConfig +from data_designer.config.models import ModelConfig +from data_designer.interface.data_designer import DataDesigner + +from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect +from anonymizer.config.models import DetectionModelSelection +from anonymizer.config.replace_strategies import Redact +from anonymizer.engine.constants import ( + COL_ANY_HIGH_LEAKED, + COL_FINAL_ENTITIES, + COL_LEAKAGE_MASS, + COL_NEEDS_HUMAN_REVIEW, + COL_NEEDS_REPAIR, + COL_REPAIR_ITERATIONS, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_SEED_VALIDATION_CANDIDATES, + COL_TEXT, + COL_UTILITY_SCORE, + COL_WEIGHTED_LEAKAGE_RATE, +) +from anonymizer.engine.detection.detection_workflow import EntityDetectionResult, EntityDetectionWorkflow +from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter +from anonymizer.engine.replace.replace_runner import ReplacementResult, 
def test_ndd_adapter_records_workflow_measurement_without_raw_text() -> None:
    """The adapter emits one ``ndd_workflow`` record that contains no input text."""
    input_df = pd.DataFrame(
        {
            "text": ["Alice works at Acme", "Bob works at Beta"],
            RECORD_ID_COLUMN: ["record-a", "record-b"],
        }
    )
    mock_dd = Mock(spec=DataDesigner)
    # The mocked preview hands back only the first of the two input rows.
    mock_dd.preview.return_value = SimpleNamespace(dataset=input_df.iloc[[0]].copy())
    adapter = NddAdapter(data_designer=mock_dd)
    collector = MeasurementCollector(record_hash_key="test-key")

    with measurement_session(collector):
        result = adapter.run_workflow(
            input_df,
            model_configs=[ModelConfig(alias="detector", model="dummy")],
            columns=[
                LLMTextColumnConfig(
                    name="raw_detected",
                    prompt="{{ text }}",
                    model_alias="detector",
                )
            ],
            workflow_name="entity-detection",
            preview_num_records=2,
        )

    # Two rows in, one row out: the missing row is expected as a failed record.
    assert len(result.failed_records) == 1
    records = [record for record in collector.records if record["record_type"] == "ndd_workflow"]
    assert len(records) == 1
    record = records[0]
    assert record["workflow_name"] == "entity-detection"
    assert record["model_aliases"] == ["detector"]
    assert record["input_row_count"] == 2
    assert record["seed_row_count"] == 2
    assert record["output_row_count"] == 1
    assert record["failed_record_count"] == 1
    assert record["elapsed_sec"] >= 0

    # None of the names from the input text may appear in the serialized record.
    serialized = json.dumps(record)
    assert "Alice" not in serialized
    assert "Acme" not in serialized
    assert "Bob" not in serialized
RECORD_ID_COLUMN: ["record-a"], + } + ) + + class UsageStats: + def model_dump(self, *, mode: str) -> dict[str, object]: + assert mode == "json" + return { + "token_usage": { + "input_tokens": 12, + "output_tokens": 4, + "total_tokens": 16, + }, + "request_usage": { + "successful_requests": 2, + "failed_requests": 1, + "total_requests": 3, + }, + } + + class ModelRegistry: + def get_model_usage_snapshot(self) -> dict[str, UsageStats]: + return {"dummy-model": UsageStats()} + + class UsageDataDesigner: + def _create_resource_provider(self, *_args: object, **_kwargs: object) -> SimpleNamespace: + return SimpleNamespace(model_registry=ModelRegistry()) + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + self._create_resource_provider("preview-dataset", _config_builder) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, UsageDataDesigner())) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="detector", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="detector", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + record = next(record for record in collector.records if record["record_type"] == "ndd_workflow") + assert record["model_usage"]["dummy-model"]["token_usage"]["input_tokens"] == 12 + assert record["observed_input_tokens"] == 12 + assert record["observed_output_tokens"] == 4 + assert record["observed_total_tokens"] == 16 + assert record["observed_successful_requests"] == 2 + assert record["observed_failed_requests"] == 1 + assert record["observed_total_requests"] == 3 + assert record["input_rows_per_sec"] >= 0 + assert record["output_rows_per_sec"] >= 0 + assert record["observed_tokens_per_sec"] >= 0 + assert 
record["observed_requests_per_sec"] >= 0 + assert record["observed_tokens_per_successful_request"] == 8 + + +def test_anonymizer_records_per_record_measurement_without_raw_pii(tmp_path: Path) -> None: + input_csv = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_csv, index=False) + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Acme", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + validation_candidates = { + "candidates": [ + {"value": "Alice", "label": "first_name"}, + {"value": "Acme", "label": "company_name"}, + ] + } + detection_workflow = Mock(spec=EntityDetectionWorkflow) + detection_workflow.run.return_value = EntityDetectionResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [validation_candidates], + } + ), + failed_records=[], + ) + replace_runner = Mock(spec=ReplacementWorkflow) + replace_runner.run.return_value = ReplacementResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_REPLACED_TEXT: ["[REDACTED] works at [REDACTED]"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [validation_candidates], + } + ), + failed_records=[], + ) + rewrite_runner = Mock(spec=RewriteWorkflow) + rewrite_runner.run.return_value = RewriteResult(dataframe=pd.DataFrame(), failed_records=[]) + anonymizer = Anonymizer( + detection_workflow=detection_workflow, + replace_runner=replace_runner, + rewrite_runner=rewrite_runner, + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + anonymizer.run( + config=AnonymizerConfig(replace=Redact(), detect=Detect(validation_max_entities_per_call=2)), + data=AnonymizerInput(source=str(input_csv)), + ) + + record_metrics = [record for record in collector.records if record["record_type"] 
== "record"] + assert len(record_metrics) == 1 + record = record_metrics[0] + assert record["mode"] == "replace" + assert record["strategy"] == "Redact" + assert record["text_length_chars"] == len("Alice works at Acme") + assert record["text_length_chars_bucket"] == "1-127" + assert record["text_length_tokens"] > 0 + assert record["text_length_tokens_bucket"] == "1-127" + assert record["final_entity_count"] == 2 + assert record["final_entity_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["detected_candidate_count"] == 2 + assert record["validation_chunk_count"] == 1 + assert record["llm_calls_estimated_by_stage"] == { + "entity_detection": 3, + "replace_map_generation": 0, + } + assert record["llm_calls_estimated_total"] == 3 + assert len(record["record_hash"]) == 64 + + stage_records = [record for record in collector.records if record["record_type"] == "stage"] + assert any(record["stage"] == "Anonymizer._run_internal" for record in stage_records) + assert any(record.get("input_rows_per_sec") is not None for record in stage_records) + + run_records = [record for record in collector.records if record["record_type"] == "run"] + assert len(run_records) == 1 + run_record = run_records[0] + assert run_record["mode"] == "replace" + assert run_record["strategy"] == "Redact" + assert run_record["input_row_count"] == 1 + assert run_record["input_source"] == {"kind": "local_file", "scheme": None, "suffix": ".csv"} + assert run_record["input_text_column"] == "text" + assert run_record["input_has_id_column"] is False + assert run_record["input_has_data_summary"] is False + assert run_record["detect"]["entity_label_source"] == "default" + assert run_record["detect"]["entity_label_count"] > 0 + assert run_record["replace"]["strategy"] == "Redact" + assert run_record["replace"]["normalize_label"] is True + assert len(run_record["source_hash"]) == 64 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "Acme" not in 
serialized + assert str(input_csv) not in serialized + + +def test_anonymizer_measurement_config_writes_jsonl(tmp_path: Path) -> None: + input_csv = tmp_path / "input.csv" + output_jsonl = tmp_path / "measurements.jsonl" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_csv, index=False) + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + ] + } + detection_workflow = Mock(spec=EntityDetectionWorkflow) + detection_workflow.run.return_value = EntityDetectionResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + } + ), + failed_records=[], + ) + replace_runner = Mock(spec=ReplacementWorkflow) + replace_runner.run.return_value = ReplacementResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_REPLACED_TEXT: ["[REDACTED] works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + } + ), + failed_records=[], + ) + anonymizer = Anonymizer( + detection_workflow=detection_workflow, + replace_runner=replace_runner, + rewrite_runner=Mock(spec=RewriteWorkflow), + ) + + with configured_measurement_session( + MeasurementConfig( + output_path=output_jsonl, + run_id="measurement-run", + record_hash_key="test-key", + run_tags={"config_id": "redact-default", "workload_id": "unit-small"}, + ) + ): + anonymizer.run( + config=AnonymizerConfig(replace=Redact()), + data=AnonymizerInput(source=str(input_csv)), + ) + + records = [json.loads(line) for line in output_jsonl.read_text(encoding="utf-8").splitlines()] + assert {record["record_type"] for record in records} >= {"record", "run", "stage"} + assert {record["run_id"] for record in records} == {"measurement-run"} + assert {record["schema_version"] for record in records} == {MEASUREMENT_SCHEMA_VERSION} + 
assert {record["run_tags"]["workload_id"] for record in records} == {"unit-small"} + assert all(isinstance(record["timestamp_unix_sec"], float) for record in records) + + serialized = json.dumps(records) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert str(input_csv) not in serialized + + +def test_measurement_config_record_level_false_skips_record_rows(tmp_path: Path) -> None: + output_json = tmp_path / "measurements.json" + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + } + ) + + with configured_measurement_session( + MeasurementConfig( + output_path=output_json, + output_format="json", + record_level=False, + run_id="stage-only", + record_hash_key="test-key", + ) + ): + with stage_timer("example", input_row_count=1): + pass + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + records = json.loads(output_json.read_text(encoding="utf-8")) + assert [record["record_type"] for record in records] == ["stage"] + assert records[0]["run_id"] == "stage-only" + assert records[0]["input_rows_per_sec"] >= 0 + + +def test_measurement_config_from_env_returns_none_without_output_path( + monkeypatch: pytest.MonkeyPatch, +) -> None: + prefix = "ANON_TEST_EMPTY_MEASUREMENT_" + monkeypatch.setenv(f"{prefix}RUN_ID", "env-run") + + assert MeasurementConfig.from_env(prefix=prefix) is None + + +def test_measurement_config_from_env_parses_supported_values( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + output_path = tmp_path / "measurements.json" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(output_path)) + monkeypatch.setenv(f"{prefix}OUTPUT_FORMAT", "json") + monkeypatch.setenv(f"{prefix}RECORD_LEVEL", "false") + monkeypatch.setenv(f"{prefix}FAIL_ON_WRITE_ERROR", "true") + 
monkeypatch.setenv(f"{prefix}RUN_ID", "env-run") + monkeypatch.setenv(f"{prefix}RUN_TAGS", '{"config_id": "redact-default", "attempt": 2}') + + config = MeasurementConfig.from_env(prefix=prefix) + + assert config is not None + assert config.output_path == str(output_path) + assert config.output_format == "json" + assert config.record_level is False + assert config.fail_on_write_error is True + assert config.run_id == "env-run" + assert config.run_tags == {"config_id": "redact-default", "attempt": 2} + + +def test_measurement_config_from_sources_keeps_env_opt_in( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(tmp_path / "env.jsonl")) + explicit = MeasurementConfig(output_path=tmp_path / "explicit.jsonl") + + assert MeasurementConfig.from_sources(env=False, prefix=prefix) is None + assert MeasurementConfig.from_sources(explicit=explicit, env=True, prefix=prefix) is explicit + + +def test_measurement_config_from_env_reports_sanitized_invalid_values( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + prefix = "ANON_TEST_MEASUREMENT_" + secret_payload = "sk-secret-token-value" + monkeypatch.setenv(f"{prefix}OUTPUT_PATH", str(tmp_path / "measurements.jsonl")) + monkeypatch.setenv(f"{prefix}RUN_TAGS", secret_payload) + + with pytest.raises(ValueError) as exc_info: + MeasurementConfig.from_env(prefix=prefix) + + message = str(exc_info.value) + assert f"{prefix}RUN_TAGS" in message + assert secret_payload not in message + assert str(tmp_path) not in message + + +def test_default_measurement_env_prefix_is_anonymizer_scoped() -> None: + assert DEFAULT_MEASUREMENT_ENV_PREFIX == "ANONYMIZER_MEASUREMENT_" + + +def test_measurement_config_write_errors_are_best_effort( + caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + 
raise OSError(f"cannot write {_self.output_path}") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + caplog.set_level(logging.WARNING, logger="anonymizer.measurement") + output_path = tmp_path / "secret-output-sk-live-value.jsonl" + + with configured_measurement_session(MeasurementConfig(output_path=output_path)) as collector: + assert collector is not None + collector.record("example") + + assert "Failed to write Anonymizer measurement records" in caplog.text + assert str(output_path) not in caplog.text + assert "sk-live-value" not in caplog.text + + +def test_measurement_config_strict_write_errors_can_fail_clean_body( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError("cannot write") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + + with pytest.raises(OSError, match="cannot write"): + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", fail_on_write_error=True) + ) as collector: + assert collector is not None + collector.record("example") + + +def test_measurement_config_write_errors_do_not_mask_body_errors( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError("cannot write") + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + + with pytest.raises(RuntimeError, match="body failed"): + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", fail_on_write_error=True) + ) as collector: + assert collector is not None + collector.record("example") + raise RuntimeError("body failed") + + +def test_record_metrics_capture_generic_counts_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": 
"first_name", "start_position": 0, "end_position": 5}, + {"value": "Acme", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + ground_truth_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Beta", "label": "company_name", "start_position": 15, "end_position": 19}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + {"original": "Acme", "label": "company_name", "synthetic": "Maya"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [final_entities], + "ground_truth_entities": [ground_truth_entities], + COL_REPLACEMENT_MAP: [replacement_map], + COL_SEED_VALIDATION_CANDIDATES: [{"candidates": final_entities["entities"]}], + COL_REPAIR_ITERATIONS: [2], + COL_UTILITY_SCORE: [0.82], + COL_LEAKAGE_MASS: [0.2], + COL_WEIGHTED_LEAKAGE_RATE: [0.1], + COL_ANY_HIGH_LEAKED: [False], + COL_NEEDS_HUMAN_REVIEW: [True], + COL_NEEDS_REPAIR: [False], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="rewrite", + strategy="Rewrite", + text_column=COL_TEXT, + validation_max_entities_per_call=2, + ) + + record = collector.records[0] + assert record["ground_truth_entity_count"] == 2 + assert record["ground_truth_entity_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["entity_true_positive_count"] == 1 + assert record["entity_false_positive_count"] == 1 + assert record["entity_false_negative_count"] == 1 + assert record["entity_precision"] == 0.5 + assert record["entity_recall"] == 0.5 + assert record["entity_f1"] == 0.5 + assert record["replacement_count"] == 2 + assert record["replacement_label_counts"] == {"company_name": 1, "first_name": 1} + assert record["replacement_duplicate_value_count"] == 1 + assert record["repair_iterations"] == 2 + 
assert record["utility_score"] == 0.82 + assert record["leakage_mass"] == 0.2 + assert record["weighted_leakage_rate"] == 0.1 + assert record["any_high_leaked"] is False + assert record["needs_human_review"] is True + assert record["needs_repair"] is False + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "Acme" not in serialized + assert "Beta" not in serialized + assert "Maya" not in serialized + + +def test_record_metrics_normalizes_integral_row_index_types() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice works at Acme"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + }, + index=pd.Index([np.int64(7)]), + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + assert collector.records[0]["row_index"] == 7 + + +def test_record_hash_uses_run_scoped_secret_by_default() -> None: + first = MeasurementCollector() + second = MeasurementCollector() + + assert first.record_hash(row_index=0, text="Alice works at Acme") != second.record_hash( + row_index=0, + text="Alice works at Acme", + ) + + +def test_stage_timer_records_errors() -> None: + workflow = EntityDetectionWorkflow(adapter=Mock(spec=NddAdapter)) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector), pytest.raises(ValueError, match="privacy_goal is required"): + workflow.run( + pd.DataFrame({COL_TEXT: ["Alice"]}), + model_configs=[], + selected_models=DetectionModelSelection( + entity_detector="detector", + entity_validator=["validator"], + entity_augmenter="augmenter", + latent_detector="latent", + ), + gliner_detection_threshold=0.3, + tag_latent_entities=True, + privacy_goal=None, + ) + + stage_records = [record for record in collector.records if record["record_type"] == 
"stage"] + assert len(stage_records) == 1 + record = stage_records[0] + assert record["schema_version"] == MEASUREMENT_SCHEMA_VERSION + assert record["record_type"] == "stage" + assert record["run_id"] == collector.run_id + assert record["run_tags"] == {} + assert isinstance(record["timestamp_unix_sec"], float) + assert record["stage"] == "EntityDetectionWorkflow.run" + assert record["status"] == "error" + assert record["elapsed_sec"] >= 0 + assert record["input_row_count"] == 1 + assert record["input_rows_per_sec"] >= 0 + assert record["output_rows_per_sec"] is None + assert record["tag_latent_entities"] is True + + +def test_rewrite_llm_call_estimate_splits_by_stage() -> None: + calls = estimate_llm_calls_by_stage( + mode="rewrite", + strategy="Rewrite", + has_grouped_entities=True, + validation_chunk_count=2, + repair_iterations=2, + ) + + assert calls == { + "entity_detection": 4, + "latent_entity_detection": 1, + "replace_map_generation": 1, + "rewrite_pipeline": 5, + "rewrite_evaluate": 9, + "rewrite_repair": 2, + "rewrite_final_judge": 1, + } + + +def test_rewrite_llm_call_estimate_skips_rewrite_body_without_entities() -> None: + calls = estimate_llm_calls_by_stage( + mode="rewrite", + strategy="Rewrite", + has_grouped_entities=False, + validation_chunk_count=0, + repair_iterations=2, + ) + + assert calls == { + "entity_detection": 2, + "latent_entity_detection": 0, + "replace_map_generation": 0, + "rewrite_pipeline": 0, + "rewrite_evaluate": 0, + "rewrite_repair": 0, + "rewrite_final_judge": 0, + } diff --git a/uv.lock b/uv.lock index ae31c1ea..45102849 100644 --- a/uv.lock +++ b/uv.lock @@ -2425,6 +2425,7 @@ dependencies = [ { name = "data-designer" }, { name = "httpx" }, { name = "pydantic" }, + { name = "pydantic-settings" }, { name = "pygments" }, { name = "tiktoken" }, ] @@ -2462,6 +2463,7 @@ requires-dist = [ { name = "data-designer", specifier = "==0.6.0" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "pydantic", specifier = ">=2.9,<3" }, + 
{ name = "pydantic-settings", specifier = ">=2.12,<3" }, { name = "pygments", specifier = ">=2.20.0" }, { name = "tiktoken", specifier = ">=0.9.0" }, ] From 019d852c7fcef4e794755f398baf4b312c160cdb Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Wed, 3 Jun 2026 16:21:13 +0000 Subject: [PATCH 02/26] feat: add measurement benchmark tooling Signed-off-by: Aaron Gonzales --- src/anonymizer/measurement.py | 391 ++++++++----- tests/test_measurement.py | 12 + tests/tools/test_measurement_tools.py | 136 +++++ tools/measurement/README.md | 328 +++++++++++ .../measurement/examples/repo-data-smoke.yaml | 25 + tools/measurement/export_measurements.py | 205 +++++++ tools/measurement/run_benchmarks.py | 529 ++++++++++++++++++ 7 files changed, 1471 insertions(+), 155 deletions(-) create mode 100644 tests/tools/test_measurement_tools.py create mode 100644 tools/measurement/README.md create mode 100644 tools/measurement/examples/repo-data-smoke.yaml create mode 100755 tools/measurement/export_measurements.py create mode 100755 tools/measurement/run_benchmarks.py diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index a56d8d3e..252b6ed8 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -26,6 +26,8 @@ from pydantic import Field, ValidationError from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError +from anonymizer.engine.constants import COL_FINAL_ENTITIES + if TYPE_CHECKING: import pandas as pd @@ -322,30 +324,36 @@ def record_ndd_workflow( if collector is None: return observed_usage = _summarize_model_usage(model_usage) - collector.record( - "ndd_workflow", - workflow_name=workflow_name, - status=status, - model_aliases=sorted(set(model_aliases)), - input_row_count=input_row_count, - seed_row_count=seed_row_count, - output_row_count=output_row_count, - failed_record_count=failed_record_count, - elapsed_sec=elapsed_sec, - preview_num_records=preview_num_records, - column_count=column_count, - 
column_names=column_names or [], - model_usage=dict(model_usage or {}), + workflow_fields = { + "workflow_name": workflow_name, + "status": status, + "model_aliases": sorted(set(model_aliases)), + "input_row_count": input_row_count, + "seed_row_count": seed_row_count, + "output_row_count": output_row_count, + "failed_record_count": failed_record_count, + "elapsed_sec": elapsed_sec, + "preview_num_records": preview_num_records, + "column_count": column_count, + "column_names": column_names or [], + "model_usage": dict(model_usage or {}), + } + collector.record("ndd_workflow", **_ndd_workflow_fields(workflow_fields, observed_usage)) + + +def _ndd_workflow_fields(fields: dict[str, Any], observed_usage: dict[str, int | None]) -> dict[str, Any]: + return { + **fields, **observed_usage, **_throughput_fields( - elapsed_sec=elapsed_sec, - input_row_count=input_row_count, - output_row_count=output_row_count, + elapsed_sec=cast(float, fields["elapsed_sec"]), + input_row_count=cast(int, fields["input_row_count"]), + output_row_count=cast(int | None, fields["output_row_count"]), total_tokens=observed_usage["observed_total_tokens"], total_requests=observed_usage["observed_total_requests"], successful_requests=observed_usage["observed_successful_requests"], ), - ) + } def record_run_metadata( @@ -364,15 +372,7 @@ def record_run_metadata( return detect = getattr(config, "detect", None) - entity_labels = getattr(detect, "entity_labels", None) - if entity_labels is None: - from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS - - entity_label_count = len(DEFAULT_ENTITY_LABELS) - else: - entity_label_count = len(entity_labels) source = str(getattr(data, "source", "")) - source_metadata = _source_metadata(source) collector.record( "run", mode=mode, @@ -380,18 +380,11 @@ def record_run_metadata( input_row_count=input_row_count, preview_num_records=preview_num_records, source_hash=collector.record_hash(row_index="source", text=source), - input_source=source_metadata, + 
input_source=_source_metadata(source), input_text_column=str(getattr(data, "text_column", "")), input_has_id_column=bool(getattr(data, "id_column", None)), input_has_data_summary=bool(getattr(data, "data_summary", None)), - detect={ - "gliner_threshold": getattr(detect, "gliner_threshold", None), - "entity_label_source": "custom" if entity_labels is not None else "default", - "entity_label_count": entity_label_count, - "entity_labels": list(entity_labels) if entity_labels is not None else None, - "validation_max_entities_per_call": getattr(detect, "validation_max_entities_per_call", None), - "validation_excerpt_window_chars": getattr(detect, "validation_excerpt_window_chars", None), - }, + detect=_detect_config_metadata(detect), replace=_replace_config_metadata(getattr(config, "replace", None)), rewrite=_rewrite_config_metadata(getattr(config, "rewrite", None)), models=[_model_config_metadata(model_config) for model_config in model_configs], @@ -409,101 +402,179 @@ def record_record_metrics( ) -> None: """Record per-row count, length, and nominal-call metrics from a trace DataFrame.""" collector = current_collector() - if collector is None: - return - if not collector.record_level: + if collector is None or not collector.record_level: return + ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None) + columns = set(dataframe.columns) + for row_index, row in dataframe.iterrows(): + final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES)) + collector.record( + "record", + **_base_record_fields( + collector=collector, + row_index=row_index, + row=row, + text_column=text_column, + mode=mode, + strategy=strategy, + ), + **_entity_record_fields(row, final_entities=final_entities, ground_truth_column=ground_truth_column), + **_replacement_record_fields(row, columns=columns), + **_rewrite_record_fields(row, columns=columns), + **_llm_record_fields( + row, + columns=columns, + mode=mode, + strategy=strategy, + 
final_entity_count=len(final_entities), + validation_max_entities_per_call=validation_max_entities_per_call, + ), + ) + + +def _detect_config_metadata(detect: Any | None) -> dict[str, Any]: + entity_labels = getattr(detect, "entity_labels", None) + if entity_labels is None: + from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS + + entity_label_count = len(DEFAULT_ENTITY_LABELS) + else: + entity_label_count = len(entity_labels) + return { + "gliner_threshold": getattr(detect, "gliner_threshold", None), + "entity_label_source": "custom" if entity_labels is not None else "default", + "entity_label_count": entity_label_count, + "entity_labels": list(entity_labels) if entity_labels is not None else None, + "validation_max_entities_per_call": getattr(detect, "validation_max_entities_per_call", None), + "validation_excerpt_window_chars": getattr(detect, "validation_excerpt_window_chars", None), + } + + +def _base_record_fields( + *, + collector: MeasurementCollector, + row_index: object, + row: Any, + text_column: str, + mode: str, + strategy: str, +) -> dict[str, Any]: + text = str(row.get(text_column, "")) + text_length_tokens = _count_text_tokens(text) + return { + "mode": mode, + "strategy": strategy, + "row_index": _safe_row_index(row_index), + "record_hash": collector.record_hash(row_index=row_index, text=text), + "text_length_chars": len(text), + "text_length_chars_bucket": _size_bucket(len(text)), + "text_length_tokens": text_length_tokens, + "text_length_tokens_bucket": _size_bucket(text_length_tokens), + } + + +def _entity_record_fields( + row: Any, + *, + final_entities: list[dict[str, Any]], + ground_truth_column: str | None, +) -> dict[str, Any]: + ground_truth_entities = ( + _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None + ) + return { + "final_entity_count": len(final_entities), + "final_entity_label_counts": dict( + sorted(Counter(e.get("label", "") for e in final_entities if e.get("label")).items()) 
+ ), + **_entity_ground_truth_metrics(final_entities, ground_truth_entities), + } + + +def _replacement_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: + from anonymizer.engine.constants import COL_REPLACEMENT_MAP + + if COL_REPLACEMENT_MAP not in columns: + return {} + return _replacement_map_metrics(row.get(COL_REPLACEMENT_MAP)) + + +def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: from anonymizer.engine.constants import ( COL_ANY_HIGH_LEAKED, - COL_ENTITIES_BY_VALUE, - COL_FINAL_ENTITIES, COL_LEAKAGE_MASS, COL_NEEDS_HUMAN_REVIEW, COL_NEEDS_REPAIR, - COL_REPAIR_ITERATIONS, - COL_REPLACEMENT_MAP, - COL_SEED_VALIDATION_CANDIDATES, COL_UTILITY_SCORE, COL_WEIGHTED_LEAKAGE_RATE, ) - ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None) - for row_index, row in dataframe.iterrows(): - text = str(row.get(text_column, "")) - text_length_tokens = _count_text_tokens(text) - final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES)) - final_entity_count = len(final_entities) - ground_truth_metrics = _entity_ground_truth_metrics( - final_entities, - _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None, - ) - replacement_metrics = ( - _replacement_map_metrics(row.get(COL_REPLACEMENT_MAP)) if COL_REPLACEMENT_MAP in dataframe.columns else {} - ) - detected_candidate_count = ( - _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",)) - if COL_SEED_VALIDATION_CANDIDATES in dataframe.columns - else None - ) - validation_chunk_count = _validation_chunk_count( - detected_candidate_count, - validation_max_entities_per_call=validation_max_entities_per_call, - ) - grouped_entity_count = ( - _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",)) - if COL_ENTITIES_BY_VALUE in dataframe.columns - else final_entity_count - ) - 
repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0) - llm_calls_by_stage = estimate_llm_calls_by_stage( - mode=mode, - strategy=strategy, - has_grouped_entities=grouped_entity_count > 0, - validation_chunk_count=validation_chunk_count, - repair_iterations=repair_iterations, - ) - total_estimated = ( - sum(llm_calls_by_stage.values()) - if all(value is not None for value in llm_calls_by_stage.values()) - else None - ) - collector.record( - "record", - mode=mode, - strategy=strategy, - row_index=_safe_row_index(row_index), - record_hash=collector.record_hash(row_index=row_index, text=text), - text_length_chars=len(text), - text_length_chars_bucket=_size_bucket(len(text)), - text_length_tokens=text_length_tokens, - text_length_tokens_bucket=_size_bucket(text_length_tokens), - final_entity_count=final_entity_count, - final_entity_label_counts=dict( - sorted(Counter(e.get("label", "") for e in final_entities if e.get("label")).items()) - ), - **ground_truth_metrics, - **replacement_metrics, - utility_score=_coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in dataframe.columns else None, - leakage_mass=_coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in dataframe.columns else None, - weighted_leakage_rate=( - _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE)) - if COL_WEIGHTED_LEAKAGE_RATE in dataframe.columns - else None - ), - any_high_leaked=( - _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in dataframe.columns else None - ), - needs_human_review=( - _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in dataframe.columns else None - ), - needs_repair=_coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in dataframe.columns else None, - detected_candidate_count=detected_candidate_count, - validation_chunk_count=validation_chunk_count, - repair_iterations=repair_iterations if mode == "rewrite" else 0, - llm_calls_estimated_by_stage=llm_calls_by_stage, - 
llm_calls_estimated_total=total_estimated, - ) + return { + "utility_score": _coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in columns else None, + "leakage_mass": _coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in columns else None, + "weighted_leakage_rate": ( + _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE)) if COL_WEIGHTED_LEAKAGE_RATE in columns else None + ), + "any_high_leaked": _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in columns else None, + "needs_human_review": ( + _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in columns else None + ), + "needs_repair": _coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in columns else None, + } + + +def _llm_record_fields( + row: Any, + *, + columns: set[str], + mode: str, + strategy: str, + final_entity_count: int, + validation_max_entities_per_call: int, +) -> dict[str, Any]: + from anonymizer.engine.constants import COL_REPAIR_ITERATIONS + + detected_candidate_count = _detected_candidate_count(row, columns=columns) + validation_chunk_count = _validation_chunk_count( + detected_candidate_count, + validation_max_entities_per_call=validation_max_entities_per_call, + ) + grouped_entity_count = _grouped_entity_count(row, columns=columns, final_entity_count=final_entity_count) + repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0) + calls_by_stage = estimate_llm_calls_by_stage( + mode=mode, + strategy=strategy, + has_grouped_entities=grouped_entity_count > 0, + validation_chunk_count=validation_chunk_count, + repair_iterations=repair_iterations, + ) + total_estimated = ( + sum(calls_by_stage.values()) if all(value is not None for value in calls_by_stage.values()) else None + ) + return { + "detected_candidate_count": detected_candidate_count, + "validation_chunk_count": validation_chunk_count, + "repair_iterations": repair_iterations if mode == "rewrite" else 0, + "llm_calls_estimated_by_stage": 
calls_by_stage, + "llm_calls_estimated_total": total_estimated, + } + + +def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None: + from anonymizer.engine.constants import COL_SEED_VALIDATION_CANDIDATES + + if COL_SEED_VALIDATION_CANDIDATES not in columns: + return None + return _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",)) + + +def _grouped_entity_count(row: Any, *, columns: set[str], final_entity_count: int) -> int: + from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE + + if COL_ENTITIES_BY_VALUE not in columns: + return final_entity_count + return _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",)) def _source_metadata(source: str) -> dict[str, Any]: @@ -892,50 +963,58 @@ def _count_text_tokens(text: str) -> int: def _summarize_model_usage(model_usage: Mapping[str, Any] | None) -> dict[str, int | None]: - input_tokens = 0 - output_tokens = 0 - total_tokens = 0 - reasoning_tokens = 0 - has_reasoning_tokens = False - successful_requests = 0 - failed_requests = 0 - total_requests = 0 - + totals = _empty_model_usage_totals() for usage in (model_usage or {}).values(): if not isinstance(usage, Mapping): continue + _add_model_usage_totals(totals, usage) + + if totals["total_tokens"] == 0: + totals["total_tokens"] = totals["input_tokens"] + totals["output_tokens"] + if totals["total_requests"] == 0: + totals["total_requests"] = totals["successful_requests"] + totals["failed_requests"] - token_usage = usage.get("token_usage") - if isinstance(token_usage, Mapping): - input_tokens += _coerce_int(token_usage.get("input_tokens"), default=0) - output_tokens += _coerce_int(token_usage.get("output_tokens"), default=0) - total_tokens += _coerce_int(token_usage.get("total_tokens"), default=0) - if token_usage.get("reasoning_tokens") is not None: - has_reasoning_tokens = True - reasoning_tokens += 
_coerce_int(token_usage.get("reasoning_tokens"), default=0) - - request_usage = usage.get("request_usage") - if isinstance(request_usage, Mapping): - successful_requests += _coerce_int(request_usage.get("successful_requests"), default=0) - failed_requests += _coerce_int(request_usage.get("failed_requests"), default=0) - total_requests += _coerce_int(request_usage.get("total_requests"), default=0) - - if total_tokens == 0: - total_tokens = input_tokens + output_tokens - if total_requests == 0: - total_requests = successful_requests + failed_requests + return { + "observed_input_tokens": totals["input_tokens"], + "observed_output_tokens": totals["output_tokens"], + "observed_total_tokens": totals["total_tokens"], + "observed_reasoning_tokens": totals["reasoning_tokens"] if totals["has_reasoning_tokens"] else None, + "observed_successful_requests": totals["successful_requests"], + "observed_failed_requests": totals["failed_requests"], + "observed_total_requests": totals["total_requests"], + } + +def _empty_model_usage_totals() -> dict[str, int | bool]: return { - "observed_input_tokens": input_tokens, - "observed_output_tokens": output_tokens, - "observed_total_tokens": total_tokens, - "observed_reasoning_tokens": reasoning_tokens if has_reasoning_tokens else None, - "observed_successful_requests": successful_requests, - "observed_failed_requests": failed_requests, - "observed_total_requests": total_requests, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "reasoning_tokens": 0, + "has_reasoning_tokens": False, + "successful_requests": 0, + "failed_requests": 0, + "total_requests": 0, } +def _add_model_usage_totals(totals: dict[str, int | bool], usage: Mapping[str, Any]) -> None: + token_usage = usage.get("token_usage") + if isinstance(token_usage, Mapping): + totals["input_tokens"] += _coerce_int(token_usage.get("input_tokens"), default=0) + totals["output_tokens"] += _coerce_int(token_usage.get("output_tokens"), default=0) + totals["total_tokens"] 
+= _coerce_int(token_usage.get("total_tokens"), default=0) + if token_usage.get("reasoning_tokens") is not None: + totals["has_reasoning_tokens"] = True + totals["reasoning_tokens"] += _coerce_int(token_usage.get("reasoning_tokens"), default=0) + + request_usage = usage.get("request_usage") + if isinstance(request_usage, Mapping): + totals["successful_requests"] += _coerce_int(request_usage.get("successful_requests"), default=0) + totals["failed_requests"] += _coerce_int(request_usage.get("failed_requests"), default=0) + totals["total_requests"] += _coerce_int(request_usage.get("total_requests"), default=0) + + def _json_safe(value: object) -> Any: if isinstance(value, dict): return {str(k): _json_safe(v) for k, v in value.items()} @@ -944,7 +1023,9 @@ def _json_safe(value: object) -> Any: if isinstance(value, tuple): return [_json_safe(v) for v in value] if isinstance(value, set): - return sorted(_json_safe(v) for v in value) + return sorted((_json_safe(v) for v in value), key=str) + if isinstance(value, float) and not math.isfinite(value): + return None if isinstance(value, (str, int, float, bool)) or value is None: return value return str(value) diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 540ef4da..098a3096 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -329,6 +329,18 @@ def test_anonymizer_measurement_config_writes_jsonl(tmp_path: Path) -> None: assert str(input_csv) not in serialized +def test_measurement_records_write_strict_json_safe_values(tmp_path: Path) -> None: + output_jsonl = tmp_path / "measurements.jsonl" + collector = MeasurementCollector(record_hash_key="test-key") + collector.record("run", non_finite=float("nan"), mixed_set={1, "two"}) + + collector.write_jsonl(output_jsonl) + + payload = json.loads(output_jsonl.read_text(encoding="utf-8")) + assert payload["non_finite"] is None + assert payload["mixed_set"] == [1, "two"] + + def 
test_measurement_config_record_level_false_skips_record_rows(tmp_path: Path) -> None: output_json = tmp_path / "measurements.json" dataframe = pd.DataFrame( diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py new file mode 100644 index 00000000..3448a71f --- /dev/null +++ b/tests/tools/test_measurement_tools.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest +from pydantic import ValidationError + +from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: + tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") + dataframe = pd.DataFrame( + [ + {"record_type": "run", "run_id": "case-a", "run_tags": {"suite_id": "suite-a"}}, + {"record_type": "stage", "run_id": "case-a", "stage": "detect", "metrics": {"rows": 2}}, + ] + ) + + result = tool.export_tables( + dataframe, + input_path=tmp_path / "measurements.jsonl", + output_dir=tmp_path / "tables", + export_format=tool.ExportFormat.csv, + overwrite=False, + ) + + assert result.total_rows == 2 + assert {table.record_type for table in result.tables} == {"run", "stage"} + assert (tmp_path / "tables/run.csv").exists() + assert (tmp_path / "tables/stage.csv").exists() + 
assert (tmp_path / "tables/manifest.json").exists() + + +def test_benchmark_spec_validates_matrix_references(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-suite +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +matrix: + - workload: missing + config: redact +""", + encoding="utf-8", + ) + + with pytest.raises(ValidationError, match="unknown workload"): + tool.load_spec(spec_path) + + +def test_benchmark_partial_rewrite_goal_uses_public_defaults() -> None: + tool = load_tool("measurement_benchmark_tool_defaults", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + rewrite = tool.build_rewrite(tool.RewriteSpec(protect="Direct payroll identifiers")) + + assert rewrite.privacy_goal.protect == "Direct payroll identifiers" + assert rewrite.privacy_goal.preserve == DEFAULT_PRESERVE_TEXT + + +def test_benchmark_output_dir_requires_overwrite_for_existing_files(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_output", REPO_ROOT / "tools/measurement/run_benchmarks.py") + output_dir = tmp_path / "benchmark-output" + output_dir.mkdir() + existing = output_dir / "summary.json" + existing.write_text("{}", encoding="utf-8") + + with pytest.raises(ValueError, match="not empty"): + tool.prepare_output_dir(output_dir, overwrite=False, dry_run=False) + + tool.prepare_output_dir(output_dir, overwrite=True, dry_run=False) + + assert (output_dir / "raw").is_dir() + assert not existing.exists() + + +def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_dry_run", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: smoke-suite +workloads: + - id: biography + source: biographies.csv +configs: + - id: redact + 
replace: redact +matrix: + - workload: biography + config: redact + repetitions: 2 +""", + encoding="utf-8", + ) + output_dir = tmp_path / "dry-run-output" + + result = tool.run_or_plan( + spec_path, + output=output_dir, + overwrite=False, + dry_run=True, + export=False, + fail_fast=False, + ) + + assert len(result.cases) == 2 + assert result.table_dir is None + assert {case.status for case in result.cases} == {tool.CaseStatus.planned} + assert not output_dir.exists() diff --git a/tools/measurement/README.md b/tools/measurement/README.md new file mode 100644 index 00000000..3501b9cd --- /dev/null +++ b/tools/measurement/README.md @@ -0,0 +1,328 @@ + + + +# Measurement Tools + +`export_measurements.py` converts Anonymizer measurement JSONL into one table +per `record_type`. + +Run these tools inside the project environment, either with an activated venv +or through `uv run`. + +```bash +uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables +``` + +By default it writes Parquet files plus `manifest.json`: + +- `run.parquet` +- `stage.parquet` +- `record.parquet` +- `ndd_workflow.parquet` when adapter records are present + +Use `--format csv` or `--format jsonl` for non-Parquet output, and +`--overwrite` to replace existing output files. + +`run_benchmarks.py` runs repeatable Anonymizer workloads and writes the same +measurement JSONL format, one raw file per benchmark case plus a combined +`measurements.jsonl`. + +```bash +uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark-runs/suite +uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json +``` + +Benchmark suites are YAML files with three parts: + +- `workloads`: input datasets and text-column metadata. +- `configs`: Anonymizer replace or rewrite configurations. +- `matrix`: optional workload/config pairs and repetition counts. When omitted, + every workload is crossed with every config once. 
+ +Example: + +```yaml +suite_id: shell-and-biography-smoke +model_configs: ./model-configs.yaml +model_providers: ./providers.yaml +workloads: + - id: biographies + source: ./data/biographies.csv + text_column: text + - id: support + source: ./data/support.csv + text_column: body + id_column: ticket_id +configs: + - id: redact-default + replace: redact + - id: hash-agent-labels + detect: + entity_labels: [person, email, api_key, password] + replace: + strategy: hash + digest_length: 12 + - id: rewrite-low-risk + rewrite: + risk_tolerance: low + max_repair_iterations: 1 +matrix: + - workload: biographies + config: redact-default + repetitions: 3 + - workload: support + config: hash-agent-labels +``` + +The runner refuses to write into a non-empty output directory unless +`--overwrite` is set. By default it also exports Parquet tables into +`tables/`; pass `--no-export` when you only want the raw measurement JSONL. + +## Output Layout + +A benchmark run writes one raw measurement file per case, then combines them: + +```text +benchmark-runs/suite-id/ + raw/ + biographies__redact-default__r000.jsonl + support__hash-agent-labels__r000.jsonl + measurements.jsonl + summary.json + tables/ + manifest.json + run.parquet + stage.parquet + record.parquet + ndd_workflow.parquet +``` + +Use `summary.json` to inspect case status and errors. Use `measurements.jsonl` +when you need the original structured records. Use `tables/` for analysis. + +The exporter groups records by `record_type`: + +- `run`: one row per Anonymizer run, with sanitized config, workload, model, and + runtime metadata. +- `stage`: one row per measured pipeline stage, with elapsed time, row counts, + and throughput fields. +- `record`: one row per input row when record-level measurement is enabled, + with text-size buckets, entity counts, replacement counts, rewrite scores, + and estimated nominal LLM call counts. 
+- `ndd_workflow`: one row per DataDesigner adapter call, with model aliases, + elapsed time, row counts, failed-record counts, and observed token/request + usage when DataDesigner exposes it. + +The tables never store raw text, prompts, generated outputs, entity values, or +replacement maps. `record_hash` is a run-scoped HMAC, so it can join rows within +one run but should not be treated as a durable dataset identifier. + +## Analysis Patterns + +Start with these questions: + +- Which workload/config pair is fastest at the same quality target? +- Which stage dominates wall time: detection, replacement, rewrite, or a + DataDesigner sub-workflow? +- Does latency scale with text length, entity count, or rewrite repair work? +- Do token counts, request counts, and failed records explain latency outliers? +- Are quality metrics worse on one data shape, such as legal text, biographies, + support tickets, shell history, or mixed natural-language/code records? + +Most analyses join `stage`, `record`, and `ndd_workflow` back to `run` through +`run_id`, then group by run tags: + +- `run_tags.suite_id` +- `run_tags.workload_id` +- `run_tags.config_id` +- `run_tags.repetition` +- `run_tags.case_id` + +Prefer medians and percentiles over averages when comparing latency. LLM calls +usually have long tails, and one retry or provider stall can distort a mean. 
+ +## Pandas Examples + +Load exported tables: + +```python +from pathlib import Path + +import pandas as pd + +tables = Path("benchmark-runs/shell-and-biography-smoke/tables") +run = pd.read_parquet(tables / "run.parquet") +stage = pd.read_parquet(tables / "stage.parquet") +record = pd.read_parquet(tables / "record.parquet") +ndd = pd.read_parquet(tables / "ndd_workflow.parquet") +``` + +Compare end-to-end stage latency by workload and config: + +```python +stage_group_cols = ["run_tags.workload_id", "run_tags.config_id", "stage"] + +stage_summary = ( + stage + .groupby(stage_group_cols) + .agg( + runs=("run_id", "nunique"), + median_sec=("elapsed_sec", "median"), + p95_sec=("elapsed_sec", lambda s: s.quantile(0.95)), + rows_per_sec=("rows_per_sec", "median"), + ) + .reset_index() + .sort_values(["run_tags.workload_id", "stage", "median_sec"]) +) + +print(stage_summary) +``` + +Summarize record shape and quality by workload, config, and text-size bucket (record rows carry no per-record timing, so find the slow run in `stage` first): + +```python +record_view = record[ + [ + "run_tags.workload_id", + "run_tags.config_id", + "record_hash", + "text_length_tokens", + "text_length_tokens_bucket", + "final_entity_count", + "nominal_llm_call_count", + "utility_score", + "leakage_mass", + ] +].copy() + +shape_group_cols = [ + "run_tags.workload_id", + "run_tags.config_id", + "text_length_tokens_bucket", +] + +by_shape = ( + record_view + .groupby(shape_group_cols) + .agg( + records=("record_hash", "count"), + median_entities=("final_entity_count", "median"), + median_nominal_calls=("nominal_llm_call_count", "median"), + median_utility=("utility_score", "median"), + median_leakage=("leakage_mass", "median"), + ) + .reset_index() +) + +print(by_shape) +``` + +Summarize DataDesigner token and request usage: + +```python +workflow_group_cols = ["run_tags.workload_id", "run_tags.config_id", "workflow_name"] + +token_summary = ( + ndd + .groupby(workflow_group_cols) + .agg( + calls=("workflow_name", "count"), + median_sec=("elapsed_sec", "median"), + 
total_input_tokens=("observed_input_tokens", "sum"), + total_output_tokens=("observed_output_tokens", "sum"), + total_requests=("observed_total_requests", "sum"), + failed_records=("failed_record_count", "sum"), + ) + .reset_index() + .sort_values(["run_tags.workload_id", "run_tags.config_id", "median_sec"]) +) + +print(token_summary) +``` + +Join run metadata to stage timing: + +```python +run_meta = run[ + [ + "run_id", + "mode", + "strategy", + "detect.entity_label_count", + "detect.validation_max_entities_per_call", + ] +] + +stage_with_config = stage.merge(run_meta, on="run_id", how="left") + +config_group_cols = ["mode", "strategy", "detect.entity_label_count", "stage"] + +print(stage_with_config.groupby(config_group_cols)["elapsed_sec"].median()) +``` + +For quick interactive work, CSV can be easier than Parquet: + +```bash +uv run python tools/measurement/export_measurements.py \ + benchmark-runs/suite-id/measurements.jsonl \ + --output /tmp/suite-csv \ + --format csv \ + --overwrite +``` + +## Metric Interpretation + +Use metrics as signals, not as a single score. + +Latency and throughput: + +- `elapsed_sec`: wall time for a measured stage or DataDesigner workflow. +- `rows_per_sec`: completed output rows per second for the measured block. +- `tokens_per_sec`: observed total tokens per second when token usage exists. +- `text_length_tokens_bucket`: a coarse text-size bucket for comparing similar + inputs without storing text. + +LLM usage: + +- `observed_input_tokens`, `observed_output_tokens`, and + `observed_total_tokens`: provider-reported usage when available. Missing or + zero values mean the provider path did not expose usage, not necessarily that + no tokens were consumed. +- `observed_total_requests`, `observed_successful_requests`, and + `observed_failed_requests`: request counts when DataDesigner exposes them. +- `nominal_llm_call_count`: an internal estimate based on the Anonymizer + pipeline shape. 
Treat it as expected work, not observed provider traffic. + +Entity and quality metrics: + +- `final_entity_count`: entities that survive detection and validation. +- `final_entity_label_counts`: per-label entity counts serialized as JSON in + exported tabular files. +- `ground_truth_*`: precision, recall, F1, false positives, and false negatives + when the input includes one of the supported ground-truth entity columns. +- `utility_score`, `leakage_mass`, `weighted_leakage_rate`, + `needs_repair`, and `needs_human_review`: rewrite-mode evaluation fields. + These are null for replace-mode runs. + +Error and reliability metrics: + +- `failed_record_count`: records dropped by a DataDesigner workflow. +- `status`: completion state for a stage or workflow. +- `summary.json` case errors: runner-level failures, such as invalid inputs or + model endpoint failures. + +## Reading Results Safely + +Compare like with like. A shell-history workload, a support-ticket workload, +and a legal-document workload stress different parts of Anonymizer. Group by +`workload_id` before drawing conclusions about model routing, speculative +decoding, validation chunk size, or rewrite repair settings. + +Record-level rows describe input shape and output quality, not per-record wall +time. Stage and workflow rows carry timing. To explain a slow run, first find +the slow stage, then inspect the records in that run for text length, entity +count, nominal call count, and rewrite repair signals. + +When token or request fields are missing, check `ndd_workflow.model_usage`. +The measurement layer records deeper provider usage only when DataDesigner +returns it. diff --git a/tools/measurement/examples/repo-data-smoke.yaml b/tools/measurement/examples/repo-data-smoke.yaml new file mode 100644 index 00000000..5d0e5623 --- /dev/null +++ b/tools/measurement/examples/repo-data-smoke.yaml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +suite_id: repo-data-smoke +workloads: + - id: biographies + source: ../../../docs/data/NVIDIA_synthetic_biographies.csv + text_column: biography + - id: legal + source: ../../../docs/data/TAB_legal_sample25.csv + text_column: text +configs: + - id: biographies-redact-default + replace: redact + - id: legal-hash-agent-labels + detect: + entity_labels: [person, email, api_key, password] + replace: + strategy: hash + digest_length: 12 +matrix: + - workload: biographies + config: biographies-redact-default + - workload: legal + config: legal-hash-agent-labels diff --git a/tools/measurement/export_measurements.py b/tools/measurement/export_measurements.py new file mode 100755 index 00000000..796fcc40 --- /dev/null +++ b/tools/measurement/export_measurements.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Export Anonymizer measurement JSONL into per-record-type tables. 
+ +Usage: + uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables + uv run python tools/measurement/export_measurements.py measurements.jsonl -o tables --format csv + uv run python tools/measurement/export_measurements.py measurements.jsonl -o tables --json +""" + +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.export") + +MANIFEST_FILENAME = "manifest.json" + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class TableSummary(BaseModel): + record_type: str + rows: int + columns: int + path: str + + +class ExportResult(BaseModel): + input_path: str + output_dir: str + format: ExportFormat + total_rows: int + tables: list[TableSummary] = Field(default_factory=list) + manifest_path: str + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def read_measurements(path: Path) -> pd.DataFrame: + if not path.exists(): + raise ValueError(f"input path does not exist: {path}") + if path.is_dir(): + raise ValueError(f"input path is a directory: {path}") + if path.suffix == ".json": + dataframe = pd.read_json(path) + else: + dataframe = pd.read_json(path, lines=True) + if "record_type" not in dataframe.columns: + raise ValueError("measurement input must contain 
a record_type field") + return dataframe + + +def normalize_table(rows: pd.DataFrame) -> pd.DataFrame: + relevant_rows = rows.dropna(axis="columns", how="all") + normalized = pd.json_normalize(relevant_rows.to_dict("records"), sep=".") + for column in normalized.columns: + if normalized[column].map(_is_nested_value).any(): + normalized[column] = normalized[column].map(_json_cell) + return normalized + + +def _is_nested_value(value: object) -> bool: + return isinstance(value, dict | list) + + +def _json_cell(value: object) -> object: + if not _is_nested_value(value): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def write_table(table: pd.DataFrame, path: Path, export_format: ExportFormat) -> None: + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + + +def export_tables( + dataframe: pd.DataFrame, + *, + input_path: Path, + output_dir: Path, + export_format: ExportFormat, + overwrite: bool, +) -> ExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _export_one_table(record_type, rows, output_dir=output_dir, export_format=export_format, overwrite=overwrite) + for record_type, rows in dataframe.groupby("record_type", sort=False) + ] + result = ExportResult( + input_path=str(input_path), + output_dir=str(output_dir), + format=export_format, + total_rows=len(dataframe), + tables=tables, + manifest_path=str(output_dir / MANIFEST_FILENAME), + ) + write_manifest(result, output_dir / MANIFEST_FILENAME, overwrite=overwrite) + return result + + +def _export_one_table( + record_type: str, + rows: pd.DataFrame, + *, + output_dir: Path, + export_format: ExportFormat, + overwrite: bool, +) -> TableSummary: + table = normalize_table(rows) + path = output_dir / f"{record_type}.{export_format.value}" + ensure_can_write(path, overwrite=overwrite) + 
write_table(table, path, export_format) + return TableSummary(record_type=record_type, rows=len(table), columns=len(table.columns), path=str(path)) + + +def write_manifest(result: ExportResult, path: Path, *, overwrite: bool) -> None: + ensure_can_write(path, overwrite=overwrite) + path.write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def ensure_can_write(path: Path, *, overwrite: bool) -> None: + if path.exists() and not overwrite: + raise ValueError(f"output already exists: {path}; pass --overwrite to replace it") + + +def render_result(result: ExportResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [f"Wrote {len(result.tables)} table(s) from {result.total_rows} measurement record(s)"] + lines.append(f"Output: {result.output_dir}") + for table in result.tables: + lines.append(f"- {table.record_type}: {table.rows} rows, {table.columns} columns -> {table.path}") + lines.append(f"Manifest: {result.manifest_path}") + return "\n".join(lines) + + +@app.default +def main( + input_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + output_dir = output or input_path.with_suffix("").with_name(f"{input_path.stem}-tables") + try: + dataframe = read_measurements(input_path) + result = export_tables( + dataframe, + input_path=input_path, + output_dir=output_dir, + export_format=format, + overwrite=overwrite, + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ 
== "__main__": + app() diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py new file mode 100755 index 00000000..bc0ade51 --- /dev/null +++ b/tools/measurement/run_benchmarks.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run Anonymizer benchmark suites and export measurement tables. + +Usage: + uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark-runs/suite + uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json +""" + +import json +import logging +import shutil +import sys +import time +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import yaml +from export_measurements import ExportFormat, export_tables, read_measurements +from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator + +from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect, Rewrite +from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute +from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance +from anonymizer.interface.anonymizer import Anonymizer +from anonymizer.measurement import MeasurementConfig, configured_measurement_session + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.benchmark") + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class CaseStatus(StrEnum): + planned = "planned" + completed = "completed" + error = "error" + + +class ReplaceKind(StrEnum): + redact = "redact" + hash = "hash" + annotate = "annotate" + substitute = "substitute" + + +class WorkloadSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + source: str + text_column: str = "text" + 
id_column: str | None = None + data_summary: str | None = None + + +class ReplaceSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + strategy: ReplaceKind + format_template: str | None = None + normalize_label: bool | None = None + algorithm: str | None = None + digest_length: int | None = None + instructions: str | None = None + + +class RewriteSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + protect: str | None = None + preserve: str | None = None + instructions: str | None = None + risk_tolerance: RiskTolerance = RiskTolerance.low + max_repair_iterations: int = 3 + strict_entity_protection: bool = False + + +class ConfigSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + detect: dict[str, Any] = Field(default_factory=dict) + replace: str | ReplaceSpec | None = None + rewrite: RewriteSpec | None = None + emit_telemetry: bool = False + + @model_validator(mode="after") + def validate_mode(self) -> "ConfigSpec": + if self.replace is None and self.rewrite is None: + raise ValueError("config must define replace or rewrite") + if self.replace is not None and self.rewrite is not None: + raise ValueError("config cannot define both replace and rewrite") + return self + + +class MatrixEntry(BaseModel): + model_config = ConfigDict(extra="forbid") + + workload: str + config: str + repetitions: int = Field(default=1, ge=1) + + +def _duplicates(values: list[str]) -> list[str]: + seen: set[str] = set() + duplicates: set[str] = set() + for value in values: + if value in seen: + duplicates.add(value) + seen.add(value) + return sorted(duplicates) + + +class BenchmarkSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + suite_id: str + model_configs: str | None = None + model_providers: str | None = None + artifact_path: str | None = None + workloads: list[WorkloadSpec] = Field(min_length=1) + configs: list[ConfigSpec] = Field(min_length=1) + matrix: list[MatrixEntry] | None = Field(default=None, min_length=1) + + 
@model_validator(mode="after") + def validate_ids(self) -> "BenchmarkSpec": + workload_ids = [workload.id for workload in self.workloads] + config_ids = [config.id for config in self.configs] + if duplicate_workloads := _duplicates(workload_ids): + raise ValueError(f"duplicate workload id(s): {', '.join(duplicate_workloads)}") + if duplicate_configs := _duplicates(config_ids): + raise ValueError(f"duplicate config id(s): {', '.join(duplicate_configs)}") + self._validate_matrix_references(set(workload_ids), set(config_ids)) + return self + + def _validate_matrix_references(self, workload_ids: set[str], config_ids: set[str]) -> None: + if self.matrix is None: + return + missing_workloads = sorted({entry.workload for entry in self.matrix} - workload_ids) + missing_configs = sorted({entry.config for entry in self.matrix} - config_ids) + if missing_workloads: + raise ValueError(f"matrix references unknown workload id(s): {', '.join(missing_workloads)}") + if missing_configs: + raise ValueError(f"matrix references unknown config id(s): {', '.join(missing_configs)}") + + +class BenchmarkCase(BaseModel): + suite_id: str + workload_id: str + config_id: str + repetition: int + case_id: str + status: CaseStatus = CaseStatus.planned + elapsed_sec: float | None = None + measurement_path: str | None = None + error: str | None = None + + +class BenchmarkResult(BaseModel): + suite_id: str + output_dir: str + measurement_path: str + summary_path: str + table_dir: str | None + cases: list[BenchmarkCase] + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def 
load_spec(path: Path) -> BenchmarkSpec: + if not path.exists() or path.is_dir(): + raise ValueError(f"spec path is not a file: {path}") + raw = yaml.safe_load(path.read_text(encoding="utf-8")) + if not isinstance(raw, dict): + raise ValueError("benchmark spec must be a YAML mapping") + return BenchmarkSpec.model_validate(raw) + + +def build_cases(spec: BenchmarkSpec) -> list[BenchmarkCase]: + matrix = spec.matrix or _cross_product_matrix(spec) + return [ + BenchmarkCase( + suite_id=spec.suite_id, + workload_id=entry.workload, + config_id=entry.config, + repetition=repetition, + case_id=f"{entry.workload}__{entry.config}__r{repetition:03d}", + ) + for entry in matrix + for repetition in range(entry.repetitions) + ] + + +def _cross_product_matrix(spec: BenchmarkSpec) -> list[MatrixEntry]: + return [ + MatrixEntry(workload=workload.id, config=config.id, repetitions=1) + for workload in spec.workloads + for config in spec.configs + ] + + +def prepare_output_dir(output_dir: Path, *, overwrite: bool, dry_run: bool) -> None: + if dry_run: + return + if output_dir.exists() and not output_dir.is_dir(): + raise ValueError(f"output path exists and is not a directory: {output_dir}") + if output_dir.exists(): + if overwrite: + shutil.rmtree(output_dir) + elif any(output_dir.iterdir()): + raise ValueError(f"output directory is not empty: {output_dir}; pass --overwrite to replace it") + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "raw").mkdir(exist_ok=True) + + +def run_suite( + spec: BenchmarkSpec, + *, + spec_path: Path, + output_dir: Path, + export: bool, + fail_fast: bool, +) -> BenchmarkResult: + contexts = _build_contexts(spec, spec_path=spec_path, output_dir=output_dir) + anonymizer = Anonymizer(**contexts["anonymizer_kwargs"]) + cases = [ + _run_case(case, spec, contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast) + for case in build_cases(spec) + ] + measurement_path = combine_measurements(cases, output_dir / "measurements.jsonl") + 
should_export = export and measurement_path.stat().st_size > 0 + table_dir = export_measurement_tables(measurement_path, output_dir / "tables") if should_export else None + result = BenchmarkResult( + suite_id=spec.suite_id, + output_dir=str(output_dir), + measurement_path=str(measurement_path), + summary_path=str(output_dir / "summary.json"), + table_dir=str(table_dir) if table_dir is not None else None, + cases=cases, + ) + write_summary(result) + return result + + +def _build_contexts(spec: BenchmarkSpec, *, spec_path: Path, output_dir: Path) -> dict[str, Any]: + base_dir = spec_path.parent + artifact_path = _resolve_optional_path(spec.artifact_path, base_dir) or output_dir / "artifacts" + return { + "base_dir": base_dir, + "workloads": {workload.id: workload for workload in spec.workloads}, + "configs": {config.id: config for config in spec.configs}, + "raw_dir": output_dir / "raw", + "anonymizer_kwargs": { + "model_configs": _resolve_config_source(spec.model_configs, base_dir), + "model_providers": _resolve_config_source(spec.model_providers, base_dir), + "artifact_path": artifact_path, + }, + } + + +def _run_case( + case: BenchmarkCase, + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + fail_fast: bool, +) -> BenchmarkCase: + raw_path = contexts["raw_dir"] / f"{case.case_id}.jsonl" + started = time.perf_counter() + try: + workload = _get_item(contexts["workloads"], case.workload_id, "workload") + config = _get_item(contexts["configs"], case.config_id, "config") + _execute_case( + anonymizer, workload, config, raw_path=raw_path, case=case, spec=spec, base_dir=contexts["base_dir"] + ) + return case.model_copy( + update={ + "status": CaseStatus.completed, + "elapsed_sec": time.perf_counter() - started, + "measurement_path": str(raw_path), + } + ) + except Exception as exc: + if fail_fast: + raise + return case.model_copy( + update={ + "status": CaseStatus.error, + "elapsed_sec": time.perf_counter() - started, + 
"measurement_path": str(raw_path), + "error": str(exc), + } + ) + + +def _execute_case( + anonymizer: Anonymizer, + workload: WorkloadSpec, + config: ConfigSpec, + *, + raw_path: Path, + case: BenchmarkCase, + spec: BenchmarkSpec, + base_dir: Path, +) -> None: + measurement = MeasurementConfig(output_path=raw_path, run_id=case.case_id, run_tags=_run_tags(case, spec)) + with configured_measurement_session(measurement): + anonymizer.run(config=build_anonymizer_config(config), data=build_input(workload, base_dir)) + + +def build_input(workload: WorkloadSpec, base_dir: Path) -> AnonymizerInput: + return AnonymizerInput( + source=str(_resolve_input_source(workload.source, base_dir)), + text_column=workload.text_column, + id_column=workload.id_column, + data_summary=workload.data_summary, + ) + + +def build_anonymizer_config(config: ConfigSpec) -> AnonymizerConfig: + detect = Detect.model_validate(config.detect) + if config.replace is not None: + return AnonymizerConfig( + detect=detect, replace=build_replace(config.replace), emit_telemetry=config.emit_telemetry + ) + return AnonymizerConfig(detect=detect, rewrite=build_rewrite(config.rewrite), emit_telemetry=config.emit_telemetry) + + +def build_replace(raw: str | ReplaceSpec) -> Redact | Hash | Annotate | Substitute: + spec = ReplaceSpec(strategy=ReplaceKind(raw)) if isinstance(raw, str) else raw + if spec.strategy == ReplaceKind.redact: + return Redact(**_present({"format_template": spec.format_template, "normalize_label": spec.normalize_label})) + if spec.strategy == ReplaceKind.hash: + return Hash( + **_present( + { + "format_template": spec.format_template, + "algorithm": spec.algorithm, + "digest_length": spec.digest_length, + } + ) + ) + if spec.strategy == ReplaceKind.annotate: + return Annotate(**_present({"format_template": spec.format_template})) + return Substitute(**_present({"instructions": spec.instructions})) + + +def build_rewrite(spec: RewriteSpec | None) -> Rewrite: + if spec is None: + raise 
ValueError("rewrite config is missing") + privacy_goal = _privacy_goal(spec) + return Rewrite( + privacy_goal=privacy_goal, + instructions=spec.instructions, + risk_tolerance=spec.risk_tolerance, + max_repair_iterations=spec.max_repair_iterations, + strict_entity_protection=spec.strict_entity_protection, + ) + + +def _privacy_goal(spec: RewriteSpec) -> PrivacyGoal | None: + if spec.protect is None and spec.preserve is None: + return None + return PrivacyGoal( + protect=spec.protect or DEFAULT_PROTECT_TEXT, + preserve=spec.preserve or DEFAULT_PRESERVE_TEXT, + ) + + +def combine_measurements(cases: list[BenchmarkCase], destination: Path) -> Path: + with destination.open("w", encoding="utf-8") as output: + for case in cases: + if case.measurement_path is None: + continue + source = Path(case.measurement_path) + if source.exists(): + output.write(source.read_text(encoding="utf-8")) + return destination + + +def export_measurement_tables(measurement_path: Path, table_dir: Path) -> Path: + dataframe = read_measurements(measurement_path) + export_tables( + dataframe, input_path=measurement_path, output_dir=table_dir, export_format=ExportFormat.parquet, overwrite=True + ) + return table_dir + + +def write_summary(result: BenchmarkResult) -> None: + Path(result.summary_path).write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def render_result(result: BenchmarkResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + completed = sum(case.status == CaseStatus.completed for case in result.cases) + errored = sum(case.status == CaseStatus.error for case in result.cases) + planned = sum(case.status == CaseStatus.planned for case in result.cases) + if planned and completed == 0 and errored == 0: + return f"Planned {planned} case(s); output={result.output_dir}" + return f"Ran {completed}/{len(result.cases)} case(s); errors={errored}; output={result.output_dir}" + + +def _run_tags(case: BenchmarkCase, spec: 
BenchmarkSpec) -> dict[str, Any]: + return { + "suite_id": spec.suite_id, + "workload_id": case.workload_id, + "config_id": case.config_id, + "repetition": case.repetition, + "case_id": case.case_id, + } + + +def _present(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} + + +def _get_item(items: dict[str, Any], item_id: str, item_type: str) -> Any: + if item_id not in items: + raise ValueError(f"unknown {item_type}: {item_id}") + return items[item_id] + + +def _resolve_input_source(source: str, base_dir: Path) -> str | Path: + if "://" in source: + return source + return _resolve_path(source, base_dir) + + +def _resolve_optional_path(raw: str | None, base_dir: Path) -> Path | None: + if raw is None: + return None + return _resolve_path(raw, base_dir) + + +def _resolve_config_source(raw: str | None, base_dir: Path) -> str | None: + if raw is None or "\n" in raw: + return raw + candidate = Path(raw).expanduser() + if candidate.suffix in {".yaml", ".yml"}: + return str(_resolve_path(raw, base_dir)) + return raw + + +def _resolve_path(raw: str, base_dir: Path) -> Path: + path = Path(raw).expanduser() + return path if path.is_absolute() else base_dir / path + + +def dry_run_result(spec: BenchmarkSpec, *, output_dir: Path, export: bool) -> BenchmarkResult: + cases = build_cases(spec) + return BenchmarkResult( + suite_id=spec.suite_id, + output_dir=str(output_dir), + measurement_path=str(output_dir / "measurements.jsonl"), + summary_path=str(output_dir / "summary.json"), + table_dir=str(output_dir / "tables") if export else None, + cases=cases, + ) + + +@app.default +def main( + spec: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + dry_run: Annotated[bool, cyclopts.Parameter("--dry-run")] = False, + export: Annotated[bool, cyclopts.Parameter("--export")] = True, + fail_fast: 
Annotated[bool, cyclopts.Parameter("--fail-fast")] = False, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = run_or_plan( + spec, output=output, overwrite=overwrite, dry_run=dry_run, export=export, fail_fast=fail_fast + ) + except (ValueError, ValidationError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + if any(case.status == CaseStatus.error for case in result.cases): + raise SystemExit(1) + + +def run_or_plan( + spec_path: Path, + *, + output: Path | None, + overwrite: bool, + dry_run: bool, + export: bool, + fail_fast: bool, +) -> BenchmarkResult: + benchmark_spec = load_spec(spec_path) + output_dir = output or Path("benchmark-runs") / benchmark_spec.suite_id + prepare_output_dir(output_dir, overwrite=overwrite, dry_run=dry_run) + if dry_run: + return dry_run_result(benchmark_spec, output_dir=output_dir, export=export) + return run_suite(benchmark_spec, spec_path=spec_path, output_dir=output_dir, export=export, fail_fast=fail_fast) + + +if __name__ == "__main__": + app() From 4745a4bb1d8239a0cd080219c04b606b5dce1964 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Wed, 3 Jun 2026 18:07:41 +0000 Subject: [PATCH 03/26] Add opt-in DD tracing for benchmarks Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 209 +++++++++++++++- src/anonymizer/measurement.py | 138 ++++++++++- tests/test_measurement.py | 99 ++++++++ tests/tools/test_measurement_tools.py | 182 ++++++++++++++ tools/measurement/README.md | 43 ++++ .../run-repo-data-smoke-with-dd-traces.sh | 15 ++ tools/measurement/run_benchmarks.py | 224 +++++++++++++++++- 7 files changed, 895 insertions(+), 15 deletions(-) create mode 100644 tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh diff --git 
a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 13e390c9..3621a178 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -8,10 +8,12 @@ import tempfile import time import uuid -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from threading import RLock +from typing import TYPE_CHECKING, Any, cast from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder @@ -29,6 +31,7 @@ logger = logging.getLogger("anonymizer.ndd") RECORD_ID_COLUMN = "_anonymizer_record_id" +_DD_MESSAGE_TRACE_PATCH_LOCK = RLock() @dataclass(frozen=True) @@ -109,7 +112,7 @@ def run_workflow( config_builder.add_column(column) try: - with usage_probe: + with usage_probe, _dd_message_trace(workflow_name=workflow_name): if preview_num_records is None: run_results = self._data_designer.create( config_builder, @@ -340,3 +343,203 @@ def _model_usage_as_json(stats: object) -> Any: if callable(model_dump): return model_dump(mode="json") return stats + + +@contextmanager +def _dd_message_trace(*, workflow_name: str) -> Iterator[None]: + collector = current_collector() + if collector is None or not collector.dd_trace_enabled: + yield + return + + from data_designer.engine.models.facade import ModelFacade + + with _DD_MESSAGE_TRACE_PATCH_LOCK: + original_completion = ModelFacade.completion + original_acompletion = ModelFacade.acompletion + ModelFacade.completion = _traced_completion( + original_completion, collector=collector, workflow_name=workflow_name + ) + ModelFacade.acompletion = _traced_acompletion( + original_acompletion, + collector=collector, + workflow_name=workflow_name, + ) + try: + yield + finally: + ModelFacade.completion = original_completion + ModelFacade.acompletion = 
original_acompletion + + +def _traced_completion(original_completion: Any, *, collector: Any, workflow_name: str) -> Any: + def traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: + return _run_traced_completion( + original_completion, + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + args=args, + kwargs=kwargs, + ) + + return traced + + +def _run_traced_completion( + completion: Any, + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + started = time.perf_counter() + response: Any | None = None + status = "completed" + error_type: str | None = None + try: + response = completion(model_facade, messages, *args, **kwargs) + return response + except BaseException as exc: + status = "error" + error_type = type(exc).__name__ + raise + finally: + _record_dd_message_trace( + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=False, + ) + + +def _traced_acompletion(original_acompletion: Any, *, collector: Any, workflow_name: str) -> Any: + async def traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: + return await _run_traced_acompletion( + original_acompletion, + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + args=args, + kwargs=kwargs, + ) + + return traced + + +async def _run_traced_acompletion( + acompletion: Any, + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + started = time.perf_counter() + response: Any | None = None + status = "completed" + error_type: str | None = None + try: + response = await acompletion(model_facade, 
messages, *args, **kwargs) + return response + except BaseException as exc: + status = "error" + error_type = type(exc).__name__ + raise + finally: + _record_dd_message_trace( + collector=collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=True, + ) + + +def _record_dd_message_trace( + *, + collector: Any, + workflow_name: str, + model_facade: Any, + messages: list[Any], + response: Any | None, + elapsed_sec: float, + status: str, + error_type: str | None, + is_async: bool, +) -> None: + if error_type == "SyncClientUnavailableError": + return + collector.record_dd_message_trace( + workflow_name=workflow_name, + model_alias=getattr(model_facade, "model_alias", None), + model_name=getattr(model_facade, "model_name", None), + model_provider_name=getattr(model_facade, "model_provider_name", None), + modality="chat", + is_async=is_async, + status=status, + error_type=error_type, + elapsed_sec=elapsed_sec, + messages=_trace_messages(messages, mode=collector.dd_trace_mode), + response=_trace_response(response), + usage=_trace_usage(getattr(response, "usage", None) if response is not None else None), + ) + + +def _trace_messages(messages: list[Any], *, mode: str) -> list[dict[str, Any]]: + selected = messages if mode == "all_messages" else messages[-1:] + return [_trace_message(message) for message in selected] + + +def _trace_message(message: Any) -> dict[str, Any]: + to_dict = getattr(message, "to_dict", None) + if callable(to_dict): + return cast(dict[str, Any], to_dict()) + if isinstance(message, Mapping): + return dict(message) + return {"role": getattr(message, "role", None), "content": getattr(message, "content", None)} + + +def _trace_response(response: Any | None) -> dict[str, Any] | None: + if response is None: + return None + message = getattr(response, "message", None) + if message is None: + return None 
+ return { + "content": getattr(message, "content", None), + "reasoning_content": getattr(message, "reasoning_content", None), + "tool_calls": _trace_tool_calls(getattr(message, "tool_calls", [])), + } + + +def _trace_tool_calls(tool_calls: Any) -> list[Any]: + if isinstance(tool_calls, list): + return [getattr(tool_call, "__dict__", tool_call) for tool_call in tool_calls] + return [] + + +def _trace_usage(usage: Any | None) -> dict[str, Any] | None: + if usage is None: + return None + return { + "input_tokens": getattr(usage, "input_tokens", None), + "output_tokens": getattr(usage, "output_tokens", None), + "total_tokens": getattr(usage, "total_tokens", None), + } diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index 252b6ed8..309264c5 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -39,6 +39,8 @@ _GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") MEASUREMENT_SCHEMA_VERSION = 1 DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" +DD_TRACE_MODES = {"none", "last_message", "all_messages"} +DDTraceMode = Literal["none", "last_message", "all_messages"] logger = logging.getLogger("anonymizer.measurement") @@ -47,6 +49,12 @@ class _MeasurementWriter(Protocol): def write(self, records: list[dict[str, Any]], path: str | Path) -> None: ... +class _MeasurementSink(Protocol): + def write_record(self, record: dict[str, Any]) -> None: ... + + def close(self) -> None: ... 
+ + class _JsonlMeasurementWriter: def write(self, records: list[dict[str, Any]], path: str | Path) -> None: output_path = Path(path) @@ -70,6 +78,19 @@ def _writer_for_format(output_format: Literal["jsonl", "json"]) -> _MeasurementW return _JsonlMeasurementWriter() +class _JsonlMeasurementSink: + def __init__(self, path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + self._file = output_path.open("w", encoding="utf-8", buffering=1) + + def write_record(self, record: dict[str, Any]) -> None: + self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + def close(self) -> None: + self._file.close() + + class _MeasurementEnvSettings(BaseSettings): model_config = SettingsConfigDict( env_prefix=DEFAULT_MEASUREMENT_ENV_PREFIX, @@ -80,6 +101,10 @@ class _MeasurementEnvSettings(BaseSettings): output_path: str | None = None output_format: Literal["jsonl", "json"] = "jsonl" record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | None = None fail_on_write_error: bool = False run_id: str | None = None run_tags: dict[str, Any] = Field(default_factory=dict) @@ -100,10 +125,22 @@ def __init__( record_hash_key: bytes | str | None = None, record_level: bool = True, run_tags: Mapping[str, Any] | None = None, + record_sink: _MeasurementSink | None = None, + keep_records: bool = True, + dd_trace_mode: DDTraceMode = "none", + dd_trace_sink: _MeasurementSink | None = None, + fail_on_write_error: bool = False, ) -> None: self.run_id = run_id or uuid.uuid4().hex self.record_level = record_level self.run_tags = cast(dict[str, Any], _json_safe(dict(run_tags or {}))) + self._record_sink = record_sink + self._keep_records = keep_records + self._dd_trace_mode = dd_trace_mode + self._dd_trace_sink = dd_trace_sink + self._fail_on_write_error = fail_on_write_error + self._sink_failed = False + self._dd_trace_failed = False if 
record_hash_key is None: self._record_hash_key = secrets.token_bytes(32) elif isinstance(record_hash_key, str): @@ -127,7 +164,65 @@ def record(self, record_type: str, **fields: Any) -> None: "run_tags": self.run_tags, "timestamp_unix_sec": time.time(), } - self._records.append(_json_safe(record)) + safe_record = _json_safe(record) + if self._keep_records: + self._records.append(safe_record) + if self._record_sink is not None: + self._write_record_to_sink(safe_record) + + def close(self) -> None: + """Close any streaming measurement sink attached to this collector.""" + if self._record_sink is not None: + self._record_sink.close() + if self._dd_trace_sink is not None: + self._dd_trace_sink.close() + + @property + def dd_trace_mode(self) -> DDTraceMode: + return self._dd_trace_mode + + @property + def dd_trace_enabled(self) -> bool: + return self._dd_trace_mode != "none" and self._dd_trace_sink is not None + + def record_dd_message_trace(self, **fields: Any) -> None: + """Write an explicitly opt-in DataDesigner message trace record. + + These records may contain raw prompts, input text, model outputs, and + PII. They are intentionally written to a separate trace sink and are + never appended to the safe measurement record list. 
+ """ + if not self.dd_trace_enabled or self._dd_trace_failed: + return + + record = _json_safe( + { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": "dd_message_trace", + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + ) + try: + cast(_MeasurementSink, self._dd_trace_sink).write_record(record) + except Exception: + self._dd_trace_failed = True + logger.warning("Failed to write DataDesigner message trace records") + if self._fail_on_write_error: + raise + + def _write_record_to_sink(self, record: dict[str, Any]) -> None: + if self._sink_failed: + return + try: + cast(_MeasurementSink, self._record_sink).write_record(record) + except Exception: + self._sink_failed = True + logger.warning("Failed to stream Anonymizer measurement records") + if self._fail_on_write_error: + raise def record_hash(self, *, row_index: object, text: str) -> str: """Return a run-scoped HMAC for joining records without storing text.""" @@ -161,6 +256,10 @@ class MeasurementConfig: output_path: str | Path output_format: Literal["jsonl", "json"] = "jsonl" record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | Path | None = None run_id: str | None = None record_hash_key: bytes | str | None = None run_tags: Mapping[str, Any] | None = None @@ -169,6 +268,12 @@ class MeasurementConfig: def __post_init__(self) -> None: if self.output_format not in {"jsonl", "json"}: raise ValueError("output_format must be 'jsonl' or 'json'") + if self.streaming and self.output_format != "jsonl": + raise ValueError("streaming measurement output only supports jsonl") + if self.dd_trace not in DD_TRACE_MODES: + raise ValueError("dd_trace must be 'none', 'last_message', or 'all_messages'") + if self.dd_trace != "none" and self.dd_trace_path is None: + raise ValueError("dd_trace_path is required when dd_trace is enabled") @classmethod def from_env(cls, *, prefix: str 
= DEFAULT_MEASUREMENT_ENV_PREFIX) -> MeasurementConfig | None: @@ -189,6 +294,10 @@ def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> Measuremen output_path=settings.output_path, output_format=settings.output_format, record_level=settings.record_level, + streaming=settings.streaming, + keep_records=settings.keep_records, + dd_trace=settings.dd_trace, + dd_trace_path=settings.dd_trace_path, run_id=settings.run_id, run_tags=settings.run_tags, fail_on_write_error=settings.fail_on_write_error, @@ -232,11 +341,18 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator yield None return + sink = _JsonlMeasurementSink(config.output_path) if config.streaming else None + dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) if config.dd_trace != "none" else None collector = MeasurementCollector( run_id=config.run_id, record_hash_key=config.record_hash_key, record_level=config.record_level, run_tags=config.run_tags, + record_sink=sink, + keep_records=config.keep_records, + dd_trace_mode=config.dd_trace, + dd_trace_sink=dd_trace_sink, + fail_on_write_error=config.fail_on_write_error, ) with measurement_session(collector): body_error: BaseException | None = None @@ -246,7 +362,11 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator body_error = exc raise finally: - _write_collector_safely(config=config, collector=collector, body_error=body_error) + if config.streaming: + _close_collector_safely(config=config, collector=collector, body_error=body_error) + else: + _write_collector_safely(config=config, collector=collector, body_error=body_error) + _close_collector_safely(config=config, collector=collector, body_error=body_error) def current_collector() -> MeasurementCollector | None: @@ -735,6 +855,20 @@ def _write_collector_safely( raise +def _close_collector_safely( + *, + config: MeasurementConfig, + collector: MeasurementCollector, + body_error: BaseException | None, +) -> None: + try: + 
collector.close() + except Exception as exc: + logger.warning("Failed to close Anonymizer measurement stream (%s)", type(exc).__name__) + if body_error is None and config.fail_on_write_error: + raise + + def _measurement_env_error_message(exc: SettingsError | ValidationError, *, prefix: str) -> str: fields: set[str] = set() if isinstance(exc, ValidationError): diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 098a3096..7b4c5676 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -15,6 +15,9 @@ import pytest from data_designer.config.column_configs import LLMTextColumnConfig from data_designer.config.models import ModelConfig +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, Usage +from data_designer.engine.models.facade import ModelFacade +from data_designer.engine.models.utils import ChatMessage from data_designer.interface.data_designer import DataDesigner from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect @@ -498,6 +501,102 @@ def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector raise RuntimeError("body failed") +def test_streaming_measurement_session_writes_jsonl_without_retaining_records(tmp_path: Path) -> None: + output_path = tmp_path / "measurements.jsonl" + + with configured_measurement_session( + MeasurementConfig(output_path=output_path, streaming=True, keep_records=False) + ) as collector: + assert collector is not None + collector.record("example", value=1) + + assert collector.records == [] + assert output_path.read_text(encoding="utf-8").count("\n") == 1 + + collector.record("example", value=2) + + lines = output_path.read_text(encoding="utf-8").splitlines() + assert len(lines) == 2 + assert [json.loads(line)["value"] for line in lines] == [1, 2] + + +def test_streaming_measurement_requires_jsonl_output(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="streaming measurement 
output only supports jsonl"): + MeasurementConfig(output_path=tmp_path / "measurements.json", output_format="json", streaming=True) + + +def test_dd_message_trace_requires_trace_path(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="dd_trace_path is required"): + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_trace="last_message") + + +def test_ndd_adapter_writes_opt_in_dd_message_trace( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + def fake_completion( + _self: object, + _messages: list[ChatMessage], + skip_usage_tracking: bool = False, + **_kwargs: object, + ) -> ChatCompletionResponse: + assert skip_usage_tracking is False + return ChatCompletionResponse( + message=AssistantMessage(content="secret response"), + usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), + ) + + monkeypatch.setattr(ModelFacade, "completion", fake_completion) + + class TraceDataDesigner: + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + ModelFacade.completion( + SimpleNamespace(model_alias="alias", model_name="dummy-model", model_provider_name="provider"), + [ + ChatMessage.as_system("system secret"), + ChatMessage.as_user("prompt secret"), + ], + ) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + trace_path = tmp_path / "trace.jsonl" + + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", dd_trace="last_message", dd_trace_path=trace_path + ) + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="alias", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="alias", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + trace = 
json.loads(trace_path.read_text(encoding="utf-8").strip()) + assert trace["record_type"] == "dd_message_trace" + assert trace["workflow_name"] == "entity-detection" + assert trace["model_alias"] == "alias" + assert trace["status"] == "completed" + assert trace["messages"] == [{"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}] + assert trace["response"]["content"] == "secret response" + assert trace["usage"] == {"input_tokens": 3, "output_tokens": 2, "total_tokens": 5} + + serialized_measurements = (tmp_path / "measurements.jsonl").read_text(encoding="utf-8") + assert "prompt secret" not in serialized_measurements + assert "secret response" not in serialized_measurements + + def test_record_metrics_capture_generic_counts_without_raw_values() -> None: final_entities = { "entities": [ diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 3448a71f..2d295a64 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -5,8 +5,11 @@ import importlib.util import sys +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path from types import ModuleType +from typing import Any import pandas as pd import pytest @@ -134,3 +137,182 @@ def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None assert result.table_dir is None assert {case.status for case in result.cases} == {tool.CaseStatus.planned} assert not output_dir.exists() + + +def test_benchmark_preflight_rejects_missing_text_column(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_preflight_input", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"body": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-input-suite +workloads: + - id: biography + source: input.csv + text_column: text 
+configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="workload 'biography' text_column 'text' not found"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_preflight_models", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-model-suite +model_configs: | + selected_models: + detection: + entity_detector: detector + entity_validator: [validator] + entity_augmenter: augmenter + replace: + replacement_generator: missing-replacer + model_configs: + - alias: detector + model: test/detector + - alias: validator + model: test/validator + - alias: augmenter + model: test/augmenter +workloads: + - id: biography + source: input.csv +configs: + - id: substitute + replace: substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="missing-replacer"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + provider_path = tmp_path / "providers.yaml" + provider_path.write_text("not_providers: []\n", encoding="utf-8") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-provider-suite +model_providers: providers.yaml +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + 
) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="providers"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_provider_path", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + provider_path = tmp_path / "providers.yaml" + provider_path.write_text( + """ +providers: + - name: test-provider + endpoint: https://example.com/v1 + provider_type: openai + api_key: TEST_API_KEY +""", + encoding="utf-8", + ) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: provider-path-suite +model_providers: providers.yaml +workloads: + - id: biography + source: input.csv +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_case_passes_dd_trace_config_to_measurement_session( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_dd_trace", REPO_ROOT / "tools/measurement/run_benchmarks.py") + captured: list[Any] = [] + + @contextmanager + def fake_measurement_session(config: Any) -> Iterator[None]: + captured.append(config) + yield None + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> None: + assert config.replace is not None + assert data.text_column == "text" + + monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) + + spec = tool.BenchmarkSpec( + suite_id="trace-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "input.csv", index=False) + case = tool.BenchmarkCase( 
+ suite_id="trace-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + trace_path = tmp_path / "traces" / "input__redact__r000.jsonl" + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=tmp_path / "raw" / "input__redact__r000.jsonl", + trace_path=trace_path, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.all_messages, + ) + + assert len(captured) == 1 + assert captured[0].dd_trace == "all_messages" + assert captured[0].dd_trace_path == trace_path + assert captured[0].streaming is True + assert captured[0].keep_records is False diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 3501b9cd..82f7a955 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -30,6 +30,25 @@ measurement JSONL format, one raw file per benchmark case plus a combined ```bash uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark-runs/suite uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json +uv run python tools/measurement/run_benchmarks.py suite.yaml \ + --output benchmark-runs/suite \ + --dd-trace last-message +``` + +To rerun the repo-data smoke suite with DataDesigner traces enabled: + +```bash +bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh +``` + +The script writes to `/tmp/anonymizer-repo-data-smoke-dd-traces` by default. +Pass a different output directory as the first argument, or set +`DD_TRACE_MODE=all-messages` when you need full chat history: + +```bash +DD_TRACE_MODE=all-messages \ + bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh \ + /tmp/anonymizer-repo-data-smoke-dd-traces-full ``` Benchmark suites are YAML files with three parts: @@ -77,6 +96,21 @@ matrix: The runner refuses to write into a non-empty output directory unless `--overwrite` is set. 
By default it also exports Parquet tables into `tables/`; pass `--no-export` when you only want the raw measurement JSONL. +Before starting a real run, the benchmark runner performs cheap preflight +checks: suite/config parsing, local dataset existence, CSV/Parquet text-column +metadata, provider YAML shape, and active model-alias references. `--dry-run` +only expands the planned matrix and skips these file/config checks. + +For debugging DataDesigner calls, pass `--dd-trace last-message` or +`--dd-trace all-messages`. Trace records are written separately from sanitized +measurements, under `traces/{case_id}.jsonl` by default. Use `--trace-dir` to +choose another directory. `last-message` stores only the final prompt message +for each DataDesigner model call; `all-messages` stores the full message list. + +DataDesigner traces may contain raw input text, prompts, model outputs, entity +values, replacement values, secrets, and PII. Treat them as debug artifacts: +keep them out of shared benchmark bundles unless they have been reviewed or +redacted. ## Output Layout @@ -87,6 +121,8 @@ benchmark-runs/suite-id/ raw/ biographies__redact-default__r000.jsonl support__hash-agent-labels__r000.jsonl + traces/ + biographies__redact-default__r000.jsonl measurements.jsonl summary.json tables/ @@ -97,8 +133,15 @@ benchmark-runs/suite-id/ ndd_workflow.parquet ``` +Raw per-case JSONL files are streamed as measurement events are recorded, so a +long run leaves inspectable partial output before the case exits. The combined +`measurements.jsonl` is written after the completed and errored case files are +collected. + Use `summary.json` to inspect case status and errors. Use `measurements.jsonl` when you need the original structured records. Use `tables/` for analysis. +Use `traces/` only when `--dd-trace` was enabled and you need raw +DataDesigner message-level debugging. 
The exporter groups records by `record_type`: diff --git a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh new file mode 100644 index 00000000..cc521b92 --- /dev/null +++ b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +set -euo pipefail + +output_dir="${1:-/tmp/anonymizer-repo-data-smoke-dd-traces}" +trace_mode="${DD_TRACE_MODE:-last-message}" + +uv run python tools/measurement/run_benchmarks.py \ + /stable-cache/anonymizer/repo-data-smoke.yaml \ + --output "${output_dir}" \ + --overwrite \ + --dd-trace "${trace_mode}" \ + --trace-dir "${output_dir}/traces" diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index bc0ade51..13843940 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -18,13 +18,26 @@ from typing import Annotated, Any import cyclopts +import pandas as pd +import pyarrow.parquet as pq import yaml +from data_designer.config.models import ModelProvider +from data_designer.config.utils.io_helpers import load_config_file from export_measurements import ExportFormat, export_tables, read_measurements from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator -from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect, Rewrite +from anonymizer.config.anonymizer_config import ( + AnonymizerConfig, + AnonymizerInput, + Detect, + Rewrite, + infer_input_source_suffix, + is_remote_input_source, +) from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance +from anonymizer.engine.io.constants import 
SUPPORTED_IO_FORMATS +from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import MeasurementConfig, configured_measurement_session @@ -46,6 +59,12 @@ class CaseStatus(StrEnum): error = "error" +class DDTraceMode(StrEnum): + none = "none" + last_message = "last_message" + all_messages = "all_messages" + + class ReplaceKind(StrEnum): redact = "redact" hash = "hash" @@ -163,6 +182,7 @@ class BenchmarkCase(BaseModel): status: CaseStatus = CaseStatus.planned elapsed_sec: float | None = None measurement_path: str | None = None + trace_path: str | None = None error: str | None = None @@ -236,6 +256,105 @@ def prepare_output_dir(output_dir: Path, *, overwrite: bool, dry_run: bool) -> N (output_dir / "raw").mkdir(exist_ok=True) +def preflight_suite(spec: BenchmarkSpec, *, spec_path: Path) -> None: + """Validate cheap suite inputs before any benchmark case consumes model time.""" + base_dir = spec_path.parent + errors: list[str] = [] + parsed_models = None + + try: + parsed_models = parse_model_configs(_resolve_config_source(spec.model_configs, base_dir)) + except Exception as exc: + errors.append(f"model_configs invalid: {exc}") + + try: + _preflight_model_providers(spec, base_dir=base_dir) + except Exception as exc: + errors.append(f"model_providers invalid: {exc}") + + for workload in spec.workloads: + try: + _preflight_workload(workload, base_dir=base_dir) + except Exception as exc: + errors.append(str(exc)) + + for config in spec.configs: + try: + anonymizer_config = build_anonymizer_config(config) + except Exception as exc: + errors.append(f"config '{config.id}' invalid: {exc}") + continue + if parsed_models is None: + continue + try: + validate_model_alias_references( + parsed_models.model_configs, + parsed_models.selected_models, + check_substitute=isinstance(anonymizer_config.replace, Substitute) + or anonymizer_config.rewrite is not 
None, + check_rewrite=anonymizer_config.rewrite is not None, + ) + except ValueError as exc: + errors.append(f"config '{config.id}' model aliases invalid: {exc}") + + if errors: + raise ValueError("Benchmark preflight failed:\n- " + "\n- ".join(errors)) + + +def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: + raw = _resolve_config_source(spec.model_providers, base_dir) + if raw is None: + return + config_source: str | Path = raw + if isinstance(raw, str) and "\n" not in raw: + candidate = Path(raw.strip()).expanduser() + if candidate.suffix in (".yaml", ".yml"): + if not candidate.is_file(): + raise FileNotFoundError(f"Providers config file not found: {candidate}") + config_source = candidate + config_dict = load_config_file(config_source) + raw_providers = config_dict.get("providers") if isinstance(config_dict, dict) else None + if not isinstance(raw_providers, list): + raise ValueError("model_providers YAML must contain a top-level 'providers' list.") + for provider in raw_providers: + ModelProvider.model_validate(provider) + + +def _preflight_workload(workload: WorkloadSpec, *, base_dir: Path) -> None: + resolved_source = _resolve_input_source(workload.source, base_dir) + input_data = AnonymizerInput( + source=str(resolved_source), + text_column=workload.text_column, + id_column=workload.id_column, + data_summary=workload.data_summary, + ) + columns = _input_columns(input_data.source) + if columns is None: + return + if workload.text_column not in columns: + raise ValueError( + f"workload '{workload.id}' text_column '{workload.text_column}' not found in {input_data.source}; " + f"available columns: {sorted(columns)}" + ) + if workload.id_column is not None and workload.id_column not in columns: + raise ValueError( + f"workload '{workload.id}' id_column '{workload.id_column}' not found in {input_data.source}; " + f"available columns: {sorted(columns)}" + ) + + +def _input_columns(source: str) -> set[str] | None: + suffix = 
infer_input_source_suffix(source) + if suffix not in SUPPORTED_IO_FORMATS: + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. Use {supported_formats}.") + if is_remote_input_source(source): + return None + if suffix == ".csv": + return set(pd.read_csv(source, nrows=0).columns) + return set(pq.ParquetFile(source).schema_arrow.names) + + def run_suite( spec: BenchmarkSpec, *, @@ -243,8 +362,16 @@ def run_suite( output_dir: Path, export: bool, fail_fast: bool, + dd_trace: DDTraceMode, + trace_dir: Path | None, ) -> BenchmarkResult: - contexts = _build_contexts(spec, spec_path=spec_path, output_dir=output_dir) + contexts = _build_contexts( + spec, + spec_path=spec_path, + output_dir=output_dir, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) anonymizer = Anonymizer(**contexts["anonymizer_kwargs"]) cases = [ _run_case(case, spec, contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast) @@ -265,7 +392,14 @@ def run_suite( return result -def _build_contexts(spec: BenchmarkSpec, *, spec_path: Path, output_dir: Path) -> dict[str, Any]: +def _build_contexts( + spec: BenchmarkSpec, + *, + spec_path: Path, + output_dir: Path, + dd_trace: DDTraceMode, + trace_dir: Path | None, +) -> dict[str, Any]: base_dir = spec_path.parent artifact_path = _resolve_optional_path(spec.artifact_path, base_dir) or output_dir / "artifacts" return { @@ -273,6 +407,8 @@ def _build_contexts(spec: BenchmarkSpec, *, spec_path: Path, output_dir: Path) - "workloads": {workload.id: workload for workload in spec.workloads}, "configs": {config.id: config for config in spec.configs}, "raw_dir": output_dir / "raw", + "dd_trace": dd_trace, + "trace_dir": trace_dir or output_dir / "traces", "anonymizer_kwargs": { "model_configs": _resolve_config_source(spec.model_configs, base_dir), "model_providers": _resolve_config_source(spec.model_providers, base_dir), @@ -290,18 +426,28 @@ def _run_case( fail_fast: bool, ) -> BenchmarkCase: raw_path = 
contexts["raw_dir"] / f"{case.case_id}.jsonl" + trace_path = _case_trace_path(case, contexts=contexts) started = time.perf_counter() try: workload = _get_item(contexts["workloads"], case.workload_id, "workload") config = _get_item(contexts["configs"], case.config_id, "config") _execute_case( - anonymizer, workload, config, raw_path=raw_path, case=case, spec=spec, base_dir=contexts["base_dir"] + anonymizer, + workload, + config, + raw_path=raw_path, + trace_path=trace_path, + case=case, + spec=spec, + base_dir=contexts["base_dir"], + dd_trace=contexts["dd_trace"], ) return case.model_copy( update={ "status": CaseStatus.completed, "elapsed_sec": time.perf_counter() - started, "measurement_path": str(raw_path), + "trace_path": str(trace_path) if trace_path is not None else None, } ) except Exception as exc: @@ -312,22 +458,40 @@ def _run_case( "status": CaseStatus.error, "elapsed_sec": time.perf_counter() - started, "measurement_path": str(raw_path), + "trace_path": str(trace_path) if trace_path is not None else None, "error": str(exc), } ) +def _case_trace_path(case: BenchmarkCase, *, contexts: dict[str, Any]) -> Path | None: + if contexts["dd_trace"] == DDTraceMode.none: + return None + return contexts["trace_dir"] / f"{case.case_id}.jsonl" + + def _execute_case( anonymizer: Anonymizer, workload: WorkloadSpec, config: ConfigSpec, *, raw_path: Path, + trace_path: Path | None, case: BenchmarkCase, spec: BenchmarkSpec, base_dir: Path, + dd_trace: DDTraceMode, ) -> None: - measurement = MeasurementConfig(output_path=raw_path, run_id=case.case_id, run_tags=_run_tags(case, spec)) + measurement = MeasurementConfig( + output_path=raw_path, + run_id=case.case_id, + run_tags=_run_tags(case, spec), + streaming=True, + keep_records=False, + dd_trace=dd_trace.value, + dd_trace_path=trace_path, + fail_on_write_error=True, + ) with configured_measurement_session(measurement): anonymizer.run(config=build_anonymizer_config(config), data=build_input(workload, base_dir)) @@ -471,8 
+635,20 @@ def _resolve_path(raw: str, base_dir: Path) -> Path: return path if path.is_absolute() else base_dir / path -def dry_run_result(spec: BenchmarkSpec, *, output_dir: Path, export: bool) -> BenchmarkResult: +def dry_run_result( + spec: BenchmarkSpec, + *, + output_dir: Path, + export: bool, + dd_trace: DDTraceMode, + trace_dir: Path | None, +) -> BenchmarkResult: cases = build_cases(spec) + if dd_trace != DDTraceMode.none: + resolved_trace_dir = trace_dir or output_dir / "traces" + cases = [ + case.model_copy(update={"trace_path": str(resolved_trace_dir / f"{case.case_id}.jsonl")}) for case in cases + ] return BenchmarkResult( suite_id=spec.suite_id, output_dir=str(output_dir), @@ -492,13 +668,22 @@ def main( dry_run: Annotated[bool, cyclopts.Parameter("--dry-run")] = False, export: Annotated[bool, cyclopts.Parameter("--export")] = True, fail_fast: Annotated[bool, cyclopts.Parameter("--fail-fast")] = False, + dd_trace: Annotated[DDTraceMode, cyclopts.Parameter("--dd-trace")] = DDTraceMode.none, + trace_dir: Annotated[Path | None, cyclopts.Parameter("--trace-dir")] = None, json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, ) -> None: configure_logging(log_format) try: result = run_or_plan( - spec, output=output, overwrite=overwrite, dry_run=dry_run, export=export, fail_fast=fail_fast + spec, + output=output, + overwrite=overwrite, + dry_run=dry_run, + export=export, + fail_fast=fail_fast, + dd_trace=dd_trace, + trace_dir=trace_dir, ) except (ValueError, ValidationError) as exc: log_bad_input(str(exc)) @@ -516,13 +701,32 @@ def run_or_plan( dry_run: bool, export: bool, fail_fast: bool, + dd_trace: DDTraceMode = DDTraceMode.none, + trace_dir: Path | None = None, ) -> BenchmarkResult: benchmark_spec = load_spec(spec_path) output_dir = output or Path("benchmark-runs") / benchmark_spec.suite_id - prepare_output_dir(output_dir, overwrite=overwrite, 
dry_run=dry_run) + if trace_dir is not None and dd_trace == DDTraceMode.none: + raise ValueError("--trace-dir requires --dd-trace") if dry_run: - return dry_run_result(benchmark_spec, output_dir=output_dir, export=export) - return run_suite(benchmark_spec, spec_path=spec_path, output_dir=output_dir, export=export, fail_fast=fail_fast) + return dry_run_result( + benchmark_spec, + output_dir=output_dir, + export=export, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) + preflight_suite(benchmark_spec, spec_path=spec_path) + prepare_output_dir(output_dir, overwrite=overwrite, dry_run=dry_run) + return run_suite( + benchmark_spec, + spec_path=spec_path, + output_dir=output_dir, + export=export, + fail_fast=fail_fast, + dd_trace=dd_trace, + trace_dir=trace_dir, + ) if __name__ == "__main__": From 6e7a30cba777068aeb4d3042169a22bf08e56a23 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 8 Jun 2026 19:59:18 +0000 Subject: [PATCH 04/26] feat: add benchmark analysis strategy probes Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/constants.py | 1 + .../engine/detection/chunked_validation.py | 12 +- .../engine/detection/detection_workflow.py | 77 +- src/anonymizer/engine/detection/rules.py | 274 ++ src/anonymizer/engine/ndd/adapter.py | 28 +- .../engine/replace/llm_replace_workflow.py | 43 +- .../engine/replace/structured_substitute.py | 299 ++ src/anonymizer/measurement.py | 246 +- tests/engine/test_chunked_validation.py | 35 + tests/engine/test_detection_rules.py | 318 +++ tests/engine/test_llm_replace_workflow.py | 35 + tests/engine/test_structured_substitute.py | 160 ++ tests/test_measurement.py | 344 +++ tests/tools/test_benchmark_output_analysis.py | 907 ++++++ tests/tools/test_compare_strategy_pairs.py | 1474 ++++++++++ tests/tools/test_dd_parser_compat.py | 93 + tests/tools/test_dd_trace_analysis.py | 160 ++ .../tools/test_detection_artifact_analysis.py | 159 ++ tests/tools/test_detection_strategies.py | 2143 +++++++++++++++ 
tests/tools/test_direct_detection_probe.py | 147 + tests/tools/test_extract_signature_deltas.py | 190 ++ tests/tools/test_measurement_tools.py | 1298 ++++++++- tests/tools/test_replacement_strategies.py | 76 + .../test_replay_replacement_strategies.py | 430 +++ .../tools/test_screen_strategy_comparisons.py | 1046 +++++++ .../test_staged_detection_output_analysis.py | 242 ++ tests/tools/test_staged_detection_probe.py | 812 ++++++ tools/measurement/README.md | 1569 ++++++++++- tools/measurement/analyze_benchmark_output.py | 1180 ++++++++ tools/measurement/analyze_dd_traces.py | 431 +++ .../analyze_detection_artifacts.py | 376 +++ .../analyze_staged_detection_output.py | 528 ++++ tools/measurement/compare_strategy_pairs.py | 1459 ++++++++++ tools/measurement/dd_parser_compat.py | 108 + tools/measurement/detection_strategies.py | 2436 +++++++++++++++++ tools/measurement/direct_detection_probe.py | 559 ++++ tools/measurement/extract_signature_deltas.py | 556 ++++ tools/measurement/replacement_strategies.py | 58 + .../replay_replacement_strategies.py | 736 +++++ tools/measurement/run_benchmarks.py | 881 +++++- .../screen_strategy_comparisons.py | 1166 ++++++++ tools/measurement/staged_detection_probe.py | 1475 ++++++++++ 42 files changed, 24477 insertions(+), 90 deletions(-) create mode 100644 src/anonymizer/engine/detection/rules.py create mode 100644 src/anonymizer/engine/replace/structured_substitute.py create mode 100644 tests/engine/test_detection_rules.py create mode 100644 tests/engine/test_structured_substitute.py create mode 100644 tests/tools/test_benchmark_output_analysis.py create mode 100644 tests/tools/test_compare_strategy_pairs.py create mode 100644 tests/tools/test_dd_parser_compat.py create mode 100644 tests/tools/test_dd_trace_analysis.py create mode 100644 tests/tools/test_detection_artifact_analysis.py create mode 100644 tests/tools/test_detection_strategies.py create mode 100644 tests/tools/test_direct_detection_probe.py create mode 100644 
tests/tools/test_extract_signature_deltas.py create mode 100644 tests/tools/test_replacement_strategies.py create mode 100644 tests/tools/test_replay_replacement_strategies.py create mode 100644 tests/tools/test_screen_strategy_comparisons.py create mode 100644 tests/tools/test_staged_detection_output_analysis.py create mode 100644 tests/tools/test_staged_detection_probe.py create mode 100644 tools/measurement/analyze_benchmark_output.py create mode 100644 tools/measurement/analyze_dd_traces.py create mode 100644 tools/measurement/analyze_detection_artifacts.py create mode 100644 tools/measurement/analyze_staged_detection_output.py create mode 100644 tools/measurement/compare_strategy_pairs.py create mode 100644 tools/measurement/dd_parser_compat.py create mode 100644 tools/measurement/detection_strategies.py create mode 100644 tools/measurement/direct_detection_probe.py create mode 100644 tools/measurement/extract_signature_deltas.py create mode 100644 tools/measurement/replacement_strategies.py create mode 100644 tools/measurement/replay_replacement_strategies.py create mode 100644 tools/measurement/screen_strategy_comparisons.py create mode 100644 tools/measurement/staged_detection_probe.py diff --git a/src/anonymizer/engine/constants.py b/src/anonymizer/engine/constants.py index fdfecadf..d9e1f079 100644 --- a/src/anonymizer/engine/constants.py +++ b/src/anonymizer/engine/constants.py @@ -45,6 +45,7 @@ COL_ENTITIES_BY_VALUE = "_entities_by_value" COL_REPLACED_TEXT = "__nemo_anonymizer_text_output__" COL_REPLACEMENT_MAP = "_replacement_map" +COL_REPLACEMENT_MAP_SOURCE = "_replacement_map_source" # LlmReplaceWorkflow internal prompt-construction columns. 
Created by # `LlmReplaceWorkflow.generate_map_only` for the replacement-generator prompt diff --git a/src/anonymizer/engine/detection/chunked_validation.py b/src/anonymizer/engine/detection/chunked_validation.py index 50601cad..b870fab9 100644 --- a/src/anonymizer/engine/detection/chunked_validation.py +++ b/src/anonymizer/engine/detection/chunked_validation.py @@ -102,6 +102,11 @@ class ChunkedValidationParams(BaseModel): max_entities_per_call: Upper bound on candidates per chunk. excerpt_window_chars: Chars of surrounding raw text included in each chunk's excerpt on either side of the chunk span. + single_chunk_full_text: If True, a row with one validation chunk sees + the full tagged document. If False, even a single chunk uses the + excerpt window. The default preserves production parity with the + pre-chunking validation path; benchmarks may disable it to probe + compact validation prompts. prompt_template: Jinja2 source for the validation prompt (with ``_seed_tagged_text``, ``_validation_skeleton``, ``_tag_notation`` placeholders). Typically produced by ``_get_validation_prompt``. @@ -119,6 +124,7 @@ class ChunkedValidationParams(BaseModel): pool: list[str] = Field(min_length=1) max_entities_per_call: int = Field(gt=0) excerpt_window_chars: int = Field(gt=0) + single_chunk_full_text: bool = True prompt_template: str = Field(repr=False) system_prompt: str | None = Field(default=None, repr=False) @@ -449,7 +455,11 @@ def chunked_validate_row( # only making one call there's no cost reason to clip, and clipping # would silently narrow the context the validator sees. Computed once # here because ``len(chunks) == 1`` is loop-invariant. 
- single_chunk_tagged_text = build_tagged_text(text, all_spans, notation=notation) if len(chunks) == 1 else None + single_chunk_tagged_text = ( + build_tagged_text(text, all_spans, notation=notation) + if len(chunks) == 1 and params.single_chunk_full_text + else None + ) dispatch_kwargs_per_chunk: list[dict[str, Any]] = [] for chunk_index, chunk in enumerate(chunks): diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index e482e912..153eb50d 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -49,7 +49,12 @@ parse_detected_entities, prepare_validation_inputs, ) -from anonymizer.engine.detection.postprocess import EntitySpan, group_entities_by_value +from anonymizer.engine.detection.postprocess import EntitySpan, build_tagged_text, group_entities_by_value +from anonymizer.engine.detection.rules import ( + STRUCTURED_RULE_FAST_LANE_LABELS, + SUPPORTED_RULE_LABELS, + detect_high_confidence_entities, +) from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases from anonymizer.engine.prompt_utils import substitute_placeholders @@ -86,6 +91,33 @@ class EntityDetectionWorkflow: def __init__(self, adapter: NddAdapter) -> None: self._adapter = adapter + def detect_with_high_confidence_rules( + self, + dataframe: pd.DataFrame, + *, + entity_labels: list[str] | None = None, + ) -> EntityDetectionResult: + """Detect only deterministic high-confidence rule spans without DataDesigner. + + This is an internal fast-lane primitive for benchmark probes and + future routing work. It is intentionally limited to labels with narrow + deterministic coverage and does not attempt contextual PII detection. 
+ """ + labels = _resolve_detection_labels(entity_labels) + _ensure_high_confidence_rule_labels(labels) + output = dataframe.copy() + output[COL_DETECTED_ENTITIES] = output[COL_TEXT].apply( + lambda text: _high_confidence_rule_payload(text, labels=labels) + ) + output[COL_TAGGED_TEXT] = output.apply( + lambda row: _tagged_text_from_entities( + text=row.get(COL_TEXT, ""), + raw_entities=row.get(COL_DETECTED_ENTITIES, {}), + ), + axis=1, + ) + return EntityDetectionResult(dataframe=output, failed_records=[]) + def detect_and_validate_entities( self, dataframe: pd.DataFrame, @@ -95,6 +127,7 @@ def detect_and_validate_entities( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, data_summary: str | None = None, preview_num_records: int | None = None, @@ -144,6 +177,7 @@ def detect_and_validate_entities( pool=list(validator_aliases), max_entities_per_call=validation_max_entities_per_call, excerpt_window_chars=validation_excerpt_window_chars, + single_chunk_full_text=validation_single_chunk_full_text, prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels), ) @@ -355,6 +389,47 @@ def _resolve_detection_labels(entity_labels: list[str] | None) -> list[str]: return list(entity_labels) +def labels_are_supported_by_high_confidence_rules(labels: list[str]) -> bool: + """Return True when every label can be handled by deterministic rules.""" + return set(labels).issubset(SUPPORTED_RULE_LABELS) + + +def labels_are_supported_by_structured_rule_fast_lane(labels: list[str]) -> bool: + """Return True when every label is safe for the structured no-model fast lane.""" + return set(labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) + + +def _ensure_high_confidence_rule_labels(labels: list[str]) -> None: + unsupported = 
sorted(set(labels) - SUPPORTED_RULE_LABELS) + if unsupported: + supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) + raise ValueError( + f"unsupported high-confidence rule labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + +def _high_confidence_rule_payload(text: object, *, labels: list[str]) -> dict: + spans = detect_high_confidence_entities(str(text), labels=labels) + return EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump(mode="json") + + +def _tagged_text_from_entities(*, text: object, raw_entities: object) -> str: + parsed = EntitiesSchema.from_raw(raw_entities) + spans = [ + EntitySpan( + entity_id=e.id, + value=e.value, + label=e.label, + start_position=e.start_position, + end_position=e.end_position, + score=e.score, + source=e.source, + ) + for e in parsed.entities + ] + return build_tagged_text(text=str(text), entities=spans) + + def _materialize_final_entities(raw: object, *, allowed_labels: set[str] | None) -> dict: """Build COL_FINAL_ENTITIES, optionally filtering to *allowed_labels*.""" parsed = EntitiesSchema.from_raw(raw) diff --git a/src/anonymizer/engine/detection/rules.py b/src/anonymizer/engine/detection/rules.py new file mode 100644 index 00000000..1b96748b --- /dev/null +++ b/src/anonymizer/engine/detection/rules.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass + +from anonymizer.engine.detection.postprocess import EntitySpan, resolve_overlaps + +_RULE_SCORE = 1.0 +_RULE_SOURCE = "rule" +_RELIGIOUS_BELIEF_TERMS = ( + "agnostic", + "atheist", + "baptist", + "buddhist", + "catholic", + "christian", + "hindu", + "jewish", + "mormon", + "muslim", + "protestant", + "secular", +) +_RELIGIOUS_BELIEF_RE = "|".join(re.escape(term) for term in _RELIGIOUS_BELIEF_TERMS) +_COOKIE_PAIR_RE = r"[A-Za-z][A-Za-z0-9_-]*=[^;'\s\"\r\n]+" +_COOKIE_VALUE_RE = rf"({_COOKIE_PAIR_RE}(?:;\s*{_COOKIE_PAIR_RE})*)" +_STRUCTURED_ID_VALUE_RE = ( + r"(?:[A-Za-z][A-Za-z0-9]{1,20}[-_][A-Za-z0-9][A-Za-z0-9_-]{5,}|" + r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})" +) + + +@dataclass(frozen=True) +class _RulePattern: + label: str + pattern: re.Pattern[str] + group: int = 0 + + +_RULES: tuple[_RulePattern, ...] 
= ( + _RulePattern( + label="api_key", + pattern=re.compile(r"sk-(?:test|ant-api03|proj|prod)-[A-Za-z0-9_-]{16,}"), + ), + _RulePattern(label="api_key", pattern=re.compile(r"ghp_[A-Za-z0-9_]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"hf_[A-Za-z0-9]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"pat-[A-Za-z0-9_-]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"xoxb-[A-Za-z0-9-]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"AIza[A-Za-z0-9_-]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"ya29\.[A-Za-z0-9_-]{20,}")), + _RulePattern(label="api_key", pattern=re.compile(r"AKIA[A-Z0-9]{16,}")), + _RulePattern( + label="api_key", + pattern=re.compile( + r"\b(?:api[_-]?key|token|auth[_-]?token|session[_-]?id|aws_access_key_id|access_key_id)=" + r"([^\s;'\"\\]{8,})", + flags=re.IGNORECASE, + ), + group=1, + ), + _RulePattern( + label="api_key", + pattern=re.compile(r"Authorization:\s*Bearer\s+([A-Za-z0-9._-]{16,})", flags=re.IGNORECASE), + group=1, + ), + _RulePattern( + label="http_cookie", + pattern=re.compile(rf"\bCookie:\s*{_COOKIE_VALUE_RE}", flags=re.IGNORECASE), + group=1, + ), + _RulePattern( + label="http_cookie", + pattern=re.compile(rf"\bcookie\s*=\s*{_COOKIE_VALUE_RE}", flags=re.IGNORECASE), + group=1, + ), + _RulePattern( + label="pin", + pattern=re.compile(r"(?]+", + flags=re.IGNORECASE, + ), + ), + _RulePattern( + label="email", + pattern=re.compile(r"(?]+")), + _RulePattern( + label="date_of_birth", + pattern=re.compile( + r"\b(?:born|date\s+of\s+birth|dob)\s*(?:[:=-]|\bin\b|\bon\b)?\s*" + r"(\d{4}|\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}-\d{2}-\d{2})\b", + flags=re.IGNORECASE, + ), + group=1, + ), + _RulePattern( + label="religious_belief", + pattern=re.compile( + rf"\b(?:describes?\s+(?:himself|herself|themself|themselves)\s+as|" + rf"identif(?:y|ies)\s+as|raised\s+in\s+the|practicing)\s+" + rf"(?:a|an|the)?\s*({_RELIGIOUS_BELIEF_RE})\b", + flags=re.IGNORECASE, + ), + group=1, + ), + 
_RulePattern( + label="street_address", + pattern=re.compile( + r"\b(?:lives?\s+at|living\s+at|house\s+on|home\s+on)\s+" + r"([A-Z0-9][A-Za-z0-9.\s-]{1,60}?\b" + r"(?:Street|St\.?|Avenue|Ave\.?|Road|Rd\.?|Drive|Dr\.?|Trail|Boulevard|Blvd\.?|Lane|Ln\.?|Court|Ct\.?))", + ), + group=1, + ), + _RulePattern( + label="organization_name", + pattern=re.compile( + r"\b(?:at|from|with|joining|joined)\s+" + r"([A-Z][A-Za-z0-9&.'\u2019 -]{2,90}?\b" + r"(?:Center|Hospital|Clinic|University|College|Institute|Bank|Builders|Construction|Woodworks|Health))" + r"\b", + ), + group=1, + ), +) + +SUPPORTED_RULE_LABELS = frozenset(rule.label for rule in _RULES) +STRUCTURED_RULE_FAST_LANE_LABELS = frozenset( + { + "api_key", + "email", + "http_cookie", + "password", + "pin", + "unique_id", + "url", + "user_name", + } +) + + +def detect_high_confidence_entities(text: str, labels: Iterable[str] | None = None) -> list[EntitySpan]: + """Detect deterministic high-confidence PII and secret spans in raw text. + + These rules intentionally cover narrow, high-signal command/log and prose + patterns. They are suitable as a local seed detector or benchmark probe, + not as a complete replacement for model-backed contextual detection. 
+ """ + allowed_labels = set(labels) if labels is not None else None + spans: list[EntitySpan] = [] + + for rule in _RULES: + if allowed_labels is not None and rule.label not in allowed_labels: + continue + for match in rule.pattern.finditer(text): + start, end = match.span(rule.group) + if start < 0 or end <= start: + continue + value = text[start:end] + value, end = _trim_rule_value(label=rule.label, value=value, end=end) + if not value: + continue + spans.append( + EntitySpan( + entity_id=_build_rule_entity_id(label=rule.label, start=start, end=end), + value=value, + label=rule.label, + start_position=start, + end_position=end, + score=_RULE_SCORE, + source=_RULE_SOURCE, + ) + ) + + return resolve_overlaps(_deduplicate(spans)) + + +def _trim_rule_value(*, label: str, value: str, end: int) -> tuple[str, int]: + if label != "http_cookie": + return value, end + trimmed = value.rstrip(".,") + return trimmed, end - (len(value) - len(trimmed)) + + +def _deduplicate(entities: list[EntitySpan]) -> list[EntitySpan]: + seen: set[tuple[str, int, int]] = set() + deduplicated: list[EntitySpan] = [] + for entity in entities: + key = (entity.label, entity.start_position, entity.end_position) + if key in seen: + continue + seen.add(key) + deduplicated.append(entity) + return deduplicated + + +def _build_rule_entity_id(*, label: str, start: int, end: int) -> str: + return f"{label}_{start}_{end}" diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 3621a178..70b05f9c 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -329,6 +329,10 @@ def model_usage(self) -> dict[str, Any] | None: def _get_model_usage_snapshot(model_registry: object) -> Mapping[str, object] | None: + alias_snapshot = _get_model_usage_snapshot_by_alias(model_registry) + if alias_snapshot: + return alias_snapshot + get_snapshot = getattr(model_registry, "get_model_usage_snapshot", None) if not callable(get_snapshot): return 
None @@ -338,6 +342,28 @@ def _get_model_usage_snapshot(model_registry: object) -> Mapping[str, object] | return None +def _get_model_usage_snapshot_by_alias(model_registry: object) -> Mapping[str, object] | None: + models = getattr(model_registry, "_models", None) + if not isinstance(models, Mapping): + return None + + snapshot: dict[str, object] = {} + for model_alias, model_facade in models.items(): + stats = getattr(model_facade, "usage_stats", None) + if stats is None or not getattr(stats, "has_usage", False): + continue + payload = _model_usage_as_json(stats) + if isinstance(payload, Mapping): + payload = { + **payload, + "model_alias": getattr(model_facade, "model_alias", str(model_alias)), + "model_name": getattr(model_facade, "model_name", None), + "model_provider_name": getattr(model_facade, "model_provider_name", None), + } + snapshot[str(model_alias)] = payload + return snapshot or None + + def _model_usage_as_json(stats: object) -> Any: model_dump = getattr(stats, "model_dump", None) if callable(model_dump): @@ -484,8 +510,6 @@ def _record_dd_message_trace( error_type: str | None, is_async: bool, ) -> None: - if error_type == "SyncClientUnavailableError": - return collector.record_dd_message_trace( workflow_name=workflow_name, model_alias=getattr(model_facade, "model_alias", None), diff --git a/src/anonymizer/engine/replace/llm_replace_workflow.py b/src/anonymizer/engine/replace/llm_replace_workflow.py index ccd5cb1d..531b6827 100644 --- a/src/anonymizer/engine/replace/llm_replace_workflow.py +++ b/src/anonymizer/engine/replace/llm_replace_workflow.py @@ -19,6 +19,7 @@ COL_ENTITIES_FOR_REPLACE_JSON, COL_ENTITY_EXAMPLES, COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, ENTITY_LABEL_EXAMPLES, ) from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter @@ -28,6 +29,7 @@ from anonymizer.engine.schemas import EntitiesByValueSchema, EntityReplacementMapSchema logger = logging.getLogger("anonymizer.replace.llm_workflow") +REPLACEMENT_MAP_SOURCE_LLM 
= "llm" # Workflow-internal scratch columns used only to build the replacement-generator # prompt. Created in `generate_map_only` and dropped before returning — nothing @@ -71,6 +73,7 @@ def generate_map_only( # Partition: rows with an empty entity list bypass replacement-map generation. entity_rows, passthrough_rows = split_rows(working_df, column=COL_ENTITIES_FOR_REPLACE, predicate=bool) passthrough_rows[COL_REPLACEMENT_MAP] = [{"replacements": []} for _ in range(len(passthrough_rows))] + passthrough_rows[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LLM if entity_rows.empty: passthrough_only = merge_and_reorder(passthrough_rows) @@ -110,6 +113,7 @@ def generate_map_only( ), axis=1, ) + output_df[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LLM combined = merge_and_reorder(output_df, passthrough_rows) return LlmReplaceResult( @@ -160,28 +164,56 @@ def _filter_replacement_map_to_input_entities( for label in entity.labels if entity.value and label } + protected_original_values = {value for value, _ in allowed_pairs} filtered: list[dict[str, str]] = [] seen: set[tuple[str, str]] = set() + synthetic_collision_labels: Counter[str] = Counter() for replacement in parsed_map.replacements: key = (replacement.original, replacement.label) if key not in allowed_pairs or key in seen: continue + if replacement.synthetic in protected_original_values: + synthetic_collision_labels[replacement.label] += 1 + seen.add(key) + filtered.append( + { + "original": replacement.original, + "label": replacement.label, + "synthetic": _collision_safe_synthetic( + replacement.label, + index=synthetic_collision_labels[replacement.label], + protected_original_values=protected_original_values, + ), + } + ) + continue seen.add(key) filtered.append(replacement.model_dump()) + if synthetic_collision_labels: + logger.warning( + "Replacement map repaired synthetic-original collision entries for record %s; repaired=%d " + "(repaired_by_label=%s)", + record_id or "", + 
sum(synthetic_collision_labels.values()), + dict(synthetic_collision_labels), + ) if logger.isEnabledFor(logging.DEBUG): raw_pairs = {(r.original, r.label) for r in parsed_map.replacements} filtered_pairs = {(f["original"], f["label"]) for f in filtered} unrequested_labels = Counter(label for _, label in (raw_pairs - allowed_pairs)) unfilled_labels = Counter(label for _, label in (allowed_pairs - filtered_pairs)) logger.debug( - "Replacement map record %s: requested=%d raw=%d filtered=%d%s%s", + "Replacement map record %s: requested=%d raw=%d filtered=%d%s%s%s", record_id or "", len(allowed_pairs), len(parsed_map.replacements), len(filtered), f" unrequested_by_label={dict(unrequested_labels)}" if unrequested_labels else "", f" unfilled_by_label={dict(unfilled_labels)}" if unfilled_labels else "", + f" synthetic_original_collision_by_label={dict(synthetic_collision_labels)}" + if synthetic_collision_labels + else "", ) if not filtered and allowed_pairs: requested_labels = Counter(label for _, label in allowed_pairs) @@ -195,6 +227,15 @@ def _filter_replacement_map_to_input_entities( return {"replacements": filtered} +def _collision_safe_synthetic(label: str, *, index: int, protected_original_values: set[str]) -> str: + label_token = "".join(char.upper() if char.isalnum() else "_" for char in label).strip("_") or "VALUE" + while True: + candidate = f"[SUBSTITUTE_{label_token}_{index}]" + if candidate not in protected_original_values: + return candidate + index += 1 + + def _get_replacement_mapping_prompt(*, entities_column: str, instructions: str | None = None) -> str: instruction_block = f"\nAdditional instructions: {instructions}\n" if instructions else "" prompt = """Generate synthetic replacements for sensitive entities. ONE value per entity, used consistently. 
diff --git a/src/anonymizer/engine/replace/structured_substitute.py b/src/anonymizer/engine/replace/structured_substitute.py new file mode 100644 index 00000000..6205c005 --- /dev/null +++ b/src/anonymizer/engine/replace/structured_substitute.py @@ -0,0 +1,299 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import re +from collections.abc import Callable, Iterable + +import pandas as pd + +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP, COL_REPLACEMENT_MAP_SOURCE +from anonymizer.engine.schemas import EntitiesByValueSchema + +REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED = "local_structured" +SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS = frozenset( + { + "api_key", + "date_of_birth", + "email", + "http_cookie", + "organization_name", + "password", + "pin", + "religious_belief", + "street_address", + "unique_id", + "url", + "user_name", + } +) + +_RELIGIOUS_BELIEF_SUBSTITUTES = ( + "agnostic", + "atheist", + "buddhist", + "catholic", + "christian", + "hindu", + "jewish", + "muslim", + "secular", +) + + +def apply_structured_substitution_maps( + dataframe: pd.DataFrame, + *, + entities_column: str = COL_ENTITIES_BY_VALUE, +) -> pd.DataFrame: + """Attach deterministic substitute maps for supported structured labels. + + This helper intentionally builds only replacement maps. Text rewriting still + uses the normal replacement-map application path, so span handling remains + identical to LLM-backed ``Substitute``. 
+ """ + output_df = dataframe.copy() + output_df[COL_REPLACEMENT_MAP] = output_df[entities_column].apply(build_structured_substitution_map) + output_df[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + return output_df + + +def build_structured_substitution_map(raw_entities: object) -> dict[str, list[dict[str, str]]]: + """Build a substitute map without model calls for narrow structured labels.""" + parsed = EntitiesByValueSchema.from_raw(raw_entities) + unsupported = _unsupported_labels(parsed) + if unsupported: + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + raise ValueError( + f"local structured substitute supports only deterministic structured labels; " + f"unsupported labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + original_values = {entity.value for entity in parsed.entities_by_value if entity.value} + synthetic_values: set[str] = set() + replacements: list[dict[str, str]] = [] + seen: set[tuple[str, str]] = set() + for entity in parsed.entities_by_value: + if not entity.value: + continue + for label in entity.labels: + if not label: + continue + key = (entity.value, label) + if key in seen: + continue + seen.add(key) + synthetic = structured_substitute_value( + entity.value, + label, + forbidden_values=original_values | synthetic_values, + ) + synthetic_values.add(synthetic) + replacements.append( + { + "original": entity.value, + "label": label, + "synthetic": synthetic, + } + ) + return {"replacements": replacements} + + +def structured_substitute_value( + value: str, + label: str, + *, + forbidden_values: Iterable[str] | None = None, +) -> str: + """Return a deterministic synthetic value for one supported structured label.""" + generator = _GENERATORS.get(label) + if generator is None: + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + raise ValueError(f"unsupported local structured substitute label: {label}; supported labels: {supported}") + forbidden = 
{str(item) for item in forbidden_values or () if item is not None} + forbidden.add(value) + + for salt in ("", "alternate"): + synthetic = generator(value, _digest(value=value, label=label, salt=salt)) + if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): + return synthetic + + fallback_index = 0 + while True: + synthetic = f"synthetic-{label}-{_digest(value=value, label=label, salt=f'fallback-{fallback_index}')[:12]}" + if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): + return synthetic + fallback_index += 1 + + +def _synthetic_is_allowed(synthetic: str, *, original: str, forbidden_values: set[str]) -> bool: + """Reject self-preservation and exact collisions with other protected originals.""" + return synthetic != original and original not in synthetic and synthetic not in forbidden_values + + +def _unsupported_labels(parsed: EntitiesByValueSchema) -> list[str]: + labels = {label for entity in parsed.entities_by_value for label in entity.labels if label} + return sorted(labels - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) + + +def _digest(*, value: str, label: str, salt: str = "") -> str: + return hashlib.sha256(f"{label}\0{salt}\0{value}".encode("utf-8")).hexdigest() + + +def _api_key(value: str, digest: str) -> str: + if value.startswith("ghp_"): + return "ghp_" + digest[:36] + if value.startswith("hf_"): + return "hf_" + digest[:40] + if value.startswith("pat-"): + return "pat-" + digest[:40] + if value.startswith("xoxb-"): + return f"xoxb-{digest[:12]}-{digest[12:24]}-{digest[24:36]}" + if value.startswith("ya29."): + return "ya29." 
+ digest[:44] + if value.startswith("AKIA"): + return "AKIA" + digest[:16].upper() + sk_match = re.match(r"^(sk-(?:test|ant-api03|proj|prod)-)", value) + if sk_match: + return sk_match.group(1) + digest[:48] + return "tok_" + digest[:32] + + +def _password(_value: str, digest: str) -> str: + return f"Synthetic!{digest[:10]}A7" + + +def _email(_value: str, digest: str) -> str: + return f"user-{digest[:12]}@example.invalid" + + +def _http_cookie(value: str, digest: str) -> str: + parts = [part.strip() for part in value.split(";") if part.strip()] + rendered: list[str] = [] + for index, part in enumerate(parts): + if "=" not in part: + continue + name = part.split("=", 1)[0].strip() or f"cookie_{index}" + rendered.append(f"{name}={_cookie_value(name=name, digest=digest, index=index)}") + if rendered: + return "; ".join(rendered) + return f"session_id={digest[:32]}; auth_token={digest[32:56]}" + + +def _cookie_value(*, name: str, digest: str, index: int) -> str: + offset = (index * 8) % max(len(digest) - 8, 1) + chunk = digest[offset : offset + 24] + normalized = name.lower() + if "session" in normalized: + return chunk[:32] + if normalized.endswith("id") or normalized == "user_id": + return str(10000 + (int(chunk[:8], 16) % 90000)) + if "token" in normalized or "auth" in normalized or "jwt" in normalized: + return f"tok_{chunk}" + return chunk + + +def _pin(value: str, digest: str) -> str: + length = max(4, min(len(value), 8)) + number = int(digest[:12], 16) % (10**length) + return f"{number:0{length}d}" + + +def _unique_id(value: str, digest: str) -> str: + if re.fullmatch(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}", value): + return f"{digest[:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" + prefix_match = re.match(r"^([A-Za-z]+[-_])", value) + if prefix_match: + prefix = prefix_match.group(1) + return f"{prefix}{digest[:20]}" + if value.isdigit(): + return str(100000 + (int(digest[:12], 16) % 900000)) + return 
f"id_{digest[:24]}" + + +def _user_name(value: str, digest: str) -> str: + separator = "." if "." in value else "_" if "_" in value else "" + if separator: + return f"user{separator}{digest[:10]}" + return f"user{digest[:12]}" + + +def _url(value: str, digest: str) -> str: + scheme_match = re.match(r"^([A-Za-z][A-Za-z0-9+.-]*://)", value) + scheme = scheme_match.group(1) if scheme_match else "https://" + scheme_name = scheme[:-3].lower() + if scheme_name in {"postgres", "postgresql", "mysql", "mariadb", "mongodb", "mongodb+srv", "redis", "rediss"}: + return f"{scheme}user_{digest[:8]}:Synthetic!{digest[8:16]}@db-{digest[16:24]}.example.invalid:5432/app" + return f"{scheme}synthetic-{digest[:16]}.example.invalid/resource/{digest[16:24]}" + + +def _date_of_birth(value: str, digest: str) -> str: + year = 1950 + (int(digest[:4], 16) % 50) + month = 1 + (int(digest[4:6], 16) % 12) + day = 1 + (int(digest[6:8], 16) % 28) + if re.fullmatch(r"\d{4}", value): + return str(year) + ymd = re.fullmatch(r"\d{4}([/-])\d{1,2}\1\d{1,2}", value) + if ymd: + sep = ymd.group(1) + return f"{year:04d}{sep}{month:02d}{sep}{day:02d}" + mdy = re.fullmatch(r"\d{1,2}([/-])\d{1,2}\1(\d{2}|\d{4})", value) + if mdy: + sep = mdy.group(1) + rendered_year = f"{year % 100:02d}" if len(value.rsplit(sep, 1)[-1]) == 2 else f"{year:04d}" + return f"{month:02d}{sep}{day:02d}{sep}{rendered_year}" + return f"{year:04d}-{month:02d}-{day:02d}" + + +def _street_address(value: str, digest: str) -> str: + suffix_match = re.search( + r"\b(Street|St\.?|Avenue|Ave\.?|Road|Rd\.?|Drive|Dr\.?|Trail|Boulevard|Blvd\.?|Lane|Ln\.?|Court|Ct\.?)$", + value, + ) + suffix = suffix_match.group(1) if suffix_match else "Street" + number = 100 + (int(digest[:4], 16) % 8900) + return f"{number} Cedar Ridge {suffix}" + + +def _organization_name(value: str, digest: str) -> str: + suffixes = ( + "Center", + "Hospital", + "Clinic", + "University", + "College", + "Institute", + "Bank", + "Builders", + "Construction", + "Woodworks", + 
"Health", + ) + suffix = next((candidate for candidate in suffixes if value.endswith(candidate)), "Group") + prefixes = ("Northbridge", "Helios", "Mariner", "Summit", "Cedar") + prefix = prefixes[int(digest[:2], 16) % len(prefixes)] + return f"{prefix} {suffix}" + + +def _religious_belief(value: str, digest: str) -> str: + normalized = value.lower() + candidates = [candidate for candidate in _RELIGIOUS_BELIEF_SUBSTITUTES if candidate != normalized] + return candidates[int(digest[:2], 16) % len(candidates)] + + +_GENERATORS: dict[str, Callable[[str, str], str]] = { + "api_key": _api_key, + "date_of_birth": _date_of_birth, + "email": _email, + "http_cookie": _http_cookie, + "organization_name": _organization_name, + "password": _password, + "pin": _pin, + "religious_belief": _religious_belief, + "street_address": _street_address, + "unique_id": _unique_id, + "url": _url, + "user_name": _user_name, +} diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index 309264c5..510eb988 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -440,6 +440,81 @@ def record_ndd_workflow( model_usage: Mapping[str, Any] | None = None, ) -> None: """Record one DataDesigner workflow execution through the adapter boundary.""" + _record_model_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=input_row_count, + output_row_count=output_row_count, + failed_record_count=failed_record_count, + elapsed_sec=elapsed_sec, + status=status, + seed_row_count=seed_row_count, + preview_num_records=preview_num_records, + column_count=column_count, + column_names=column_names, + model_usage=model_usage, + record_type="ndd_workflow", + extra_fields=None, + ) + + +def record_model_workflow( + *, + workflow_name: str, + model_aliases: list[str], + input_row_count: int, + output_row_count: int | None, + failed_record_count: int | None, + elapsed_sec: float, + status: str = "completed", + seed_row_count: int | None = 
None, + preview_num_records: int | None = None, + column_count: int | None = None, + column_names: list[str] | None = None, + model_usage: Mapping[str, Any] | None = None, + extra_fields: Mapping[str, Any] | None = None, +) -> None: + """Record one sanitized model-backed workflow execution. + + Use this for non-DataDesigner model calls that still need benchmark + accounting. Raw prompts, text, responses, and replacement values do not + belong in ``model_usage``. + """ + _record_model_workflow( + workflow_name=workflow_name, + model_aliases=model_aliases, + input_row_count=input_row_count, + output_row_count=output_row_count, + failed_record_count=failed_record_count, + elapsed_sec=elapsed_sec, + status=status, + seed_row_count=seed_row_count, + preview_num_records=preview_num_records, + column_count=column_count, + column_names=column_names, + model_usage=model_usage, + record_type="model_workflow", + extra_fields=extra_fields, + ) + + +def _record_model_workflow( + *, + workflow_name: str, + model_aliases: list[str], + input_row_count: int, + output_row_count: int | None, + failed_record_count: int | None, + elapsed_sec: float, + status: str, + seed_row_count: int | None, + preview_num_records: int | None, + column_count: int | None, + column_names: list[str] | None, + model_usage: Mapping[str, Any] | None, + record_type: str, + extra_fields: Mapping[str, Any] | None, +) -> None: collector = current_collector() if collector is None: return @@ -457,14 +532,19 @@ def record_ndd_workflow( "column_count": column_count, "column_names": column_names or [], "model_usage": dict(model_usage or {}), + **dict(extra_fields or {}), } - collector.record("ndd_workflow", **_ndd_workflow_fields(workflow_fields, observed_usage)) + collector.record(record_type, **_model_workflow_fields(workflow_fields, observed_usage)) -def _ndd_workflow_fields(fields: dict[str, Any], observed_usage: dict[str, int | None]) -> dict[str, Any]: +def _model_workflow_fields(fields: dict[str, Any], 
observed_usage: dict[str, int | None]) -> dict[str, Any]: return { **fields, **observed_usage, + "observed_failed_request_rate": _safe_ratio( + observed_usage["observed_failed_requests"], + observed_usage["observed_total_requests"], + ), **_throughput_fields( elapsed_sec=cast(float, fields["elapsed_sec"]), input_row_count=cast(int, fields["input_row_count"]), @@ -540,8 +620,9 @@ def record_record_metrics( strategy=strategy, ), **_entity_record_fields(row, final_entities=final_entities, ground_truth_column=ground_truth_column), - **_replacement_record_fields(row, columns=columns), + **_replacement_record_fields(row, columns=columns, final_entities=final_entities), **_rewrite_record_fields(row, columns=columns), + **_original_value_leak_record_fields(row, columns=columns, final_entities=final_entities), **_llm_record_fields( row, columns=columns, @@ -612,12 +693,22 @@ def _entity_record_fields( } -def _replacement_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: +def _replacement_record_fields( + row: Any, + *, + columns: set[str], + final_entities: list[dict[str, Any]], +) -> dict[str, Any]: from anonymizer.engine.constants import COL_REPLACEMENT_MAP if COL_REPLACEMENT_MAP not in columns: return {} - return _replacement_map_metrics(row.get(COL_REPLACEMENT_MAP)) + raw_map = row.get(COL_REPLACEMENT_MAP) + return { + **_replacement_map_metrics(raw_map), + **_replacement_coverage_metrics(raw_map, final_entities), + **_replacement_collision_metrics(raw_map, final_entities), + } def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: @@ -644,6 +735,67 @@ def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: } +def _original_value_leak_record_fields( + row: Any, + *, + columns: set[str], + final_entities: list[dict[str, Any]], +) -> dict[str, Any]: + output_column = _output_text_column(columns) + if output_column is None: + return {"original_value_leak_count": None, "original_value_leak_label_counts": {}} + 
output_text = str(row.get(output_column, "")) + leaked = [ + entity + for entity in final_entities + if entity.get("value") and _output_contains_original_value(output_text, str(entity.get("value"))) + ] + return { + "original_value_leak_count": len(leaked), + "original_value_leak_label_counts": dict( + sorted(Counter(str(entity.get("label") or "") for entity in leaked if entity.get("label")).items()) + ), + } + + +def _output_contains_original_value(output_text: str, value: str) -> bool: + if _needs_boundary_sensitive_leak_match(value): + return _contains_with_alnum_boundaries(output_text, value) + return value in output_text + + +def _needs_boundary_sensitive_leak_match(value: str) -> bool: + return len(value) <= 4 or value.isdigit() + + +def _contains_with_alnum_boundaries(output_text: str, value: str) -> bool: + start = 0 + while True: + match_start = output_text.find(value, start) + if match_start < 0: + return False + match_end = match_start + len(value) + if _has_alnum_boundaries(output_text, match_start, match_end): + return True + start = match_start + 1 + + +def _has_alnum_boundaries(text: str, start: int, end: int) -> bool: + before_is_alnum = start > 0 and text[start - 1].isalnum() + after_is_alnum = end < len(text) and text[end].isalnum() + return not before_is_alnum and not after_is_alnum + + +def _output_text_column(columns: set[str]) -> str | None: + from anonymizer.engine.constants import COL_REPLACED_TEXT, COL_REWRITTEN_TEXT + + if COL_REPLACED_TEXT in columns: + return COL_REPLACED_TEXT + if COL_REWRITTEN_TEXT in columns: + return COL_REWRITTEN_TEXT + return None + + def _llm_record_fields( row: Any, *, @@ -662,12 +814,14 @@ def _llm_record_fields( ) grouped_entity_count = _grouped_entity_count(row, columns=columns, final_entity_count=final_entity_count) repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0) + replace_map_generation_uses_llm = _replace_map_generation_uses_llm(row, columns=columns) calls_by_stage = 
estimate_llm_calls_by_stage( mode=mode, strategy=strategy, has_grouped_entities=grouped_entity_count > 0, validation_chunk_count=validation_chunk_count, repair_iterations=repair_iterations, + replace_map_generation_uses_llm=replace_map_generation_uses_llm, ) total_estimated = ( sum(calls_by_stage.values()) if all(value is not None for value in calls_by_stage.values()) else None @@ -681,6 +835,15 @@ def _llm_record_fields( } +def _replace_map_generation_uses_llm(row: Any, *, columns: set[str]) -> bool: + from anonymizer.engine.constants import COL_REPLACEMENT_MAP_SOURCE + from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + + if COL_REPLACEMENT_MAP_SOURCE not in columns: + return True + return row.get(COL_REPLACEMENT_MAP_SOURCE) != REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + + def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None: from anonymizer.engine.constants import COL_SEED_VALIDATION_CANDIDATES @@ -816,11 +979,12 @@ def estimate_llm_calls_by_stage( has_grouped_entities: bool, validation_chunk_count: int | None, repair_iterations: int = 0, + replace_map_generation_uses_llm: bool = True, ) -> dict[str, int | None]: """Estimate nominal model calls for one record, split by workflow stage.""" detection_calls = None if validation_chunk_count is None else 2 + validation_chunk_count replace_map_generation = 0 - if has_grouped_entities and (mode == "rewrite" or strategy == "Substitute"): + if replace_map_generation_uses_llm and has_grouped_entities and (mode == "rewrite" or strategy == "Substitute"): replace_map_generation = 1 if mode != "rewrite": @@ -978,16 +1142,7 @@ def _entity_identity_set(entities: list[dict[str, Any]]) -> set[tuple[str, str]] def _replacement_map_metrics(raw: object) -> dict[str, Any]: - payload = _coerce_payload(raw) - if isinstance(payload, Mapping): - replacements_raw = cast(Mapping[str, Any], payload).get("replacements") - replacements = replacements_raw if 
isinstance(replacements_raw, list) else [] - elif isinstance(payload, list): - replacements = payload - else: - replacements = [] - - replacement_maps = [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] + replacement_maps = _replacement_maps_from_raw(raw) synthetic_values = [] for item in replacement_maps: synthetic = item.get("replacement", item.get("synthetic")) @@ -1002,6 +1157,65 @@ def _replacement_map_metrics(raw: object) -> dict[str, Any]: } +def _replacement_coverage_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + replacement_original_values = { + str(original) + for item in _replacement_maps_from_raw(raw) + if (original := item.get("original")) is not None and str(original) + } + missing_entities = [ + entity + for entity in final_entities + if entity.get("value") and str(entity.get("value")) not in replacement_original_values + ] + missing_values = {str(entity.get("value")) for entity in missing_entities if entity.get("value")} + return { + "replacement_missing_final_entity_count": len(missing_entities), + "replacement_missing_final_entity_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in missing_entities if entity.get("label")).items() + ) + ), + "replacement_missing_final_value_count": len(missing_values), + } + + +def _replacement_collision_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + synthetic_values = { + str(synthetic) + for item in _replacement_maps_from_raw(raw) + if (synthetic := item.get("replacement", item.get("synthetic"))) is not None and str(synthetic) + } + collided_entities = [ + entity for entity in final_entities if entity.get("value") and str(entity.get("value")) in synthetic_values + ] + collided_values = {str(entity.get("value")) for entity in collided_entities if entity.get("value")} + return { + "replacement_synthetic_original_collision_count": len(collided_entities), + 
"replacement_synthetic_original_collision_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in collided_entities if entity.get("label")).items() + ) + ), + "replacement_synthetic_original_collision_value_count": len(collided_values), + } + + +def _replacement_maps_from_raw(raw: object) -> list[Mapping[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + replacements_raw = cast(Mapping[str, Any], payload).get("replacements") + tolist = getattr(replacements_raw, "tolist", None) + if callable(tolist): + replacements_raw = tolist() + replacements = replacements_raw if isinstance(replacements_raw, list) else [] + elif isinstance(payload, list): + replacements = payload + else: + replacements = [] + return [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] + + def _count_items(raw: object, *, primary_key: str, fallback_keys: tuple[str, ...] = ()) -> int: payload = _coerce_payload(raw) if isinstance(payload, Mapping): diff --git a/tests/engine/test_chunked_validation.py b/tests/engine/test_chunked_validation.py index f9b402a3..6eb3ed9e 100644 --- a/tests/engine/test_chunked_validation.py +++ b/tests/engine/test_chunked_validation.py @@ -447,6 +447,41 @@ def test_single_chunk_sends_single_chunk_tagged_text_not_windowed_excerpt(self) "around Alice/Bob would clip the suffix entirely." ) + def test_single_chunk_can_use_compact_excerpt_when_configured(self) -> None: + prefix = "START_ONLY_MARKER " + ("prefix filler " * 80) + suffix = (" suffix filler" * 80) + " END_ONLY_MARKER" + middle = "Alice met Bob." 
+ text = prefix + middle + suffix + alice_start = len(prefix) + bob_start = alice_start + 10 + + spans = [ + _entity_span("a", "Alice", "first_name", alice_start, alice_start + 5), + _entity_span("b", "Bob", "first_name", bob_start, bob_start + 3), + ] + candidates = _candidates_schema( + ("a", "Alice", "first_name"), + ("b", "Bob", "first_name"), + ) + row = _build_row(text=text, seed_entities=spans, candidates=candidates) + + facade = FakeFacade("v0", response={"decisions": [{"id": "a", "decision": "keep"}]}) + params = ChunkedValidationParams( + pool=["v0"], + max_entities_per_call=10, + excerpt_window_chars=20, + single_chunk_full_text=False, + prompt_template=_MINIMAL_TEMPLATE, + ) + + chunked_validate_row(row, params, {"v0": facade}) + + prompt = facade.calls[0]["prompt"] + assert "Alice" in prompt + assert "Bob" in prompt + assert "START_ONLY_MARKER" not in prompt + assert "END_ONLY_MARKER" not in prompt + def test_empty_candidates_short_circuits_without_calls(self) -> None: row = _build_row(text="hello", seed_entities=[], candidates=_candidates_schema()) facade = FakeFacade("v0", response={"decisions": []}) diff --git a/tests/engine/test_detection_rules.py b/tests/engine/test_detection_rules.py new file mode 100644 index 00000000..d7640ed5 --- /dev/null +++ b/tests/engine/test_detection_rules.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter +from unittest.mock import Mock + +import pandas as pd +import pytest + +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TAGGED_TEXT, COL_TEXT +from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow +from anonymizer.engine.detection.rules import ( + STRUCTURED_RULE_FAST_LANE_LABELS, + SUPPORTED_RULE_LABELS, + detect_high_confidence_entities, +) +from anonymizer.engine.schemas import EntitiesSchema + +SHELL_TEXT = """$ curl -H 'Authorization: Bearer sk-test-AAAAAAAAAAAAAAAAAAAAAAAA' https://internal.example.test/api +$ export AWS_ACCESS_KEY_ID=AKIATEST1234567890FAKE +$ export AWS_SECRET_ACCESS_KEY=fakeSecretValue1234567890! +$ docker run -e DATABASE_URL='postgres://app_user:fakeDbPass123!@db.example.test:5432/app' -e API_KEY=ghp_FAKEtoken1234567890abcdef myapp:latest +$ ssh jane.doe@example.test@host-01.example.test +Password: fakeSshPass123! 
+""" + + +def test_detect_high_confidence_entities_extracts_shell_secret_values() -> None: + entities = detect_high_confidence_entities( + SHELL_TEXT, + labels=["api_key", "password", "email", "url"], + ) + + assert Counter(entity.label for entity in entities) == { + "api_key": 3, + "password": 2, + "email": 1, + "url": 2, + } + values_by_label = {(entity.label, entity.value) for entity in entities} + assert ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA") in values_by_label + assert ("api_key", "AKIATEST1234567890FAKE") in values_by_label + assert ("api_key", "ghp_FAKEtoken1234567890abcdef") in values_by_label + assert ("password", "fakeSecretValue1234567890!") in values_by_label + assert ("password", "fakeSshPass123!") in values_by_label + assert ("email", "jane.doe@example.test") in values_by_label + assert ("url", "https://internal.example.test/api") in values_by_label + assert ("url", "postgres://app_user:fakeDbPass123!@db.example.test:5432/app") in values_by_label + + values = [entity.value for entity in entities] + assert all(not value.startswith(("Authorization", "Bearer", "API_KEY=", "Password:")) for value in values) + + +def test_detect_high_confidence_entities_extracts_email_before_sentence_punctuation() -> None: + entities = detect_high_confidence_entities( + "Email alice@example.com. 
Then contact bob@example.co.uk, if needed.", + labels=["email"], + ) + + assert [entity.value for entity in entities] == ["alice@example.com", "bob@example.co.uk"] + + +def test_detect_high_confidence_entities_excludes_config_url_separators() -> None: + text = ( + "DATABASE_URL=postgres://svc_user:DbSecretPass2026!@db.example.test:5432/app; " + "endpoint: https://internal.example.test/admin;" + ) + + entities = detect_high_confidence_entities(text, labels=["url"]) + + assert [entity.value for entity in entities] == [ + "postgres://svc_user:DbSecretPass2026!@db.example.test:5432/app", + "https://internal.example.test/admin", + ] + + +def test_supported_rule_labels_match_detected_label_families() -> None: + assert SUPPORTED_RULE_LABELS == { + "api_key", + "date_of_birth", + "email", + "http_cookie", + "organization_name", + "password", + "pin", + "religious_belief", + "street_address", + "unique_id", + "url", + "user_name", + } + + +def test_structured_rule_fast_lane_excludes_narrow_prose_labels() -> None: + assert STRUCTURED_RULE_FAST_LANE_LABELS == { + "api_key", + "email", + "http_cookie", + "password", + "pin", + "unique_id", + "url", + "user_name", + } + assert {"date_of_birth", "organization_name", "religious_belief", "street_address"}.isdisjoint( + STRUCTURED_RULE_FAST_LANE_LABELS + ) + + +def test_detect_high_confidence_entities_respects_label_filter() -> None: + entities = detect_high_confidence_entities(SHELL_TEXT, labels=["password"]) + + assert Counter(entity.label for entity in entities) == {"password": 3} + assert {entity.value for entity in entities} == { + "fakeSecretValue1234567890!", + "fakeDbPass123!", + "fakeSshPass123!", + } + + +def test_detect_high_confidence_entities_extracts_sudo_stdin_password() -> None: + text = '$ echo "P@ssw0rd-local-2026!" 
| sudo -S systemctl restart nginx' + + entities = detect_high_confidence_entities(text, labels=["password"]) + + assert [(entity.label, entity.value) for entity in entities] == [("password", "P@ssw0rd-local-2026!")] + + +def test_detect_high_confidence_entities_does_not_treat_generic_echo_as_password() -> None: + text = '$ echo "P@ssw0rd-local-2026!" | grep local' + + assert detect_high_confidence_entities(text, labels=["password"]) == [] + + +def test_detect_high_confidence_entities_does_not_emit_secret_false_positives_for_prose() -> None: + prose = ( + "Alice Johnson filed Case No. 2025-CV-12345 in Superior Court. " + "The opinion cites Section 10(b), Exhibit A-17, and docket trace order_390974. " + "A biography says Jordan Patel joined NVIDIA in 2021 and later moved to Seattle." + ) + + entities = detect_high_confidence_entities(prose, labels=["api_key", "password", "email", "url"]) + + assert entities == [] + + +def test_detect_high_confidence_entities_extracts_contextual_date_of_birth() -> None: + text = "The applicant was born in 1978 and later moved to Berlin. Another report cites 2024." + + entities = detect_high_confidence_entities(text, labels=["date_of_birth"]) + + assert [(entity.label, entity.value) for entity in entities] == [("date_of_birth", "1978")] + + +def test_detect_high_confidence_entities_ignores_standalone_year_for_date_of_birth() -> None: + text = "The report cites filings from 1978, 2021, and 2024." + + assert detect_high_confidence_entities(text, labels=["date_of_birth"]) == [] + + +def test_detect_high_confidence_entities_extracts_narrow_prose_patterns() -> None: + text = ( + "After graduation he spent three years at NASA's Goddard Space Flight Center before joining a lab. " + "Idilio describes himself as secular and leans progressive on most political issues. " + "Outside the lab, Idilio shares a modest house on West Roberts Drive with his wife." 
+ ) + + entities = detect_high_confidence_entities( + text, + labels=["organization_name", "religious_belief", "street_address"], + ) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("organization_name", "NASA's Goddard Space Flight Center"), + ("religious_belief", "secular"), + ("street_address", "West Roberts Drive"), + ] + + +def test_detect_high_confidence_entities_avoids_generic_prose_belief_false_positive() -> None: + text = "Jordan describes himself as careful and later worked at a local lab near Roberts Drive." + + assert ( + detect_high_confidence_entities( + text, + labels=["organization_name", "religious_belief", "street_address"], + ) + == [] + ) + + +def test_detect_high_confidence_entities_returns_sorted_non_overlapping_spans() -> None: + entities = detect_high_confidence_entities( + "token=sk-test-BBBBBBBBBBBBBBBBBBBBBBBB and Auth: ignored\nPassword: fakePass123!", + labels=["api_key", "password"], + ) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", "sk-test-BBBBBBBBBBBBBBBBBBBBBBBB"), + ("password", "fakePass123!"), + ] + assert entities[0].end_position < entities[1].start_position + + +def test_detect_high_confidence_entities_extracts_session_id_assignments() -> None: + text = "Cookie: session_id=abc123xyz; auth_token=xoxb-STRUCTURED-Slack-Token-000000" + + entities = detect_high_confidence_entities(text, labels=["api_key"]) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", "abc123xyz"), + ("api_key", "xoxb-STRUCTURED-Slack-Token-000000"), + ] + + +def test_detect_high_confidence_entities_extracts_structured_identifier_labels() -> None: + text = ( + "POST /audit HTTP/1.1\n" + "Cookie: session_id=abc123xyz; user_id=26762; auth_token=token-abcdef\n" + "trace-id: req_KA5k78XNwT0yUNZkPpwq\n" + "pin=97294\n" + "user_name=sloanenguy217\n" + ) + + entities = detect_high_confidence_entities( + text, + labels=["http_cookie", "pin", "unique_id", "user_name"], + ) + 
+ assert [(entity.label, entity.value) for entity in entities] == [ + ("http_cookie", "session_id=abc123xyz; user_id=26762; auth_token=token-abcdef"), + ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), + ("pin", "97294"), + ("user_name", "sloanenguy217"), + ] + + +def test_detect_high_confidence_entities_extracts_quoted_structured_identifier_keys() -> None: + text = '{"user": "avery_khan", "pin": "4921", "callback": "https://internal.example.test/admin"}' + + entities = detect_high_confidence_entities(text, labels=["pin", "url", "user_name"]) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("user_name", "avery_khan"), + ("pin", "4921"), + ("url", "https://internal.example.test/admin"), + ] + + +def test_detect_high_confidence_entities_excludes_cookie_sentence_punctuation() -> None: + text = "Cookie: session_id=abc123xyz; auth_token=token-abcdef. Recovery flow starts." + + entities = detect_high_confidence_entities(text, labels=["http_cookie"]) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("http_cookie", "session_id=abc123xyz; auth_token=token-abcdef"), + ] + + +def test_detect_high_confidence_entities_extracts_service_principal_user_and_tenant_id() -> None: + text = "$ az login --service-principal -u skylerlee985 -p fakePass123! --tenant trace-1b7278d77a73ef4e" + + entities = detect_high_confidence_entities(text, labels=["user_name", "unique_id"]) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("user_name", "skylerlee985"), + ("unique_id", "trace-1b7278d77a73ef4e"), + ] + + +def test_detect_high_confidence_entities_extracts_audit_user_and_trace_id() -> None: + text = "Audit record: user skylerlee985 opened session with trace-id req_KA5k78XNwT0yUNZkPpwq." 
+ + entities = detect_high_confidence_entities(text, labels=["user_name", "unique_id"]) + + assert [(entity.label, entity.value) for entity in entities] == [ + ("user_name", "skylerlee985"), + ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), + ] + + +def test_detect_high_confidence_entities_does_not_extract_structured_identifiers_from_generic_prose() -> None: + text = "The order_390974 filing mentions user research, cookie recipes, and a five digit docket page." + + assert detect_high_confidence_entities(text, labels=["http_cookie", "pin", "unique_id", "user_name"]) == [] + + +def test_workflow_can_detect_with_high_confidence_rules_without_adapter_calls() -> None: + adapter = Mock() + workflow = EntityDetectionWorkflow(adapter=adapter) + + result = workflow.detect_with_high_confidence_rules( + pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), + entity_labels=["api_key", "password"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("password", "fakePass123!"), + ] + tagged_text = result.dataframe[COL_TAGGED_TEXT].iloc[0] + assert "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" in tagged_text + assert "fakePass123!" 
in tagged_text + assert result.failed_records == [] + + +def test_workflow_rule_detection_rejects_unsupported_labels() -> None: + workflow = EntityDetectionWorkflow(adapter=Mock()) + + with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): + workflow.detect_with_high_confidence_rules( + pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), + entity_labels=["api_key", "person"], + ) diff --git a/tests/engine/test_llm_replace_workflow.py b/tests/engine/test_llm_replace_workflow.py index f7abbc44..f6ae372b 100644 --- a/tests/engine/test_llm_replace_workflow.py +++ b/tests/engine/test_llm_replace_workflow.py @@ -330,6 +330,41 @@ def test_filter_replacement_map_anomaly_summaries_do_not_leak_pii( _assert_no_pii_in_logs(caplog, extra_secrets=("Acme Corp", "NovaCorp")) +def test_filter_replacement_map_repairs_synthetic_original_collisions_without_pii( + caplog: pytest.LogCaptureFixture, +) -> None: + """Synthetic values must not reuse another protected original from the same row.""" + parsed_entities = EntitiesByValueSchema.model_validate( + { + "entities_by_value": [ + {"value": "1979-01-01", "labels": ["date"]}, + {"value": "1980-02-02", "labels": ["date"]}, + ] + } + ) + raw_map = { + "replacements": [ + {"original": "1979-01-01", "label": "date", "synthetic": "1980-02-02"}, + {"original": "1980-02-02", "label": "date", "synthetic": "1991-03-04"}, + ] + } + + with caplog.at_level(logging.WARNING, logger="anonymizer"): + result = _filter_replacement_map_to_input_entities( + raw_map=raw_map, parsed_entities=parsed_entities, record_id="row-collision" + ) + + assert result == { + "replacements": [ + {"original": "1979-01-01", "label": "date", "synthetic": "[SUBSTITUTE_DATE_1]"}, + {"original": "1980-02-02", "label": "date", "synthetic": "1991-03-04"}, + ] + } + assert "synthetic-original collision" in caplog.text + assert "date" in caplog.text + _assert_no_pii_in_logs(caplog, extra_secrets=("1979-01-01", 
"1980-02-02", "1991-03-04")) + + def test_filter_replacement_map_empty_warning_does_not_leak_pii( caplog: pytest.LogCaptureFixture, ) -> None: diff --git a/tests/engine/test_structured_substitute.py b/tests/engine/test_structured_substitute.py new file mode 100644 index 00000000..83014cba --- /dev/null +++ b/tests/engine/test_structured_substitute.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pandas as pd +import pytest + +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP +from anonymizer.engine.replace import structured_substitute as structured_substitute_module +from anonymizer.engine.replace.structured_substitute import ( + apply_structured_substitution_maps, + build_structured_substitution_map, + structured_substitute_value, +) + + +@pytest.mark.parametrize( + ("label", "value"), + [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("date_of_birth", "1978-02-03"), + ("email", "alice@example.com"), + ("http_cookie", "session_id=abc123xyz; user_id=12345; auth_token=secret-token"), + ("organization_name", "Acme Research Center"), + ("password", "CorrectHorse!123"), + ("pin", "97294"), + ("religious_belief", "secular"), + ("street_address", "123 Maple Street"), + ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), + ("url", "https://staging.example.com/admin"), + ("user_name", "sloanenguy217"), + ], +) +def test_structured_substitute_value_does_not_preserve_original(label: str, value: str) -> None: + synthetic = structured_substitute_value(value, label) + + assert synthetic != value + assert value not in synthetic + + +def test_build_structured_substitution_map_for_entities_by_value() -> None: + raw_entities = { + "entities_by_value": [ + {"value": "alice@example.com", "labels": ["email"]}, + {"value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "labels": ["api_key"]}, + ] + } + + 
replacement_map = build_structured_substitution_map(raw_entities) + + replacements = replacement_map["replacements"] + assert [(item["original"], item["label"]) for item in replacements] == [ + ("alice@example.com", "email"), + ("sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "api_key"), + ] + serialized = str(replacement_map) + assert "alice@example.com" in serialized + assert "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" in serialized + assert all(item["original"] != item["synthetic"] for item in replacements) + + +def test_build_structured_substitution_map_avoids_other_original_values(monkeypatch: pytest.MonkeyPatch) -> None: + first_original = "alice@example.com" + second_original = "bob@example.com" + first_default_digest = structured_substitute_module._digest(value=first_original, label="email") + + def synthetic_email(value: str, digest: str) -> str: + if value == first_original and digest == first_default_digest: + return second_original + return f"user-{digest[:12]}@example.invalid" + + monkeypatch.setitem(structured_substitute_module._GENERATORS, "email", synthetic_email) + raw_entities = { + "entities_by_value": [ + {"value": first_original, "labels": ["email"]}, + {"value": second_original, "labels": ["email"]}, + ] + } + + replacement_map = build_structured_substitution_map(raw_entities) + + replacements = {(item["original"], item["label"]): item["synthetic"] for item in replacement_map["replacements"]} + assert replacements[(first_original, "email")] != second_original + assert set(replacements.values()).isdisjoint({first_original, second_original}) + + +def test_build_structured_substitution_map_avoids_duplicate_synthetic_values( + monkeypatch: pytest.MonkeyPatch, +) -> None: + first_original = "Acme Research Center" + second_original = "Globex Research Center" + default_digests = { + structured_substitute_module._digest(value=first_original, label="organization_name"), + structured_substitute_module._digest(value=second_original, label="organization_name"), + } + + def 
synthetic_organization(_value: str, digest: str) -> str: + if digest in default_digests: + return "Northbridge Center" + return f"Helios {digest[:8]} Center" + + monkeypatch.setitem(structured_substitute_module._GENERATORS, "organization_name", synthetic_organization) + raw_entities = { + "entities_by_value": [ + {"value": first_original, "labels": ["organization_name"]}, + {"value": second_original, "labels": ["organization_name"]}, + ] + } + + replacement_map = build_structured_substitution_map(raw_entities) + + synthetics = [item["synthetic"] for item in replacement_map["replacements"]] + assert len(synthetics) == 2 + assert len(set(synthetics)) == 2 + assert "Northbridge Center" in synthetics + + +def test_structured_cookie_substitute_preserves_cookie_shape() -> None: + synthetic = structured_substitute_value( + "session_id=abc123xyz; user_id=12345; auth_token=secret-token", + "http_cookie", + ) + + assert "session_id=" in synthetic + assert "user_id=" in synthetic + assert "auth_token=" in synthetic + session_value = synthetic.split("session_id=", 1)[1].split(";", 1)[0] + assert len(session_value) >= 16 + assert not session_value.isdigit() + assert "abc123xyz" not in synthetic + assert "secret-token" not in synthetic + + +def test_structured_unique_id_substitute_preserves_uuid_shape() -> None: + synthetic = structured_substitute_value("1ce21179-998b-447b-2dee-3e8adb6afa35", "unique_id") + + assert synthetic.count("-") == 4 + assert synthetic != "1ce21179-998b-447b-2dee-3e8adb6afa35" + + +def test_build_structured_substitution_map_rejects_unsupported_labels() -> None: + raw_entities = {"entities_by_value": [{"value": "Alice", "labels": ["person"]}]} + + with pytest.raises(ValueError, match="unsupported labels: person"): + build_structured_substitution_map(raw_entities) + + +def test_apply_structured_substitution_maps_adds_replacement_map_column() -> None: + dataframe = pd.DataFrame( + {COL_ENTITIES_BY_VALUE: [{"entities_by_value": [{"value": 
"alice@example.com", "labels": ["email"]}]}]} + ) + + output = apply_structured_substitution_maps(dataframe) + + assert COL_REPLACEMENT_MAP in output.columns + replacement = output[COL_REPLACEMENT_MAP].iloc[0]["replacements"][0] + assert replacement["original"] == "alice@example.com" + assert replacement["label"] == "email" + assert replacement["synthetic"].endswith("@example.invalid") diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 7b4c5676..64f154d4 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -20,6 +20,7 @@ from data_designer.engine.models.utils import ChatMessage from data_designer.interface.data_designer import DataDesigner +import anonymizer.measurement as measurement from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect from anonymizer.config.models import DetectionModelSelection from anonymizer.config.replace_strategies import Redact @@ -32,6 +33,7 @@ COL_REPAIR_ITERATIONS, COL_REPLACED_TEXT, COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, COL_SEED_VALIDATION_CANDIDATES, COL_TEXT, COL_UTILITY_SCORE, @@ -40,6 +42,7 @@ from anonymizer.engine.detection.detection_workflow import EntityDetectionResult, EntityDetectionWorkflow from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter from anonymizer.engine.replace.replace_runner import ReplacementResult, ReplacementWorkflow +from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED from anonymizer.engine.rewrite.rewrite_workflow import RewriteResult, RewriteWorkflow from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import ( @@ -169,6 +172,149 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa assert record["observed_tokens_per_successful_request"] == 8 +def test_ndd_adapter_records_datadesigner_model_usage_by_alias_for_shared_model_names() -> None: + input_df = pd.DataFrame( + { + "text": ["Alice works 
at Acme"], + RECORD_ID_COLUMN: ["record-a"], + } + ) + + class UsageStats: + has_usage = True + + def __init__(self, *, input_tokens: int, output_tokens: int, successful: int, failed: int) -> None: + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.successful = successful + self.failed = failed + + def model_dump(self, *, mode: str) -> dict[str, object]: + assert mode == "json" + return { + "token_usage": { + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "total_tokens": self.input_tokens + self.output_tokens, + }, + "request_usage": { + "successful_requests": self.successful, + "failed_requests": self.failed, + "total_requests": self.successful + self.failed, + }, + } + + class ModelRegistry: + def __init__(self) -> None: + self._models = { + "validator": SimpleNamespace( + model_alias="validator", + model_name="shared-model", + model_provider_name="local-vllm", + usage_stats=UsageStats(input_tokens=12, output_tokens=4, successful=2, failed=1), + ), + "augmenter": SimpleNamespace( + model_alias="augmenter", + model_name="shared-model", + model_provider_name="local-vllm", + usage_stats=UsageStats(input_tokens=20, output_tokens=8, successful=1, failed=0), + ), + } + + def get_model_usage_snapshot(self) -> dict[str, UsageStats]: + return { + "shared-model": UsageStats(input_tokens=999, output_tokens=999, successful=99, failed=99), + } + + class UsageDataDesigner: + def _create_resource_provider(self, *_args: object, **_kwargs: object) -> SimpleNamespace: + return SimpleNamespace(model_registry=ModelRegistry()) + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + self._create_resource_provider("preview-dataset", _config_builder) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, UsageDataDesigner())) + collector = MeasurementCollector(record_hash_key="test-key") + + with 
measurement_session(collector): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="validator", model="shared-model")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="validator", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + record = next(record for record in collector.records if record["record_type"] == "ndd_workflow") + assert sorted(record["model_usage"]) == ["augmenter", "validator"] + assert record["model_usage"]["validator"]["model_alias"] == "validator" + assert record["model_usage"]["validator"]["model_name"] == "shared-model" + assert record["model_usage"]["validator"]["model_provider_name"] == "local-vllm" + assert record["model_usage"]["validator"]["token_usage"]["input_tokens"] == 12 + assert record["model_usage"]["augmenter"]["token_usage"]["input_tokens"] == 20 + assert record["observed_input_tokens"] == 32 + assert record["observed_output_tokens"] == 12 + assert record["observed_total_tokens"] == 44 + assert record["observed_successful_requests"] == 3 + assert record["observed_failed_requests"] == 1 + assert record["observed_total_requests"] == 4 + + +def test_records_generic_model_workflow_usage_without_raw_text() -> None: + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + assert hasattr(measurement, "record_model_workflow") + measurement.record_model_workflow( + workflow_name="entity-detection-native-rules-router", + model_aliases=["native-direct"], + input_row_count=1, + output_row_count=1, + failed_record_count=0, + elapsed_sec=0.25, + model_usage={ + "native-direct": { + "model_alias": "native-direct", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 3, + "failed_requests": 0, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + }, + }, + }, + ) + + 
records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-rules-router" + assert record["model_aliases"] == ["native-direct"] + assert record["observed_total_requests"] == 3 + assert record["observed_input_tokens"] == 30 + assert record["observed_output_tokens"] == 12 + assert record["observed_total_tokens"] == 42 + assert record["observed_failed_request_rate"] == 0 + assert record["observed_tokens_per_successful_request"] == 14 + + serialized = json.dumps(record) + assert "Alice" not in serialized + assert "sk-test" not in serialized + + def test_anonymizer_records_per_record_measurement_without_raw_pii(tmp_path: Path) -> None: input_csv = tmp_path / "input.csv" pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_csv, index=False) @@ -235,6 +381,8 @@ def test_anonymizer_records_per_record_measurement_without_raw_pii(tmp_path: Pat assert record["final_entity_label_counts"] == {"company_name": 1, "first_name": 1} assert record["detected_candidate_count"] == 2 assert record["validation_chunk_count"] == 1 + assert record["original_value_leak_count"] == 0 + assert record["original_value_leak_label_counts"] == {} assert record["llm_calls_estimated_by_stage"] == { "entity_detection": 3, "replace_map_generation": 0, @@ -655,6 +803,12 @@ def test_record_metrics_capture_generic_counts_without_raw_values() -> None: assert record["replacement_count"] == 2 assert record["replacement_label_counts"] == {"company_name": 1, "first_name": 1} assert record["replacement_duplicate_value_count"] == 1 + assert record["replacement_missing_final_entity_count"] == 0 + assert record["replacement_missing_final_entity_label_counts"] == {} + assert record["replacement_missing_final_value_count"] == 0 + assert record["replacement_synthetic_original_collision_count"] == 0 + assert 
record["replacement_synthetic_original_collision_label_counts"] == {} + assert record["replacement_synthetic_original_collision_value_count"] == 0 assert record["repair_iterations"] == 2 assert record["utility_score"] == 0.82 assert record["leakage_mass"] == 0.2 @@ -670,6 +824,196 @@ def test_record_metrics_capture_generic_counts_without_raw_values() -> None: assert "Maya" not in serialized +def test_record_metrics_counts_missing_replacement_map_entries_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "2030-01-01", "label": "date", "start_position": 13, "end_position": 23}, + {"value": "2030-01-01", "label": "date", "start_position": 27, "end_position": 37}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice filed 2030-01-01 and 2030-01-01"], + COL_FINAL_ENTITIES: [final_entities], + COL_REPLACEMENT_MAP: [replacement_map], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["replacement_missing_final_entity_count"] == 2 + assert record["replacement_missing_final_entity_label_counts"] == {"date": 2} + assert record["replacement_missing_final_value_count"] == 1 + assert record["replacement_synthetic_original_collision_count"] == 0 + assert record["replacement_synthetic_original_collision_label_counts"] == {} + assert record["replacement_synthetic_original_collision_value_count"] == 0 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "2030-01-01" not in serialized + assert "Maya" not in serialized + + +def 
test_record_metrics_counts_synthetic_original_collisions_without_raw_values() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "2030-01-01", "label": "date", "start_position": 13, "end_position": 23}, + ] + } + replacement_map = { + "replacements": [ + {"original": "Alice", "label": "first_name", "synthetic": "Maya"}, + {"original": "2029-12-01", "label": "date", "synthetic": "2030-01-01"}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice filed 2030-01-01"], + COL_FINAL_ENTITIES: [final_entities], + COL_REPLACEMENT_MAP: [replacement_map], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["replacement_synthetic_original_collision_count"] == 1 + assert record["replacement_synthetic_original_collision_label_counts"] == {"date": 1} + assert record["replacement_synthetic_original_collision_value_count"] == 1 + + serialized = json.dumps(collector.records) + assert "Alice" not in serialized + assert "2030-01-01" not in serialized + assert "2029-12-01" not in serialized + assert "Maya" not in serialized + + +def test_record_metrics_counts_original_value_replacement_leaks_without_raw_values() -> None: + leaked_key = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + dataframe = pd.DataFrame( + { + COL_TEXT: [f"token={leaked_key}"], + COL_REPLACED_TEXT: [f"still token={leaked_key}"], + COL_FINAL_ENTITIES: [{"entities": [{"value": leaked_key, "label": "api_key"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = 
collector.records[0] + assert record["original_value_leak_count"] == 1 + assert record["original_value_leak_label_counts"] == {"api_key": 1} + assert leaked_key not in json.dumps(collector.records) + + +def test_record_metrics_ignores_short_value_inside_hash_replacement_token() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice is 34 years old."], + COL_REPLACED_TEXT: ["Alice is years old."], + COL_FINAL_ENTITIES: [{"entities": [{"value": "34", "label": "age"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["original_value_leak_count"] == 0 + assert record["original_value_leak_label_counts"] == {} + + +def test_record_metrics_counts_standalone_short_value_replacement_leaks() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice is 34 years old."], + COL_REPLACED_TEXT: ["Alice is 34 years old."], + COL_FINAL_ENTITIES: [{"entities": [{"value": "34", "label": "age"}]}], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Hash", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["original_value_leak_count"] == 1 + assert record["original_value_leak_label_counts"] == {"age": 1} + + +def test_record_metrics_do_not_estimate_llm_replace_map_call_for_local_structured_source() -> None: + dataframe = pd.DataFrame( + { + COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"], + COL_FINAL_ENTITIES: [{"entities": [{"value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "label": "api_key"}]}], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACEMENT_MAP_SOURCE: [REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED], + } + ) + 
collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Substitute", + text_column=COL_TEXT, + validation_max_entities_per_call=100, + ) + + record = collector.records[0] + assert record["llm_calls_estimated_by_stage"] == { + "entity_detection": None, + "replace_map_generation": 0, + } + assert record["llm_calls_estimated_total"] is None + + def test_record_metrics_normalizes_integral_row_index_types() -> None: dataframe = pd.DataFrame( { diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py new file mode 100644 index 00000000..44ba8358 --- /dev/null +++ b/tests/tools/test_benchmark_output_analysis.py @@ -0,0 +1,907 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + 
_write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "ndd_workflow", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "elapsed_sec": 8.5, + "observed_total_requests": 4, + "observed_successful_requests": 3, + "observed_input_tokens": 5000, + "observed_output_tokens": 1000, + "observed_total_tokens": 6000, + "observed_failed_requests": 1, + "model_usage": { + "nvidia/gliner-pii": { + "request_usage": { + "successful_requests": 1, + "failed_requests": 0, + "total_requests": 1, + }, + "token_usage": { + "input_tokens": 1000, + "output_tokens": 100, + "total_tokens": 1100, + }, + }, + "local-nemotron-json": { + "model_alias": "local-nemotron-json", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 2, + "failed_requests": 1, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 4000, + "output_tokens": 900, + "total_tokens": 4900, + }, + }, + }, + "detect": {"validation_max_entities_per_call": 10}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r000", + "final_entity_count": 14, + "replacement_count": 12, + "replacement_missing_final_entity_count": 2, + "replacement_missing_final_entity_label_counts": {"date": 2}, + "replacement_missing_final_value_count": 1, + "replacement_synthetic_original_collision_count": 1, + "replacement_synthetic_original_collision_label_counts": {"date": 1}, + "replacement_synthetic_original_collision_value_count": 1, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": 
"default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "shell__rules-only__r000", + "final_entity_count": 8, + "replacement_count": 8, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_entity_label_counts": {}, + "replacement_missing_final_value_count": 0, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_label_counts": {}, + "replacement_synthetic_original_collision_value_count": 0, + "original_value_leak_count": 1, + "original_value_leak_label_counts": {"api_key": 1}, + "run_tags": { + "suite_id": "suite", + "workload_id": "shell", + "config_id": "rules-only", + "experimental_detection_strategy": "rules_only", + "experimental_replacement_strategy": "local_structured_substitute", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "shell__rules-only__r000", + }, + }, + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "repetition": 0, + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "seed_entity_count": 13, + "seed_validation_candidate_count": 13, + "augmented_entity_count": 1, + "augmented_new_final_value_count": 1, + "final_entity_count": 14, + "final_source_counts": {"detector": 11, "augmenter": 3}, + "final_entity_signature_hashes": ["bio-hash-a", "bio-hash-b"], + "final_entity_signature_labels": {"bio-hash-a": "person", "bio-hash-b": "city"}, + "final_entity_signature_details": { + "bio-hash-a": { + "label": "person", + "source": "detector", + "row_index": 0, + "start_position": 0, + "end_position": 5, + "value_hash": "hash-person", + "value_length": 5, + } + }, + "final_entity_signature_count": 2, + }, + { + "suite_id": "suite", + "workload_id": "shell", + 
"config_id": "rules-only", + "repetition": 0, + "case_id": "shell__rules-only__r000", + "run_id": "shell__rules-only__r000", + "workflow_name": "rules-only", + "seed_entity_count": 8, + "seed_validation_candidate_count": 0, + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "final_entity_count": 8, + "final_source_counts": {"rule": 8}, + "final_entity_signature_hashes": ["shell-hash-a"], + "final_entity_signature_labels": {"shell-hash-a": "api_key"}, + "final_entity_signature_details": { + "shell-hash-a": { + "label": "api_key", + "source": "rule", + "row_index": 0, + "start_position": 12, + "end_position": 32, + "value_hash": "hash-secret", + "value_length": 20, + } + }, + "final_entity_signature_count": 1, + }, + ], + ) + traces_dir = benchmark_dir / "traces" + traces_dir.mkdir() + _write_jsonl( + traces_dir / "bio__default__r000.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "model_alias": "local-nemotron-json", + "status": "error", + "error_type": "SyncClientUnavailableError", + "is_async": False, + "messages": [{"role": "user", "content": "Alice has sk-test"}], + "response": "Alice still has sk-test", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "dd_message_trace", + "run_id": "bio__default__r000", + "workflow_name": "entity-detection", + "model_alias": "local-nemotron-json", + "status": "success", + "is_async": True, + "messages": [{"role": "user", "content": "sk-test"}], + "response": "Alice", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": 
"raw_json", + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.case_count == 2 + assert result.group_count == 2 + assert result.model_usage_count == 2 + assert result.model_usage_group_count == 2 + cases = {row.case_id: row for row in result.cases} + assert cases["bio__default__r000"].experimental_replacement_strategy == "default" + assert cases["bio__default__r000"].observed_total_requests == 4 + assert cases["bio__default__r000"].observed_successful_requests == 3 + assert cases["bio__default__r000"].observed_input_tokens == 5000 + assert cases["bio__default__r000"].observed_output_tokens == 1000 + assert cases["bio__default__r000"].observed_total_tokens == 6000 + assert cases["bio__default__r000"].observed_failed_request_rate == pytest.approx(1 / 4) + assert cases["bio__default__r000"].dd_trace_record_count == 2 + assert cases["bio__default__r000"].dd_trace_error_count == 1 + assert cases["bio__default__r000"].dd_trace_sync_client_unavailable_count == 1 + assert cases["bio__default__r000"].observed_bridge_fallback_requests == 1 + assert cases["bio__default__r000"].observed_non_bridge_total_requests == 3 + assert cases["bio__default__r000"].observed_non_bridge_failed_requests == 0 + assert cases["bio__default__r000"].observed_non_bridge_failed_request_rate == 0 + assert cases["bio__default__r000"].validation_max_entities_per_call == 10 + assert cases["bio__default__r000"].original_value_leak_count == 0 + assert cases["bio__default__r000"].original_value_leak_record_count == 0 + assert cases["bio__default__r000"].original_value_leak_label_counts == {} + assert cases["bio__default__r000"].replacement_missing_final_entity_count == 2 + assert cases["bio__default__r000"].replacement_missing_final_entity_label_counts == {"date": 2} + assert cases["bio__default__r000"].replacement_missing_final_value_count == 1 + assert 
cases["bio__default__r000"].replacement_synthetic_original_collision_count == 1 + assert cases["bio__default__r000"].replacement_synthetic_original_collision_label_counts == {"date": 1} + assert cases["bio__default__r000"].replacement_synthetic_original_collision_value_count == 1 + assert cases["bio__default__r000"].seed_validation_candidate_count == 13 + assert cases["bio__default__r000"].estimated_seed_validation_chunk_count == 2 + assert cases["bio__default__r000"].augmented_new_final_value_count == 1 + assert cases["bio__default__r000"].artifact_final_detector_entity_count == 11 + assert cases["bio__default__r000"].artifact_final_augmenter_entity_count == 3 + assert cases["bio__default__r000"].artifact_final_entity_signature_hashes == ["bio-hash-a", "bio-hash-b"] + assert cases["bio__default__r000"].artifact_final_entity_signature_labels == { + "bio-hash-a": "person", + "bio-hash-b": "city", + } + assert cases["bio__default__r000"].artifact_final_entity_signature_details == { + "bio-hash-a": { + "label": "person", + "source": "detector", + "row_index": 0, + "start_position": 0, + "end_position": 5, + "value_hash": "hash-person", + "value_length": 5, + } + } + assert cases["bio__default__r000"].artifact_final_entity_signature_count == 2 + assert cases["shell__rules-only__r000"].observed_total_requests == 0 + assert cases["shell__rules-only__r000"].experimental_replacement_strategy == "local_structured_substitute" + assert cases["shell__rules-only__r000"].observed_failed_request_rate is None + assert cases["shell__rules-only__r000"].observed_bridge_fallback_requests is None + assert cases["shell__rules-only__r000"].observed_non_bridge_failed_requests is None + assert cases["shell__rules-only__r000"].final_entity_count == 8 + assert cases["shell__rules-only__r000"].replacement_missing_final_entity_count == 0 + assert cases["shell__rules-only__r000"].replacement_missing_final_entity_label_counts == {} + assert 
cases["shell__rules-only__r000"].replacement_missing_final_value_count == 0 + assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_count == 0 + assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_label_counts == {} + assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_value_count == 0 + assert cases["shell__rules-only__r000"].original_value_leak_count == 1 + assert cases["shell__rules-only__r000"].original_value_leak_record_count == 1 + assert cases["shell__rules-only__r000"].original_value_leak_label_counts == {"api_key": 1} + assert cases["shell__rules-only__r000"].artifact_final_rule_entity_count == 8 + assert cases["shell__rules-only__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] + assert cases["shell__rules-only__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} + assert cases["shell__rules-only__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "rule" + model_rows = {row.model_name: row for row in result.model_usage} + assert model_rows["nvidia/gliner-pii"].observed_failed_requests == 0 + assert model_rows["nvidia/gliner-pii"].observed_failed_request_rate == 0 + assert model_rows["nvidia/gliner-pii"].observed_total_tokens == 1100 + assert model_rows["nvidia/nemotron-3-super"].model_alias == "local-nemotron-json" + assert model_rows["nvidia/nemotron-3-super"].experimental_replacement_strategy == "default" + assert model_rows["nvidia/nemotron-3-super"].model_provider_name == "local-vllm" + assert model_rows["nvidia/nemotron-3-super"].observed_failed_requests == 1 + assert model_rows["nvidia/nemotron-3-super"].observed_failed_request_rate == pytest.approx(1 / 3) + assert model_rows["nvidia/nemotron-3-super"].observed_total_tokens == 4900 + model_groups = {(row.model_alias, row.model_name): row for row in result.model_usage_groups} + nemotron_group = model_groups[("local-nemotron-json", 
"nvidia/nemotron-3-super")] + assert nemotron_group.model_provider_name == "local-vllm" + assert nemotron_group.sum_observed_failed_requests == 1 + assert nemotron_group.observed_failed_request_rate == pytest.approx(1 / 3) + assert nemotron_group.median_observed_total_requests == 3 + bio_group = next(group for group in result.groups if group.workload_id == "bio") + assert bio_group.experimental_replacement_strategy == "default" + assert bio_group.median_observed_bridge_fallback_requests == 1 + assert bio_group.median_observed_non_bridge_total_requests == 3 + assert bio_group.median_observed_non_bridge_failed_requests == 0 + assert bio_group.median_observed_non_bridge_failed_request_rate == 0 + assert bio_group.median_replacement_missing_final_entity_count == 2 + assert bio_group.median_replacement_missing_final_value_count == 1 + assert bio_group.replacement_missing_final_entity_label_counts == {"date": 2} + assert bio_group.median_replacement_synthetic_original_collision_count == 1 + assert bio_group.median_replacement_synthetic_original_collision_value_count == 1 + assert bio_group.replacement_synthetic_original_collision_label_counts == {"date": 1} + shell_group = next(group for group in result.groups if group.workload_id == "shell") + assert shell_group.experimental_replacement_strategy == "local_structured_substitute" + assert shell_group.sum_original_value_leak_count == 1 + assert shell_group.leaking_case_count == 1 + assert shell_group.median_original_value_leak_count == 1 + + serialized = result.model_dump_json() + assert "Alice" not in serialized + assert "sk-test" not in serialized + + +def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_model_workflow", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + 
"record_type": "model_workflow", + "run_id": "bio__native__r000", + "workflow_name": "entity-detection-native-rules-router", + "elapsed_sec": 0.25, + "observed_total_requests": 3, + "observed_successful_requests": 3, + "observed_failed_requests": 0, + "observed_input_tokens": 30, + "observed_output_tokens": 12, + "observed_total_tokens": 42, + "model_usage": { + "native-direct": { + "model_alias": "native-direct", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "request_usage": { + "successful_requests": 3, + "failed_requests": 0, + "total_requests": 3, + }, + "token_usage": { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + }, + } + }, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "native", + "experimental_detection_strategy": "native_rules_router", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__native__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__native__r000", + "final_entity_count": 2, + "replacement_count": 2, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "native", + "experimental_detection_strategy": "native_rules_router", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "repetition": 0, + "case_id": "bio__native__r000", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.case_count == 1 + case = result.cases[0] + assert case.observed_total_requests == 3 + assert case.observed_total_tokens == 42 + assert case.observed_failed_request_rate == 0 + assert result.model_usage_count == 1 + model_row = result.model_usage[0] + assert model_row.workflow_name == "entity-detection-native-rules-router" + assert model_row.model_alias == "native-direct" + assert model_row.model_name == 
"nvidia/nemotron-3-super" + assert model_row.observed_total_tokens == 42 + assert result.groups[0].median_observed_total_requests == 3 + assert result.model_usage_groups[0].sum_observed_total_tokens == 42 + + +def test_analyze_benchmark_output_accepts_detection_artifact_override(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_override", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r000", + }, + } + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "final_entity_signature_hashes": ["stale-hash"], + "final_entity_signature_count": 1, + } + ], + ) + refreshed_artifacts = tmp_path / "refreshed-detection-artifacts.jsonl" + _write_jsonl( + refreshed_artifacts, + [ + { + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "final_entity_count": 2, + "final_entity_signature_hashes": ["fresh-hash-a", "fresh-hash-b"], + "final_entity_signature_labels": {"fresh-hash-a": "person", "fresh-hash-b": "email"}, + "final_entity_signature_count": 2, + } + ], + ) + + default_result = tool.analyze_benchmark_output(benchmark_dir) + override_result = tool.analyze_benchmark_output(benchmark_dir, detection_artifacts=refreshed_artifacts) + + assert default_result.cases[0].artifact_final_entity_signature_hashes == ["stale-hash"] + assert override_result.detection_artifacts_path == str(refreshed_artifacts) + assert override_result.cases[0].artifact_final_entity_signature_hashes == ["fresh-hash-a", "fresh-hash-b"] + assert 
override_result.cases[0].artifact_final_entity_signature_labels == { + "fresh-hash-a": "person", + "fresh-hash-b": "email", + } + + +def test_analyze_benchmark_output_requires_detection_artifact_override_path(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_override_missing", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "bio__default__r000", + "run_tags": {"case_id": "bio__default__r000"}, + } + ], + ) + + with pytest.raises(ValueError, match="input path does not exist"): + tool.analyze_benchmark_output(benchmark_dir, detection_artifacts=tmp_path / "missing.jsonl") + + +def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_export", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + result = tool.BenchmarkOutputAnalysis( + benchmark_dir=str(tmp_path / "benchmark"), + cases=[ + tool.CaseAnalysisRow( + suite_id="suite", + workload_id="shell", + config_id="rules", + experimental_detection_strategy="rules_only", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + repetition=0, + case_id="shell__rules__r000", + run_id="shell__rules__r000", + final_entity_count=8, + ) + ], + groups=[ + tool.GroupAnalysisRow( + workload_id="shell", + config_id="rules", + experimental_detection_strategy="rules_only", + experimental_replacement_strategy="local_structured_substitute", + case_count=1, + median_final_entity_count=8, + median_observed_successful_requests=0, + median_observed_input_tokens=0, + median_observed_output_tokens=0, + median_observed_failed_request_rate=0, + median_artifact_final_entity_count=8, + median_artifact_final_rule_entity_count=8, + ) + ], + model_usage=[ + 
tool.ModelUsageAnalysisRow( + workload_id="shell", + config_id="rules", + experimental_detection_strategy="rules_only", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + case_id="shell__rules__r000", + run_id="shell__rules__r000", + workflow_name="entity-detection", + model_name="nvidia/gliner-pii", + observed_total_requests=1, + observed_successful_requests=1, + observed_total_tokens=1200, + ) + ], + model_usage_groups=[ + tool.ModelUsageGroupAnalysisRow( + workload_id="shell", + config_id="rules", + experimental_detection_strategy="rules_only", + experimental_replacement_strategy="local_structured_substitute", + dd_parser_compat="raw_json", + workflow_name="entity-detection", + model_name="nvidia/gliner-pii", + case_count=1, + workflow_count=1, + sum_observed_total_requests=1, + sum_observed_successful_requests=1, + sum_observed_total_tokens=1200, + median_observed_total_requests=1, + median_observed_total_tokens=1200, + ) + ], + ) + + output_dir = tmp_path / "tables" + tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) + + assert pd.read_csv(output_dir / "case_analysis.csv")["case_id"].tolist() == ["shell__rules__r000"] + assert pd.read_csv(output_dir / "case_analysis.csv")["experimental_replacement_strategy"].tolist() == [ + "local_structured_substitute" + ] + assert pd.read_csv(output_dir / "group_analysis.csv")["case_count"].tolist() == [1] + assert pd.read_csv(output_dir / "model_analysis.csv")["model_name"].tolist() == ["nvidia/gliner-pii"] + assert pd.read_csv(output_dir / "model_group_analysis.csv")["workflow_count"].tolist() == [1] + assert (output_dir / "manifest.json").exists() + + +def test_analyze_benchmark_output_preserves_zero_entity_cases(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_zero", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + 
benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "empty__redact__r000", + "final_entity_count": 0, + "replacement_count": 0, + "run_tags": { + "workload_id": "empty", + "config_id": "redact", + "experimental_detection_strategy": "default", + "case_id": "empty__redact__r000", + }, + } + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.cases[0].final_entity_count == 0 + assert result.groups[0].median_final_entity_count == 0 + + +def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_replacement_strategy_groups", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "record", + "run_id": "secrets__candidate__r000", + "final_entity_count": 4, + "run_tags": { + "workload_id": "secrets", + "config_id": "candidate", + "experimental_detection_strategy": "rules_covered_or_default", + "experimental_replacement_strategy": "default", + "case_id": "secrets__candidate__r000", + }, + }, + { + "record_type": "record", + "run_id": "secrets__candidate__r001", + "final_entity_count": 4, + "run_tags": { + "workload_id": "secrets", + "config_id": "candidate", + "experimental_detection_strategy": "rules_covered_or_default", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "secrets__candidate__r001", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + assert result.group_count == 2 + assert {group.experimental_replacement_strategy for group in result.groups} == { + "default", + "local_structured_substitute", + } + + +def test_analyze_benchmark_output_surfaces_route_counts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_route_counts", + REPO_ROOT / 
"tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "model_workflow", + "run_id": "mixed__router__r000", + "workflow_name": "entity-detection-rules-covered-router", + "status": "completed", + "input_row_count": 2, + "output_row_count": 2, + "failed_record_count": 0, + "elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_successful_requests": 0, + "observed_failed_requests": 0, + "observed_input_tokens": 0, + "observed_output_tokens": 0, + "observed_total_tokens": 0, + "route_total_row_count": 2, + "route_rule_row_count": 1, + "route_fallback_row_count": 1, + "run_tags": { + "workload_id": "mixed", + "config_id": "router", + "experimental_detection_strategy": "rules_covered_or_default", + "experimental_replacement_strategy": "default", + "case_id": "mixed__router__r000", + }, + } + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + case = result.cases[0] + assert case.route_total_row_count == 2 + assert case.route_rule_row_count == 1 + assert case.route_fallback_row_count == 1 + group = result.groups[0] + assert group.median_route_total_row_count == 2 + assert group.median_route_rule_row_count == 1 + assert group.median_route_fallback_row_count == 1 + + +def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_failures", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "stage", + "run_id": "shell__candidate__r000", + "stage": "Anonymizer._run_internal", + "status": "completed", + "elapsed_sec": 1.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "rules_guardrail_detector_only", + 
"repetition": 0, + "case_id": "shell__candidate__r000", + }, + }, + { + "record_type": "ndd_workflow", + "run_id": "shell__candidate__r001", + "workflow_name": "entity-detection", + "status": "error", + "elapsed_sec": 0.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "rules_guardrail_detector_only", + "repetition": 1, + "case_id": "shell__candidate__r001", + }, + }, + { + "record_type": "stage", + "run_id": "shell__candidate__r001", + "stage": "Anonymizer._run_internal", + "status": "error", + "elapsed_sec": 0.2, + "run_tags": { + "workload_id": "shell", + "config_id": "candidate", + "experimental_detection_strategy": "rules_guardrail_detector_only", + "repetition": 1, + "case_id": "shell__candidate__r001", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + cases = {row.case_id: row for row in result.cases} + assert cases["shell__candidate__r000"].case_failed is False + assert cases["shell__candidate__r000"].error_stage_count == 0 + assert cases["shell__candidate__r000"].error_ndd_workflow_count == 0 + assert cases["shell__candidate__r001"].case_failed is True + assert cases["shell__candidate__r001"].error_stage_count == 1 + assert cases["shell__candidate__r001"].error_ndd_workflow_count == 1 + assert result.groups[0].failed_case_count == 1 + assert result.groups[0].failed_case_rate == pytest.approx(0.5) + assert result.groups[0].error_stage_count == 1 + assert result.groups[0].error_ndd_workflow_count == 1 + assert "failed_cases=1/2" in tool.render_result(result, json_output=False) + + +def test_analyze_benchmark_output_groups_artifact_contribution_metrics(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_artifact_group", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": 
"record", + "run_id": "bio__default__r000", + "final_entity_count": 10, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r001", + "final_entity_count": 14, + "run_tags": { + "workload_id": "bio", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default__r001", + }, + }, + ], + ) + _write_jsonl( + benchmark_dir / "detection-artifacts.jsonl", + [ + { + "workload_id": "bio", + "config_id": "default", + "case_id": "bio__default__r000", + "run_id": "bio__default__r000", + "seed_entity_count": 9, + "seed_validation_candidate_count": 9, + "augmented_entity_count": 4, + "augmented_new_final_value_count": 1, + "final_entity_count": 11, + "final_source_counts": {"detector": 10, "augmenter": 1}, + "final_entity_signature_hashes": ["a", "b"], + "final_entity_signature_count": 2, + }, + { + "workload_id": "bio", + "config_id": "default", + "case_id": "bio__default__r001", + "run_id": "bio__default__r001", + "seed_entity_count": 13, + "seed_validation_candidate_count": 13, + "augmented_entity_count": 8, + "augmented_new_final_value_count": 3, + "final_entity_count": 15, + "final_source_counts": {"detector": 12, "augmenter": 3}, + "final_entity_signature_hashes": ["a", "b", "c", "d"], + "final_entity_signature_count": 4, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + group = result.groups[0] + assert group.median_final_entity_count == 12 + assert group.median_observed_successful_requests == 0 + assert group.median_observed_input_tokens == 0 + assert group.median_observed_output_tokens == 0 + assert group.median_observed_failed_request_rate is None + assert group.median_seed_entity_count == 11 + assert group.median_seed_validation_candidate_count == 11 + assert group.median_augmented_entity_count == 6 + assert 
group.median_augmented_new_final_value_count == 2 + assert group.median_artifact_final_entity_count == 13 + assert group.median_artifact_final_detector_entity_count == 11 + assert group.median_artifact_final_augmenter_entity_count == 2 + assert group.median_artifact_final_entity_signature_count == 3 diff --git a/tests/tools/test_compare_strategy_pairs.py b/tests/tools/test_compare_strategy_pairs.py new file mode 100644 index 00000000..be9f455a --- /dev/null +++ b/tests/tools/test_compare_strategy_pairs.py @@ -0,0 +1,1474 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> None: + tool = load_tool("measurement_compare_strategy_pairs", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py") + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "shell-no-augment", + "experimental_detection_strategy": "no_augment", + "experimental_replacement_strategy": "default", + "case_id": "shell__base", + "pipeline_elapsed_sec": 4.5, + "observed_total_requests": 3, + "observed_total_tokens": 2875, + "observed_failed_requests": 1, + "final_entity_count": 4, + "seed_validation_candidate_count": 5, + "augmented_entity_count": 2, + "augmented_new_final_value_count": 1, + "artifact_final_detector_entity_count": 4, + }, + { + "workload_id": "shell-1", + 
"config_id": "shell-filter", + "experimental_detection_strategy": "rules_filter_guardrail_no_augment", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 0.8, + "observed_total_requests": 1, + "observed_total_tokens": 101, + "observed_failed_requests": 0, + "final_entity_count": 8, + "seed_validation_candidate_count": 0, + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "artifact_final_rule_entity_count": 8, + }, + { + "workload_id": "legal-1", + "config_id": "legal-no-augment", + "experimental_detection_strategy": "no_augment", + "experimental_replacement_strategy": "default", + "case_id": "legal__base", + "pipeline_elapsed_sec": 21, + "observed_total_requests": 3, + "observed_total_tokens": 8847, + "observed_failed_requests": 1, + "final_entity_count": 26, + "seed_validation_candidate_count": 40, + "artifact_final_detector_entity_count": 26, + }, + { + "workload_id": "legal-1", + "config_id": "legal-filter", + "experimental_detection_strategy": "rules_filter_guardrail_no_augment", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "legal__candidate", + "pipeline_elapsed_sec": 20.9, + "observed_total_requests": 3, + "observed_total_tokens": 8847, + "observed_failed_requests": 1, + "final_entity_count": 26, + "seed_validation_candidate_count": 40, + "artifact_final_detector_entity_count": 26, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_strategy="no_augment", + candidate_strategy="rules_filter_guardrail_no_augment", + ) + + by_workload = {row.workload_id: row for row in rows} + shell = by_workload["shell-1"] + assert shell.baseline_replacement_strategy == "default" + assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.final_entity_count_delta == 4 + assert shell.observed_total_requests_delta == -2 + assert shell.observed_total_tokens_delta == -2774 + assert 
shell.seed_validation_candidate_count_delta == -5 + assert shell.baseline_augmented_entity_count == 2 + assert shell.candidate_augmented_entity_count == 0 + assert shell.augmented_entity_count_delta == -2 + assert shell.baseline_augmented_new_final_value_count == 1 + assert shell.candidate_augmented_new_final_value_count == 0 + assert shell.augmented_new_final_value_count_delta == -1 + assert shell.candidate_rule_entity_count == 8 + assert shell.safety_verdict == "review" + assert shell.performance_verdict == "improved" + assert shell.candidate_verdict == "review" + assert shell.flags == ["no_candidate_detector_entities", "candidate_uses_rule_entities"] + + legal = by_workload["legal-1"] + assert legal.baseline_replacement_strategy == "default" + assert legal.candidate_replacement_strategy == "local_structured_substitute" + assert legal.final_entity_count_delta == 0 + assert legal.observed_total_tokens_delta == 0 + assert legal.candidate_detector_entity_count == 26 + assert legal.safety_verdict == "pass" + assert legal.performance_verdict == "improved" + assert legal.candidate_verdict == "candidate_viable" + assert legal.flags == [] + + +def test_compare_case_analysis_rejects_ambiguous_strategy_selector() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_ambiguous", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "base-a", + "experimental_detection_strategy": "no_augment", + "case_id": "a", + }, + { + "workload_id": "shell-1", + "config_id": "base-b", + "experimental_detection_strategy": "no_augment", + "case_id": "b", + }, + { + "workload_id": "shell-1", + "config_id": "candidate", + "experimental_detection_strategy": "rules_only", + "case_id": "c", + }, + ] + ) + + with pytest.raises(ValueError, match="baseline selector matched multiple configs"): + tool.compare_case_analysis(table, baseline_strategy="no_augment", candidate_strategy="rules_only") + + +def 
test_compare_case_analysis_rejects_candidate_synthetic_original_collisions() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_replacement_collisions", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-1", + "config_id": "baseline", + "experimental_detection_strategy": "default", + "case_id": "base", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 4, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_value_count": 0, + }, + { + "workload_id": "legal-1", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "cand", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 2, + "observed_total_tokens": 1000, + "final_entity_count": 4, + "replacement_synthetic_original_collision_count": 1, + "replacement_synthetic_original_collision_value_count": 1, + "replacement_synthetic_original_collision_label_counts.date": 1, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_replacement_synthetic_original_collision_count == 1 + assert row.replacement_synthetic_original_collision_count_delta == 1 + assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + assert "candidate_replacement_synthetic_original_collision" in row.flags + assert row.value_protection_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_compare_case_analysis_rejects_candidate_missing_replacement_map_entries() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_missing_replacement_map_entries", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + 
[ + { + "workload_id": "structured-identifiers", + "config_id": "baseline", + "experimental_detection_strategy": "default", + "case_id": "base", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 4, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_value_count": 0, + }, + { + "workload_id": "structured-identifiers", + "config_id": "candidate", + "experimental_detection_strategy": "default", + "case_id": "cand", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 2, + "observed_total_tokens": 1000, + "final_entity_count": 4, + "replacement_missing_final_entity_count": 1, + "replacement_missing_final_value_count": 1, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_replacement_missing_final_entity_count == 1 + assert row.replacement_missing_final_entity_count_delta == 1 + assert "candidate_replacement_missing_final_entity" in row.flags + assert row.value_protection_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_compare_case_tables_allows_candidate_from_separate_run() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_cross_run", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + baseline = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "legal-no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "legal__base", + "observed_total_tokens": 55790, + "final_entity_count": 193, + "artifact_final_detector_entity_count": 172, + } + ] + ) + candidate = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "legal-rules-guardrail", + "experimental_detection_strategy": "rules_guardrail_no_augment", + "case_id": "legal__candidate", + "observed_total_tokens": 
55805, + "final_entity_count": 193, + "artifact_final_detector_entity_count": 172, + } + ] + ) + + rows = tool.compare_case_tables( + baseline, + candidate, + baseline_strategy="no_augment", + candidate_strategy="rules_guardrail_no_augment", + ) + + assert len(rows) == 1 + assert rows[0].workload_id == "legal-5" + assert rows[0].baseline_config_id == "legal-no-augment" + assert rows[0].candidate_config_id == "legal-rules-guardrail" + assert rows[0].observed_total_tokens_delta == 15 + assert rows[0].flags == ["token_increase"] + + +def test_compare_case_analysis_preserves_augmentation_contribution_deltas() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_augmentation", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-2", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "augmented_entity_count": 8, + "augmented_new_final_value_count": 3, + "artifact_final_augmenter_entity_count": 3, + "artifact_final_entity_signature_hashes": ["a", "b", "c"], + }, + { + "workload_id": "legal-2", + "config_id": "no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "no-augment-r0", + "augmented_entity_count": 0, + "augmented_new_final_value_count": 0, + "artifact_final_augmenter_entity_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="no-augment") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_augmented_entity_count == 8 + assert row.candidate_augmented_entity_count == 0 + assert row.augmented_entity_count_delta == -8 + assert row.baseline_augmented_new_final_value_count == 3 + assert row.candidate_augmented_new_final_value_count == 0 + assert row.augmented_new_final_value_count_delta == -3 + assert row.baseline_augmenter_entity_count == 3 + assert row.candidate_augmenter_entity_count 
== 0 + + +def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_detector_only", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-1", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "bio-1", + "config_id": "detector-only", + "experimental_detection_strategy": "detector_only", + "case_id": "bio__detector", + "pipeline_elapsed_sec": 5, + "observed_total_requests": 1, + "observed_total_tokens": 1000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="detector_only") + + assert len(rows) == 1 + row = rows[0] + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["candidate_skips_llm_validation"] + + +def test_compare_case_analysis_review_gates_rule_detector_only_candidates() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_rule_detector_only", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "shell__default", + "pipeline_elapsed_sec": 8, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-1", + 
"config_id": "rule-detector-only", + "experimental_detection_strategy": "rules_guardrail_detector_only", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 1, + "observed_total_requests": 1, + "observed_total_tokens": 200, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 1, + "artifact_final_rule_entity_count": 1, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_strategy="default", + candidate_strategy="rules_guardrail_detector_only", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["candidate_uses_rule_entities", "candidate_skips_llm_validation"] + + +def test_compare_case_analysis_review_gates_rules_covered_or_default_when_signatures_match() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_rules_covered_or_default", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-1", + "config_id": "rule-labels-default", + "experimental_detection_strategy": "default", + "case_id": "shell__default", + "pipeline_elapsed_sec": 21.4, + "observed_total_requests": 8, + "observed_total_tokens": 9854, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + { + "workload_id": "shell-1", + "config_id": "rule-labels-covered-or-default", + "experimental_detection_strategy": "rules_covered_or_default", + "case_id": "shell__candidate", + "pipeline_elapsed_sec": 0.001, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 2, + "artifact_final_rule_entity_count": 2, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": 
{"a": "api_key", "b": "password"}, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_config="rule-labels-default", + candidate_config="rule-labels-covered-or-default", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.shared_final_entity_signature_label_counts == {"api_key": 1, "password": 1} + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "pass" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["no_candidate_detector_entities", "candidate_uses_rule_entities"] + + +def test_compare_case_analysis_flags_signature_loss_even_when_counts_match() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_signature_loss", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "final_entity_count": 3, + "artifact_final_entity_signature_hashes": '["a","b","c"]', + "artifact_final_entity_signature_labels": '{"a":"first_name","b":"city","c":"first_name"}', + }, + { + "workload_id": "bio-5", + "config_id": "no-augment", + "experimental_detection_strategy": "no_augment", + "case_id": "bio__no_augment", + "final_entity_count": 3, + "artifact_final_entity_signature_hashes": ["a", "b", "d"], + "artifact_final_entity_signature_labels": {"a": "first_name", "b": "city", "d": "last_name"}, + }, + ] + ) + + rows = tool.compare_case_analysis( + table, + baseline_config="default", + candidate_config="no-augment", + ) + + assert len(rows) == 1 + row = rows[0] + assert row.final_entity_count_delta == 0 + assert row.baseline_only_final_entity_signature_count == 1 + assert row.candidate_only_final_entity_signature_count 
== 1 + assert row.shared_final_entity_signature_count == 2 + assert row.baseline_only_final_entity_signature_label_counts == {"first_name": 1} + assert row.candidate_only_final_entity_signature_label_counts == {"last_name": 1} + assert row.shared_final_entity_signature_label_counts == {"city": 1, "first_name": 1} + assert row.safety_verdict == "fail" + assert row.performance_verdict == "unknown" + assert row.candidate_verdict == "reject" + assert row.flags == ["entity_signature_loss"] + + +def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_coverage", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 2, + "artifact_final_detector_entity_count": 2, + "artifact_final_entity_signature_hashes": ["api-token", "pin"], + "artifact_final_entity_signature_labels": {"api-token": "api_key", "pin": "pin"}, + "artifact_final_entity_signature_details": { + "api-token": { + "label": "api_key", + "source": "augmenter", + "row_index": 0, + "start_position": 42, + "end_position": 66, + "value_hash": "token-hash", + "value_length": 24, + }, + "pin": { + "label": "pin", + "source": "detector", + "row_index": 0, + "start_position": 90, + "end_position": 95, + "value_hash": "pin-hash", + "value_length": 5, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "rules-local", + "experimental_detection_strategy": "rules_covered_or_default", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 2, + "artifact_final_rule_entity_count": 2, + 
"artifact_final_entity_signature_hashes": ["cookie", "pin"], + "artifact_final_entity_signature_labels": {"cookie": "http_cookie", "pin": "pin"}, + "artifact_final_entity_signature_details": { + "cookie": { + "label": "http_cookie", + "source": "rule", + "row_index": 0, + "start_position": 30, + "end_position": 80, + "value_hash": "cookie-hash", + "value_length": 50, + }, + "pin": { + "label": "pin", + "source": "rule", + "row_index": 0, + "start_position": 90, + "end_position": 95, + "value_hash": "pin-hash", + "value_length": 5, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 1 + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_covered_signature_label_counts == {"api_key": 1} + assert row.baseline_only_candidate_uncovered_signature_label_counts == {} + assert "entity_signature_loss" not in row.flags + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_review_gates_covered_label_mismatch() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_label_mismatch", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-row", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 1000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["dob"], + "artifact_final_entity_signature_labels": {"dob": "date_of_birth"}, + "artifact_final_entity_signature_details": { + "dob": { + "label": 
"date_of_birth", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 35, + "value_hash": "date-hash", + "value_length": 15, + }, + }, + }, + { + "workload_id": "legal-row", + "config_id": "native-validation", + "experimental_detection_strategy": "detector_native_validate_no_augment", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "observed_total_requests": 2, + "observed_total_tokens": 300, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["date"], + "artifact_final_entity_signature_labels": {"date": "date"}, + "artifact_final_entity_signature_details": { + "date": { + "label": "date", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 35, + "value_hash": "date-hash", + "value_length": 15, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-validation") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_label_mismatch_signature_count == 1 + assert row.baseline_only_candidate_label_mismatch_signature_label_counts == {"date_of_birth": 1} + assert "entity_signature_loss" not in row.flags + assert "covered_label_mismatch" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_treats_high_overlap_candidate_span_as_covered() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_overlap", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + 
"experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-assignment"], + "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, + "artifact_final_entity_signature_details": { + "token-assignment": { + "label": "http_cookie", + "source": "augmenter", + "row_index": 0, + "start_position": 20, + "end_position": 68, + "value_hash": "token-assignment-hash", + "value_length": 48, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "rules-local", + "experimental_detection_strategy": "rules_covered_or_default", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 1, + "artifact_final_rule_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-value"], + "artifact_final_entity_signature_labels": {"token-value": "api_key"}, + "artifact_final_entity_signature_details": { + "token-value": { + "label": "api_key", + "source": "rule", + "row_index": 0, + "start_position": 26, + "end_position": 69, + "value_hash": "token-value-hash", + "value_length": 43, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 1 + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_overlapping_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert row.baseline_only_candidate_overlapping_signature_label_counts == {"http_cookie": 1} + assert row.baseline_only_candidate_uncovered_signature_label_counts == {} + assert "entity_signature_loss" not in 
row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_treats_small_assignment_prefix_gap_as_boundary_overlap() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_candidate_span_boundary", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["login-assignment"], + "artifact_final_entity_signature_labels": {"login-assignment": "unique_id"}, + "artifact_final_entity_signature_details": { + "login-assignment": { + "label": "unique_id", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 38, + "value_hash": "login-assignment-hash", + "value_length": 18, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "rules-local", + "experimental_detection_strategy": "rules_covered_or_default", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 0.01, + "observed_total_requests": 0, + "observed_total_tokens": 0, + "final_entity_count": 1, + "artifact_final_rule_entity_count": 1, + "artifact_final_entity_signature_hashes": ["login-value"], + "artifact_final_entity_signature_labels": {"login-value": "user_name"}, + "artifact_final_entity_signature_details": { + "login-value": { + "label": "user_name", + "source": "rule", + "row_index": 0, + "start_position": 26, + "end_position": 38, + "value_hash": "login-value-hash", + "value_length": 12, + }, + }, + }, + ] + ) + + rows = 
tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_candidate_covered_signature_count == 1 + assert row.baseline_only_candidate_overlapping_signature_count == 1 + assert row.baseline_only_candidate_uncovered_signature_count == 0 + assert "entity_signature_loss" not in row.flags + assert "span_boundary_mismatch" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_flags_replacement_only_detection_instability() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_replacement_only_detection_instability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "config_id": "dd-substitute", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 10, + "observed_total_requests": 4, + "observed_total_tokens": 4000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-assignment"], + "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, + "artifact_final_entity_signature_details": { + "token-assignment": { + "label": "http_cookie", + "source": "detector", + "row_index": 0, + "start_position": 20, + "end_position": 70, + "value_hash": "token-assignment-hash", + "value_length": 50, + }, + }, + }, + { + "workload_id": "structured-identifiers", + "config_id": "local-substitute", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "local_structured_substitute", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 7, + 
"observed_total_requests": 3, + "observed_total_tokens": 3000, + "final_entity_count": 1, + "artifact_final_detector_entity_count": 1, + "artifact_final_entity_signature_hashes": ["token-value"], + "artifact_final_entity_signature_labels": {"token-value": "api_key"}, + "artifact_final_entity_signature_details": { + "token-value": { + "label": "api_key", + "source": "detector", + "row_index": 0, + "start_position": 26, + "end_position": 70, + "value_hash": "token-value-hash", + "value_length": 44, + }, + }, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="dd-substitute", candidate_config="local-substitute") + + assert len(rows) == 1 + row = rows[0] + assert "covered_label_mismatch" in row.flags + assert "replacement_only_detection_instability" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_rejects_candidate_original_value_leaks() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_original_value_leak", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "structured-secrets", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "structured__default", + "pipeline_elapsed_sec": 10, + "observed_total_tokens": 1000, + "final_entity_count": 2, + "original_value_leak_count": 0, + "original_value_leak_record_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + { + "workload_id": "structured-secrets", + "config_id": "candidate", + "experimental_detection_strategy": "rules_covered_or_default", + "case_id": "structured__candidate", + "pipeline_elapsed_sec": 1, + "observed_total_tokens": 0, + "final_entity_count": 2, + 
"original_value_leak_count": 2, + "original_value_leak_record_count": 1, + "original_value_leak_label_counts.api_key": 1, + "original_value_leak_label_counts.password": 1, + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.candidate_original_value_leak_count == 2 + assert row.candidate_original_value_leak_record_count == 1 + assert row.original_value_leak_count_delta == 2 + assert row.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + assert row.flags == ["candidate_original_value_leak"] + + +def test_compare_case_analysis_marks_clean_but_expensive_candidate_for_review() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_verdicts", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "legal__default", + "pipeline_elapsed_sec": 20, + "observed_total_tokens": 1000, + "final_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "legal__candidate", + "pipeline_elapsed_sec": 25, + "observed_total_tokens": 1200, + "final_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + assert rows[0].safety_verdict == "pass" + assert rows[0].performance_verdict == "regressed" + assert 
rows[0].candidate_verdict == "review" + assert rows[0].flags == ["token_increase"] + + +def test_compare_case_analysis_separates_bridge_fallbacks_from_provider_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_trace_adjusted", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "bio__default", + "pipeline_elapsed_sec": 20, + "observed_total_requests": 4, + "observed_failed_requests": 1, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 3, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 12, + "artifact_final_detector_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "bio-5", + "config_id": "windowed", + "experimental_detection_strategy": "windowed", + "case_id": "bio__windowed", + "pipeline_elapsed_sec": 15, + "observed_total_requests": 6, + "observed_failed_requests": 2, + "observed_bridge_fallback_requests": 2, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 12, + "artifact_final_detector_entity_count": 12, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="windowed") + + assert len(rows) == 1 + row = rows[0] + assert row.observed_failed_requests_delta == 1 + assert row.observed_bridge_fallback_requests_delta == 1 + assert row.observed_non_bridge_failed_requests_delta == 0 + assert "bridge_fallback_increase" in row.flags + assert "failed_request_increase" not in row.flags + + +def test_compare_case_analysis_flags_non_bridge_provider_failure_increase() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_non_bridge_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = 
pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "legal__default", + "observed_total_requests": 5, + "observed_failed_requests": 1, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 0, + "final_entity_count": 10, + "artifact_final_detector_entity_count": 10, + "artifact_final_entity_signature_hashes": ["a"], + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "experimental_detection_strategy": "candidate", + "case_id": "legal__candidate", + "observed_total_requests": 5, + "observed_failed_requests": 2, + "observed_bridge_fallback_requests": 1, + "observed_non_bridge_total_requests": 4, + "observed_non_bridge_failed_requests": 1, + "final_entity_count": 10, + "artifact_final_detector_entity_count": 10, + "artifact_final_entity_signature_hashes": ["a"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.observed_bridge_fallback_requests_delta == 0 + assert row.observed_non_bridge_failed_requests_delta == 1 + assert "bridge_fallback_increase" not in row.flags + assert "failed_request_increase" in row.flags + assert row.safety_verdict == "review" + assert row.performance_verdict == "unchanged" + assert row.candidate_verdict == "review" + + +def test_compare_case_analysis_rejects_candidate_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_case_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-5", + "config_id": "default", + "experimental_detection_strategy": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 8, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "default", + 
"experimental_detection_strategy": "default", + "case_id": "default-r1", + "pipeline_elapsed_sec": 8, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "experimental_detection_strategy": "rules_guardrail_detector_only", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "experimental_detection_strategy": "rules_guardrail_detector_only", + "case_id": "candidate-r1", + "pipeline_elapsed_sec": 0.2, + "case_failed": True, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 0 + assert row.candidate_failed_case_count == 1 + assert row.failed_case_count_delta == 1 + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + assert "candidate_case_failures" in row.flags + + +def test_compare_case_analysis_counts_model_workflow_errors_as_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_model_workflow_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-5", + "config_id": "default", + "case_id": "default-r0", + "pipeline_elapsed_sec": 8, + "error_stage_count": 0, + "error_ndd_workflow_count": 0, + "error_model_workflow_count": 0, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + { + "workload_id": "shell-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 2, + "error_stage_count": 0, + "error_ndd_workflow_count": 0, + "error_model_workflow_count": 1, + "artifact_final_entity_signature_hashes": ["a", 
"b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 0 + assert row.candidate_failed_case_count == 1 + assert row.safety_verdict == "fail" + assert row.candidate_verdict == "reject" + assert "candidate_case_failures" in row.flags + + +def test_compare_case_analysis_review_gates_baseline_case_failures() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_baseline_case_failure", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "bio-5", + "config_id": "baseline", + "case_id": "baseline-r0", + "pipeline_elapsed_sec": 30, + "case_failed": True, + "artifact_final_entity_signature_hashes": [], + }, + { + "workload_id": "bio-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "pipeline_elapsed_sec": 20, + "case_failed": False, + "artifact_final_entity_signature_hashes": ["a", "b"], + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_failed_case_count == 1 + assert row.candidate_failed_case_count == 0 + assert row.failed_case_count_delta == -1 + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + assert row.flags == ["baseline_case_failures"] + + +def test_compare_case_analysis_flags_repeated_signature_instability() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_stability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", 
+ "config_id": "default", + "case_id": "default-r1", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r1", + "artifact_final_entity_signature_hashes": ["a"], + "artifact_final_entity_signature_labels": {"a": "person"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.baseline_stable_final_entity_signature_count == 2 + assert row.candidate_stable_final_entity_signature_count == 1 + assert row.stable_final_entity_signature_count_delta == -1 + assert row.baseline_stable_candidate_unstable_final_entity_signature_count == 1 + assert row.baseline_stable_candidate_unstable_final_entity_signature_label_counts == {"date": 1} + assert row.safety_verdict == "fail" + assert row.candidate_verdict == "reject" + assert row.flags == ["stable_entity_signature_loss"] + + +def test_compare_case_analysis_does_not_infer_stability_loss_from_single_candidate_run() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_single_candidate_stability", + REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "default", + "case_id": "default-r1", + 
"artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + { + "workload_id": "legal-5", + "config_id": "candidate", + "case_id": "candidate-r0", + "artifact_final_entity_signature_hashes": ["a", "b"], + "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, + }, + ] + ) + + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") + + assert len(rows) == 1 + row = rows[0] + assert row.baseline_only_final_entity_signature_count == 0 + assert row.candidate_only_final_entity_signature_count == 0 + assert row.baseline_stable_final_entity_signature_count is None + assert row.candidate_stable_final_entity_signature_count is None + assert row.stable_final_entity_signature_count_delta is None + assert row.baseline_stable_candidate_unstable_final_entity_signature_count is None + assert row.safety_verdict == "pass" + assert "stable_entity_signature_loss" not in row.flags + + +def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_export", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" + ) + rows = [ + tool.ComparisonRow( + workload_id="shell-1", + baseline_config_id="base", + candidate_config_id="candidate", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="review", + signature_parity_verdict="pass", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + baseline_final_entity_count=4, + candidate_final_entity_count=8, + final_entity_count_delta=4, + flags=["candidate_uses_rule_entities"], + ) + ] + + output = tmp_path / "comparison.csv" + tool.write_comparisons(rows, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["workload_id"].tolist() == ["shell-1"] + 
assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["final_entity_count_delta"].tolist() == [4] + assert exported["flags"].tolist() == ['["candidate_uses_rule_entities"]'] + + +def test_compare_strategy_pairs_summarizes_candidate_verdicts() -> None: + tool = load_tool( + "measurement_compare_strategy_pairs_summary", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" + ) + rows = [ + tool.ComparisonRow( + workload_id="legal-1", + baseline_config_id="legal-default", + candidate_config_id="legal-no-augment", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="pass", + signature_parity_verdict="pass", + safety_verdict="pass", + performance_verdict="improved", + candidate_verdict="candidate_viable", + ), + tool.ComparisonRow( + workload_id="bio-1", + baseline_config_id="bio-default", + candidate_config_id="bio-no-augment", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="fail", + signature_parity_verdict="fail", + safety_verdict="fail", + performance_verdict="improved", + candidate_verdict="reject", + baseline_only_final_entity_signature_label_counts={"first_name": 2}, + ), + tool.ComparisonRow( + workload_id="shell-1", + baseline_config_id="shell-default", + candidate_config_id="shell-rules", + baseline_case_count=1, + candidate_case_count=1, + value_protection_verdict="review", + signature_parity_verdict="pass", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + ), + ] + + summary = tool.summarize_comparisons(rows) + + assert summary.comparison_count == 3 + assert summary.value_protection_verdict_counts == {"fail": 1, "pass": 1, "review": 1} + assert summary.signature_parity_verdict_counts == {"fail": 1, "pass": 2, "review": 0} + assert summary.safety_verdict_counts == {"fail": 1, "pass": 1, "review": 1} + assert summary.performance_verdict_counts["improved"] == 3 + assert summary.candidate_verdict_counts 
== {"candidate_viable": 1, "reject": 1, "review": 1} + assert summary.candidate_viable_workloads == ["legal-1"] + assert summary.review_workloads == ["shell-1"] + assert summary.rejected_workloads == ["bio-1"] + + rendered = tool.render_result( + tool.ComparisonResult( + input_path="case_analysis.csv", + baseline_selector="strategy:default", + candidate_selector="strategy:no_augment", + summary=summary, + comparisons=rows, + ), + json_output=False, + ) + assert "Compared 3 workload(s): viable=1, review=1, reject=1" in rendered + assert ( + "- bio-1: verdict=reject (safety=fail, value_protection=fail, signature_parity=fail, performance=improved)" + ) in rendered + assert "elapsed unknown->unknown" in rendered + assert "lost_labels=first_name:2" in rendered diff --git a/tests/tools/test_dd_parser_compat.py b/tests/tools/test_dd_parser_compat.py new file mode 100644 index 00000000..b9b0241c --- /dev/null +++ b/tests/tools/test_dd_parser_compat.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +from data_designer.engine.models.recipes import response_recipes as recipes +from pydantic import BaseModel + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +class TinyPayload(BaseModel): + value: str + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_raw_json_compat_accepts_pydantic_raw_json_and_restores_parser() -> None: + tool = load_tool("measurement_dd_parser_compat_pydantic", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + original = recipes.PydanticResponseRecipe._build_parser_fn + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.PydanticResponseRecipe(TinyPayload) + parsed = recipe._build_parser_fn()('{"value": "ok"}') + + assert parsed.value == "ok" + assert recipes.PydanticResponseRecipe._build_parser_fn is original + + +def test_raw_json_compat_accepts_pydantic_json_after_reasoning_prefix() -> None: + tool = load_tool( + "measurement_dd_parser_compat_pydantic_reasoning", REPO_ROOT / "tools/measurement/dd_parser_compat.py" + ) + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.PydanticResponseRecipe(TinyPayload) + parsed = recipe.parse('reasoning text\n\n\n\n{"value": "ok"}') + + assert parsed.value == "ok" + + +def test_raw_json_compat_accepts_structured_raw_json_and_restores_parser() -> None: + tool = load_tool("measurement_dd_parser_compat_structured", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + original = recipes.StructuredResponseRecipe._build_parser_fn + schema = { + "type": "object", + 
"properties": {"value": {"type": "string"}}, + "required": ["value"], + } + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.StructuredResponseRecipe(schema) + parsed = recipe._build_parser_fn()('{"value": "ok"}') + + assert parsed == {"value": "ok"} + assert recipes.StructuredResponseRecipe._build_parser_fn is original + + +def test_raw_json_compat_uses_outermost_embedded_json_object() -> None: + tool = load_tool("measurement_dd_parser_compat_outer_json", REPO_ROOT / "tools/measurement/dd_parser_compat.py") + schema = { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": {"value": {"type": "string"}}, + "required": ["value"], + }, + } + }, + "required": ["items"], + } + + with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): + recipe = recipes.StructuredResponseRecipe(schema) + parsed = recipe.parse('text before {"items": [{"value": "first"}, {"value": "last"}]}') + + assert parsed == {"items": [{"value": "first"}, {"value": "last"}]} diff --git a/tests/tools/test_dd_trace_analysis.py b/tests/tools/test_dd_trace_analysis.py new file mode 100644 index 00000000..82bb5457 --- /dev/null +++ b/tests/tools/test_dd_trace_analysis.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_dd_traces_classifies_response_shape_without_raw_content(tmp_path: Path) -> None: + tool = load_tool("measurement_dd_trace_analysis", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") + trace_path = tmp_path / "trace.jsonl" + _write_jsonl( + trace_path, + [ + { + "record_type": "dd_message_trace", + "run_id": "case-1", + "run_tags": { + "suite_id": "suite", + "workload_id": "legal", + "config_id": "default", + "case_id": "case-1", + }, + "workflow_name": "entity-detection", + "model_alias": "validator", + "model_name": "nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "status": "completed", + "error_type": None, + "elapsed_sec": 3.0, + "messages": [{"role": "user", "content": [{"type": "text", "text": "secret prompt text"}]}], + "response": {"content": '{"decisions": []}', "reasoning_content": None, "tool_calls": []}, + "usage": {"input_tokens": 11, "output_tokens": 7, "total_tokens": 18}, + }, + { + "record_type": "dd_message_trace", + "run_id": "case-1", + "run_tags": { + "suite_id": "suite", + "workload_id": "legal", + "config_id": "default", + "case_id": "case-1", + }, + "workflow_name": "entity-detection", + "model_alias": "augmenter", + "model_name": 
"nvidia/nemotron-3-super", + "model_provider_name": "local-vllm", + "status": "completed", + "error_type": None, + "elapsed_sec": 4.0, + "messages": [{"role": "user", "content": [{"type": "text", "text": "another secret prompt"}]}], + "response": { + "content": 'reasoning\n\n```json\n{"entities": []}\n```', + "reasoning_content": None, + "tool_calls": [], + }, + "usage": {"input_tokens": 13, "output_tokens": 17, "total_tokens": 30}, + }, + { + "record_type": "other", + "response": {"content": "ignored raw text"}, + }, + ], + ) + + result = tool.analyze_trace_path(trace_path) + + assert result.trace_record_count == 2 + rows = {row.model_alias: row for row in result.rows} + assert rows["validator"].response_shape == "raw_json" + assert rows["validator"].response_has_embedded_json is True + assert rows["validator"].prompt_chars == len("secret prompt text") + assert rows["augmenter"].response_shape == "fenced_json" + assert rows["augmenter"].response_has_thinking is True + assert rows["augmenter"].response_chars > 0 + serialized = result.model_dump_json() + assert "secret prompt text" not in serialized + assert "another secret prompt" not in serialized + assert "decisions" not in serialized + assert "entities" not in serialized + + group = next(row for row in result.groups if row.response_shape == "fenced_json") + assert group.model_alias == "augmenter" + assert group.model_name == "nvidia/nemotron-3-super" + assert group.trace_record_count == 1 + assert group.sum_total_tokens == 30 + assert group.error_count == 0 + + +def test_analyze_dd_traces_reads_trace_directory_and_exports_tables(tmp_path: Path) -> None: + tool = load_tool("measurement_dd_trace_analysis_export", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") + trace_dir = tmp_path / "traces" + trace_dir.mkdir() + _write_jsonl( + trace_dir / "a.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "case-a", + "workflow_name": "entity-detection", + "model_name": "nvidia/nemotron-3-super", + 
"status": "error", + "error_type": "ParserException", + "elapsed_sec": 1.0, + "response": None, + "usage": None, + } + ], + ) + _write_jsonl( + trace_dir / "b.jsonl", + [ + { + "record_type": "dd_message_trace", + "run_id": "case-b", + "workflow_name": "entity-detection", + "model_name": "nvidia/gliner-pii", + "status": "completed", + "elapsed_sec": 0.2, + "response": {"content": "plain text", "reasoning_content": None, "tool_calls": []}, + "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + } + ], + ) + + result = tool.analyze_trace_path(trace_dir) + output_dir = tmp_path / "analysis" + tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) + + assert result.trace_record_count == 2 + rows = {row.run_id: row for row in result.rows} + assert rows["case-a"].status == "error" + assert rows["case-a"].response_shape == "none" + assert rows["case-b"].response_shape == "text" + assert pd.read_csv(output_dir / "trace_analysis.csv")["run_id"].tolist() == ["case-a", "case-b"] + assert pd.read_csv(output_dir / "trace_group_analysis.csv")["trace_record_count"].sum() == 2 + assert (output_dir / "manifest.json").exists() diff --git a/tests/tools/test_detection_artifact_analysis.py b/tests/tools/test_detection_artifact_analysis.py new file mode 100644 index 00000000..3bc25c51 --- /dev/null +++ b/tests/tools/test_detection_artifact_analysis.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _entity(value: str, label: str, start: int, end: int, *, source: str = "detector") -> dict[str, object]: + return { + "id": f"{label}_{start}_{end}", + "value": value, + "label": label, + "start_position": start, + "end_position": end, + "score": 1.0, + "source": source, + } + + +def _write_artifact(root: Path, workflow: str, rows: list[dict[str, object]]) -> None: + parquet_dir = root / workflow / "parquet-files" + parquet_dir.mkdir(parents=True) + pd.DataFrame(rows).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + + +def test_detection_artifact_analysis_reports_augmentation_contribution(tmp_path: Path) -> None: + tool = load_tool( + "measurement_detection_artifact_analysis", + REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py", + ) + artifact_root = tmp_path / "artifacts" + _write_artifact( + artifact_root, + "entity-detection", + [ + { + "_seed_entities_json": json.dumps([_entity("Alice", "first_name", 0, 5)]), + "_seed_validation_candidates": json.dumps( + {"candidates": [{"id": "first_name_0_5", "value": "Alice", "label": "first_name"}]} + ), + "_augmented_entities": json.dumps( + { + "entities": [ + {"value": "Alice", "label": "first_name", "reason": "duplicate"}, + {"value": "12 February 1980", "label": "api_key", "reason": "date mislabeled"}, + ] + } + ), + "_detected_entities": json.dumps( + { + "entities": [ + 
_entity("Alice", "first_name", 0, 5), + _entity("12 February 1980", "api_key", 20, 36, source="augmenter"), + ] + } + ), + "_validation_candidates": json.dumps( + { + "candidates": [ + {"id": "first_name_0_5", "value": "Alice", "label": "first_name"}, + {"id": "api_key_20_36", "value": "12 February 1980", "label": "api_key"}, + ] + } + ), + } + ], + ) + + result = tool.analyze_artifacts(artifact_root) + + assert len(result.rows) == 1 + row = result.rows[0] + assert row.seed_entity_count == 1 + assert row.seed_validation_candidate_count == 1 + assert row.merged_validation_candidate_count == 2 + assert row.augmented_entity_count == 2 + assert row.augmented_duplicate_seed_value_count == 1 + assert row.augmented_new_value_count == 1 + assert row.augmented_new_final_value_count == 1 + assert row.final_entity_count == 2 + assert row.weak_api_key_shape_count == 1 + assert row.weak_api_key_shape_label_counts == {"api_key": 1} + assert row.final_entity_signature_count == 2 + assert len(row.final_entity_signature_hashes) == 2 + assert set(row.final_entity_signature_labels) == set(row.final_entity_signature_hashes) + assert sorted(row.final_entity_signature_labels.values()) == ["api_key", "first_name"] + assert set(row.final_entity_signature_details) == set(row.final_entity_signature_hashes) + first_name_detail = next( + detail for detail in row.final_entity_signature_details.values() if detail["label"] == "first_name" + ) + assert first_name_detail["source"] == "detector" + assert first_name_detail["row_index"] == 0 + assert first_name_detail["start_position"] == 0 + assert first_name_detail["end_position"] == 5 + assert first_name_detail["value_length"] == 5 + assert len(first_name_detail["value_hash"]) == 16 + + serialized = row.model_dump_json() + assert "Alice" not in serialized + assert "12 February" not in serialized + + +def test_detection_artifact_analysis_handles_no_augment_rows(tmp_path: Path) -> None: + tool = load_tool( + 
"measurement_detection_artifact_analysis_no_augment", + REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py", + ) + artifact_root = tmp_path / "artifacts" + _write_artifact( + artifact_root, + "entity-detection-no-augment", + [ + { + "_seed_entities_json": json.dumps([_entity("Aydin", "city", 12, 17)]), + "_seed_validation_candidates": json.dumps( + {"candidates": [{"id": "city_12_17", "value": "Aydin", "label": "city"}]} + ), + "_augmented_entities": json.dumps({"entities": []}), + "_detected_entities": json.dumps({"entities": [_entity("Aydin", "city", 12, 17)]}), + } + ], + ) + + result = tool.analyze_artifacts(artifact_root) + + assert len(result.rows) == 1 + row = result.rows[0] + assert row.workflow_name == "entity-detection-no-augment" + assert row.seed_entity_count == 1 + assert row.seed_validation_candidate_count == 1 + assert row.merged_validation_candidate_count == 0 + assert row.augmented_entity_count == 0 + assert row.augmented_new_value_count == 0 + assert row.augmented_new_final_value_count == 0 + assert row.final_entity_count == 1 + assert row.final_source_counts == {"detector": 1} + assert row.final_entity_signature_count == 1 + assert row.final_entity_signature_hashes == sorted(row.final_entity_signature_hashes) + assert list(row.final_entity_signature_labels.values()) == ["city"] diff --git a/tests/tools/test_detection_strategies.py b/tests/tools/test_detection_strategies.py new file mode 100644 index 00000000..c5bda974 --- /dev/null +++ b/tests/tools/test_detection_strategies.py @@ -0,0 +1,2143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +import threading +import time +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import Mock + +import pandas as pd +import pytest + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_FINAL_ENTITIES, + COL_INITIAL_TAGGED_TEXT, + COL_MERGED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, + COL_VALIDATED_ENTITIES, + COL_VALIDATED_SEED_ENTITIES, + COL_VALIDATION_DECISIONS, +) +from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow +from anonymizer.engine.ndd.model_loader import load_default_model_selection +from anonymizer.engine.schemas import EntitiesSchema +from anonymizer.measurement import MeasurementCollector, measurement_session + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_rules_only_strategy_detects_rule_spans_and_restores_workflow_method() -> None: + tool = load_tool("measurement_detection_strategies", REPO_ROOT / "tools/measurement/detection_strategies.py") + original = EntityDetectionWorkflow.detect_and_validate_entities + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_only): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), + 
model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "password"], + ) + + assert EntityDetectionWorkflow.detect_and_validate_entities is original + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("password", "fakePass123!"), + ] + + +def test_rules_only_strategy_rejects_unsupported_labels_at_runtime() -> None: + tool = load_tool( + "measurement_detection_strategies_runtime_guard", REPO_ROOT / "tools/measurement/detection_strategies.py" + ) + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_only): + workflow = EntityDetectionWorkflow(adapter=Mock()) + with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "person"], + ) + + +def test_rules_covered_or_default_short_circuits_structured_fast_lane_labels() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_covered_short_circuit", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + 
gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "password"], + ) + + adapter.run_workflow.assert_not_called() + assert "_anonymizer_row_order" not in result.dataframe.columns + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("password", "fakePass123!"), + ] + + +def test_rules_covered_or_default_falls_back_for_uncovered_structured_assignments() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_covered_row_fallback", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + calls: list[pd.DataFrame] = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(dataframe.copy()) + output = dataframe.copy() + output[COL_DETECTED_ENTITIES] = [ + EntitiesSchema(entities=[]).model_dump(mode="json") for _ in range(len(output)) + ] + output[COL_TAGGED_TEXT] = output[COL_TEXT] + return SimpleNamespace(dataframe=output, failed_records=[]) + + rows = [ + "token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", + '{"password": "SecretNoRule123!"}', + ] + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: rows}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "password"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] 
+ + assert len(calls) == 1 + assert calls[0][COL_TEXT].tolist() == ['{"password": "SecretNoRule123!"}'] + assert "_anonymizer_row_order" not in result.dataframe.columns + assert result.dataframe[COL_TEXT].tolist() == rows + first_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + second_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[1]).entities + assert [(entity.label, entity.value) for entity in first_entities] == [ + ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA") + ] + assert second_entities == [] + + +def test_rules_covered_or_default_records_route_counts() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_covered_route_counts", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + output = dataframe.copy() + output[COL_DETECTED_ENTITIES] = [ + EntitiesSchema(entities=[]).model_dump(mode="json") for _ in range(len(output)) + ] + output[COL_TAGGED_TEXT] = output[COL_TEXT] + return SimpleNamespace(dataframe=output, failed_records=[]) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + collector = MeasurementCollector(record_hash_key="test-key") + try: + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.rules_covered_or_default + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame( + { + COL_TEXT: [ + "token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", + '{"password": "SecretNoRule123!"}', + ] + } + ), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "password"], + 
) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-rules-covered-router" + assert record["route_total_row_count"] == 2 + assert record["route_rule_row_count"] == 1 + assert record["route_fallback_row_count"] == 1 + assert record["observed_total_requests"] == 0 + assert record["observed_total_tokens"] == 0 + + +def test_native_rules_router_strategy_runs_staged_detection_without_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_native_rules_router", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SequencedClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs = [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace(content=self.outputs.pop(0), elapsed_sec=0.1, usage={}) + + adapter = Mock() + client = SequencedClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_rules_router, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = 
EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "direct_seed"), + ("organization_name", "NVIDIA", "augmenter"), + ] + assert len(client.prompts) == 3 + + +def test_native_rules_router_strategy_records_direct_model_usage() -> None: + tool = load_tool( + "measurement_detection_strategies_native_rules_router_usage", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SequencedClient: + def __init__(self) -> None: + self.outputs = [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ] + + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content=self.outputs.pop(0), + elapsed_sec=0.1, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_rules_router, + native_client=SequencedClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["first_name", "organization_name"], + ) + + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-rules-router" + assert record["observed_total_requests"] == 3 + assert 
record["observed_successful_requests"] == 3 + assert record["observed_failed_requests"] == 0 + assert record["observed_input_tokens"] == 30 + assert record["observed_output_tokens"] == 12 + assert record["observed_total_tokens"] == 42 + assert record["model_usage"]["native-direct"]["model_name"] == "nvidia/nemotron-3-super" + + +def test_native_candidate_validate_no_augment_strategy_skips_data_designer_and_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_native_candidate_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SequencedClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs = [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + ] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=self.outputs.pop(0), + elapsed_sec=0.1, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + adapter = Mock() + client = SequencedClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) 
for entity in entities] == [ + ("first_name", "Alice", "direct_seed"), + ] + assert len(client.prompts) == 2 + assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-candidate-validate-no-augment" + assert record["observed_total_requests"] == 2 + assert record["observed_total_tokens"] == 28 + + +def test_detector_native_validate_no_augment_strategy_reuses_detector_seed_and_direct_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + 
tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_called_once() + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( + "entity-detection-detector-native-validate-no-augment-seed" + ) + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + assert len(client.prompts) == 1 + assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + + +def test_detector_native_validate_no_augment_ignores_invalid_reclass_labels() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate_invalid_label", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." 
+ seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class ValidationClient: + def complete(self, request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content=( + '{"decisions": [' + '{"id": "first_name_0_5", "decision": "reclass", ' + '"proposed_label": "drop", "reason": "invalid label"}' + "]}" + ), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=ValidationClient(), + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + + +def test_detector_native_validate_native_augment_uses_direct_validation_and_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_augment", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." 
+ seed_row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "first_name_0_5", + "value": "Alice", + "label": "first_name", + "start_position": 0, + "end_position": 5, + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(seed_row) + + class DirectClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + if "Find additional sensitive entities" in request.prompt: + return SimpleNamespace( + content='{"entities": [{"value": "NVIDIA", "label": "organization_name"}]}', + elapsed_sec=0.2, + usage={"prompt_tokens": 50, "completion_tokens": 6, "total_tokens": 56}, + ) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) + client = DirectClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_called_once() + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( + 
"entity-detection-detector-native-validate-native-augment-seed" + ) + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ("organization_name", "NVIDIA", "augmenter"), + ] + assert len(client.prompts) == 2 + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + assert records[0]["workflow_name"] == "entity-detection-detector-native-validate-native-augment" + assert records[0]["observed_total_requests"] == 2 + assert records[0]["observed_total_tokens"] == 104 + + +def test_detector_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_native_validate_parallel", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + def seed_row(text: str, value: str) -> dict[str, object]: + row = { + COL_TEXT: text, + COL_RAW_DETECTED: "", + COL_SEED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": f"first_name_0_{len(value)}", + "value": value, + "label": "first_name", + "start_position": 0, + "end_position": len(value), + "score": 0.9, + "source": "detector", + } + ] + ).model_dump(mode="json"), + COL_TAG_NOTATION: "xml", + } + tool.prepare_validation_inputs(row) + return row + + class ValidationClient: + def __init__(self) -> None: + self.lock = threading.Lock() + self.active_count = 0 + self.max_active_count = 0 + + def complete(self, request): # type: ignore[no-untyped-def] + with self.lock: + self.active_count += 1 + self.max_active_count = max(self.max_active_count, self.active_count) + time.sleep(0.05) + with self.lock: + self.active_count -= 1 + candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" + return SimpleNamespace( + content=json.dumps({"decisions": [{"id": candidate_id, "decision": "keep", 
"reason": "real name"}]}), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + seed_df = pd.DataFrame( + [ + seed_row("Alice works at NVIDIA.", "Alice"), + seed_row("Bob works at NVIDIA.", "Bob"), + ], + index=[10, 4], + ) + adapter = Mock() + adapter.run_workflow.return_value = SimpleNamespace(dataframe=seed_df, failed_records=[]) + client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + assert client.max_active_count > 1 + assert list(result.dataframe.index) == [10, 4] + entities_by_row = [ + [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] + for raw in result.dataframe[COL_DETECTED_ENTITIES] + ] + assert entities_by_row == [ + [("first_name", "Alice", "detector")], + [("first_name", "Bob", "detector")], + ] + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert [record["observed_total_requests"] for record in records] == [1, 1] + assert [record["observed_total_tokens"] for record in records] == [48, 48] + record = records[0] + assert record["workflow_name"] == "entity-detection-detector-native-validate-no-augment" + assert record["observed_total_requests"] == 1 + assert record["observed_total_tokens"] == 48 + + +def 
test_gliner_native_validate_no_augment_strategy_bypasses_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_validate", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + assert request.text == text + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + seed_client = GlinerSeedClient() + validation_client = ValidationClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=validation_client, + gliner_seed_client=seed_client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, 
entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ] + assert len(validation_client.prompts) == 1 + assert all("Find additional sensitive entities" not in prompt for prompt in validation_client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-gliner-native-validate-no-augment" + assert record["observed_total_requests"] == 2 + assert record["observed_total_tokens"] == 73 + assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] + assert record["model_usage"]["gliner-direct"]["model_name"] == "nvidia/gliner-pii" + assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 + assert record["model_usage"]["native-direct"]["model_name"] == "nvidia/nemotron-3-super" + assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 48 + + +def test_gliner_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_parallel", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + value = str(request.text).split()[0] + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": value, + "label": "first_name", + "start": 0, + "end": len(value), + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" + return SimpleNamespace( + content=json.dumps({"decisions": 
[{"id": candidate_id, "decision": "keep", "reason": "real name"}]}), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + collector = MeasurementCollector(record_hash_key="test-key") + dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]) + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=ValidationClient(), + gliner_seed_client=GlinerSeedClient(), + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + dataframe, + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + assert list(result.dataframe.index) == [10, 4] + entities_by_row = [ + [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] + for raw in result.dataframe[COL_DETECTED_ENTITIES] + ] + assert entities_by_row == [ + [("first_name", "Alice", "detector")], + [("first_name", "Bob", "detector")], + ] + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert [record["observed_total_requests"] for record in records] == [2, 2] + assert [record["observed_total_tokens"] for record in records] == [73, 73] + + +def test_gliner_native_validate_no_augment_parallel_rows_keep_failed_records() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_parallel_failure", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + if str(request.text).startswith("Broken"): + raise RuntimeError("seed unavailable") + 
return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class ValidationClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Broken row"]}, index=[0, 1]) + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + native_client=ValidationClient(), + gliner_seed_client=GlinerSeedClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + dataframe, + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name"], + ) + + assert list(result.dataframe.index) == [0] + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("first_name", "Alice")] + assert [(failed.record_id, failed.step) for failed in result.failed_records] == [ + ("1", "entity-detection-gliner-native-validate-no-augment") + ] + assert "seed unavailable" in result.failed_records[0].reason + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + assert records[0]["observed_total_requests"] == 2 + + +def 
test_gliner_native_validate_native_augment_strategy_bypasses_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_gliner_native_augment", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice works at NVIDIA." + + class GlinerSeedClient: + def detect(self, request): # type: ignore[no-untyped-def] + assert request.text == text + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "first_name", + "start": 0, + "end": 5, + "score": 0.99, + } + ] + } + ), + elapsed_sec=0.2, + usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + ) + + class NativeClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + self.outputs = [ + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=self.outputs.pop(0), + elapsed_sec=0.1, + usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, + ) + + adapter = Mock() + seed_client = GlinerSeedClient() + native_client = NativeClient() + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + native_client=native_client, + gliner_seed_client=seed_client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=True, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + 
entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "detector"), + ("organization_name", "NVIDIA", "augmenter"), + ] + assert any("Find additional sensitive entities" in prompt for prompt in native_client.prompts) + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-gliner-native-validate-native-augment" + assert record["observed_total_requests"] == 3 + assert record["observed_total_tokens"] == 121 + assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] + assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 + assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 96 + + +def test_native_single_pass_strategy_runs_one_direct_call_without_data_designer() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content=json.dumps( + { + "entities": [ + {"value": "Alice", "label": "first_name", "start": 0, "end": 5}, + {"value": "NVIDIA", "label": "organization_name", "start": 15, "end": 21}, + ] + } + ), + elapsed_sec=0.1, + usage={}, + ) + + adapter = Mock() + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + 
model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name", "organization_name"], + ) + + adapter.run_workflow.assert_not_called() + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("first_name", "Alice", "direct_single_pass"), + ("organization_name", "NVIDIA", "direct_single_pass"), + ] + assert len(client.prompts) == 1 + assert '"start"' in client.prompts[0] + assert '"end"' in client.prompts[0] + + +def test_native_single_pass_recall_strategy_uses_label_examples() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_recall", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={}, + ) + + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass_recall, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person", "email"], + ) + + assert len(client.prompts) == 1 + assert "- person" in client.prompts[0] + assert "- email:" in client.prompts[0] + assert "Bias toward high recall" in client.prompts[0] + + +def test_native_single_pass_values_strategy_uses_value_only_prompt() -> None: + tool = load_tool( + 
"measurement_detection_strategies_native_single_pass_values", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def __init__(self) -> None: + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person"}]}', + elapsed_sec=0.1, + usage={}, + ) + + client = SinglePassClient() + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass_values, + native_client=client, + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice met Alice."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.value, entity.start_position, entity.end_position) for entity in entities] == [ + ("Alice", 0, 5), + ("Alice", 10, 15), + ] + assert len(client.prompts) == 1 + assert '"start"' not in client.prompts[0] + assert '"end"' not in client.prompts[0] + assert '{"entities": [{"value": "exact substring", "label": "one_allowed_label"' in client.prompts[0] + + +def test_native_single_pass_strategy_records_direct_model_usage() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_usage", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class SinglePassClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "first_name", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={"prompt_tokens": 20, "completion_tokens": 7, "total_tokens": 27}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + + with 
measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=SinglePassClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-single-pass" + assert record["observed_total_requests"] == 1 + assert record["observed_successful_requests"] == 1 + assert record["observed_failed_requests"] == 0 + assert record["observed_input_tokens"] == 20 + assert record["observed_output_tokens"] == 7 + assert record["observed_total_tokens"] == 27 + + +def test_native_single_pass_strategy_adds_non_overlapping_rule_spans() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_rule_guardrail", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + text = "Alice logged in.\nPassword: SuperSecret123!\n" + + class SinglePassClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', + elapsed_sec=0.1, + usage={}, + ) + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=SinglePassClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person", "password"], + ) + + entities = 
EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("person", "Alice", "direct_single_pass"), + ("password", "SuperSecret123!", "rule"), + ] + + +def test_native_single_pass_strategy_records_parser_errors_as_failures() -> None: + tool = load_tool( + "measurement_detection_strategies_native_single_pass_parser_error", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + + class InvalidJsonClient: + def complete(self, _request): # type: ignore[no-untyped-def] + return SimpleNamespace( + content="not json", + elapsed_sec=0.1, + usage={"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11}, + ) + + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.native_single_pass, + native_client=InvalidJsonClient(), + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + + assert result.dataframe.empty + assert len(result.failed_records) == 1 + assert result.failed_records[0].step == "entity-detection-native-single-pass" + records = [record for record in collector.records if record["record_type"] == "model_workflow"] + assert len(records) == 1 + record = records[0] + assert record["workflow_name"] == "entity-detection-native-single-pass" + assert record["status"] == "error" + assert record["failed_record_count"] == 1 + assert record["observed_successful_requests"] == 1 + assert record["observed_failed_requests"] == 0 + + +def test_rules_covered_or_default_uses_default_pipeline_for_contextual_labels() -> None: + tool = load_tool( + 
"measurement_detection_strategies_rules_covered_default_fallback", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["api_key", "person"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert len(calls) == 1 + assert calls[0]["entity_labels"] == ["api_key", "person"] + assert calls[0]["validation_single_chunk_full_text"] is False + assert EntityDetectionWorkflow.detect_and_validate_entities is original + assert EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities == [] + + +def test_rules_covered_or_default_uses_default_pipeline_for_narrow_prose_rule_labels() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_covered_prose_rule_fallback", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + 
dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Jordan worked at Acme Research Center and lived on Maple Street."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + validation_single_chunk_full_text=False, + entity_labels=["organization_name", "street_address"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert len(calls) == 1 + assert calls[0]["entity_labels"] == ["organization_name", "street_address"] + assert EntityDetectionWorkflow.detect_and_validate_entities is original + assert EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities == [] + + +def test_detector_only_strategy_finalizes_gliner_spans_without_validation_or_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_only", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + text = "Alice met Alice at the lab." 
+ start = text.index("Alice") + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: + assert [column.name for column in columns] == [ + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_DETECTED_ENTITIES, + ] + assert kwargs["workflow_name"] == "entity-detection-detector-only" + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": start, + "end": start + len("Alice"), + "score": 0.99, + } + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = json.loads(row[COL_SEED_ENTITIES_JSON]) + assert [(entity["label"], entity["value"]) for entity in seed_entities] == [("person", "Alice")] + row = columns[3].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.detector_only): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("person", "Alice"), ("person", "Alice")] + assert "Alice" in result.dataframe[COL_TAGGED_TEXT].iloc[0] + + +def test_rules_guardrail_detector_only_adds_rule_spans_without_validation_or_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_guardrail_detector_only", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + text = 
f"Alice exported token={token}" + alice_start = text.index("Alice") + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: + assert [column.name for column in columns] == [ + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_DETECTED_ENTITIES, + ] + assert kwargs["workflow_name"] == "entity-detection-rules-guardrail-detector-only" + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": alice_start, + "end": alice_start + len("Alice"), + "score": 0.99, + } + ] + } + ), + } + for column in columns[1:]: + row = column.generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail_detector_only): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["api_key", "person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("person", "Alice", "detector"), + ("api_key", token, "rule"), + ] + + +def test_rules_guardrail_keeps_default_pipeline_and_adds_rule_spans() -> None: + tool = load_tool( + "measurement_detection_strategies_default_rules_guardrail", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + text = "The applicant was born in 1978 and later moved to Berlin." 
+ calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **_: object, + ) -> object: + calls.append(self) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["date_of_birth"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert len(calls) == 1 + assert EntityDetectionWorkflow.detect_and_validate_entities is original + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [("date_of_birth", "1978", "rule")] + + +def test_compact_validation_strategy_disables_full_text_single_chunk_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_compact_validation", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + 
EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.compact_validation): + workflow = EntityDetectionWorkflow(adapter=Mock()) + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at Acme."]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert EntityDetectionWorkflow.detect_and_validate_entities is original + assert calls[0]["validation_single_chunk_full_text"] is False + + +def test_rules_guardrail_compact_validation_combines_rule_guardrail_and_compact_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_guardrail_compact_validation", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + text = "The applicant was born in 1978 and later moved to Berlin." 
+ calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.rules_guardrail_compact_validation + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["date_of_birth"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert calls[0]["validation_single_chunk_full_text"] is False + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [("date_of_birth", "1978", "rule")] + + +def test_rules_guardrail_prefers_rule_label_for_exact_span_overlap() -> None: + tool = load_tool( + "measurement_detection_strategies_default_rules_guardrail_exact_overlap", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + text = "Idilio describes himself as secular and leans progressive on most political issues." 
+ start = text.index("secular") + calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **_: object, + ) -> object: + calls.append(self) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema( + entities=[ + { + "id": "political_view_0", + "value": "secular", + "label": "political_view", + "start_position": start, + "end_position": start + len("secular"), + "score": 1.0, + "source": "detector", + } + ] + ).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["political_view", "religious_belief"], + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert len(calls) == 1 + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("religious_belief", "secular", "rule") + ] + + +def test_rules_guardrail_can_apply_explicit_rule_labels_outside_model_labels() -> None: + tool = load_tool( + "measurement_detection_strategies_default_rules_guardrail_rule_labels", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + original = EntityDetectionWorkflow.detect_and_validate_entities + text = "Outside the lab, Idilio shares a modest house on West Roberts Drive with his wife." 
+ calls = [] + + def fake_original( + self: EntityDetectionWorkflow, + dataframe: pd.DataFrame, + **kwargs: object, + ) -> object: + calls.append(kwargs["entity_labels"]) + return SimpleNamespace( + dataframe=pd.DataFrame( + [ + { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + } + ] + ), + failed_records=[], + ) + + EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] + try: + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.rules_guardrail, + rule_labels=["street_address"], + ): + workflow = EntityDetectionWorkflow(adapter=Mock()) + result = workflow.run( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["first_name"], + privacy_goal=None, + tag_latent_entities=False, + ) + finally: + EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] + + assert calls == [["first_name"]] + detected_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + final_entities = EntitiesSchema.from_raw(result.dataframe[COL_FINAL_ENTITIES].iloc[0]).entities + expected = [("street_address", "West Roberts Drive", "rule")] + assert [(entity.label, entity.value, entity.source) for entity in detected_entities] == expected + assert [(entity.label, entity.value, entity.source) for entity in final_entities] == expected + + +def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: + tool = load_tool( + "measurement_detection_strategies_prose_augment_focus", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + before = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + + with 
tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.prose_augment_focus): + inside = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + + after = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + assert "Contextual prose recall focus" not in before + assert "Contextual prose recall focus" in inside + assert "organization and institution names" in inside + assert after == before + + +def test_rules_guardrail_no_augment_adds_rule_spans_after_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_guardrail", REPO_ROOT / "tools/measurement/detection_strategies.py" + ) + adapter = Mock() + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + assert [column.name for column in columns][-3:] == [ + COL_AUGMENTED_ENTITIES, + COL_MERGED_ENTITIES, + COL_DETECTED_ENTITIES, + ] + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_MERGED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), + COL_VALIDATED_ENTITIES: {"decisions": []}, + } + row = columns[-1].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail_no_augment): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["api_key", "password"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [ + ("api_key", 
"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), + ("password", "fakePass123!"), + ] + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-guardrail-no-augment" + + +def test_rules_filter_guardrail_no_augment_filters_rule_spans_before_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_filter_guardrail", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + text = f"Alice used token={token}" + alice_start = text.index("Alice") + token_start = text.index(token) + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": alice_start, + "end": alice_start + len("Alice"), + "score": 0.99, + }, + { + "text": token, + "label": "api_key", + "start": token_start, + "end": token_start + len(token), + "score": 0.99, + }, + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities + assert [(entity.label, entity.value) for entity in seed_entities] == [("person", "Alice")] + assert [candidate["label"] for candidate in row[COL_SEED_VALIDATION_CANDIDATES]["candidates"]] == ["person"] + row[COL_MERGED_ENTITIES] = EntitiesSchema(entities=[]).model_dump(mode="json") + row[COL_VALIDATED_ENTITIES] = {"decisions": []} + row = columns[-1].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context( + tool.ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment + ): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + 
pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["api_key", "person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("api_key", token)] + assert ( + adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-filter-guardrail-no-augment" + ) + + +def test_rules_filter_guardrail_keeps_augmentation_but_skips_rule_validation() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_filter_guardrail_with_augmentation", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + text = f"Alice used token={token}" + alice_start = text.index("Alice") + token_start = text.index(token) + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + assert [column.name for column in columns] == [ + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_VALIDATION_CANDIDATES, + COL_VALIDATION_DECISIONS, + COL_VALIDATED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_AUGMENTED_ENTITIES, + COL_MERGED_ENTITIES, + COL_DETECTED_ENTITIES, + ] + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": alice_start, + "end": alice_start + len("Alice"), + "score": 0.99, + }, + { + "text": token, + "label": "api_key", + "start": token_start, + "end": token_start + len(token), + "score": 0.99, + }, + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities + assert [(entity.label, entity.value) for entity in seed_entities] == [("person", "Alice")] + assert [candidate["label"] for candidate in 
row[COL_SEED_VALIDATION_CANDIDATES]["candidates"]] == ["person"] + + row[COL_VALIDATION_DECISIONS] = { + "decisions": [ + { + "id": "person_0_5", + "decision": "keep", + "proposed_label": "", + "reason": "test keep", + } + ] + } + row = columns[4].generator_function(row) + row = columns[5].generator_function(row) + validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities + assert [(entity.label, entity.value, entity.source) for entity in validated_seed] == [ + ("person", "Alice", "detector"), + ("api_key", token, "rule"), + ] + assert f"{token}" in row[COL_INITIAL_TAGGED_TEXT] + + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + row = columns[7].generator_function(row) + row = columns[8].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["api_key", "person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value, entity.source) for entity in entities] == [ + ("person", "Alice", "detector"), + ("api_key", token, "rule"), + ] + assert adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-filter-guardrail" + + +def test_rules_filter_guardrail_preserves_different_label_rule_overlap() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_filter_guardrail_preserve_contextual_overlap", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + phrase = "Christian Democrat" + rule_value = "Christian" + text = f"He 
identifies as a {phrase}." + phrase_start = text.index(phrase) + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": phrase, + "label": "political_view", + "start": phrase_start, + "end": phrase_start + len(phrase), + "score": 0.99, + }, + { + "text": rule_value, + "label": "religious_belief", + "start": phrase_start, + "end": phrase_start + len(rule_value), + "score": 0.99, + }, + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities + assert [(entity.label, entity.value) for entity in seed_entities] == [("political_view", phrase)] + assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["label"] == "political_view" + + row[COL_VALIDATION_DECISIONS] = { + "decisions": [ + { + "id": f"political_view_{phrase_start}_{phrase_start + len(phrase)}", + "decision": "keep", + "proposed_label": "", + "reason": "test keep", + } + ] + } + row = columns[4].generator_function(row) + row = columns[5].generator_function(row) + validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities + assert any(entity.label == "political_view" and entity.value == phrase for entity in validated_seed) + assert f"{phrase}" in row[COL_INITIAL_TAGGED_TEXT] + + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + row = columns[7].generator_function(row) + row = columns[8].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + 
model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["political_view", "religious_belief"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert any(entity.label == "political_view" and entity.value == phrase for entity in entities) + + +def test_rules_filter_guardrail_preserves_longer_same_label_detector_overlap() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_filter_guardrail_preserve_longer_same_label", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + rule_value = "Great Health" + detector_value = f"{rule_value} and Mountain Timber" + text = f"After apprenticeships with {detector_value}, Darwin started his own shop." + detector_start = text.index(detector_value) + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": detector_value, + "label": "organization_name", + "start": detector_start, + "end": detector_start + len(detector_value), + "score": 0.99, + } + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities + assert [(entity.label, entity.value) for entity in seed_entities] == [("organization_name", detector_value)] + assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["value"] == detector_value + + row[COL_VALIDATION_DECISIONS] = { + "decisions": [ + { + "id": f"organization_name_{detector_start}_{detector_start + len(detector_value)}", + "decision": "keep", + "proposed_label": "", + "reason": "test keep", + } + ] + } + row = columns[4].generator_function(row) + row = columns[5].generator_function(row) + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + row = 
columns[7].generator_function(row) + row = columns[8].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["organization_name"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("organization_name", detector_value)] + + +def test_rules_filter_guardrail_does_not_shadow_validated_different_label_exact_span() -> None: + tool = load_tool( + "measurement_detection_strategies_rules_filter_guardrail_no_shadow_exact_span", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + value = "Bowdoin College" + text = f"He completed his MLS at {value}, and his early career followed." 
+ start = text.index(value) + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": value, + "label": "university", + "start": start, + "end": start + len(value), + "score": 0.99, + } + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["label"] == "university" + + row[COL_VALIDATION_DECISIONS] = { + "decisions": [ + { + "id": f"university_{start}_{start + len(value)}", + "decision": "keep", + "proposed_label": "", + "reason": "test keep", + } + ] + } + row = columns[4].generator_function(row) + row = columns[5].generator_function(row) + validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities + assert [(entity.label, entity.value) for entity in validated_seed] == [("university", value)] + + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + row = columns[7].generator_function(row) + row = columns[8].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["organization_name", "university"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("university", value)] diff --git a/tests/tools/test_direct_detection_probe.py b/tests/tools/test_direct_detection_probe.py new file mode 
100644 index 00000000..c37960d6 --- /dev/null +++ b/tests/tools/test_direct_detection_probe.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +class FakeClient: + def complete(self, request): # type: ignore[no-untyped-def] + module = sys.modules["measurement_direct_detection_probe_case"] + return module.DirectCompletion( + content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + elapsed_sec=1.25, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + +def test_direct_detection_case_uses_direct_client_and_canonical_artifacts() -> None: + tool = load_tool( + "measurement_direct_detection_probe_case", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + result = tool.run_direct_detection_case( + tool.DirectDetectionRequest( + case_id="case-1", + text="Alice met Alice.", + labels=["first_name"], + row_index=0, + prompt_mode=tool.PromptMode.compact, + ), + client=FakeClient(), + ) + + assert result.status == tool.CaseStatus.completed + assert result.elapsed_sec == 1.25 + assert result.usage == {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14} + assert result.allowed_suggestion_count == 1 + assert result.final_entity_count == 2 + assert result.final_entity_signature_count == 2 + assert result.final_label_counts == 
{"first_name": 2} + assert result.artifact.final_source_counts == {"direct_llm": 2} + + +def test_finalize_suggestions_filters_labels_and_deduplicates_signature_hashes() -> None: + tool = load_tool( + "measurement_direct_detection_probe_finalize", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + artifact = tool.finalize_direct_suggestions( + text="Alice works at NVIDIA.", + suggestions=[ + {"value": "Alice", "label": "first_name"}, + {"value": "NVIDIA", "label": "organization_name"}, + {"value": "works", "label": "unsupported"}, + ], + labels=["first_name", "organization_name"], + row_index=0, + workflow_name="direct-detection", + ) + + assert artifact.final_entity_count == 2 + assert artifact.final_label_counts == {"first_name": 1, "organization_name": 1} + assert artifact.weak_api_key_shape_count == 0 + assert set(artifact.final_entity_signature_labels.values()) == {"first_name", "organization_name"} + + +def test_signature_comparison_counts_shared_baseline_only_and_direct_only_labels() -> None: + tool = load_tool( + "measurement_direct_detection_probe_comparison", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + comparison = tool.compare_signature_sets( + baseline_hashes={"shared", "baseline-only"}, + baseline_labels={"shared": "person", "baseline-only": "date"}, + direct_hashes={"shared", "direct-only"}, + direct_labels={"shared": "person", "direct-only": "city"}, + ) + + assert comparison.shared_final_entity_signature_count == 1 + assert comparison.baseline_only_final_entity_signature_count == 1 + assert comparison.direct_only_final_entity_signature_count == 1 + assert comparison.baseline_only_label_counts == {"date": 1} + assert comparison.direct_only_label_counts == {"city": 1} + + +def test_baseline_comparison_skips_rows_without_signature_hashes() -> None: + tool = load_tool( + "measurement_direct_detection_probe_missing_signatures", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + + class 
LocalFakeClient: + def complete(self, request): # type: ignore[no-untyped-def] + return tool.DirectCompletion( + content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + elapsed_sec=1.25, + usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, + ) + + case = tool.run_direct_detection_case( + tool.DirectDetectionRequest( + case_id="case-1", + text="Alice met Alice.", + labels=["first_name"], + row_index=0, + prompt_mode=tool.PromptMode.compact, + ), + client=LocalFakeClient(), + ) + + compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 2}) + + assert compared.comparison is None + + +def test_baseline_artifact_reader_rejects_duplicate_row_indexes(tmp_path: Path) -> None: + tool = load_tool( + "measurement_direct_detection_probe_duplicate_baseline", + REPO_ROOT / "tools/measurement/direct_detection_probe.py", + ) + baseline_path = tmp_path / "baseline.jsonl" + baseline_path.write_text('{"row_index": 0}\n{"row_index": 0}\n', encoding="utf-8") + + with pytest.raises(ValueError, match="multiple rows for row_index=0"): + tool._read_baseline_artifacts(baseline_path) diff --git a/tests/tools/test_extract_signature_deltas.py b/tests/tools/test_extract_signature_deltas.py new file mode 100644 index 00000000..5e442932 --- /dev/null +++ b/tests/tools/test_extract_signature_deltas.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _entity(value: str, label: str, start: int, end: int, source: str = "detector") -> dict[str, object]: + return {"value": value, "label": label, "start_position": start, "end_position": end, "source": source} + + +def _write_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: + parquet_file = root / "entity-detection" / "parquet-files" / "batch_00000.parquet" + parquet_file.parent.mkdir(parents=True) + pd.DataFrame([{COL_TEXT: text, COL_DETECTED_ENTITIES: {"entities": entities}}]).to_parquet(parquet_file) + artifact_path = root / "detection-artifacts.jsonl" + row = tool.build_detection_artifact_row_from_entities( + workflow_name="entity-detection", + batch_file="entity-detection/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) + return artifact_path + + +def _write_seed_only_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: + parquet_file = root / 
"entity-detection-seed" / "parquet-files" / "batch_00000.parquet" + parquet_file.parent.mkdir(parents=True) + pd.DataFrame([{COL_TEXT: text}]).to_parquet(parquet_file) + artifact_path = root / "detection-artifacts.jsonl" + row = tool.build_detection_artifact_row_from_entities( + workflow_name="entity-detection-seed", + batch_file="entity-detection-seed/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) + return artifact_path + + +def _case_metadata() -> dict[str, object]: + return { + "suite_id": "suite", + "workload_id": "bio", + "config_id": "config", + "case_id": "bio__config__r000", + "run_id": "bio__config__r000", + "repetition": 0, + } + + +def test_extract_signature_deltas_masks_candidate_only_context(tmp_path: Path) -> None: + analyzer = load_tool( + "measurement_detection_artifact_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_artifact_case(baseline_root, analyzer, [_entity("Alice", "person", 0, 5)], "Alice met NASA") + candidate = _write_artifact_case( + candidate_root, + analyzer, + [_entity("Alice", "person", 0, 5), _entity("NASA", "organization_name", 10, 14, "augmenter")], + "Alice met NASA", + ) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.side == "candidate_only" + assert row.label == 
"organization_name" + assert row.source == "augmenter" + assert row.resolution == "parquet" + assert "NASA" not in row.masked_context + assert "[ORGANIZATION_NAME:" in row.masked_context + + +def test_extract_signature_deltas_recovers_guardrail_rule_context(tmp_path: Path) -> None: + analyzer = load_tool( + "measurement_detection_artifact_rule_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas_rule", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_artifact_case(baseline_root, analyzer, [], "The applicant was born in 1990.") + candidate = _write_artifact_case(candidate_root, analyzer, [], "The applicant was born in 1990.") + rule_entity = analyzer.EntitySchema( + value="1990", label="date_of_birth", start_position=26, end_position=30, source="rule" + ) + rule_row = analyzer.build_detection_artifact_row_from_entities( + workflow_name="entity-detection", + batch_file="entity-detection/parquet-files/batch_00000.parquet", + row_index=0, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=[rule_entity], + ).model_dump() + pd.json_normalize([{**_case_metadata(), **rule_row}], sep=".").to_json(candidate, orient="records", lines=True) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.label == "date_of_birth" + assert row.source == "rule" + assert row.resolution == "rule" + assert "1990" not in row.masked_context + assert "[DATE_OF_BIRTH:" in row.masked_context + + +def test_extract_signature_deltas_uses_signature_details_when_final_parquet_is_unavailable(tmp_path: Path) -> None: + analyzer = load_tool( + 
"measurement_detection_artifact_detail_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + ) + tool = load_tool( + "measurement_extract_signature_deltas_details", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + ) + baseline_root = tmp_path / "baseline" + candidate_root = tmp_path / "candidate" + baseline = _write_seed_only_artifact_case(baseline_root, analyzer, [], "Alice met NASA") + candidate = _write_seed_only_artifact_case( + candidate_root, + analyzer, + [_entity("NASA", "organization_name", 10, 14, "detector")], + "Alice met NASA", + ) + + result = tool.extract_signature_deltas( + baseline, + candidate, + baseline_artifact_root=baseline_root, + candidate_artifact_root=candidate_root, + ) + + assert result.delta_count == 1 + row = result.rows[0] + assert row.label == "organization_name" + assert row.source == "detector" + assert row.resolution == "artifact_details" + assert row.start_position == 10 + assert row.end_position == 14 + assert "NASA" not in row.masked_context + assert "[ORGANIZATION_NAME:" in row.masked_context diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 2d295a64..80abba54 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib.util +import json import sys from collections.abc import Iterator from contextlib import contextmanager @@ -16,6 +17,7 @@ from pydantic import ValidationError from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT +from anonymizer.engine.constants import COL_FINAL_ENTITIES REPO_ROOT = Path(__file__).resolve().parents[2] @@ -31,6 +33,19 @@ def load_tool(module_name: str, path: Path) -> ModuleType: return module +def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[str, Any]: + return { + "base_dir": tmp_path, + "workloads": {workload.id: workload for workload in spec.workloads}, + "configs": 
{config.id: config for config in spec.configs}, + "raw_dir": tmp_path / "raw", + "dd_trace": tool.DDTraceMode.none, + "trace_dir": tmp_path / "traces", + "dd_parser_compat": spec.dd_parser_compat, + "artifact_path": tmp_path / "artifacts", + } + + def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") dataframe = pd.DataFrame( @@ -55,6 +70,716 @@ def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: assert (tmp_path / "tables/manifest.json").exists() +def test_benchmark_exports_detection_artifact_analysis(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_analysis", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + parquet_dir = artifact_root / "entity-detection" / "parquet-files" + parquet_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": '[{"value":"Alice","label":"first_name","start_position":0,"end_position":5}]', + "_augmented_entities": '{"entities":[{"value":"Alice","label":"first_name"}]}', + "_detected_entities": ( + '{"entities":[{"value":"Alice","label":"first_name",' + '"start_position":0,"end_position":5,"source":"detector"}]}' + ), + } + ] + ).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + output_path = tmp_path / "detection-artifacts.jsonl" + + result_path = tool.export_detection_artifact_analysis(artifact_root, output_path) + + assert result_path == output_path + rows = [pd.read_json(output_path, lines=True).iloc[0].to_dict()] + assert rows[0]["augmented_duplicate_seed_value_count"] == 1 + assert "Alice" not in output_path.read_text(encoding="utf-8") + + +def test_benchmark_exports_rules_only_synthetic_detection_artifacts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_only_synthetic_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path 
= tmp_path / "input.csv" + secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + pd.DataFrame({"text": [f"export API_KEY={secret}"]}).to_csv(input_path, index=False) + config = tool.ConfigSpec( + id="rules-only-redact", + replace="redact", + detect={"entity_labels": ["api_key", "email", "password", "url"]}, + experimental_detection_strategy="rules_only", + ) + case = tool.BenchmarkCase( + suite_id="rules-suite", + workload_id="input", + config_id="rules-only-redact", + repetition=0, + case_id="input__rules-only-redact__r000", + ) + output_path = tmp_path / "raw" / "input__rules-only-redact__r000.detection-artifacts.jsonl" + + result = tool.export_rules_only_case_detection_artifacts( + config, + tool.AnonymizerInput(source=str(input_path), text_column="text"), + output_path, + case=case, + ) + + assert result == output_path + text = output_path.read_text(encoding="utf-8") + assert secret not in text + row = json.loads(text) + assert row["workflow_name"] == "entity-detection-rules-only" + assert row["final_entity_count"] == 1 + assert row["final_entity_signature_count"] == 1 + assert row["final_label_counts.api_key"] == 1 + assert row["final_source_counts.rule"] == 1 + assert any(key.startswith("final_entity_signature_labels.") for key in row) + + +def test_benchmark_exports_rules_covered_or_default_synthetic_artifacts_for_structured_fast_lane_labels( + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_covered_synthetic_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + pd.DataFrame({"text": [f"export API_KEY={secret}"]}).to_csv(input_path, index=False) + config = tool.ConfigSpec( + id="rules-covered-redact", + replace="redact", + detect={"entity_labels": ["api_key", "email", "password", "url"]}, + experimental_detection_strategy="rules_covered_or_default", + ) + case = tool.BenchmarkCase( + suite_id="rules-suite", + workload_id="input", + 
config_id="rules-covered-redact", + repetition=0, + case_id="input__rules-covered-redact__r000", + ) + output_path = tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl" + + result = tool.export_rules_only_case_detection_artifacts( + config, + tool.AnonymizerInput(source=str(input_path), text_column="text"), + output_path, + case=case, + ) + + assert result == output_path + row = json.loads(output_path.read_text(encoding="utf-8")) + assert row["workflow_name"] == "entity-detection-rules-only" + assert row["final_entity_count"] == 1 + assert row["final_label_counts.api_key"] == 1 + + +def test_benchmark_does_not_export_rules_covered_or_default_artifacts_for_contextual_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_covered_contextual_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + config = tool.ConfigSpec( + id="rules-covered-redact", + replace="redact", + detect={"entity_labels": ["api_key", "person"]}, + experimental_detection_strategy="rules_covered_or_default", + ) + case = tool.BenchmarkCase( + suite_id="rules-suite", + workload_id="input", + config_id="rules-covered-redact", + repetition=0, + case_id="input__rules-covered-redact__r000", + ) + + result = tool.export_rules_only_case_detection_artifacts( + config, + tool.AnonymizerInput(source=str(input_path), text_column="text"), + tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl", + case=case, + ) + + assert result is None + + +def test_benchmark_does_not_export_rules_covered_artifacts_for_narrow_prose_rule_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_covered_prose_rule_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": 
["Jordan worked at Acme Research Center and lived on Maple Street."]}).to_csv( + input_path, + index=False, + ) + config = tool.ConfigSpec( + id="rules-covered-redact", + replace="redact", + detect={"entity_labels": ["organization_name", "street_address"]}, + experimental_detection_strategy="rules_covered_or_default", + ) + case = tool.BenchmarkCase( + suite_id="rules-suite", + workload_id="input", + config_id="rules-covered-redact", + repetition=0, + case_id="input__rules-covered-redact__r000", + ) + + result = tool.export_rules_only_case_detection_artifacts( + config, + tool.AnonymizerInput(source=str(input_path), text_column="text"), + tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl", + case=case, + ) + + assert result is None + + +def test_benchmark_detection_artifact_analysis_ignores_stale_artifacts(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_delta", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + stale_dir = artifact_root / "entity-detection-old" / "parquet-files" + stale_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": '{"entities":[]}', + } + ] + ).to_parquet(stale_dir / "batch_00000.parquet", index=False) + snapshot = tool.snapshot_detection_artifacts(artifact_root) + fresh_dir = artifact_root / "entity-detection-new" / "parquet-files" + fresh_dir.mkdir(parents=True) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": ( + '{"entities":[{"value":"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA",' + '"label":"api_key","start_position":0,"end_position":32,"source":"augmenter"}]}' + ), + } + ] + ).to_parquet(fresh_dir / "batch_00000.parquet", index=False) + output_path = tmp_path / "detection-artifacts.jsonl" + + result_path = tool.export_detection_artifact_analysis( + artifact_root, + 
output_path, + artifact_snapshot=snapshot, + ) + + assert result_path == output_path + rows = pd.read_json(output_path, lines=True) + assert rows["workflow_name"].tolist() == ["entity-detection-new"] + assert rows["final_entity_count"].tolist() == [1] + + +def test_benchmark_case_detection_artifact_analysis_adds_case_metadata(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_artifact_case", REPO_ROOT / "tools/measurement/run_benchmarks.py") + artifact_root = tmp_path / "artifacts" + parquet_dir = artifact_root / "entity-detection" / "parquet-files" + parquet_dir.mkdir(parents=True) + snapshot = tool.snapshot_detection_artifacts(artifact_root) + pd.DataFrame( + [ + { + "_seed_entities_json": "[]", + "_augmented_entities": '{"entities":[]}', + "_detected_entities": ( + '{"entities":[{"value":"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA",' + '"label":"api_key","start_position":0,"end_position":32,"source":"detector"}]}' + ), + } + ] + ).to_parquet(parquet_dir / "batch_00000.parquet", index=False) + case = tool.BenchmarkCase( + suite_id="suite-a", + workload_id="shell", + config_id="rules", + repetition=2, + case_id="shell__rules__r002", + ) + output_path = tmp_path / "raw" / "shell__rules__r002.detection-artifacts.jsonl" + + result_path = tool.export_case_detection_artifact_analysis( + artifact_root, + output_path, + case=case, + artifact_snapshot=snapshot, + ) + + assert result_path == output_path + rows = pd.read_json(output_path, lines=True) + assert rows["suite_id"].tolist() == ["suite-a"] + assert rows["workload_id"].tolist() == ["shell"] + assert rows["config_id"].tolist() == ["rules"] + assert rows["case_id"].tolist() == ["shell__rules__r002"] + assert rows["run_id"].tolist() == ["shell__rules__r002"] + assert rows["repetition"].tolist() == [2] + assert "sk-test" not in output_path.read_text(encoding="utf-8") + + +def _stale_detection_artifact_payload() -> dict[str, Any]: + return { + "workflow_name": "entity-detection", + "batch_file": 
"entity-detection/parquet-files/batch_00000.parquet", + "row_index": 0, + "seed_entity_count": 1, + "seed_validation_candidate_count": 1, + "merged_validation_candidate_count": 1, + "augmented_entity_count": 0, + "final_entity_count": 1, + "augmented_duplicate_seed_value_count": 0, + "augmented_new_value_count": 0, + "augmented_new_final_value_count": 0, + "weak_api_key_shape_count": 0, + "final_entity_signature_count": 1, + "final_entity_signature_hashes": ["stale"], + "final_entity_signature_labels": {"stale": "person"}, + "weak_api_key_shape_label_counts": {}, + "final_label_counts": {"person": 1}, + "final_source_counts": {"detector": 1}, + } + + +def _final_trace_dataframe_with_rule_entity() -> pd.DataFrame: + return pd.DataFrame( + { + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": "Alice", + "label": "person", + "start_position": 0, + "end_position": 5, + "source": "detector", + }, + { + "value": "1990", + "label": "date_of_birth", + "start_position": 25, + "end_position": 29, + "source": "rule", + }, + ] + } + ] + } + ) + + +def test_benchmark_patches_detection_artifacts_from_final_trace_dataframe(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_patch_result_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + output_path = tmp_path / "raw" / "case.detection-artifacts.jsonl" + tool.write_detection_artifact_payloads([_stale_detection_artifact_payload()], output_path) + + result = tool.patch_case_detection_artifacts_from_trace_dataframe( + output_path, + _final_trace_dataframe_with_rule_entity(), + ) + + assert result == output_path + text = output_path.read_text(encoding="utf-8") + assert "Alice" not in text + assert "1990" not in text + row = json.loads(text) + assert row["final_entity_count"] == 2 + assert row["final_entity_signature_count"] == 2 + assert row["final_label_counts.person"] == 1 + assert row["final_label_counts.date_of_birth"] == 1 + assert row["final_source_counts.detector"] == 1 + assert 
row["final_source_counts.rule"] == 1 + assert any(key.startswith("final_entity_signature_details.") for key in row) + assert "final_entity_signature_labels.stale" not in row + + +def test_rules_covered_or_default_detection_artifacts_use_final_trace_dataframe(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_covered_trace_artifacts", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + case = tool.BenchmarkCase( + suite_id="suite-a", + workload_id="input", + config_id="rules-covered", + repetition=0, + case_id="input__rules-covered__r000", + ) + config = tool.ConfigSpec( + id="rules-covered", + replace="redact", + detect={"entity_labels": ["api_key", "password"]}, + experimental_detection_strategy="rules_covered_or_default", + ) + trace_dataframe = pd.DataFrame( + { + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", + "label": "api_key", + "start_position": 6, + "end_position": 38, + "source": "rule", + } + ] + }, + { + "entities": [ + { + "value": "SecretNoRule123!", + "label": "password", + "start_position": 14, + "end_position": 30, + "source": "detector", + } + ] + }, + ] + } + ) + paths = tool._CaseRunPaths( + raw_path=tmp_path / "raw" / "case.jsonl", + artifact_output_path=tmp_path / "raw" / "case.detection-artifacts.jsonl", + trace_path=None, + artifact_snapshot={}, + ) + tool.write_detection_artifact_payloads([_stale_detection_artifact_payload()], paths.artifact_output_path) + contexts = {"artifact_path": tmp_path / "artifacts"} + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + execution = tool._CaseExecution( + input_data=tool.AnonymizerInput(source=str(input_path), text_column="text"), + trace_dataframe=trace_dataframe, + ) + + result = tool._case_detection_artifact_path( + contexts, + paths, + case=case, + config=config, + execution=execution, + ) + + assert result == 
paths.artifact_output_path + rows = [json.loads(line) for line in paths.artifact_output_path.read_text(encoding="utf-8").splitlines()] + assert len(rows) == 2 + assert [row["workflow_name"] for row in rows] == [ + "entity-detection-final-trace", + "entity-detection-final-trace", + ] + assert [row["row_index"] for row in rows] == [0, 1] + assert [row["final_source_counts.rule"] for row in rows] == [1.0, None] + assert [row["final_source_counts.detector"] for row in rows] == [None, 1.0] + assert "sk-test" not in paths.artifact_output_path.read_text(encoding="utf-8") + assert "SecretNoRule123!" not in paths.artifact_output_path.read_text(encoding="utf-8") + + +def test_run_suite_records_detection_artifact_analysis_path( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_run_suite_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec = tool.BenchmarkSpec( + suite_id="artifact-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + output_dir = tmp_path / "output" + output_dir.mkdir() + artifact_path = output_dir / "artifacts" + analysis_path = output_dir / "detection-artifacts.jsonl" + + class FakeAnonymizer: + def __init__(self, **kwargs: Any) -> None: + assert kwargs["artifact_path"] == artifact_path + + def fake_run_case(case: Any, *_args: Any, **_kwargs: Any) -> Any: + assert _kwargs["export_detection_artifacts"] is True + raw_path = output_dir / "raw" / f"{case.case_id}.jsonl" + raw_path.parent.mkdir() + raw_path.write_text('{"record_type":"run","run_id":"case"}\n', encoding="utf-8") + artifact_output_path = output_dir / "raw" / f"{case.case_id}.detection-artifacts.jsonl" + artifact_output_path.write_text( + '{"case_id":"input__redact__r000","workflow_name":"entity-detection"}\n', + encoding="utf-8", + ) + return case.model_copy( + update={ + "status": tool.CaseStatus.completed, + "measurement_path": 
str(raw_path), + "detection_artifact_path": str(artifact_output_path), + } + ) + + monkeypatch.setattr(tool, "Anonymizer", FakeAnonymizer) + monkeypatch.setattr(tool, "_run_case", fake_run_case) + monkeypatch.setattr(tool, "export_measurement_tables", lambda *_args: output_dir / "tables") + + result = tool.run_suite( + spec, + spec_path=tmp_path / "suite.yaml", + output_dir=output_dir, + export=True, + fail_fast=False, + dd_trace=tool.DDTraceMode.none, + trace_dir=None, + ) + + assert result.detection_artifact_analysis_path == str(analysis_path) + assert analysis_path.read_text(encoding="utf-8") == ( + '{"case_id":"input__redact__r000","workflow_name":"entity-detection"}\n' + ) + summary = (output_dir / "summary.json").read_text(encoding="utf-8") + assert "detection_artifact_analysis_path" in summary + assert "detection_artifact_path" in summary + + +def test_run_suite_skips_detection_artifact_analysis_when_export_disabled( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_run_suite_no_export_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + spec = tool.BenchmarkSpec( + suite_id="artifact-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + output_dir = tmp_path / "output" + output_dir.mkdir() + + class FakeAnonymizer: + def __init__(self, **_kwargs: Any) -> None: + pass + + def fake_run_case(case: Any, *_args: Any, **_kwargs: Any) -> Any: + assert _kwargs["export_detection_artifacts"] is False + raw_path = output_dir / "raw" / f"{case.case_id}.jsonl" + raw_path.parent.mkdir() + raw_path.write_text('{"record_type":"run","run_id":"case"}\n', encoding="utf-8") + return case.model_copy(update={"status": tool.CaseStatus.completed, "measurement_path": str(raw_path)}) + + monkeypatch.setattr(tool, "Anonymizer", FakeAnonymizer) + monkeypatch.setattr(tool, "_run_case", fake_run_case) + monkeypatch.setattr( + 
tool, + "combine_detection_artifact_analysis", + lambda *_args: pytest.fail("artifact analysis should not be combined"), + ) + + result = tool.run_suite( + spec, + spec_path=tmp_path / "suite.yaml", + output_dir=output_dir, + export=False, + fail_fast=False, + dd_trace=tool.DDTraceMode.none, + trace_dir=None, + ) + + assert result.detection_artifact_analysis_path is None + assert not (output_dir / "detection-artifacts.jsonl").exists() + + +def test_benchmark_case_retries_transient_errors_and_records_attempts( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_case_retry_success", REPO_ROOT / "tools/measurement/run_benchmarks.py") + attempts: list[Path] = [] + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=1, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + pd.DataFrame({"text": ["Alice"]}).to_csv(tmp_path / "input.csv", index=False) + + def fake_execute_case(*_args: Any, raw_path: Path, **_kwargs: Any) -> Any: + attempts.append(raw_path) + if len(attempts) == 1: + raise RuntimeError("transient provider health check failure") + raw_path.parent.mkdir(parents=True, exist_ok=True) + raw_path.write_text('{"record_type":"run"}\n', encoding="utf-8") + return tool._CaseExecution( + input_data=tool.AnonymizerInput(source=str(tmp_path / "input.csv"), text_column="text") + ) + + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + + result = tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=False, + export_detection_artifacts=False, + ) + + assert result.status == tool.CaseStatus.completed + assert result.attempt_count == 2 + assert result.attempt_errors 
== ["transient provider health check failure"] + assert attempts == [tmp_path / "raw" / "input__redact__r000.jsonl"] * 2 + + +def test_benchmark_case_records_persistent_retry_failures( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_case_retry_failure", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=1, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + + attempts = 0 + errors: list[str] = [] + + def fake_execute_case(*_args: Any, **_kwargs: Any) -> Any: + nonlocal attempts + attempts += 1 + raise RuntimeError(f"provider unavailable #{attempts}") + + def capture_error(case: Any, *, error: Exception, **kwargs: Any) -> Any: + errors.append(str(error)) + return original_run_case_error(case, error=error, **kwargs) + + original_run_case_error = tool._run_case_error + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + monkeypatch.setattr(tool, "_run_case_error", capture_error) + + result = tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=False, + export_detection_artifacts=False, + ) + + assert result.status == tool.CaseStatus.error + assert result.error == "provider unavailable #2" + assert result.attempt_count == 2 + assert result.attempt_errors == ["provider unavailable #1", "provider unavailable #2"] + assert errors == ["provider unavailable #2"] + assert attempts == 2 + + +def test_benchmark_case_fail_fast_skips_retries( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_case_retry_fail_fast", REPO_ROOT / 
"tools/measurement/run_benchmarks.py" + ) + spec = tool.BenchmarkSpec( + suite_id="retry-suite", + case_retries=3, + case_retry_backoff_sec=0, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + case = tool.BenchmarkCase( + suite_id="retry-suite", + workload_id="input", + config_id="redact", + repetition=0, + case_id="input__redact__r000", + ) + attempts = 0 + + def fake_execute_case(*_args: Any, **_kwargs: Any) -> Any: + nonlocal attempts + attempts += 1 + raise RuntimeError("fail fast") + + monkeypatch.setattr(tool, "_execute_case", fake_execute_case) + + with pytest.raises(RuntimeError, match="fail fast"): + tool._run_case( + case, + spec, + contexts=_minimal_case_contexts(tool, spec, tmp_path), + anonymizer=object(), + fail_fast=True, + export_detection_artifacts=False, + ) + + assert attempts == 1 + + +def test_combine_detection_artifact_analysis_separates_jsonl_chunks(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_combine_artifact_newlines", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + first = tmp_path / "first.jsonl" + second = tmp_path / "second.jsonl" + first.write_text('{"case_id":"one"}', encoding="utf-8") + second.write_text('{"case_id":"two"}\n', encoding="utf-8") + destination = tmp_path / "combined.jsonl" + cases = [ + tool.BenchmarkCase( + suite_id="suite", + workload_id="input", + config_id="first", + repetition=0, + case_id="first", + detection_artifact_path=str(first), + ), + tool.BenchmarkCase( + suite_id="suite", + workload_id="input", + config_id="second", + repetition=0, + case_id="second", + detection_artifact_path=str(second), + ), + ] + + result = tool.combine_detection_artifact_analysis(cases, destination) + + assert result == destination + assert destination.read_text(encoding="utf-8") == '{"case_id":"one"}\n{"case_id":"two"}\n' + + def test_benchmark_spec_validates_matrix_references(tmp_path: Path) -> None: tool = 
load_tool("measurement_benchmark_tool", REPO_ROOT / "tools/measurement/run_benchmarks.py") spec_path = tmp_path / "suite.yaml" @@ -163,6 +888,60 @@ def test_benchmark_preflight_rejects_missing_text_column(tmp_path: Path) -> None tool.preflight_suite(spec, spec_path=spec_path) +def test_build_input_materializes_sliced_csv_workload(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_sliced_input", REPO_ROOT / "tools/measurement/run_benchmarks.py") + input_path = tmp_path / "input.csv" + pd.DataFrame({"id": ["a", "b", "c", "d"], "text": ["row-a", "row-b", "row-c", "row-d"]}).to_csv( + input_path, index=False + ) + workload = tool.WorkloadSpec( + id="slice", + source="input.csv", + text_column="text", + id_column="id", + row_offset=1, + row_limit=2, + ) + + anonymizer_input = tool.build_input( + workload, + tmp_path, + slice_dir=tmp_path / "slices", + case_id="slice__redact__r000", + ) + + assert anonymizer_input.text_column == "text" + assert anonymizer_input.id_column == "id" + assert Path(anonymizer_input.source) != input_path + sliced = pd.read_csv(anonymizer_input.source) + assert sliced.to_dict("records") == [ + {"id": "b", "text": "row-b"}, + {"id": "c", "text": "row-c"}, + ] + + +def test_benchmark_preflight_rejects_sliced_remote_workload(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_sliced_remote", REPO_ROOT / "tools/measurement/run_benchmarks.py") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-remote-slice +workloads: + - id: remote + source: s3://bucket/input.csv + row_limit: 2 +configs: + - id: redact + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="row slicing requires a local workload source"): + tool.preflight_suite(spec, spec_path=spec_path) + + def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) -> None: tool = load_tool("measurement_benchmark_tool_preflight_models", 
REPO_ROOT / "tools/measurement/run_benchmarks.py") input_path = tmp_path / "input.csv" @@ -201,18 +980,75 @@ def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) tool.preflight_suite(spec, spec_path=spec_path) -def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: +def test_benchmark_preflight_rejects_local_structured_substitute_for_contextual_labels(tmp_path: Path) -> None: tool = load_tool( - "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" + "measurement_benchmark_tool_local_substitute_contextual_labels", + REPO_ROOT / "tools/measurement/run_benchmarks.py", ) input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) - provider_path = tmp_path / "providers.yaml" - provider_path.write_text("not_providers: []\n", encoding="utf-8") + pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ -suite_id: bad-provider-suite +suite_id: local-substitute-suite +workloads: + - id: input + source: input.csv +configs: + - id: local-substitute + detect: + entity_labels: [api_key, person] + replace: substitute + experimental_replacement_strategy: local_structured_substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="unsupported local structured substitute labels: person"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_local_structured_substitute_supported_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_local_substitute_supported_labels", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / 
"suite.yaml" + spec_path.write_text( + """ +suite_id: local-substitute-suite +workloads: + - id: input + source: input.csv +configs: + - id: local-substitute + detect: + entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] + replace: substitute + experimental_replacement_strategy: local_structured_substitute +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + provider_path = tmp_path / "providers.yaml" + provider_path.write_text("not_providers: []\n", encoding="utf-8") + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-provider-suite model_providers: providers.yaml workloads: - id: biography @@ -309,6 +1145,7 @@ def run(self, *, config: Any, data: Any) -> None: spec=spec, base_dir=tmp_path, dd_trace=tool.DDTraceMode.all_messages, + dd_parser_compat=tool.DDParserCompatMode.none, ) assert len(captured) == 1 @@ -316,3 +1153,452 @@ def run(self, *, config: Any, data: Any) -> None: assert captured[0].dd_trace_path == trace_path assert captured[0].streaming is True assert captured[0].keep_records is False + + +def test_benchmark_config_accepts_experimental_detection_strategy() -> None: + tool = load_tool( + "measurement_benchmark_tool_detection_strategy_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + + config = tool.ConfigSpec( + id="rules-only", + replace="redact", + experimental_detection_strategy="rules_only", + ) + + assert config.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.rules_only + anonymizer_config = tool.build_anonymizer_config(config) + assert not 
hasattr(anonymizer_config.detect, "experimental_detection_strategy") + + detector_only = tool.ConfigSpec( + id="detector-only", + replace="redact", + experimental_detection_strategy="detector_only", + ) + + assert detector_only.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.detector_only + + rules_covered = tool.ConfigSpec( + id="rules-covered", + replace="redact", + experimental_detection_strategy="rules_covered_or_default", + ) + + assert rules_covered.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.rules_covered_or_default + + native_rules_router = tool.ConfigSpec( + id="native-rules-router", + replace="redact", + experimental_detection_strategy="native_rules_router", + ) + + assert native_rules_router.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.native_rules_router + + native_candidate_validate = tool.ConfigSpec( + id="native-candidate-validate", + replace="redact", + experimental_detection_strategy="native_candidate_validate_no_augment", + ) + + assert ( + native_candidate_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment + ) + + detector_native_validate = tool.ConfigSpec( + id="detector-native-validate", + replace="redact", + experimental_detection_strategy="detector_native_validate_no_augment", + ) + + assert ( + detector_native_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment + ) + + detector_native_augment = tool.ConfigSpec( + id="detector-native-augment", + replace="redact", + experimental_detection_strategy="detector_native_validate_native_augment", + ) + + assert ( + detector_native_augment.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment + ) + + gliner_native_validate = tool.ConfigSpec( + id="gliner-native-validate", + replace="redact", + 
experimental_detection_strategy="gliner_native_validate_no_augment", + ) + + assert ( + gliner_native_validate.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment + ) + + gliner_native_augment = tool.ConfigSpec( + id="gliner-native-augment", + replace="redact", + experimental_detection_strategy="gliner_native_validate_native_augment", + ) + + assert ( + gliner_native_augment.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment + ) + + native_single_pass = tool.ConfigSpec( + id="native-single-pass", + replace="redact", + experimental_detection_strategy="native_single_pass", + ) + + assert native_single_pass.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.native_single_pass + + native_single_pass_recall = tool.ConfigSpec( + id="native-single-pass-recall", + replace="redact", + experimental_detection_strategy="native_single_pass_recall", + ) + + assert ( + native_single_pass_recall.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_recall + ) + + native_single_pass_values = tool.ConfigSpec( + id="native-single-pass-values", + replace="redact", + experimental_detection_strategy="native_single_pass_values", + ) + + assert ( + native_single_pass_values.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_values + ) + + native_single_pass_values_recall = tool.ConfigSpec( + id="native-single-pass-values-recall", + replace="redact", + experimental_detection_strategy="native_single_pass_values_recall", + ) + + assert ( + native_single_pass_values_recall.experimental_detection_strategy + == tool.ExperimentalDetectionStrategy.native_single_pass_values_recall + ) + + +def test_benchmark_config_accepts_experimental_rule_labels() -> None: + tool = load_tool("measurement_benchmark_tool_rule_labels_config", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + config = 
tool.ConfigSpec( + id="rules-guardrail", + replace="redact", + experimental_detection_strategy="rules_guardrail", + experimental_rule_labels=["street_address"], + ) + + assert config.experimental_rule_labels == ["street_address"] + anonymizer_config = tool.build_anonymizer_config(config) + assert not hasattr(anonymizer_config.detect, "experimental_rule_labels") + + detector_only = tool.ConfigSpec( + id="rules-guardrail-detector-only", + replace="redact", + experimental_detection_strategy="rules_guardrail_detector_only", + experimental_rule_labels=["api_key"], + ) + + assert detector_only.experimental_rule_labels == ["api_key"] + + +def test_benchmark_spec_accepts_dd_parser_compat() -> None: + tool = load_tool( + "measurement_benchmark_tool_dd_parser_compat_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + + spec = tool.BenchmarkSpec( + suite_id="raw-json-suite", + dd_parser_compat="raw_json", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + + assert spec.dd_parser_compat == tool.DDParserCompatMode.raw_json + + +def test_benchmark_preflight_rejects_rules_only_without_explicit_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_only_without_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rules-only-no-labels +workloads: + - id: input + source: input.csv +configs: + - id: rules-only-redact + experimental_detection_strategy: rules_only + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="requires explicit detect.entity_labels"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def 
test_benchmark_preflight_rejects_rules_only_unsupported_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_only_unsupported_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rules-only-unsupported-labels +workloads: + - id: input + source: input.csv +configs: + - id: rules-only-redact + experimental_detection_strategy: rules_only + detect: + entity_labels: [api_key, person] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_rules_only_supported_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_only_supported_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rules-only-supported-labels +workloads: + - id: input + source: input.csv +configs: + - id: rules-only-redact + experimental_detection_strategy: rules_only + detect: + entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_rules_covered_or_default_contextual_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rules_covered_contextual_labels", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + 
pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rules-covered-contextual-labels +workloads: + - id: input + source: input.csv +configs: + - id: rules-covered-redact + experimental_detection_strategy: rules_covered_or_default + detect: + entity_labels: [api_key, person] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_experimental_rule_labels_for_non_rule_strategy(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rule_labels_non_rule_strategy", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rule-labels-non-rule-strategy +workloads: + - id: input + source: input.csv +configs: + - id: redact + experimental_detection_strategy: prose_augment_focus + experimental_rule_labels: [street_address] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="experimental_rule_labels requires a rule-backed strategy"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_accepts_experimental_rule_labels_for_compact_rule_guardrail( + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_rule_labels_compact_rule_guardrail", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice lives on West Roberts Drive."]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rule-labels-compact-rule-guardrail +workloads: + - id: input + source: input.csv +configs: + - 
id: redact + experimental_detection_strategy: rules_guardrail_compact_validation + experimental_rule_labels: [street_address] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_preflight_rejects_unsupported_experimental_rule_labels(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_rule_labels_unsupported", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: rule-labels-unsupported +workloads: + - id: input + source: input.csv +configs: + - id: redact + experimental_detection_strategy: rules_guardrail + experimental_rule_labels: [person] + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="unsupported experimental_rule_labels.*person"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_case_enters_experimental_detection_strategy_context( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_detection_strategy_case", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + captured_measurements: list[Any] = [] + captured_parser_compat: list[Any] = [] + captured_strategies: list[Any] = [] + captured_rule_labels: list[Any] = [] + + @contextmanager + def fake_measurement_session(config: Any) -> Iterator[None]: + captured_measurements.append(config) + yield None + + @contextmanager + def fake_detection_strategy_context(strategy: Any, *, rule_labels: list[str] | None = None) -> Iterator[None]: + captured_strategies.append(strategy) + captured_rule_labels.append(rule_labels) + yield None + + @contextmanager + def fake_dd_parser_compat_context(mode: Any) -> Iterator[None]: + 
captured_parser_compat.append(mode) + yield None + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> None: + assert config.replace is not None + assert data.text_column == "text" + + monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) + monkeypatch.setattr(tool, "dd_parser_compat_context", fake_dd_parser_compat_context) + monkeypatch.setattr(tool, "experimental_detection_strategy_context", fake_detection_strategy_context) + + spec = tool.BenchmarkSpec( + suite_id="rules-suite", + dd_parser_compat="raw_json", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[ + tool.ConfigSpec( + id="rules-only-redact", + replace="redact", + experimental_detection_strategy="rules_only", + experimental_rule_labels=["api_key"], + ) + ], + ) + pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(tmp_path / "input.csv", index=False) + case = tool.BenchmarkCase( + suite_id="rules-suite", + workload_id="input", + config_id="rules-only-redact", + repetition=0, + case_id="input__rules-only-redact__r000", + ) + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=tmp_path / "raw" / "input__rules-only-redact__r000.jsonl", + trace_path=None, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.none, + dd_parser_compat=spec.dd_parser_compat, + ) + + assert captured_parser_compat == [tool.DDParserCompatMode.raw_json] + assert captured_strategies == [tool.ExperimentalDetectionStrategy.rules_only] + assert captured_rule_labels == [["api_key"]] + assert captured_measurements[0].run_tags["experimental_detection_strategy"] == "rules_only" + assert captured_measurements[0].run_tags["experimental_rule_labels"] == ["api_key"] + assert captured_measurements[0].run_tags["dd_parser_compat"] == "raw_json" diff --git a/tests/tools/test_replacement_strategies.py b/tests/tools/test_replacement_strategies.py new file mode 100644 index 
00000000..3b1ce86b --- /dev/null +++ b/tests/tools/test_replacement_strategies.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import Mock + +import pandas as pd + +from anonymizer.config.replace_strategies import Substitute +from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_FINAL_ENTITIES, COL_REPLACED_TEXT, COL_TEXT +from anonymizer.engine.ndd.model_loader import load_default_model_selection +from anonymizer.engine.replace.llm_replace_workflow import LlmReplaceWorkflow +from anonymizer.engine.replace.replace_runner import ReplacementWorkflow + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_local_structured_substitute_context_bypasses_dd_and_restores_method() -> None: + tool = load_tool( + "measurement_replacement_strategies_local_substitute", + REPO_ROOT / "tools/measurement/replacement_strategies.py", + ) + original_method = LlmReplaceWorkflow.generate_map_only + secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" + adapter = Mock() + runner = ReplacementWorkflow(llm_workflow=LlmReplaceWorkflow(adapter=adapter)) + dataframe = pd.DataFrame( + { + COL_TEXT: [f"export API_KEY={secret}"], + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": secret, + "label": "api_key", + "start_position": len("export API_KEY="), + "end_position": len("export API_KEY=") + len(secret), + } + ] + } + ], + COL_ENTITIES_BY_VALUE: 
[{"entities_by_value": [{"value": secret, "labels": ["api_key"]}]}], + } + ) + + with tool.experimental_replacement_strategy_context( + tool.ExperimentalReplacementStrategy.local_structured_substitute + ): + result = runner.run( + dataframe, + replace_method=Substitute(), + model_configs=[], + selected_models=load_default_model_selection().replace, + ) + + adapter.run_workflow.assert_not_called() + assert LlmReplaceWorkflow.generate_map_only is original_method + replaced = result.dataframe[COL_REPLACED_TEXT].iloc[0] + assert secret not in replaced + assert "sk-test-" in replaced diff --git a/tests/tools/test_replay_replacement_strategies.py b/tests/tools/test_replay_replacement_strategies.py new file mode 100644 index 00000000..482d1919 --- /dev/null +++ b/tests/tools/test_replay_replacement_strategies.py @@ -0,0 +1,430 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pandas as pd + +from anonymizer.engine.constants import ( + COL_FINAL_ENTITIES, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, + COL_TEXT, +) + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def test_replacement_row_metrics_counts_sanitized_replay_failures() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_metrics", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + row = pd.Series( + { + COL_FINAL_ENTITIES: { + 
"entities": [ + {"value": "secret-a", "label": "api_key", "start_position": 0, "end_position": 8}, + {"value": "secret-b", "label": "password", "start_position": 9, "end_position": 17}, + {"value": "1234", "label": "pin", "start_position": 18, "end_position": 22}, + ] + }, + COL_REPLACEMENT_MAP: { + "replacements": [ + {"original": "secret-a", "label": "api_key", "synthetic": "secret-b"}, + {"original": "secret-b", "label": "password", "synthetic": "Synthetic!123"}, + ] + }, + COL_REPLACED_TEXT: "the leaked pin is 1234", + } + ) + + metrics = tool.replacement_row_metrics(row) + + assert metrics.final_entity_count == 3 + assert metrics.replacement_count == 2 + assert metrics.missing_count == 1 + assert metrics.missing_labels == {"pin": 1} + assert metrics.collision_count == 1 + assert metrics.collision_labels == {"api_key": 1} + assert metrics.leak_count == 1 + assert metrics.leak_labels == {"pin": 1} + + +def test_summarize_replacement_dataframe_counts_sources_and_totals() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_summary", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + dataframe = pd.DataFrame( + [ + { + COL_FINAL_ENTITIES: { + "entities": [ + {"value": "alice@example.com", "label": "email", "start_position": 0, "end_position": 17} + ] + }, + COL_REPLACEMENT_MAP: { + "replacements": [ + { + "original": "alice@example.com", + "label": "email", + "synthetic": "user-123@example.invalid", + } + ] + }, + COL_REPLACED_TEXT: "user-123@example.invalid", + COL_REPLACEMENT_MAP_SOURCE: "local_structured", + } + ] + ) + + summary = tool.summarize_replacement_dataframe( + dataframe, + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + elapsed_sec=0.01, + ) + + assert summary.status == tool.ReplayStatus.completed + assert summary.final_entity_count == 1 + assert summary.replacement_count == 1 + assert summary.missing_count == 0 + assert summary.collision_count == 0 + assert summary.leak_count == 0 + 
assert summary.replacement_map_sources == {"local_structured": 1} + + +def test_aggregate_strategy_summaries_sums_repeated_backend_runs() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_aggregate", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + + summary = tool.aggregate_strategy_summaries( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + summaries=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=1.5, + row_count=5, + final_entity_count=10, + replacement_count=9, + missing_count=1, + leak_count=1, + missing_labels={"api_key": 1}, + leak_labels={"api_key": 1}, + replacement_map_sources={"llm": 5}, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.error, + elapsed_sec=2.5, + row_count=5, + final_entity_count=10, + replacement_count=10, + collision_count=1, + collision_labels={"password": 1}, + replacement_map_sources={"llm": 4}, + error="provider failed", + ), + ], + ) + + assert summary.status == tool.ReplayStatus.error + assert summary.repetition_count == 2 + assert summary.elapsed_sec == 4.0 + assert summary.row_count == 10 + assert summary.final_entity_count == 20 + assert summary.replacement_count == 19 + assert summary.missing_count == 1 + assert summary.collision_count == 1 + assert summary.leak_count == 1 + assert summary.missing_labels == {"api_key": 1} + assert summary.collision_labels == {"password": 1} + assert summary.leak_labels == {"api_key": 1} + assert summary.replacement_map_sources == {"llm": 9} + assert summary.error == "provider failed" + + +def test_replay_comparison_row_marks_fast_complete_local_replay_viable() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_row", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + 
input_path="/tmp/structured_identifiers.csv", + text_column="text", + labels=["api_key"], + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=8.8, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=6.2, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result) + + assert row.workload_id == "structured_identifiers" + assert row.baseline_replacement_strategy == "default" + assert row.candidate_replacement_strategy == "local_structured_substitute" + assert row.pipeline_elapsed_sec_delta < 0 + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "pass" + assert row.safety_verdict == "pass" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "candidate_viable" + assert row.flags == [] + + +def test_replay_comparison_row_rejects_missing_local_replacements() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_missing", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/structured_identifiers.csv", + text_column="text", + labels=["api_key"], + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=8.8, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=6.2, + row_count=1, + final_entity_count=2, + replacement_count=2, + ), + tool.ReplacementReplaySummary( + 
strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=1, + final_entity_count=2, + replacement_count=1, + missing_count=1, + ), + ], + ) + + row = tool.replay_comparison_row(result) + + assert row.candidate_replacement_missing_final_entity_count == 1 + assert "candidate_replacement_missing_final_entity" in row.flags + assert "replacement_count_loss" in row.flags + assert row.value_protection_verdict == "fail" + assert row.signature_parity_verdict == "fail" + assert row.safety_verdict == "fail" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "reject" + + +def test_replay_comparison_row_reviews_duplicate_local_synthetic_values() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_duplicates", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/biographies.csv", + text_column="biography", + labels=["organization_name", "religious_belief"], + nrows=5, + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=18.8, + detected_final_entity_count=10, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=7.9, + row_count=5, + final_entity_count=10, + replacement_count=10, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=5, + final_entity_count=10, + replacement_count=10, + duplicate_synthetic_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") + + assert row.candidate_duplicate_synthetic_replacement_count == 2 + assert "candidate_duplicate_synthetic_replacement" in row.flags + assert row.value_protection_verdict == "pass" + assert 
row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_replay_comparison_row_reviews_candidate_replacement_count_gain_over_flawed_baseline() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_comparison_baseline_flaw", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + result = tool.ReplacementReplayResult( + input_path="/tmp/biographies.csv", + text_column="biography", + labels=["organization_name"], + nrows=5, + dd_parser_compat=tool.DDParserCompatMode.raw_json, + detect_elapsed_sec=12.9, + detected_final_entity_count=2, + strategies=[ + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.dd_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=7.6, + row_count=5, + final_entity_count=2, + replacement_count=1, + missing_count=1, + leak_count=1, + missing_labels={"organization_name": 1}, + leak_labels={"organization_name": 1}, + ), + tool.ReplacementReplaySummary( + strategy=tool.ReplacementReplayStrategy.local_structured_substitute, + status=tool.ReplayStatus.completed, + elapsed_sec=0.003, + row_count=5, + final_entity_count=2, + replacement_count=2, + ), + ], + ) + + row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") + + assert row.replacement_count_delta == 1 + assert row.replacement_missing_final_entity_count_delta == -1 + assert row.baseline_replacement_missing_final_entity_label_counts == {"organization_name": 1} + assert row.candidate_replacement_missing_final_entity_label_counts == {} + assert row.baseline_original_value_leak_label_counts == {"organization_name": 1} + assert row.candidate_original_value_leak_label_counts == {} + assert "baseline_replacement_missing_final_entity" in row.flags + assert "baseline_original_value_leak" in row.flags + assert "candidate_covers_baseline_replacement_missing_final_entity" in 
row.flags + assert "candidate_covers_baseline_original_value_leak" in row.flags + assert "replacement_count_delta" in row.flags + assert row.value_protection_verdict == "pass" + assert row.signature_parity_verdict == "review" + assert row.safety_verdict == "review" + assert row.performance_verdict == "improved" + assert row.candidate_verdict == "review" + + +def test_strip_replacement_columns_removes_prior_strategy_outputs() -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_strip", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + dataframe = pd.DataFrame( + { + "text": ["hello"], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACEMENT_MAP_SOURCE: ["redact"], + COL_REPLACED_TEXT: ["hello"], + } + ) + + stripped = tool.strip_replacement_columns(dataframe) + + assert list(stripped.columns) == ["text"] + + +def test_build_replay_dataframe_uses_preview_when_nrows_is_set(tmp_path: Path) -> None: + tool = load_tool( + "measurement_replay_replacement_strategies_nrows", + REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", + ) + trace_dataframe = pd.DataFrame( + { + COL_TEXT: ["born on 1988-11-21"], + COL_FINAL_ENTITIES: [ + { + "entities": [ + { + "value": "1988-11-21", + "label": "date_of_birth", + "start_position": 8, + "end_position": 18, + } + ] + } + ], + COL_REPLACEMENT_MAP: [{"replacements": []}], + COL_REPLACED_TEXT: ["born on [REDACTED_DATE_OF_BIRTH]"], + } + ) + + class StubAnonymizer: + preview_num_records: int | None = None + + def run(self, **_kwargs: object) -> object: + raise AssertionError("run should not be used for row-limited replay") + + def preview(self, **kwargs: object) -> object: + self.preview_num_records = kwargs["num_records"] # type: ignore[assignment] + return SimpleNamespace(trace_dataframe=trace_dataframe, resolved_text_column="text") + + anonymizer = StubAnonymizer() + source = tmp_path / "multiline.csv" + source.write_text('text\n"born on 1988-11-21"\n', 
encoding="utf-8") + + _elapsed, replay_df = tool.build_replay_dataframe( + anonymizer, + source=source, + text_column="text", + labels=["date_of_birth"], + nrows=1, + dd_parser_compat=tool.DDParserCompatMode.none, + ) + + assert anonymizer.preview_num_records == 1 + assert COL_REPLACEMENT_MAP not in replay_df.columns + assert COL_REPLACED_TEXT not in replay_df.columns + assert replay_df[COL_FINAL_ENTITIES].iloc[0]["entities"][0]["value"] == "1988-11-21" diff --git a/tests/tools/test_screen_strategy_comparisons.py b/tests/tools/test_screen_strategy_comparisons.py new file mode 100644 index 00000000..95d1c3f2 --- /dev/null +++ b/tests/tools/test_screen_strategy_comparisons.py @@ -0,0 +1,1046 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + analysis_dir = tmp_path / "analysis" + analysis_dir.mkdir() + pd.DataFrame( + [ + { + "workload_id": "shell-3", + "baseline_config_id": "default", + "candidate_config_id": "rules-only", + "baseline_strategy": "default", + "candidate_strategy": "rules_only", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "baseline_case_count": 
3, + "candidate_case_count": 3, + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.0, + "observed_total_requests_delta": -12, + "observed_total_tokens_delta": -11000, + "augmented_entity_count_delta": -3, + "augmented_new_final_value_count_delta": 0, + "baseline_only_final_entity_signature_count": 0, + "baseline_stable_candidate_unstable_final_entity_signature_count": 0, + "baseline_stable_final_entity_signature_count": 12, + "candidate_stable_final_entity_signature_count": 12, + "shared_stable_final_entity_signature_count": 12, + "flags": '["candidate_skips_llm_validation"]', + }, + { + "workload_id": "legal-2", + "baseline_config_id": "default", + "candidate_config_id": "no-augment", + "baseline_strategy": "default", + "candidate_strategy": "no_augment", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -10.0, + "observed_total_requests_delta": -2, + "observed_total_tokens_delta": -10000, + "augmented_entity_count_delta": -5.5, + "augmented_new_final_value_count_delta": -2, + "baseline_only_final_entity_signature_count": 2, + "baseline_stable_candidate_unstable_final_entity_signature_count": 1, + "baseline_only_final_entity_signature_label_counts.first_name": 2, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, + "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', + }, + ] + ).to_csv(analysis_dir / "default-vs-candidates.csv", index=False) + pd.DataFrame([{"workload_id": "shell-3", "config_id": "default"}]).to_csv( + analysis_dir / "group_analysis.csv", + index=False, + ) + pd.DataFrame( + [ + { + "source_path": "analysis/default-vs-candidates.csv", + "workload_id": "shell-3", + "baseline_config_id": "default", + "candidate_config_id": "rules-only", + 
"safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + } + ] + ).to_csv(analysis_dir / "strategy-screen.csv", index=False) + (analysis_dir / "empty.csv").write_text("", encoding="utf-8") + + result = tool.screen_comparison_paths([analysis_dir]) + + assert result.scanned_file_count == 4 + assert result.comparison_file_count == 1 + assert result.row_count == 2 + assert result.summary.candidate_verdict_counts == {"reject": 1, "review": 1} + assert result.summary.review_count == 1 + assert result.summary.reject_count == 1 + assert result.summary.viable_count == 0 + legal = next(row for row in result.rows if row.workload_id == "legal-2") + assert legal.workload_family == "legal" + assert legal.flags == ["entity_signature_loss", "stable_entity_signature_loss"] + assert legal.baseline_only_label_counts == {"first_name": 2} + assert legal.stable_lost_label_counts == {"date": 1} + assert legal.augmented_new_final_value_count_delta == -2 + shell = next(row for row in result.rows if row.workload_id == "shell-3") + assert shell.baseline_replacement_strategy == "default" + assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.baseline_case_count == 3 + assert shell.candidate_case_count == 3 + assert shell.shared_stable_final_entity_signature_count == 12 + rules_local = next( + group + for group in result.groups + if group.group_key == "strategy:rules_only|replacement:local_structured_substitute" + ) + assert rules_local.candidate_replacement_strategy == "local_structured_substitute" + assert rules_local.row_count == 1 + no_augment = next(group for group in result.groups if group.group_key == "strategy:no_augment") + assert no_augment.row_count == 1 + assert no_augment.reject_count == 1 + assert no_augment.has_conflicting_verdicts is False + + +def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_export", + 
REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + rows = [ + tool.ScreenRow( + source_path="analysis/default-vs-rules.csv", + workload_id="shell", + baseline_config_id="default", + candidate_config_id="rules", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + safety_verdict="review", + performance_verdict="improved", + candidate_verdict="review", + flags=["candidate_skips_llm_validation"], + ) + ] + + output = tmp_path / "screen.csv" + tool.write_rows(rows, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["workload_id"].tolist() == ["shell"] + assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] + + +def test_screen_strategy_comparisons_dedupes_exact_rows(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_dedupe", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + comparison = pd.DataFrame( + [ + { + "workload_id": "bio", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -5.0, + } + ] + ) + comparison.to_csv(tmp_path / "a.csv", index=False) + comparison.to_csv(tmp_path / "b.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.scanned_file_count == 2 + assert result.comparison_file_count == 2 + assert result.row_count == 1 + assert result.duplicate_row_count == 1 + assert result.summary.viable_count == 1 + + +def test_screen_strategy_comparisons_surfaces_evidence_level_counts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_evidence_level", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + 
"workload_id": "structured-identifiers", + "baseline_config_id": "default", + "candidate_config_id": "local-substitute", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + }, + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default", + "candidate_config_id": "local-substitute-legacy", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "shared_stable_final_entity_signature_count": 17, + }, + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} + assert {row.evidence_level for row in result.rows} == {"split_verdicts", "stable_signatures"} + group = result.groups[0] + assert group.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} + assert group.split_verdict_candidate_verdict_counts == {"review": 1} + rendered = tool.render_result(result, json_output=False, limit=10) + assert "evidence=split_verdicts" in rendered + assert "evidence_counts=split_verdicts:1,stable_signatures:1" in rendered + + +def test_screen_strategy_comparisons_surfaces_candidate_original_value_leaks(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_original_value_leaks", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-secrets", + "baseline_config_id": "default", + "candidate_config_id": "rules-covered", + "candidate_strategy": "rules_covered_or_default", + "safety_verdict": "fail", + 
"performance_verdict": "improved", + "candidate_verdict": "reject", + "candidate_original_value_leak_count": 2, + "candidate_original_value_leak_record_count": 1, + "original_value_leak_count_delta": 2, + "candidate_original_value_leak_label_counts.api_key": 1, + "candidate_original_value_leak_label_counts.password": 1, + "candidate_replacement_synthetic_original_collision_count": 1, + "candidate_replacement_synthetic_original_collision_value_count": 1, + "replacement_synthetic_original_collision_count_delta": 1, + "candidate_replacement_synthetic_original_collision_label_counts.date": 1, + "flags": ('["candidate_original_value_leak", "candidate_replacement_synthetic_original_collision"]'), + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.reject_count == 1 + row = result.rows[0] + assert row.candidate_original_value_leak_count == 2 + assert row.candidate_original_value_leak_record_count == 1 + assert row.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert row.candidate_replacement_synthetic_original_collision_count == 1 + assert row.candidate_replacement_synthetic_original_collision_value_count == 1 + assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + group = result.groups[0] + assert group.sum_candidate_original_value_leak_count == 2 + assert group.sum_candidate_original_value_leak_record_count == 1 + assert group.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} + assert group.sum_candidate_replacement_synthetic_original_collision_count == 1 + assert group.sum_candidate_replacement_synthetic_original_collision_value_count == 1 + assert group.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} + rendered = tool.render_result(result, json_output=False, limit=10) + assert "candidate_original_value_leaks=2.0" in rendered + assert 
"candidate_replacement_collisions=1.0" in rendered + assert "leak_labels=api_key:1,password:1" in rendered + assert "collision_labels=date:1" in rendered + + +def test_screen_strategy_comparisons_marks_rule_fast_lane_review_when_only_provenance_flags_remain( + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_fast_lane_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-secrets", + "baseline_config_id": "default", + "candidate_config_id": "rules-covered", + "candidate_strategy": "rules_covered_or_default", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "candidate_original_value_leak_count": 0, + "candidate_original_value_leak_record_count": 0, + "flags": '["no_candidate_detector_entities", "candidate_uses_rule_entities"]', + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.groups[0].recommendation == "fast_lane_review" + + +def test_screen_strategy_comparisons_treats_covered_boundary_deltas_as_fast_lane_review( + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_fast_lane_boundary_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default", + "candidate_config_id": "rules-local", + "candidate_strategy": "rules_covered_or_default", + "candidate_replacement_strategy": "local_structured_substitute", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "candidate_original_value_leak_count": 0, + "candidate_original_value_leak_record_count": 0, + "flags": ( + '["entity_count_loss", "span_boundary_mismatch", ' + '"no_candidate_detector_entities", "candidate_uses_rule_entities"]' + ), + 
"baseline_only_final_entity_signature_label_counts.api_key": 2, + "baseline_only_final_entity_signature_label_counts.http_cookie": 3, + "baseline_only_candidate_covered_signature_label_counts.api_key": 2, + "baseline_only_candidate_covered_signature_label_counts.http_cookie": 3, + "baseline_only_candidate_overlapping_signature_label_counts.http_cookie": 1, + "baseline_only_candidate_uncovered_signature_count": 0, + "baseline_stable_candidate_uncovered_signature_count": 0, + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + row = result.rows[0] + assert row.baseline_only_label_counts == {} + assert result.groups[0].baseline_only_label_counts == {} + assert result.groups[0].recommendation == "fast_lane_review" + + +def test_screen_strategy_comparisons_surfaces_label_policy_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_label_policy_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "legal-r1", + "baseline_config_id": "legal-default", + "candidate_config_id": "legal-detector-native-validate", + "candidate_strategy": "detector_native_validate_no_augment", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -26.5, + "observed_total_requests_delta": -1, + "observed_total_tokens_delta": -5366, + "flags": '["covered_label_mismatch"]', + "baseline_only_candidate_label_mismatch_signature_label_counts.date_of_birth": 1, + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + assert result.summary.value_protection_verdict_counts == {"pass": 1} + assert result.summary.signature_parity_verdict_counts == {"review": 1} + row = result.rows[0] + assert row.value_protection_verdict 
== "pass" + assert row.signature_parity_verdict == "review" + assert row.label_mismatch_label_counts == {"date_of_birth": 1} + group = result.groups[0] + assert group.value_protection_verdict_counts == {"pass": 1} + assert group.signature_parity_verdict_counts == {"review": 1} + assert group.label_mismatch_label_counts == {"date_of_birth": 1} + assert group.recommendation == "label_policy_review" + rendered = tool.render_result(result, json_output=False, limit=10) + assert "value_protection=pass" in rendered + assert "signature_parity=review" in rendered + assert "label_mismatch=date_of_birth:1" in rendered + + +def test_screen_strategy_comparisons_surfaces_reliability_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_reliability_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default-substitute", + "candidate_config_id": "local-substitute", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "pass", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "flags": '["failed_request_increase"]', + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + group = result.groups[0] + assert group.flag_counts == {"failed_request_increase": 1} + assert group.recommendation == "reliability_review" + + +def test_screen_strategy_comparisons_surfaces_replacement_replay_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_replacement_replay_review", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "structured-identifiers", + "baseline_config_id": "default-substitute", + 
"candidate_config_id": "local-substitute", + "baseline_strategy": "default", + "candidate_strategy": "default", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "flags": '["covered_label_mismatch", "replacement_only_detection_instability"]', + "baseline_only_candidate_label_mismatch_signature_label_counts.api_key": 1, + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + group = result.groups[0] + assert group.flag_counts == { + "covered_label_mismatch": 1, + "replacement_only_detection_instability": 1, + } + assert group.recommendation == "replacement_replay_review" + + +def test_screen_strategy_comparisons_surfaces_baseline_defect_improvement_review(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_baseline_defect_improvement", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + pd.DataFrame( + [ + { + "workload_id": "biography-supported-structured", + "baseline_config_id": "dd-substitute", + "candidate_config_id": "local-substitute", + "baseline_strategy": "default", + "candidate_strategy": "default", + "baseline_replacement_strategy": "default", + "candidate_replacement_strategy": "local_structured_substitute", + "value_protection_verdict": "pass", + "signature_parity_verdict": "review", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "baseline_replacement_missing_final_entity_count": 2, + "candidate_replacement_missing_final_entity_count": 0, + "baseline_original_value_leak_count": 2, + "candidate_original_value_leak_count": 0, + "baseline_duplicate_synthetic_replacement_count": 1, + "candidate_duplicate_synthetic_replacement_count": 0, + 
"baseline_replacement_missing_final_entity_label_counts.organization_name": 1, + "baseline_replacement_missing_final_entity_label_counts.religious_belief": 1, + "baseline_original_value_leak_label_counts.organization_name": 2, + "flags": ( + '["baseline_replacement_missing_final_entity", "baseline_original_value_leak", ' + '"candidate_covers_baseline_replacement_missing_final_entity", ' + '"candidate_covers_baseline_original_value_leak", "replacement_count_delta"]' + ), + } + ] + ).to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + row = result.rows[0] + assert row.baseline_replacement_missing_final_entity_count == 2 + assert row.candidate_replacement_missing_final_entity_count == 0 + assert row.baseline_original_value_leak_count == 2 + assert row.candidate_original_value_leak_count == 0 + assert row.baseline_duplicate_synthetic_replacement_count == 1 + assert row.candidate_duplicate_synthetic_replacement_count == 0 + assert row.baseline_replacement_missing_final_entity_label_counts == { + "organization_name": 1, + "religious_belief": 1, + } + assert row.baseline_original_value_leak_label_counts == {"organization_name": 2} + group = result.groups[0] + assert group.baseline_defect_improvement_count == 1 + assert group.sum_baseline_replacement_missing_final_entity_count == 2 + assert group.sum_candidate_replacement_missing_final_entity_count == 0 + assert group.sum_baseline_original_value_leak_count == 2 + assert group.sum_candidate_original_value_leak_count == 0 + assert group.sum_baseline_duplicate_synthetic_replacement_count == 1 + assert group.sum_candidate_duplicate_synthetic_replacement_count == 0 + assert group.baseline_replacement_missing_final_entity_label_counts == { + "organization_name": 1, + "religious_belief": 1, + } + assert group.baseline_original_value_leak_label_counts == {"organization_name": 2} + assert group.recommendation == "candidate_covers_baseline_defects" + rendered = 
tool.render_result(result, json_output=False, limit=10) + assert "baseline_defect_improvements=1" in rendered + assert "baseline_missing_replacements=2.0" in rendered + assert "candidate_missing_replacements=0.0" in rendered + assert "baseline_missing_labels=organization_name:1,religious_belief:1" in rendered + + +def test_screen_strategy_comparisons_groups_default_detection_by_replacement_strategy() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_replacement_group", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + row = tool.ScreenRow( + source_path="comparison.csv", + workload_id="structured-identifiers", + baseline_config_id="substitute-dd", + candidate_config_id="substitute-local", + candidate_strategy="default", + baseline_replacement_strategy="default", + candidate_replacement_strategy="local_structured_substitute", + safety_verdict="pass", + performance_verdict="improved", + candidate_verdict="candidate_viable", + ) + + assert tool.group_base_for_row(row, config_aliases={}) == "replacement:local_structured_substitute" + + +def test_screen_strategy_comparisons_keeps_stale_rule_review_generic_without_leak_metrics() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_fast_lane_review_stale", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + group = tool.ScreenGroup( + group_key="strategy:rules_covered_or_default", + candidate_strategy="rules_covered_or_default", + row_count=1, + review_count=1, + performance_verdict_counts={"improved": 1}, + flag_counts={"candidate_uses_rule_entities": 1, "no_candidate_detector_entities": 1}, + ) + + assert tool.group_recommendation(group) == "review_only" + + +def test_screen_strategy_comparisons_filters_source_paths(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_source_filters", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + current_dir = tmp_path / "analysis-current-csv" + 
current_dir.mkdir() + stale_dir = tmp_path / "analysis" + stale_dir.mkdir() + pd.DataFrame( + [ + { + "workload_id": "bio-current", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + } + ] + ).to_csv(current_dir / "current-comparison.csv", index=False) + pd.DataFrame( + [ + { + "workload_id": "bio-stale", + "baseline_config_id": "default", + "candidate_config_id": "candidate", + "safety_verdict": "fail", + "performance_verdict": "regressed", + "candidate_verdict": "reject", + } + ] + ).to_csv(stale_dir / "stale-comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], source_includes=["analysis-current-csv"]) + + assert result.scanned_file_count == 1 + assert result.comparison_file_count == 1 + assert [row.workload_id for row in result.rows] == ["bio-current"] + + result = tool.screen_comparison_paths( + [tmp_path], + source_includes=["analysis"], + source_excludes=["analysis-current-csv"], + ) + + assert result.scanned_file_count == 1 + assert result.comparison_file_count == 1 + assert [row.workload_id for row in result.rows] == ["bio-stale"] + + +def test_screen_strategy_comparisons_groups_candidate_strategy_conflicts(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "legal-small", + "baseline_config_id": "default", + "candidate_config_id": "no-augment-small", + "candidate_strategy": "no_augment", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -7.0, + "observed_total_requests_delta": -2, + "observed_total_tokens_delta": -8000, + "baseline_case_count": 3, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 20, + "flags": "[]", 
+ }, + { + "workload_id": "legal-offset", + "baseline_config_id": "default", + "candidate_config_id": "no-augment-offset", + "candidate_strategy": "no_augment", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -10.0, + "observed_total_requests_delta": -1, + "observed_total_tokens_delta": -10000, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 14, + "baseline_only_final_entity_signature_label_counts.first_name": 2, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, + "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', + }, + { + "workload_id": "shell", + "baseline_config_id": "default", + "candidate_config_id": "rules", + "candidate_strategy": "rules_only", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.9, + "observed_total_tokens_delta": -11000, + "flags": '["candidate_skips_llm_validation"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path]) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == ["strategy:rules_only", "strategy:no_augment"] + no_augment = groups["strategy:no_augment"] + assert no_augment.row_count == 2 + assert no_augment.viable_count == 1 + assert no_augment.reject_count == 1 + assert no_augment.has_conflicting_verdicts is True + assert no_augment.recommendation == "conflicting_evidence" + assert no_augment.best_pipeline_elapsed_sec_delta_pct == -10.0 + assert no_augment.best_observed_total_tokens_delta == -10000 + assert no_augment.best_observed_total_requests_delta == -2 + assert no_augment.worst_pipeline_elapsed_sec_delta_pct == -7.0 + assert no_augment.worst_observed_total_tokens_delta == -8000 + assert no_augment.worst_observed_total_requests_delta == 
-1 + assert no_augment.min_baseline_case_count == 2 + assert no_augment.min_candidate_case_count == 2 + assert no_augment.min_shared_stable_final_entity_signature_count == 14 + assert no_augment.flag_counts == {"entity_signature_loss": 1, "stable_entity_signature_loss": 1} + assert no_augment.baseline_only_label_counts == {"first_name": 2} + assert no_augment.stable_lost_label_counts == {"date": 1} + + +def test_screen_strategy_comparisons_groups_default_strategy_by_config(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_default_config_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "biographies-r5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -6.0, + "observed_total_tokens_delta": -325, + "flags": "[]", + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-hybrid", + "candidate_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "regressed", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": 10.0, + "observed_total_tokens_delta": 100, + "baseline_only_final_entity_signature_label_counts.university": 1, + "flags": '["entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == [ + "config:biography-augment-temp07|family:biographies", + "config:biography-hybrid|family:biographies", + ] + assert groups["config:biography-augment-temp07|family:biographies"].recommendation == 
"single_slice_viable" + assert groups["config:biography-hybrid|family:biographies"].recommendation == "reject" + + +def test_screen_strategy_comparisons_groups_default_strategy_by_config_alias(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_config_alias_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "biographies-r2", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-hybrid-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "pass", + "performance_verdict": "improved", + "candidate_verdict": "candidate_viable", + "pipeline_elapsed_sec_delta_pct": -10.4, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 48, + "flags": "[]", + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "biography-default", + "candidate_config_id": "biography-augment-temp07", + "candidate_strategy": "default", + "safety_verdict": "fail", + "performance_verdict": "mixed", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": 16.0, + "baseline_case_count": 2, + "candidate_case_count": 2, + "shared_stable_final_entity_signature_count": 116, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts.university": 1, + "flags": '["stable_entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths( + [tmp_path], + group_by=tool.GroupBy.strategy_workload_family, + config_aliases={ + "biography-hybrid-augment-temp07": "biography-temp07-routing", + "biography-augment-temp07": "biography-temp07-routing", + }, + ) + + groups = {group.group_key: group for group in result.groups} + group = groups["alias:biography-temp07-routing|family:biographies"] + assert group.row_count == 2 + assert group.candidate_config_ids == ["biography-augment-temp07", 
"biography-hybrid-augment-temp07"] + assert group.viable_count == 1 + assert group.reject_count == 1 + assert group.recommendation == "conflicting_evidence" + assert group.min_baseline_case_count == 2 + assert group.min_candidate_case_count == 2 + assert group.min_shared_stable_final_entity_signature_count == 48 + assert group.stable_lost_label_counts == {"university": 1} + + +def test_screen_strategy_comparisons_can_group_by_strategy_and_workload_family(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_family_groups", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + table = pd.DataFrame( + [ + { + "workload_id": "shell-secrets-3", + "baseline_config_id": "default", + "candidate_config_id": "rules-only-shell", + "candidate_strategy": "rules_only", + "safety_verdict": "review", + "performance_verdict": "improved", + "candidate_verdict": "review", + "pipeline_elapsed_sec_delta_pct": -99.9, + "observed_total_tokens_delta": -11000, + "flags": '["candidate_uses_rule_entities"]', + }, + { + "workload_id": "biographies-r5-offset5", + "baseline_config_id": "default", + "candidate_config_id": "rules-only-bio", + "candidate_strategy": "rules_only", + "safety_verdict": "fail", + "performance_verdict": "improved", + "candidate_verdict": "reject", + "pipeline_elapsed_sec_delta_pct": -90.0, + "observed_total_tokens_delta": -17000, + "baseline_only_final_entity_signature_label_counts.first_name": 4, + "flags": '["entity_signature_loss"]', + }, + ] + ) + table.to_csv(tmp_path / "comparison.csv", index=False) + + result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) + + groups = {group.group_key: group for group in result.groups} + assert list(groups) == ["strategy:rules_only|family:shell-secrets", "strategy:rules_only|family:biographies"] + assert groups["strategy:rules_only|family:shell-secrets"].recommendation == "review_only" + assert 
groups["strategy:rules_only|family:shell-secrets"].workload_families == ["shell-secrets"] + assert groups["strategy:rules_only|family:biographies"].recommendation == "reject" + assert groups["strategy:rules_only|family:biographies"].baseline_only_label_counts == {"first_name": 4} + + +def test_workload_family_normalizes_slice_and_offset_suffixes() -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_workload_family", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + + assert tool.workload_family("legal-r2-offset1") == "legal" + assert tool.workload_family("legal-slice-2") == "legal" + assert tool.workload_family("biographies-r5-offset5") == "biographies" + assert tool.workload_family("shell-secrets-3") == "shell-secrets" + assert tool.workload_family("support-ticket") == "support-ticket" + + +def test_screen_strategy_comparisons_writes_group_csv(tmp_path: Path) -> None: + tool = load_tool( + "measurement_screen_strategy_comparisons_group_export", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + groups = [ + tool.ScreenGroup( + group_key="strategy:no_augment", + candidate_strategy="no_augment", + row_count=2, + viable_count=1, + reject_count=1, + has_conflicting_verdicts=True, + performance_verdict_counts={"improved": 1, "regressed": 1}, + flag_counts={"entity_signature_loss": 1}, + ) + ] + + output = tmp_path / "groups.csv" + tool.write_groups(groups, output, tool.ExportFormat.csv) + + exported = pd.read_csv(output) + assert exported["group_key"].tolist() == ["strategy:no_augment"] + assert exported["has_conflicting_verdicts"].tolist() == [True] + assert exported["recommendation"].tolist() == ["conflicting_evidence"] + assert exported["performance_verdict_counts"].tolist() == ['{"improved": 1, "regressed": 1}'] + assert exported["flag_counts"].tolist() == ['{"entity_signature_loss": 1}'] + + +def test_screen_strategy_group_recommendations() -> None: + tool = load_tool( + 
"measurement_screen_strategy_comparisons_recommendations", + REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", + ) + + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="viable", row_count=1, viable_count=1), + ) + == "single_slice_viable" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="repeated-viable", row_count=2, viable_count=2), + ) + == "candidate_family_viable" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="promising", row_count=2, viable_count=1, review_count=1), + ) + == "promising_needs_review" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="weak-promising", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"signature_counts": 1, "stable_signatures": 1}, + ), + ) + == "needs_split_verdict_rerun" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="partial-split-review", + row_count=2, + review_count=2, + evidence_level_counts={"split_verdicts": 1, "stable_signatures": 1}, + split_verdict_candidate_verdict_counts={"review": 1}, + ), + ) + == "needs_split_verdict_rerun" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="split-review-with-legacy-viable", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"signature_counts": 1, "split_verdicts": 1}, + split_verdict_candidate_verdict_counts={"review": 1}, + ), + ) + == "needs_viable_split_verdict" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="split-viable-with-review", + row_count=2, + viable_count=1, + review_count=1, + evidence_level_counts={"split_verdicts": 2}, + split_verdict_candidate_verdict_counts={"candidate_viable": 1, "review": 1}, + ), + ) + == "promising_needs_review" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review", + row_count=2, + review_count=2, + performance_verdict_counts={"improved": 2}, + ), + ) + == "review_only" 
+ ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-mixed", + row_count=2, + review_count=2, + performance_verdict_counts={"mixed": 2}, + ), + ) + == "review_mixed_performance" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-improved-and-mixed", + row_count=2, + review_count=2, + performance_verdict_counts={"improved": 1, "mixed": 1}, + ), + ) + == "review_mixed_performance" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup( + group_key="review-regressed", + row_count=2, + review_count=2, + performance_verdict_counts={"regressed": 2}, + ), + ) + == "no_performance_win" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="reject", row_count=2, reject_count=2), + ) + == "reject" + ) + assert ( + tool.group_recommendation( + tool.ScreenGroup(group_key="conflict", row_count=2, viable_count=1, reject_count=1), + ) + == "conflicting_evidence" + ) diff --git a/tests/tools/test_staged_detection_output_analysis.py b/tests/tools/test_staged_detection_output_analysis.py new file mode 100644 index 00000000..f2ccb7f9 --- /dev/null +++ b/tests/tools/test_staged_detection_output_analysis.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType + +import pandas as pd +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: + path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") + + +def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + output_dir = tmp_path / "staged" + output_dir.mkdir() + _write_jsonl( + output_dir / "staged-detection-cases.jsonl", + [ + { + "record_type": "staged_detection_case", + "case_id": "shell-row-0", + "row_index": 0, + "seed_source": "rules_router", + "status": "completed", + "elapsed_sec": 0.002, + "model_elapsed_sec": 0.0, + "model_phase_count": 0, + "model_request_count": 0, + "rule_covered_label_set": True, + "final_entity_count": 5, + "final_entity_signature_count": 5, + "final_label_counts": {"api_key": 2, "email": 1, "password": 1, "url": 1}, + "total_usage": {}, + "comparison": { + "baseline_final_entity_signature_count": 5, + "shared_final_entity_signature_count": 5, + "baseline_only_final_entity_signature_count": 0, + "direct_only_final_entity_signature_count": 0, + "baseline_only_label_counts": {}, + "direct_only_label_counts": {}, + }, + }, + { + "record_type": "staged_detection_case", + "case_id": "bio-row-0", + "row_index": 0, + 
"seed_source": "rules_plus_direct_llm", + "status": "completed", + "elapsed_sec": 10.0, + "model_elapsed_sec": 9.5, + "model_phase_count": 3, + "model_request_count": 3, + "rule_covered_label_set": False, + "final_entity_count": 3, + "final_entity_signature_count": 3, + "final_label_counts": {"person": 2, "api_key": 1}, + "total_usage": {"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, + "comparison": { + "baseline_final_entity_signature_count": 4, + "shared_final_entity_signature_count": 2, + "baseline_only_final_entity_signature_count": 2, + "direct_only_final_entity_signature_count": 1, + "baseline_only_label_counts": {"city": 1, "person": 1}, + "direct_only_label_counts": {"api_key": 1}, + }, + }, + { + "record_type": "staged_detection_case", + "case_id": "bio-row-1", + "row_index": 1, + "seed_source": "rules_plus_direct_llm", + "status": "error", + "elapsed_sec": 1.0, + "model_elapsed_sec": 0.8, + "model_phase_count": 1, + "model_request_count": 1, + "rule_covered_label_set": False, + "final_entity_count": 0, + "final_entity_signature_count": 0, + "total_usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}, + "comparison": None, + "error": "provider failed", + }, + ], + ) + + result = tool.analyze_staged_detection_output(output_dir) + + assert result.case_count == 3 + assert result.group_count == 2 + groups = {row.seed_source: row for row in result.groups} + assert groups["rules_router"].case_count == 1 + assert groups["rules_router"].completed_case_count == 1 + assert groups["rules_router"].model_elapsed_sec_sum == 0.0 + assert groups["rules_router"].model_request_count_sum == 0 + assert groups["rules_router"].rule_covered_case_count == 1 + assert groups["rules_router"].baseline_shared_signature_rate == 1.0 + assert groups["rules_router"].fast_lane_verdict == "review" + assert groups["rules_router"].flags == ["too_few_cases"] + assert groups["rules_plus_direct_llm"].case_count == 2 + assert 
groups["rules_plus_direct_llm"].completed_case_count == 1 + assert groups["rules_plus_direct_llm"].error_case_count == 1 + assert groups["rules_plus_direct_llm"].elapsed_sec_sum == pytest.approx(11.0) + assert groups["rules_plus_direct_llm"].model_elapsed_sec_sum == pytest.approx(10.3) + assert groups["rules_plus_direct_llm"].model_request_count_sum == 4 + assert groups["rules_plus_direct_llm"].total_tokens_sum == 132 + assert groups["rules_plus_direct_llm"].baseline_final_entity_signature_count_sum == 4 + assert groups["rules_plus_direct_llm"].shared_final_entity_signature_count_sum == 2 + assert groups["rules_plus_direct_llm"].baseline_only_final_entity_signature_count_sum == 2 + assert groups["rules_plus_direct_llm"].direct_only_final_entity_signature_count_sum == 1 + assert groups["rules_plus_direct_llm"].baseline_shared_signature_rate == pytest.approx(0.5) + assert groups["rules_plus_direct_llm"].fast_lane_verdict == "reject" + assert groups["rules_plus_direct_llm"].flags == [ + "too_few_cases", + "case_errors", + "baseline_signature_loss", + "uses_model", + "not_fully_rule_covered", + ] + + label_deltas = {(row.seed_source, row.delta_type, row.label): row.count for row in result.label_deltas} + assert label_deltas == { + ("rules_plus_direct_llm", "baseline_only", "city"): 1, + ("rules_plus_direct_llm", "baseline_only", "person"): 1, + ("rules_plus_direct_llm", "direct_only", "api_key"): 1, + } + + +def test_staged_detection_output_analysis_requires_repeated_cases_for_fast_lane(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis_repeated_gate", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + output_dir = tmp_path / "staged" + output_dir.mkdir() + _write_jsonl( + output_dir / "staged-detection-cases.jsonl", + [ + { + "record_type": "staged_detection_case", + "case_id": f"shell-row-{index}", + "row_index": index, + "seed_source": "rules_router", + "status": "completed", + "elapsed_sec": 0.002, 
+ "model_elapsed_sec": 0.0, + "model_phase_count": 0, + "model_request_count": 0, + "rule_covered_label_set": True, + "final_entity_count": 5, + "final_entity_signature_count": 5, + "total_usage": {}, + "comparison": { + "baseline_final_entity_signature_count": 5, + "shared_final_entity_signature_count": 5, + "baseline_only_final_entity_signature_count": 0, + "direct_only_final_entity_signature_count": 0, + "baseline_only_label_counts": {}, + "direct_only_label_counts": {}, + }, + } + for index in range(3) + ], + ) + + result = tool.analyze_staged_detection_output(output_dir) + + group = result.groups[0] + assert group.seed_source == "rules_router" + assert group.case_count == 3 + assert group.fast_lane_verdict == "fast_lane_candidate" + assert group.flags == [] + + +def test_staged_detection_output_analysis_writes_csv_tables(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis_export", + REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", + ) + cases_path = tmp_path / "cases.jsonl" + _write_jsonl( + cases_path, + [ + { + "case_id": "case-0", + "row_index": 0, + "seed_source": "rules_router", + "status": "completed", + "elapsed_sec": 0.01, + "model_elapsed_sec": 0.0, + "model_request_count": 0, + "total_usage": {}, + } + ], + ) + + result = tool.analyze_staged_detection_output(cases_path) + export = tool.write_analysis_tables(result, tmp_path / "analysis", tool.ExportFormat.csv) + + assert Path(export.manifest_path).exists() + assert [table.table for table in export.tables] == [ + "case_analysis", + "group_analysis", + "label_delta_analysis", + ] + case_table = pd.read_csv(tmp_path / "analysis" / "case_analysis.csv") + assert case_table.loc[0, "case_id"] == "case-0" + assert case_table.loc[0, "model_request_count"] == 0 + + +def test_staged_detection_output_analysis_rejects_missing_case_file(tmp_path: Path) -> None: + tool = load_tool( + "measurement_staged_detection_output_analysis_missing_input", + REPO_ROOT 
/ "tools/measurement/analyze_staged_detection_output.py", + ) + + with pytest.raises(ValueError, match="input path does not exist"): + tool.analyze_staged_detection_output(tmp_path / "missing") diff --git a/tests/tools/test_staged_detection_probe.py b/tests/tools/test_staged_detection_probe.py new file mode 100644 index 00000000..c380cba2 --- /dev/null +++ b/tests/tools/test_staged_detection_probe.py @@ -0,0 +1,812 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +from anonymizer.engine.schemas import ValidationCandidateSchema + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def load_tool(module_name: str, path: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + sys.path.insert(0, str(path.parent)) + spec.loader.exec_module(module) + return module + + +class SequencedClient: + def __init__(self, tool: ModuleType, outputs: list[str]) -> None: + self._tool = tool + self._outputs = list(outputs) + self.prompts: list[str] = [] + + def complete(self, request): # type: ignore[no-untyped-def] + self.prompts.append(request.prompt) + content = self._outputs.pop(0) + return self._tool.DirectCompletion( + content=content, + elapsed_sec=0.5, + usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + ) + + +class StaticSeedClient: + def __init__(self, tool: ModuleType, content: str) -> None: + self._tool = tool + self.content = content + self.requests = [] + + def detect(self, request): # type: ignore[no-untyped-def] + self.requests.append(request) + return self._tool.DirectCompletion( + content=self.content, + elapsed_sec=0.25, + 
usage={"prompt_tokens": 1, "completion_tokens": 3, "total_tokens": 4}, + ) + + +def test_staged_detection_reuses_validation_and_augmentation_flow() -> None: + tool = load_tool( + "measurement_staged_detection_probe_case", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_suggestion_count == 1 + assert result.seed_entity_count == 1 + assert result.validation_decision_count == 1 + assert result.augmented_suggestion_count == 1 + assert result.final_entity_count == 2 + assert result.final_label_counts == {"first_name": 1, "organization_name": 1} + assert result.artifact.final_source_counts == {"direct_seed": 1, "augmenter": 1} + assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=True) + assert result.phase_skip_reasons == tool.PhaseSkipReasons() + assert result.model_phase_count == 3 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) + assert result.model_request_count == 3 + assert result.total_usage == {"prompt_tokens": 30, "completion_tokens": 15, "total_tokens": 45} + assert len(client.prompts) == 3 + + +def test_staged_detection_can_disable_augmentation_phase() -> None: + tool = load_tool( + "measurement_staged_detection_probe_disable_augmentation", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + 
'{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_label_counts == {"first_name": 1} + assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=False) + assert result.phase_skip_reasons == tool.PhaseSkipReasons(augmentation="disabled") + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=0) + assert result.model_request_count == 2 + assert len(client.prompts) == 2 + + +def test_staged_detection_execution_exposes_output_row_without_serializing_it() -> None: + tool = load_tool( + "measurement_staged_detection_probe_execution", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": []}', + ], + ) + + execution = tool.execute_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=client, + ) + + assert execution.case.status == tool.CaseStatus.completed + assert execution.row[tool.COL_DETECTED_ENTITIES]["entities"][0]["value"] == "Alice" + assert "row" not in execution.case.model_dump() + + +def test_staged_detection_validation_drop_removes_seed_entity() -> None: + tool = load_tool( + "measurement_staged_detection_probe_drop", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = 
SequencedClient( + tool, + [ + '{"entities": [{"value": "name", "label": "first_name", "reason": "placeholder"}]}', + '{"decisions": [{"id": "first_name_3_7", "decision": "drop", "reason": "placeholder"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="my name is hidden", + labels=["first_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_entity_count == 1 + assert result.validation_decision_count == 1 + assert result.final_entity_count == 0 + assert result.final_label_counts == {} + + +def test_staged_detection_invalid_reclass_label_keeps_seed_label() -> None: + tool = load_tool( + "measurement_staged_detection_probe_invalid_reclass_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + ( + '{"decisions": [' + '{"id": "first_name_0_5", "decision": "reclass", ' + '"proposed_label": "drop", "reason": "invalid label"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"first_name": 1} + + +def test_staged_detection_discards_invalid_augmentation_labels() -> None: + tool = load_tool( + "measurement_staged_detection_probe_invalid_augmentation_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": 
"NVIDIA", "label": "drop", "reason": "invalid label"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=client, + ) + + assert result.status == tool.CaseStatus.completed + assert result.augmented_suggestion_count == 0 + assert result.final_entity_count == 1 + assert result.final_label_counts == {"first_name": 1} + + +def test_staged_detection_preserves_more_specific_seed_label_on_native_reclass() -> None: + tool = load_tool( + "measurement_staged_detection_probe_specific_label_reclass", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "birth date"}]}', + ( + '{"decisions": [' + '{"id": "date_of_birth_26_41", "decision": "reclass", ' + '"proposed_label": "date", "reason": "date expression"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="The applicant was born on 23 October 1992.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date_of_birth": 1} + + +def test_staged_detection_allows_date_of_birth_reclass_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_generic_date_reclass", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "date"}]}', + ( + '{"decisions": [' + '{"id": "date_of_birth_3_18", "decision": "reclass", ' + '"proposed_label": "date", "reason": "filing date"}' + "]}" + ), + ], + ) + + result = 
tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_demotes_native_birth_date_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_generic_date_birth_label", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + '{"entities": [{"value": "23 October 1992", "label": "date", "reason": "date"}]}', + ( + '{"decisions": [' + '{"id": "date_3_18", "decision": "reclass", ' + '"proposed_label": "date_of_birth", "reason": "ambiguous date"}' + "]}" + ), + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=client, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_can_seed_from_direct_gliner_payload_without_llm_seed_prompt() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + '{"entities": [{"text": "Alice", "label": "first_name", "start": 0, "end": 5, "score": 0.99}]}', + ) + llm_client = SequencedClient( + tool, + [ + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ], + ) + + result = tool.run_staged_detection_case( + 
tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works at NVIDIA.", + labels=["first_name", "organization_name"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"first_name": 1, "organization_name": 1} + assert result.total_usage == {"prompt_tokens": 21, "completion_tokens": 13, "total_tokens": 34} + assert len(seed_client.requests) == 1 + assert len(llm_client.prompts) == 2 + + +def test_staged_detection_promotes_gliner_date_seed_in_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_birth_date_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + ('{"entities": [{"text": "23 October 1992", "label": "date", "start": 26, "end": 41, "score": 0.99}]}'), + ) + llm_client = SequencedClient( + tool, + ['{"decisions": [{"id": "date_of_birth_26_41", "decision": "keep", "reason": "birth date"}]}'], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="The applicant was born on 23 October 1992.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"date_of_birth": 1} + + +def test_staged_detection_keeps_gliner_date_seed_without_birth_context() -> None: + tool = load_tool( + "measurement_staged_detection_probe_gliner_generic_date_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + seed_client = StaticSeedClient( + tool, + ('{"entities": 
[{"text": "23 October 1992", "label": "date", "start": 3, "end": 18, "score": 0.99}]}'), + ) + llm_client = SequencedClient( + tool, + ['{"decisions": [{"id": "date_3_18", "decision": "keep", "reason": "filing date"}]}'], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="On 23 October 1992 the applicant filed an action.", + labels=["date", "date_of_birth"], + row_index=0, + ), + client=llm_client, + seed_client=seed_client, + seed_source=tool.SeedSource.gliner, + skip_augmentation=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.gliner + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"date": 1} + + +def test_staged_detection_validation_prompt_preserves_degree_label_guidance() -> None: + tool = load_tool( + "measurement_staged_detection_probe_degree_guidance", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + request = tool.StagedDetectionRequest( + case_id="case-1", + text="He earned his Bachelor of Science in physics.", + labels=["degree", "education_level", "field_of_study"], + ) + candidates = tool.ValidationCandidatesSchema( + candidates=[ + ValidationCandidateSchema( + id="education_level_14_33", + value="Bachelor of Science", + label="education_level", + context_before="He earned his ", + context_after=" in physics.", + ) + ] + ) + + prompt = tool._validation_prompt(request, candidates) + + assert "degree" in prompt + assert "education_level" in prompt + assert "Bachelor of Science" in prompt + assert "Prefer degree for credential names" in prompt + + +def test_staged_detection_augmentation_prompt_discourages_grouped_person_and_surname_spans() -> None: + tool = load_tool( + "measurement_staged_detection_probe_augmentation_guidance", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + request = tool.StagedDetectionRequest( + case_id="case-1", + text="Her parents, Mark and Linda, 
live near West Baker Drive and Baker's grocery.", + labels=["first_name", "last_name", "organization_name", "place_name", "company_name"], + ) + + prompt = tool._augmentation_prompt(request, {tool.COL_SEED_ENTITIES_JSON: "[]"}) + + assert "split personal names connected by 'and'" in prompt + assert "Do not label a list of people as organization_name" in prompt + assert "also return the surname substring as last_name" in prompt + + +def test_staged_detection_can_seed_from_rules_without_llm_seed_prompt() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient( + tool, + [ + '{"decisions": [{"id": "email_6_23", "decision": "keep", "reason": "email address"}]}', + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Email alice@example.com at NVIDIA.", + labels=["email", "organization_name"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.rules + assert result.phase_usage.seed == {} + assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=True, augmentation=True) + assert result.phase_skip_reasons.seed == "deterministic_rules" + assert result.phase_skip_reasons.validation is None + assert result.model_phase_count == 2 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=1, augmentation=1) + assert result.model_request_count == 2 + assert result.seed_suggestion_count == 1 + assert result.seed_entity_count == 1 + assert result.final_label_counts == {"email": 1, "organization_name": 1} + assert result.artifact.final_source_counts == {"augmenter": 1, "rule": 1} + assert result.total_usage == {"prompt_tokens": 20, 
"completion_tokens": 10, "total_tokens": 30} + assert len(llm_client.prompts) == 2 + + +def test_staged_detection_can_add_rules_to_direct_llm_seed_without_validating_rules() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_plus_direct_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient( + tool, + [ + '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', + '{"decisions": [{"id": "organization_name_27_33", "decision": "keep", "reason": "employer"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Email alice@example.com at NVIDIA.", + labels=["email", "organization_name"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules_plus_direct_llm, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.rules_plus_direct_llm + assert result.seed_suggestion_count == 2 + assert result.seed_entity_count == 2 + assert result.validation_candidate_count == 1 + assert result.validation_decision_count == 1 + assert result.final_label_counts == {"email": 1, "organization_name": 1} + assert result.artifact.final_source_counts == {"direct_seed": 1, "rule": 1} + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) + assert result.model_request_count == 3 + assert '"label":"email"' not in llm_client.prompts[1] + assert '"label":"organization_name"' in llm_client.prompts[1] + + +def test_staged_detection_baseline_comparison_skips_rows_without_signature_hashes() -> None: + tool = load_tool( + "measurement_staged_detection_probe_missing_baseline_signatures", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "person name"}]}', + '{"decisions": 
[{"id": "first_name_0_5", "decision": "keep", "reason": "person name"}]}', + '{"entities": []}', + ], + ) + case = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works remotely.", + labels=["first_name"], + row_index=0, + ), + client=llm_client, + ) + + compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 1}) + + assert compared.comparison is None + + +def test_staged_detection_can_trust_rules_without_validation_prompt() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_trusted_seed", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient(tool, ['{"entities": []}']) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text=( + "$ docker run -e DATABASE_URL='postgres://app_user:fakeDbPass123!@db.example.test:5432/app' " + "-e API_KEY=ghp_FAKEtoken1234567890abcdef myapp:latest\nPassword: fakeLoginPass!" + ), + labels=["api_key", "password", "email", "url"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules_trusted, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.rules_trusted + assert result.phase_usage.seed == {} + assert result.phase_usage.validation == {} + assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=True) + assert result.phase_skip_reasons.seed == "deterministic_rules" + assert result.phase_skip_reasons.validation == "trusted_rules" + assert result.phase_skip_reasons.augmentation is None + assert result.model_phase_count == 1 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=1) + assert result.model_request_count == 1 + assert result.rule_covered_label_set is True + assert result.validation_decision_count == 3 + assert result.final_label_counts == {"api_key": 1, "password": 1, "url": 1} + assert 
result.artifact.final_source_counts == {"rule": 3} + assert result.total_usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + assert len(llm_client.prompts) == 1 + + +def test_staged_detection_can_skip_augmentation_when_all_labels_are_rule_covered() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_trusted_no_augment", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient(tool, []) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Email alice@example.com", + labels=["email"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules_trusted, + skip_augmentation_when_rule_covered=True, + ) + + assert result.status == tool.CaseStatus.completed + assert result.phase_usage.augmentation == {} + assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=False) + assert result.phase_skip_reasons == tool.PhaseSkipReasons( + seed="deterministic_rules", + validation="trusted_rules", + augmentation="rule_covered_labels", + ) + assert result.model_phase_count == 0 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=0) + assert result.model_request_count == 0 + assert result.rule_covered_label_set is True + assert result.augmented_suggestion_count == 0 + assert result.final_label_counts == {"email": 1} + assert result.total_usage == {} + assert len(llm_client.prompts) == 0 + + +def test_staged_detection_rules_router_short_circuits_rule_covered_labels() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_router_short_circuit", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient(tool, []) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Email alice@example.com and token ghp_FAKEtoken1234567890abcdef", + labels=["email", 
"api_key"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules_router, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.rules_router + assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=False) + assert result.phase_skip_reasons == tool.PhaseSkipReasons( + seed="deterministic_rules", + validation="trusted_rules", + augmentation="rule_covered_labels", + ) + assert result.model_phase_count == 0 + assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=0) + assert result.model_request_count == 0 + assert result.elapsed_sec is not None and result.elapsed_sec > 0.0 + assert result.model_elapsed_sec == 0.0 + assert result.rule_covered_label_set is True + assert result.final_label_counts == {"api_key": 1, "email": 1} + assert result.artifact.final_source_counts == {"rule": 2} + assert result.total_usage == {} + assert len(llm_client.prompts) == 0 + + +def test_staged_detection_rules_router_uses_direct_seed_for_contextual_labels() -> None: + tool = load_tool( + "measurement_staged_detection_probe_rules_router_mixed_labels", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + llm_client = SequencedClient( + tool, + [ + '{"entities": [{"value": "Alice", "label": "first_name", "reason": "person name"}]}', + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "person name"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice emails alice@example.com.", + labels=["email", "first_name"], + row_index=0, + ), + client=llm_client, + seed_source=tool.SeedSource.rules_router, + ) + + assert result.status == tool.CaseStatus.completed + assert result.seed_source == tool.SeedSource.rules_router + assert result.rule_covered_label_set is False + assert result.phase_model_work == 
tool.PhaseModelWork(seed=True, validation=True, augmentation=True) + assert result.phase_skip_reasons == tool.PhaseSkipReasons() + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) + assert result.model_request_count == 3 + assert result.final_label_counts == {"email": 1, "first_name": 1} + assert result.artifact.final_source_counts == {"direct_seed": 1, "rule": 1} + assert '"label":"email"' not in llm_client.prompts[1] + assert '"label":"first_name"' in llm_client.prompts[1] + + +def test_staged_detection_can_chunk_validation_into_local_excerpts() -> None: + tool = load_tool( + "measurement_staged_detection_probe_chunked_validation", + REPO_ROOT / "tools/measurement/staged_detection_probe.py", + ) + client = SequencedClient( + tool, + [ + ( + '{"entities": [' + '{"value": "Alice", "label": "first_name", "reason": "name"},' + '{"value": "Paris", "label": "city", "reason": "city"}' + "]}" + ), + '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', + '{"decisions": [{"id": "city_61_66", "decision": "keep", "reason": "city"}]}', + '{"entities": []}', + ], + ) + + result = tool.run_staged_detection_case( + tool.StagedDetectionRequest( + case_id="case-1", + text="Alice works in a very long remote biography before moving to Paris.", + labels=["first_name", "city"], + row_index=0, + ), + client=client, + validation_prompt_mode=tool.ValidationPromptMode.chunked_excerpt, + validation_max_entities_per_call=1, + validation_excerpt_window_chars=8, + ) + + assert result.status == tool.CaseStatus.completed + assert result.final_label_counts == {"city": 1, "first_name": 1} + assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=2, augmentation=1) + assert result.model_phase_count == 3 + assert result.model_request_count == 4 + assert len(client.prompts) == 4 + assert "Alice" in client.prompts[1] + assert "Paris" not in client.prompts[1] + assert "Paris" in 
client.prompts[2] diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 82f7a955..8a6a1d4f 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -19,6 +19,8 @@ By default it writes Parquet files plus `manifest.json`: - `stage.parquet` - `record.parquet` - `ndd_workflow.parquet` when adapter records are present +- `model_workflow.parquet` when non-DataDesigner model workflow records are + present Use `--format csv` or `--format jsonl` for non-Parquet output, and `--overwrite` to replace existing output files. @@ -64,14 +66,20 @@ Example: suite_id: shell-and-biography-smoke model_configs: ./model-configs.yaml model_providers: ./providers.yaml +dd_parser_compat: none +case_retries: 1 +case_retry_backoff_sec: 10 workloads: - id: biographies source: ./data/biographies.csv text_column: text + row_limit: 25 - id: support source: ./data/support.csv text_column: body id_column: ticket_id + row_offset: 100 + row_limit: 50 configs: - id: redact-default replace: redact @@ -93,6 +101,220 @@ matrix: config: hash-agent-labels ``` +Use `row_limit` and `row_offset` to create cheap, repeatable slices of a local +CSV or Parquet workload. The runner materializes a per-case sliced input under +`raw/inputs/` before calling Anonymizer, so each case keeps a stable input file +even when the matrix has multiple configs or repetitions. Slicing is rejected +for URL-like sources because the runner cannot safely materialize a local +subset without downloading the whole dataset first. + +Use `case_retries` and `case_retry_backoff_sec` for long-running suites on +shared model endpoints. Retries are disabled by default. When enabled, a failed +case is retried with the same `case_id` and output paths; the final case still +records `attempt_count` and `attempt_errors` in `summary.json`. `--fail-fast` +remains fail-fast and bypasses retries. 
Treat retried cases as reliability +signals during analysis, especially when failures come from provider health +checks or rate limits. + +Configs may also set `experimental_detection_strategy` for benchmark-only +pipeline probes: + +```yaml +configs: + - id: shell-rules-only + experimental_detection_strategy: rules_only + detect: + entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] + replace: + strategy: hash + digest_length: 12 +``` + +Supported values: + +- `default`: run the normal Anonymizer detection pipeline. +- `rules_guardrail`: run the normal Anonymizer detection pipeline, then union + deterministic high-confidence rule spans into the final entity set. +- `rules_filter_guardrail`: remove GLiNER candidates that are fully covered by + same-label deterministic high-confidence rule spans before validation, add + non-overlapping rule spans back before augmentation so the augmenter sees them + as already tagged, then add non-overlapping rule spans into the final entity + set. Different-label overlaps and longer detector spans remain validation + candidates so contextual spans such as a multi-token political view, + university, or organization name are not shadowed by a shorter or differently + labeled rule span. +- `no_augment`: run GLiNER detection and validation, but skip LLM augmentation. +- `rules_seed_no_augment`: add deterministic high-confidence secret spans to + the GLiNER seed set, validate those seeds, and skip LLM augmentation. +- `rules_guardrail_no_augment`: run GLiNER detection and validation, skip LLM + augmentation, then union deterministic high-confidence rule spans into the + final entity set. +- `rules_filter_guardrail_no_augment`: remove GLiNER candidates that are fully + covered by same-label deterministic high-confidence rule spans before + validation, skip LLM augmentation, then add non-overlapping rule spans into + the final entity set. 
+- `rules_guardrail_detector_only`: run only GLiNER detection and local + finalization, then union deterministic high-confidence rule spans into the + final entity set. +- `detector_only`: run only GLiNER detection and local finalization. This skips + LLM validation and LLM augmentation. +- `rules_only`: use only deterministic high-confidence rules for the detection + stage. +- `rules_covered_or_default`: if explicit `detect.entity_labels` are entirely + inside the structured-secret fast lane (`api_key`, `email`, `http_cookie`, + `password`, `pin`, `unique_id`, `url`, `user_name`), use deterministic rules + for rows whose structured assignments are covered and route suspicious + uncovered rows through the normal Anonymizer detection pipeline. Label sets + outside the fast lane always use normal detection. +- `native_rules_router`: run a benchmark-only native staged detector without + DataDesigner. Rule-covered label sets short-circuit through deterministic + rules with no model calls; other label sets use direct OpenAI-compatible + provider calls for seed extraction, validation, and augmentation. +- `native_candidate_validate_no_augment`: run a benchmark-only native staged + detector without DataDesigner using direct OpenAI-compatible calls for seed + extraction and validation, then skip augmentation. This isolates the cost and + recall impact of removing the augmentation phase from the native executor. +- `detector_native_validate_no_augment`: run the normal GLiNER detector seed + through Anonymizer/DataDesigner, then bypass DataDesigner validation and + augmentation with direct OpenAI-compatible validation calls. This isolates + whether native validation can replace DataDesigner validation when candidate + quality is held closer to the default detector path. 
+- `detector_native_validate_native_augment`: run the normal GLiNER detector + seed through Anonymizer/DataDesigner, then bypass DataDesigner validation and + augmentation with direct OpenAI-compatible validation and augmentation calls. + This keeps the default detector candidate source while testing whether direct + provider calls can replace the two downstream DataDesigner LLM phases. +- `gliner_native_validate_no_augment`: run a direct hosted-GLiNER seed without + DataDesigner, validate those detector candidates with direct + OpenAI-compatible calls, and skip augmentation. This isolates DataDesigner + detector orchestration overhead while keeping a GLiNER-style candidate source. +- `gliner_native_validate_native_augment`: run a direct hosted-GLiNER seed + without DataDesigner, validate those detector candidates with direct + OpenAI-compatible calls, then run direct native augmentation. This is the + fully staged no-DataDesigner detector/validator/augmenter lane for contextual + recall experiments. +- `native_single_pass`: run a benchmark-only native detector without + DataDesigner using one direct OpenAI-compatible provider call per row. The + model must return exact values plus `start`/`end` offsets; local code validates + offsets, unions non-overlapping deterministic rule spans, resolves overlaps, + and records parser/runtime failures as `model_workflow` errors. +- `native_single_pass_recall`: the same one-call native detector with a + recall-oriented prompt that includes Anonymizer's label examples and stronger + high-recall guidance. +- `native_single_pass_values`: the same one-call native detector, but with the + value-only prompt shape from `direct_detection_probe.py`. The model returns + exact values and labels only; local code resolves every occurrence of each + returned value into spans. +- `native_single_pass_values_recall`: the value-only one-call detector with the + recall-oriented prompt from `direct_detection_probe.py`. 
+ +These strategies exist to compare performance options. They are not public +`Detect` config fields, and they should not be treated as safe defaults across +arbitrary data. The rule-backed strategies only cover deterministic +high-confidence spans for `api_key`, `date_of_birth`, `email`, +`http_cookie`, `organization_name`, `password`, `pin`, `religious_belief`, +`street_address`, `unique_id`, `url`, and `user_name`; they will not replace +contextual detection for prose identifiers such as names in biographies or +legal documents. The prose rules (`date_of_birth`, `organization_name`, +`religious_belief`, and `street_address`) are narrow contextual patterns and +are not enough to opt into `rules_covered_or_default`; those labels fall back +to default detection unless `rules_only` is explicitly selected. The structured +identifier rules require keyed or command-style syntax such as +`Cookie:`, `pin=`, `trace-id:`, `user_name=`, or service-principal flags. They +are not general entity recognizers. `detector_only` is also unsafe as a default +because it skips the LLM validation pass that drops false positives and +reclassifies ambiguous spans. `rules_only` requires explicit `entity_labels`, +and every label must be covered by those deterministic rules. Use +`rules_covered_or_default` when a benchmark suite may include both fully +structured-secret scans and contextual workloads; it keeps the no-DataDesigner +short-circuit for the former and falls back to the default pipeline for prose +or legal labels. + +Use `native_rules_router` when you want the same routing shape without +DataDesigner orchestration. It defaults to the local OpenAI-compatible endpoint +used by the staged probe (`http://gpu-dev-pod-serve-svc:8000/v1`) and model +`nvidia/nemotron-3-super`. 
Treat it as a native-executor prototype: it can prove +that DataDesigner overhead is avoidable, but it must be compared against +baseline signatures and original-value leak metrics before any workload-specific +promotion decision. + +Use `native_candidate_validate_no_augment` when you want a narrower native +executor diagnostic: direct seed candidates plus direct validation, with no +augmentation. It is useful for proving how much speed comes from removing a +phase, but a faster run that loses baseline signatures is still a rejection. + +Use `detector_native_validate_no_augment` when you want to keep the production +detector seed while testing a direct-provider validation path. It is not a +no-DataDesigner strategy because the detector seed still runs through the +adapter, but it tells you whether DataDesigner validation/augmentation is the +load-bearing part of a workload. The native validation shim preserves +`date_of_birth` over broader `date` reclassifications only when the local +candidate context contains birth/DOB language; generic filing or event dates can +still be reclassified to `date`. + +Use `detector_native_validate_native_augment` for the same detector-seed +question when augmentation recall is expected to be load-bearing. This arm still +uses DataDesigner for the detector seed, but direct provider calls own both +validation and augmentation. + +Use `gliner_native_validate_no_augment` or +`gliner_native_validate_native_augment` when the question is specifically +"what if GLiNER did not run through DataDesigner?" These strategies use the +staged direct executor's GLiNER seed client, which defaults to +`https://integrate.api.nvidia.com/v1`, model `nvidia/gliner-pii`, and the +`NVIDIA_API_KEY` environment variable. The no-augmentation arm is a lower-cost +boundary; the native-augmentation arm is the quality-oriented no-DataDesigner +candidate. 
The integrated benchmark strategies execute staged direct rows with +bounded parallelism so hosted GLiNER and native validation/augmentation latency +is not serialized across records. These arms also normalize direct GLiNER +`date` seeds to `date_of_birth` only when the local seed context contains +birth/DOB language. +Generic filing or event dates remain `date`. Both arms still need repeated +signature, leak, label-mismatch, and reliability gates before any +workload-specific promotion. + +Use `native_single_pass`, `native_single_pass_recall`, +`native_single_pass_values`, or `native_single_pass_values_recall` for the more +aggressive "collapse detection to one call" experiment. The first pair asks the +model for `start`/`end` offsets and validates them before falling back to exact +value matching. The value-only pair uses the standalone direct-probe prompt and +lets local code recover spans from exact returned values. Recall variants spend +more prompt tokens on label examples and high-recall guidance. All one-call +variants are expected to be faster than staged native detection when the prompt +works, but they are also more parser- and recall-sensitive. Any malformed JSON +response becomes a failed case in analysis, and any missed baseline signature +should be treated as a rejection rather than a latency win. + +Replacement-map generation has a separate benchmark-only knob: + +```yaml +configs: + - id: structured-local-substitute + experimental_detection_strategy: rules_covered_or_default + experimental_replacement_strategy: local_structured_substitute + detect: + entity_labels: [api_key, email, password, url] + replace: + strategy: substitute +``` + +Supported replacement strategy values: + +- `default`: run normal replacement behavior. `Substitute` uses the configured + `replacement_generator` role through DataDesigner. 
+- `local_structured_substitute`: for `replace: substitute`, build deterministic + synthetic replacement maps locally for supported structured labels. Text + replacement still uses Anonymizer's normal replacement-map application code. + +`local_structured_substitute` requires `replace: substitute` and explicit +`detect.entity_labels`. Every label must be one of the structured substitute +labels: `api_key`, `date_of_birth`, `email`, `organization_name`, `password`, +`http_cookie`, `pin`, `religious_belief`, `street_address`, `unique_id`, `url`, +or `user_name`. The preflight rejects contextual labels such as `person`. This +is deliberate. The local substitute map generator does not understand names, +social relations, cultural consistency, or prose semantics; use the default +DataDesigner-backed `Substitute` path for those workloads. + The runner refuses to write into a non-empty output directory unless `--overwrite` is set. By default it also exports Parquet tables into `tables/`; pass `--no-export` when you only want the raw measurement JSONL. @@ -112,6 +334,325 @@ values, replacement values, secrets, and PII. Treat them as debug artifacts: keep them out of shared benchmark bundles unless they have been reviewed or redacted. +To summarize traced calls without copying raw prompts or responses into the +analysis output, run: + +```bash +uv run python tools/measurement/analyze_dd_traces.py \ + benchmark-runs/suite-id/traces \ + --output benchmark-runs/suite-id/trace-analysis \ + --format csv +``` + +This writes `trace_analysis.*` and `trace_group_analysis.*`. The row table +captures run tags, workflow/model metadata, status, elapsed time, prompt and +response lengths, token counts, and response-shape flags such as `raw_json`, +`fenced_json`, `embedded_json`, `text`, and `none`. The grouped table rolls those +fields up by workload, config, workflow, model, provider, status, error type, +and response shape. 
Use this when diagnosing local provider behavior, parser +compatibility, unexpected thinking text, or retry-heavy workflows. + +Some OpenAI-compatible local endpoints return raw JSON when their model config +uses `response_format: {type: json_object}`. DataDesigner structured recipes +currently prompt for markdown-fenced JSON, so those raw JSON responses can be +valid but still fail parsing. Set top-level `dd_parser_compat: raw_json` when a +benchmark suite needs this provider compatibility mode: + +```yaml +dd_parser_compat: raw_json +``` + +This is benchmark-only behavior. The runner patches DataDesigner structured +parser builders for the duration of a case, restores them afterward, and records +the mode in `run_tags.dd_parser_compat`. The fallback accepts either pure raw +JSON or a JSON object/array embedded after model reasoning text, then still +validates the extracted object against the requested schema. Keep the default +`none` unless a local provider or vLLM endpoint needs raw-JSON structured-output +compatibility. + +## DD-Free Direct Detection Probe + +Use `direct_detection_probe.py` to test a deliberately DD-free extraction path +against an OpenAI-compatible endpoint. This is a benchmark-only diagnostic: it +does not call DataDesigner, does not run GLiNER, and does not execute the +production detection graph. It sends one direct chat-completions request per +input row, then reuses Anonymizer's existing span postprocessing, occurrence +expansion, overlap resolution, and entity signature logic so results can be +compared against normal detection artifacts. 
+ +Example biography probe: + +```bash +uv run python tools/measurement/direct_detection_probe.py \ + docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography \ + --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ + --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/biographies-slice-1__biography-prose-default__r000.detection-artifacts.jsonl \ + --output /tmp/anonymizer-perf-explore/out-direct-detection-probe-biography \ + --overwrite \ + --json +``` + +Example legal probe: + +```bash +uv run python tools/measurement/direct_detection_probe.py \ + docs/data/TAB_legal_sample25.csv \ + --text-column text \ + --labels application_number,city,country,date,date_of_birth,nationality,person \ + --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/legal-slice-2__legal-prose-default__r000.detection-artifacts.jsonl \ + --output /tmp/anonymizer-perf-explore/out-direct-detection-probe-legal \ + --overwrite \ + --json +``` + +The tool writes `direct-detection-cases.jsonl`, +`direct-detection-artifacts.jsonl`, and `summary.json`. Case rows include model +usage, elapsed time, raw/allowed suggestion counts, final label counts, final +signature hashes, and optional baseline comparison counts. Artifact rows use +the same opaque signature fields as `analyze_detection_artifacts.py` and omit +raw entity values. For baseline comparison, pass a per-case sidecar or another +artifact file with one row per `row_index`; duplicate row indexes are rejected +so a combined multi-case artifact cannot silently select the wrong baseline. + +When this probe shape is promising, move it into a normal benchmark suite with +`experimental_detection_strategy: native_single_pass_values` or +`native_single_pass_values_recall`. 
Those strategies use the same value-only +prompt family but run through `run_benchmarks.py`, measurement collection, case +retries, artifact capture, and pairwise strategy comparison. + +Interpret this probe as a lower-friction model-call experiment, not a safe +replacement for detection. A local one-row smoke against +`nvidia/nemotron-3-super` with vLLM JSON mode and thinking disabled produced: + +- Biography: 4.1s, 906 total tokens, 19 final signatures, 18/22 baseline + signatures shared; misses included `field_of_study` and `place_name`. +- Legal: 4.9s, 1,308 total tokens, 21 final signatures, 19/22 baseline + signatures shared; misses included `date`, `date_of_birth`, and + `nationality`. + +That result makes a DD-free native executor worth exploring, but only if it +preserves the production safety decomposition (`GLiNER/rules -> validate -> +augment -> finalize`). The one-shot direct prompt is useful as a speed/quality +boundary, not as a production candidate. + +## DD-Free Staged Detection Probe + +Use `staged_detection_probe.py` to test a more conservative DD-free route. This +probe still avoids DataDesigner, but it does not collapse detection into one +model response. It can run direct LLM seed extraction, direct GLiNER seeding, +deterministic rule seeding, trusted deterministic rule seeding, or rule-routed +DD-free execution. It then runs direct validation and direct augmentation unless +trusted rules or the rule router short-circuit are selected, where rule spans +bypass validation. It reuses Anonymizer's existing row-level postprocessing +helpers for validation application, augmentation merge, occurrence expansion, +overlap resolution, and artifact signatures. 
+ +Example biography probe: + +```bash +uv run python tools/measurement/staged_detection_probe.py \ + docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography \ + --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ + --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/biographies-slice-1__biography-prose-default__r000.detection-artifacts.jsonl \ + --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography \ + --overwrite \ + --json +``` + +Example legal probe: + +```bash +uv run python tools/measurement/staged_detection_probe.py \ + docs/data/TAB_legal_sample25.csv \ + --text-column text \ + --labels application_number,city,country,date,date_of_birth,nationality,person \ + --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/legal-slice-2__legal-prose-default__r000.detection-artifacts.jsonl \ + --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-legal \ + --overwrite \ + --json +``` + +To replace the LLM seed phase with a direct GLiNER call, add +`--seed-source gliner`. The default GLiNER endpoint is NVIDIA-hosted +`https://integrate.api.nvidia.com/v1` with model `nvidia/gliner-pii`; it reads +the API key from `NVIDIA_API_KEY`. + +To replace the LLM seed phase with deterministic local rules, add +`--seed-source rules`. This still sends rule candidates through the validator. +Use `--seed-source rules-trusted` to bypass validation for high-confidence rule +spans and run only augmentation afterward. The trusted mode is a diagnostic for +rule-covered workloads; it is not a general prose/legal safety default. +Use `--seed-source rules-plus-direct-llm` to add deterministic rule spans to +direct LLM seed spans while validating only the direct LLM seed candidates. 
This +tests a mixed native path where obvious structured secrets are trusted locally +without giving up contextual model seeding for the rest of the record. +Use `--seed-source rules-router` to make that split explicit: if every requested +label is supported by deterministic rules, the probe runs trusted local rules +with no model calls; otherwise it falls back to `rules-plus-direct-llm`. +When the requested labels are all covered by deterministic rules, add +`--skip-augmentation-when-rule-covered` to measure a fully local short-circuit +with no model calls. +Use `--skip-augmentation` to disable augmentation for any seed source. This is +only a diagnostic for measuring how much recall the augmentation phase carries; +signature loss should reject the candidate even when latency improves. + +To test whether direct validation can preserve the phase boundary with less +prompt text, add `--validation-prompt-mode chunked-excerpt`. This splits seed +validation candidates into chunks of `--validation-max-entities-per-call` and +sends each chunk with a tagged local excerpt bounded by +`--validation-excerpt-window-chars`. The default remains `full-text`, which +keeps the prior one-call behavior. Treat this as a request-count/token tradeoff: +chunked excerpts can reduce prompt payload, but they also create more validator +requests and can remove context needed for labels such as legal roles, +education, demographics, or prose locations. + +The tool writes `staged-detection-cases.jsonl`, +`staged-detection-artifacts.jsonl`, and `summary.json`. Case rows include +per-phase usage for seed extraction, validation, and augmentation, true case +wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, plus +`phase_model_work`, `phase_skip_reasons`, `phase_model_requests`, +`model_phase_count`, `model_request_count`, total usage, and optional baseline +signature deltas. 
Use these fields to distinguish local work, provider latency, +and a provider that returned no token accounting. +For example, a fully local rule-covered run should show `model_phase_count: 0`, +`model_request_count: 0`, `rule_covered_label_set: true`, and +`phase_skip_reasons.augmentation: "rule_covered_labels"`; `elapsed_sec` should +still capture the local rule/postprocess wall time while `model_elapsed_sec` +remains `0.0`. A chunked-excerpt validation run should usually keep +`model_phase_count` unchanged while raising `phase_model_requests.validation`. + +To summarize those staged probe outputs without hand-written `jq`, run: + +```bash +uv run python tools/measurement/analyze_staged_detection_output.py \ + /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography \ + --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography/analysis \ + --format csv +``` + +The analyzer accepts either the staged output directory or the +`staged-detection-cases.jsonl` file directly. It writes per-case, per-seed +source group, and label-delta tables. Use `group_analysis.csv` for latency, +token, request, and signature-overlap totals; use `label_delta_analysis.csv` to +see which labels account for baseline-only misses or direct-only additions. The +analysis tables still omit raw text and raw entity values. + +The grouped table also includes a conservative `fast_lane_verdict`: + +- `fast_lane_candidate`: every case completed, every case was fully + rule-covered, the seed-source group has at least three cases, model requests + were zero, and baseline comparison found no missing signatures. +- `reject`: at least one case errored or the candidate lost any baseline + signature. +- `review`: baseline comparison is missing, fewer than three cases were + analyzed, the candidate still used model calls, or not every case was fully + rule-covered. + +Use `fast_lane_candidate` only as a workload-scoped promotion signal. 
It does +not prove that the same no-DataDesigner path is safe for prose/legal labels or +for data shapes outside the sampled suite. + +A refreshed local one-row smoke against `nvidia/nemotron-3-super` with vLLM JSON +mode and thinking disabled produced: + +- Biography: 13.7s, 4,550 total tokens, 24 final signatures, 20/22 baseline + signatures shared. The staged path recovered two signatures missed by the + one-shot direct probe, but still missed an `age` and a `place_name` signature + and added four direct-only signatures. +- Legal: 17.5s, 6,425 total tokens, 21 final signatures, 19/22 baseline + signatures shared. This did not improve signature overlap over the one-shot + direct probe and was materially slower. + +A direct hosted GLiNER seed smoke reached NVIDIA's endpoint but failed before +local validation with `DEGRADED function cannot be invoked` for +`nvidia/gliner-pii`. Keep the `--seed-source gliner` mode as a native executor +option, but do not treat hosted GLiNER availability as stable for local +performance conclusions. + +Rules seeding changed the tradeoff. On biography row 0, `rules` took 6.1s and +1,565 tokens but shared only 17/22 baseline signatures; `rules-trusted` took +5.2s and 1,019 tokens and shared 18/22. On legal row 0, `rules` took 7.1s and +2,213 tokens with 20/22 shared signatures; `rules-trusted` took 6.4s and 1,431 +tokens with the same 20/22 shared signatures. On the three-row shell-secrets +slice, `rules` exposed a validation regression: the validator reclassified a +database URL as a password, leaving row 1 with 2/3 shared baseline signatures. +`rules-trusted` preserved all shell baseline signatures and reduced each row to +one augmentation call, but that no-op augmentation still consumed 398-533 tokens +per row. With `--skip-augmentation-when-rule-covered`, the same trusted-rules +shell run preserved all 12 baseline signatures with zero model usage. 
Use this +as evidence for a native executor with rule-covered short circuiting, not as +evidence that trusted rules are safe for arbitrary text. + +Interpret this as evidence for native orchestration, not as a ready strategy. +The staged shape is closer to Anonymizer's safety model than one-shot +extraction, but the naive direct prompts spend too many tokens. The next useful +experiment is a native executor that preserves the same phase boundaries while +using compact production-equivalent prompts, direct provider clients, and a +cheap deterministic or detector-backed seed phase instead of LLM-seeded +extraction. + +## No-DataDesigner Strategy Pivot + +The strongest current performance signal comes from not invoking +DataDesigner at all for records whose requested labels and text shape are +covered by deterministic structured-secret extractors. On a local shell/structured-secret slice, +the staged `rules-router` path preserved every compared baseline signature with +zero model requests and millisecond-level elapsed time. In full Anonymizer +benchmarks, `rules_covered_or_default` plus `local_structured_substitute` +reduced structured substitute workloads by 38-99% wall time and removed most or +all observed model tokens, depending on whether the run still fell back to +default detection. + +The benchmark harness now has several integrated native strategies for that +next experiment. `native_rules_router` reuses the staged DD-free executor inside +Anonymizer's detection workflow, so benchmark cases still exercise the normal +replacement and measurement plumbing. `native_candidate_validate_no_augment` +removes augmentation to isolate the recall cost of that phase. +`detector_native_validate_no_augment` keeps the default detector seed and +switches only validation to direct provider calls. 
`native_single_pass` is the +more radical variant: it asks the local provider for all spans in one JSON +response and then lets Anonymizer validate offsets and finalize entities +locally. Use these arms to compare native provider calls against the +DataDesigner-backed `default` strategy on the same workloads. + +Treat that as a workload router, not a global replacement. The same DD-free +direct LLM approach on biography and legal prose still lost roughly a quarter +to a third of baseline signatures in repeated local probes, even though it +avoided DataDesigner. That is not an anonymization-safe trade by itself. The +current evidence points to four separate lanes: + +- **Structured fast lane:** if the explicit labels are all deterministic-rule + labels and rule extraction covers the workload, skip DataDesigner, skip model + calls, and use local redact/hash/substitute. This is the most promising path + for shell history, secrets, config files, audit logs, and similarly keyed + records. +- **Native model lane:** for prose or mixed records, preserve the production + detection decomposition but call providers directly: seed, validate, augment, + finalize. The prototype exists as `staged_detection_probe.py`, and the + benchmark harness includes detector-seeded and native-seeded variants, but + their current prompts are still research prompts and are too lossy/costly to + promote. +- **Single-pass model lane:** for a sharper boundary test, collapse prose or + mixed detection into one direct JSON span extraction call. This only becomes + interesting if it preserves baseline signatures; parser errors, invalid + offsets, or missed signatures should send the workload back to the default + pipeline. +- **Safety fallback:** route unsupported labels, uncertain text shapes, direct + parser failures, and signature-loss evidence back to the normal + DataDesigner-backed pipeline until a native executor proves equal or better + recall on repeated workload-specific comparisons.
+ +This changes the performance strategy from "make every DataDesigner phase +faster" to "avoid DataDesigner when the safety case is trivial, and use +DataDesigner as the fallback for hard cases." The benchmark interpretation +should therefore privilege signature coverage, original-value leak checks, +source provenance, and reliability flags over raw latency wins. A no-DD result +that is faster but loses baseline signatures remains a reject; a no-DD result +that is fully rule-covered, leak-free, and stable across repetitions is a +candidate for a production fast lane. + ## Output Layout A benchmark run writes one raw measurement file per case, then combines them: @@ -120,11 +661,13 @@ A benchmark run writes one raw measurement file per case, then combines them: benchmark-runs/suite-id/ raw/ biographies__redact-default__r000.jsonl + biographies__redact-default__r000.detection-artifacts.jsonl support__hash-agent-labels__r000.jsonl traces/ biographies__redact-default__r000.jsonl measurements.jsonl summary.json + detection-artifacts.jsonl tables/ manifest.json run.parquet @@ -138,11 +681,388 @@ long run leaves inspectable partial output before the case exits. The combined `measurements.jsonl` is written after the completed and errored case files are collected. -Use `summary.json` to inspect case status and errors. Use `measurements.jsonl` -when you need the original structured records. Use `tables/` for analysis. +Use `summary.json` to inspect case status, retry attempts, and errors. If a +case succeeds after retry, the combined `measurements.jsonl` contains the final +successful attempt while `summary.json` preserves the earlier failure messages. +Use `measurements.jsonl` when you need the original structured records. Use +`tables/` for analysis. Use `traces/` only when `--dd-trace` was enabled and you need raw DataDesigner message-level debugging. 
+Detection workflow artifacts can be analyzed separately when you need to know +whether augmentation helped or only added cost. `run_benchmarks.py` writes +`detection-artifacts.jsonl` automatically when export is enabled and detection +artifacts are present. The automatic export analyzes each case immediately after +it runs, then combines per-case sidecars from `raw/`; rows include `suite_id`, +`workload_id`, `config_id`, `repetition`, `case_id`, and `run_id` so they can be +joined to `measurements.jsonl` and exported tables. `rules_only` cases do not +produce DataDesigner parquet artifacts, so the runner writes a synthetic +rules-only sidecar from the same deterministic rules. That sidecar includes +counts, source=`rule`, and opaque entity signatures, but not raw entity values. +Routed strategies whose final entity set can differ from raw DataDesigner +artifacts, including row-aware `rules_covered_or_default`, write sidecars from +the final trace dataframe so rule-routed and fallback-routed rows are both +represented. + +Row-aware routed strategies also emit sanitized route telemetry into +`measurements.jsonl`, and `analyze_benchmark_output.py` surfaces it in +`case_analysis.*` and `group_analysis.*`. Use `route_total_row_count`, +`route_rule_row_count`, and `route_fallback_row_count` to confirm how many rows +used the zero-model rules lane versus the normal detection fallback before +interpreting request, token, or latency deltas. +You can also run the analyzer by hand against an artifact directory: + +```bash +uv run python tools/measurement/analyze_detection_artifacts.py \ + benchmark-runs/suite-id/artifacts \ + --output benchmark-runs/suite-id/detection-artifacts.jsonl +``` + +The analyzer reads `entity-detection*` parquet artifacts and emits one row per +artifact row. 
It reports seed, augmentation, and final entity counts; duplicate +augmentation suggestions; new augmented values that survived into final +entities; final label/source counts; and weak `api_key` shape warnings. The +output intentionally omits raw entity values. + +Use this alongside the exported measurement tables when comparing +`default` against `no_augment`: + +- High `augmented_duplicate_seed_value_count` with low + `augmented_new_final_value_count` means augmentation probably added cost + without improving that case. +- High `augmented_new_final_value_count` means augmentation found spans that + the detector+validator path missed. +- High `weak_api_key_shape_count` usually means the label set is mismatched to + the workload. For example, legal prose constrained to + `[person, email, api_key, password]` can force dates or case identifiers into + `api_key` because better prose labels are unavailable. + +For a ready-made case and grouped summary that joins `measurements.jsonl` with +`detection-artifacts.jsonl`, use: + +```bash +uv run python tools/measurement/analyze_benchmark_output.py \ + benchmark-runs/suite-id \ + --output benchmark-runs/suite-id/analysis \ + --format csv +``` + +By default this joins `benchmark-runs/suite-id/measurements.jsonl` with +`benchmark-runs/suite-id/detection-artifacts.jsonl`. To use a refreshed or +relocated sidecar that still contains benchmark case metadata, pass it +explicitly: + +```bash +uv run python tools/measurement/analyze_benchmark_output.py \ + benchmark-runs/suite-id \ + --detection-artifacts benchmark-runs/suite-id/current-analysis/detection-artifacts.jsonl \ + --output benchmark-runs/suite-id/current-analysis \ + --format csv +``` + +The override sidecar must include `case_id` or `run_id` values that match the +measurement rows. 
A raw artifact scan produced from only the DataDesigner +parquet directory can summarize detection artifacts, but it cannot be safely +joined to benchmark measurements unless benchmark case metadata is preserved. + +This writes `case_analysis.*`, `group_analysis.*`, `model_analysis.*`, and +`model_group_analysis.*`. It keeps fully local cases with no model workflow +rows, such as rule-covered `rules_only` or `native_rules_router` cases, in the +comparison with zero observed requests/tokens. Native direct-call strategies +that bypass DataDesigner write `model_workflow` rows, so their provider request +and token counts still contribute to case, group, and model summaries. When the +benchmark was run with current sidecar export, `rules_only` also has +artifact-derived signatures and source counts; older runs may only have +record-level entity counts. The joined case/group tables include +successful/failed request counts, input/output token splits, +`seed_validation_candidate_count`, `estimated_seed_validation_chunk_count`, and +`observed_failed_request_rate`; use these when testing +`detect.validation_max_entities_per_call` so you can distinguish a real chunk +count change from provider retry variance. The model tables split the same +usage by `workflow_name` and `model_name`, which is useful for separating local +detector cost from validator, augmenter, substitute, or rewrite model cost. +The case/group tables also surface incomplete benchmark cases with +`case_failed`, `error_stage_count`, `error_ndd_workflow_count`, +`error_model_workflow_count`, `failed_case_count`, and `failed_case_rate`. +Check these before interpreting a fast candidate as a safe improvement; a +failed repetition can otherwise look like entity instability or a latency win. 
+The joined case/group tables also expose final entity source counts from +detection artifacts, including `artifact_final_detector_entity_count`, +`artifact_final_rule_entity_count`, and +`artifact_final_augmenter_entity_count`. Use these to verify whether a faster +strategy is still relying on contextual detector/validator spans, or whether it +has shifted a workload entirely onto deterministic rules. +They also include `artifact_final_entity_signature_count` and +`artifact_final_entity_signature_hashes`, which are opaque per-row hashes of the +final entity label, normalized value, and offsets. The companion +`artifact_final_entity_signature_labels` field maps each opaque hash to its +entity label. These fields do not expose raw entity values, but they let +analysis tools detect when two configs report the same entity count while +protecting different spans. + +To compare a baseline and candidate strategy across common workloads, use: + +```bash +uv run python tools/measurement/compare_strategy_pairs.py \ + benchmark-runs/suite-id/analysis/case_analysis.csv \ + --baseline-strategy no_augment \ + --candidate-strategy rules_filter_guardrail_no_augment \ + --output benchmark-runs/suite-id/analysis/strategy_comparison.csv +``` + +If the candidate was run in a separate benchmark directory, pass a second case +analysis file: + +```bash +uv run python tools/measurement/compare_strategy_pairs.py \ + benchmark-runs/baseline-suite/analysis/case_analysis.csv \ + --candidate-case-analysis benchmark-runs/candidate-suite/analysis/case_analysis.csv \ + --baseline-strategy no_augment \ + --candidate-strategy rules_guardrail_no_augment +``` + +The comparison reports latency, request, token, entity-count, validation +candidate-count, augmentation-count, final source-count, and opaque +entity-signature deltas. It also reports original-value leak deltas from +`original_value_leak_count` and `original_value_leak_record_count`. 
The +`augmented_entity_count_delta` and +`augmented_new_final_value_count_delta` columns are especially useful for +no-augmentation and model-routing ablations: a faster candidate that removes +new final values from augmentation needs signature checks before promotion. +When signature labels are available, it also reports label counts for +baseline-only, candidate-only, and shared signatures. For repeated selector +runs, it also compares signatures that are stable across every repetition, +which catches cases where a candidate finds a sensitive span only +intermittently. It adds conservative flags such as +`baseline_case_failures`, `candidate_case_failures`, `entity_count_loss`, +`entity_signature_loss`, `span_boundary_mismatch`, +`covered_label_mismatch`, +`candidate_original_value_leak`, +`candidate_replacement_missing_final_entity`, +`candidate_duplicate_synthetic_replacement`, +`failed_request_increase`, `bridge_fallback_increase`, +`stable_entity_signature_loss`, `no_candidate_detector_entities`, +`candidate_uses_rule_entities`, `candidate_skips_llm_validation`, and +`replacement_only_detection_instability`, plus five verdict fields: + +- `value_protection_verdict`: `pass`, `review`, or `fail`. This axis focuses on + whether the candidate still protects the sensitive values. Candidate case + failures, candidate original-value leaks, missing replacement-map entries, + replacement collisions, and uncovered baseline signatures fail. Rule + provenance, validation skipping, + provider retry pressure, and covered boundary or label mismatches do not fail + this axis by themselves; they are represented in the semantic and overall + safety verdicts. +- `signature_parity_verdict`: `pass`, `review`, or `fail`. This axis focuses on + exact baseline signature semantics. Covered label or boundary mismatches stay + review-gated even when `value_protection_verdict` passes. +- `safety_verdict`: `pass`, `review`, or `fail`. 
Candidate case failures and + entity/signature loss fail. Candidate original-value leaks also fail, even + when entity signatures match. Baseline case failures, baseline + original-value leaks, rule-only, rule-heavy, or validation-skipping + candidates require review. Candidate provider failed-request increases or + bridge-fallback increases also require review: they are reliability signals, + not anonymization leaks. +- `performance_verdict`: `improved`, `mixed`, `regressed`, `unchanged`, or + `unknown`, based on available latency, request, and token deltas. +- `candidate_verdict`: `candidate_viable`, `review`, or `reject`. A candidate + is viable only when safety passes and measured performance improves. + +Use verdicts for triage, then inspect the underlying flags and label-count +deltas before promoting a strategy beyond benchmark experiments. +For replacement-only comparisons where the detection strategy is unchanged, +`replacement_only_detection_instability` means the candidate and baseline were +still run through independent detection passes and their detection artifacts +drifted. Treat that as a prompt to consult fixed-trace replacement replay before +blaming or promoting the replacement-map backend. +In fixed-trace replacement replay, +`candidate_duplicate_synthetic_replacement` means the local replacement backend +protected every original value but collapsed at least two replacements in the +same row to the same synthetic value. That is review-gated as a substitute +quality and relational-consistency concern rather than treated as an immediate +privacy leak. +When the replay CSV contains +`candidate_covers_baseline_replacement_missing_final_entity`, +`candidate_covers_baseline_original_value_leak`, or +`candidate_covers_baseline_replacement_synthetic_original_collision`, the +candidate removed a defect observed in the DataDesigner-backed substitute arm +on the same fixed detection trace. 
In that case `value_protection_verdict` can +pass while `signature_parity_verdict` remains review-gated, because the +candidate covered more of the final-entity set than the flawed baseline. + +For `rules_covered_or_default`, compare rule-covered configs by config ID so +the zero-model lane is checked against the same explicit label set: + +```bash +uv run python tools/measurement/compare_strategy_pairs.py \ + benchmark-runs/suite-id/analysis/case_analysis.csv \ + --baseline-config rule-labels-default \ + --candidate-config rule-labels-covered-or-default \ + --output benchmark-runs/suite-id/analysis/rules-covered-comparison.csv +``` + +Promote the fast path only when +`baseline_only_candidate_uncovered_signature_count` is zero on the target +workload, `candidate_original_value_leak_count` is zero, `candidate_verdict` is +at least `review`, and the review flags are expected rule fast-lane flags such +as `candidate_uses_rule_entities`, `no_candidate_detector_entities`, +`entity_count_loss`, or `span_boundary_mismatch`. Exact +`baseline_only_final_entity_signature_count` can be nonzero when a candidate +protects the same sensitive value with a wider or slightly narrower keyed span; +use the covered/overlapping/uncovered columns to decide whether that is an +acceptable workload policy. A run that has uncovered signatures or leaks +original detected values should reject: +in the June 8, 2026 sudo-password smoke run, the pre-fix comparison rejected the +candidate with `lost_labels=password:1`; after the narrow sudo rule was added, +the same comparison had no baseline-only signatures and remained review-gated +only because the final spans were rule-sourced. + +The command output also includes a rollup summary with verdict counts and the +workloads in each candidate-verdict bucket, which is useful for repeated runs +over larger suites. 
+ +To screen many comparison CSVs from one or more benchmark directories, use: + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py \ + benchmark-runs/ \ + --output benchmark-runs/strategy-screen.csv \ + --group-output benchmark-runs/strategy-groups.csv +``` + +When screening a scratch directory that contains older analysis outputs, filter +by source-path fragments: + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py \ + /tmp/anonymizer-perf-explore \ + --source-include analysis-current-csv \ + --source-include analysis-failure-aware-csv \ + --output current-strategy-screen.csv \ + --group-output current-strategy-groups.csv +``` + +Use `--source-exclude` to omit known stale or exploratory subdirectories. +For example, if a scratch directory contains a pre-fix comparison and a rerun, +screen only current evidence by excluding the stale source-path fragment: + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py \ + /tmp/anonymizer-perf-goal \ + --source-include comparison \ + --source-exclude before-sudo \ + --source-exclude structured-secrets-varied-comparison.csv \ + --output /tmp/anonymizer-perf-goal/strategy-screen-current.csv \ + --group-output /tmp/anonymizer-perf-goal/strategy-screen-current-groups.csv +``` + +The screen walks CSV files recursively, ignores non-comparison tables such as +`case_analysis.csv` and `group_analysis.csv`, and combines rows produced by +`compare_strategy_pairs.py`. It deduplicates exact repeated rows from copied +analysis directories, then sorts viable candidates first, then review and reject +rows, preserving latency/token deltas, flags, lost-label summaries, and +augmentation deltas. It also preserves baseline/candidate case counts, +baseline/candidate detection strategies, baseline/candidate replacement +strategies, stable-signature evidence counts, and candidate original-value leak +counts and labels. 
For DataDesigner-free experiments, it also preserves +`value_protection_verdict`, `signature_parity_verdict`, and label-mismatch +label counts, so one-off candidate rows are visible as weak evidence even before +opening the source comparison CSV. This is the quickest way to check whether a +benchmark directory contains any candidate worth rerunning on a larger workload +slice. + +Use the `evidence_level` column to separate current safety evidence from older +or weaker comparison rows. `split_verdicts` means the row has separate value +protection and signature-parity verdicts, `stable_signatures` means it has +stable-signature counts but not split verdicts, `signature_counts` means it only +has raw signature counts, and `legacy` means the screen can only use the older +aggregate verdict columns. The group output includes `evidence_level_counts` so +mixed scratch directories do not make a legacy row look as strong as a current +split-verdict rerun. + +The optional group output aggregates rows by candidate strategy when the +candidate used a non-default experimental strategy, or by candidate config +otherwise. This keeps ordinary config experiments, such as model routing or +prompt-parameter changes, from being collapsed under `strategy:default`. When +the same experiment used multiple config IDs, pass a JSON alias map: + +```json +{ + "biography-hybrid-augment-temp07": "biography-temp07-routing", + "biography-augment-temp07": "biography-temp07-routing" +} +``` + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py \ + benchmark-runs/ \ + --group-by strategy_workload_family \ + --config-aliases config-aliases.json \ + --group-output benchmark-runs/strategy-family-groups.csv +``` + +Aliases only affect default-strategy, default-replacement config grouping. +Non-default experimental detection strategies still group by strategy; when a +candidate also uses a non-default replacement strategy, the group key appends +`replacement:`. 
If detection is default and only replacement changes, +the group key is `replacement:`. Use the group output to find +candidates with conflicting evidence, such as a no-augmentation candidate that +passes one slice and rejects on another. The +group table includes both best and worst latency, token, and request deltas so a +single fast slice does not hide a slower or unsafe repeat. It also includes +minimum baseline/candidate case counts and the minimum shared stable-signature +count observed in the group, plus summed candidate original-value leak counts +and leak labels. The +`recommendation` column is deliberately conservative: +`single_slice_viable` means one viable row exists but needs repeat evidence, +`candidate_family_viable` requires two or more viable rows and no review or +reject rows, `promising_needs_review` means viable rows exist but review-gated +rows remain and at least one split-verdict row is also viable, +`needs_split_verdict_rerun` means viable-looking and review-gated rows exist but +the group has only older signature-count or stable-signature evidence, or a +review-only group mixes current split-verdict rows with older comparison rows +that should be rerun under the current verdict schema, +`needs_viable_split_verdict` means older viable rows exist and split-verdict +evidence exists, but every split-verdict row is still review- or reject-gated, +`replacement_replay_review` means an improved replacement-strategy group is +review-gated by detection artifact drift even though the detection strategy did +not change; use fixed-trace replacement replay to isolate replacement-map +behavior, +`reliability_review` means every row improved performance but one or more rows +are review-gated by provider reliability signals such as failed-request or +sync-bridge fallback increases, +`fast_lane_review` means a `rules_only` or +`rules_covered_or_default` group improved performance, had explicit zero +candidate original-value leaks, had no uncovered baseline 
signatures, and is +review-gated only by expected rule fast-lane provenance or span-boundary flags, +`label_policy_review` means every row improved performance, passed +`value_protection_verdict`, and was review-gated on `signature_parity_verdict` +because the candidate protected a baseline value under a different label, +`review_only` means the family has no failures, still needs manual review, and +every review-gated row is `improved`, +`review_mixed_performance` +means the family has no failures but has mixed performance evidence, +`no_performance_win` means review-gated rows exist without an improvement +signal, `reject` means no viable rows survived, and `conflicting_evidence` means +at least one viable row and at least one rejected row exist for the same +candidate family. + +When a strategy's safety depends on workload shape, group by workload family: + +```bash +uv run python tools/measurement/screen_strategy_comparisons.py \ + benchmark-runs/ \ + --group-by strategy_workload_family \ + --output benchmark-runs/strategy-screen.csv \ + --group-output benchmark-runs/strategy-family-groups.csv +``` + +This keeps evidence from families such as shell-secret command logs, legal +records, and biographies separate. Use this mode before claiming a broad +performance improvement from a strategy that may only be safe on rule-covered +secret workloads. Use `--group-by strategy_workload` for an even stricter +per-workload grouping. + The exporter groups records by `record_type`: - `run`: one row per Anonymizer run, with sanitized config, workload, model, and @@ -155,6 +1075,12 @@ The exporter groups records by `record_type`: - `ndd_workflow`: one row per DataDesigner adapter call, with model aliases, elapsed time, row counts, failed-record counts, and observed token/request usage when DataDesigner exposes it. 
+- `model_workflow`: one row per non-DataDesigner model-backed workflow, such as + `native_rules_router`, `native_candidate_validate_no_augment`, + `detector_native_validate_no_augment`, + `detector_native_validate_native_augment`, `native_single_pass`, and the + other `native_single_pass*` strategies, with the same sanitized usage fields + as `ndd_workflow`. The tables never store raw text, prompts, generated outputs, entity values, or replacement maps. `record_hash` is a run-scoped HMAC, so it can join rows within @@ -172,18 +1098,44 @@ Start with these questions: - Are quality metrics worse on one data shape, such as legal text, biographies, support tickets, shell history, or mixed natural-language/code records? -Most analyses join `stage`, `record`, and `ndd_workflow` back to `run` through -`run_id`, then group by run tags: +Most analyses join `stage`, `record`, `ndd_workflow`, and `model_workflow` back +to `run` through `run_id`, then group by run tags: - `run_tags.suite_id` - `run_tags.workload_id` - `run_tags.config_id` +- `run_tags.experimental_detection_strategy` +- `run_tags.experimental_replacement_strategy` +- `run_tags.dd_parser_compat` - `run_tags.repetition` - `run_tags.case_id` Prefer medians and percentiles over averages when comparing latency. LLM calls usually have long tails, and one retry or provider stall can distort a mean. +For staged DD-free detection probes, convert the probe output first: + +```bash +uv run python tools/measurement/analyze_staged_detection_output.py \ + /tmp/anonymizer-perf-goal/no-dd-rules-plus-direct-biography-r5-current \ + --output /tmp/anonymizer-perf-goal/no-dd-rules-plus-direct-biography-r5-current/analysis \ + --format csv +``` + +Then read `analysis/group_analysis.csv` to compare `elapsed_sec_sum`, +`model_elapsed_sec_sum`, `model_request_count_sum`, `total_tokens_sum`, +`baseline_shared_signature_rate`, and +`baseline_only_final_entity_signature_count_sum`. 
Use `fast_lane_verdict` as +the first gate: `reject` means stop and inspect losses before running larger +slices; `fast_lane_candidate` means the sampled workload is a plausible +zero-model rule-covered lane with repeated evidence; `review` means the output +is incomplete, has too few cases, or still uses model work. The staged analyzer +requires at least three cases in a seed-source group before a clean zero-model +run can become `fast_lane_candidate`; one-row smokes remain `review` even when +they preserve all compared signatures. Read +`analysis/label_delta_analysis.csv` when the shared-signature rate is low; it +shows which labels drove the baseline-only losses or direct-only additions. + ## Pandas Examples Load exported tables: @@ -263,7 +1215,14 @@ print(by_shape) Summarize DataDesigner token and request usage: ```python -workflow_group_cols = ["run_tags.workload_id", "run_tags.config_id", "workflow_name"] +workflow_group_cols = [ + "run_tags.workload_id", + "run_tags.config_id", + "run_tags.experimental_detection_strategy", + "run_tags.experimental_replacement_strategy", + "run_tags.dd_parser_compat", + "workflow_name", +] token_summary = ( ndd @@ -283,6 +1242,34 @@ token_summary = ( print(token_summary) ``` +Summarize provider usage by workflow and model: + +```python +model_usage = pd.read_csv("benchmark-runs/suite-id/analysis/model_group_analysis.csv") + +retry_sources = ( + model_usage + .sort_values( + ["sum_observed_failed_requests", "sum_observed_total_tokens"], + ascending=[False, False], + ) + [ + [ + "workload_id", + "config_id", + "workflow_name", + "model_name", + "sum_observed_total_requests", + "sum_observed_failed_requests", + "observed_failed_request_rate", + "sum_observed_total_tokens", + ] + ] +) + +print(retry_sources) +``` + Join run metadata to stage timing: ```python @@ -313,6 +1300,480 @@ uv run python tools/measurement/export_measurements.py \ --overwrite ``` +## Signature Delta Review + +Use `extract_signature_deltas.py` when a fast 
candidate has fewer, more, or +different final entity signatures than a higher-recall reference run. The tool +compares two `detection-artifacts.jsonl` files and recovers local context from +the DataDesigner artifact parquet files. Entity values are masked by default: +the output stores label, source, span offsets, value length, value hash, and a +small context window with the entity replaced by a placeholder. + +Example: review spans found by a text/raw-parser reference but missed by a +hybrid candidate for one workload/config pair: + +```bash +uv run python tools/measurement/extract_signature_deltas.py \ + /tmp/reference/detection-artifacts.jsonl \ + /tmp/candidate/detection-artifacts.jsonl \ + --baseline-artifact-root /tmp/reference/artifacts \ + --candidate-artifact-root /tmp/candidate/artifacts \ + --baseline-config legal-default \ + --candidate-config legal-hybrid-rules-guardrail \ + --workload legal-r2 \ + --output /tmp/legal-signature-deltas.csv \ + --format csv +``` + +Interpretation: + +- `baseline_only` rows are spans the candidate missed relative to the + reference. +- `candidate_only` rows are spans the candidate found that the reference did + not. +- `resolution=parquet` means the span was recovered from DataDesigner's final + detection artifacts. +- `resolution=artifact_details` means the span was reconstructed from + sanitized final signature details plus the artifact row's source text. This + is common for benchmark-only strategies that patch final entities from an + in-memory dataframe after a seed-stage artifact is written. +- `resolution=rule` means the span was reconstructed from deterministic + rule-guardrail logic because it was added after DataDesigner wrote parquet. +- `resolution=metadata_only` means only the opaque signature metadata was + available; use this as a signal to rerun with trace/artifact capture if the + delta matters. 
+ +## Current Local Findings + +These findings come from small local vLLM runs against +`nvidia/nemotron-3-super`; treat them as triage signals, not defaults. + +| Strategy | Latest local result | Status | Implication | +| --- | --- | --- | --- | +| `rules_only` on the three-row shell-secrets slice | Preserved all 12 stable signatures; median latency moved from 7.2s to 0.004s, requests from 12 to 0, and tokens from 11,019 to 0. | Review | Viable only for bounded secret scans where every requested label is covered by deterministic rules. | +| `rules_guardrail_detector_only` on the same shell-secrets slice | Preserved stable signatures and reduced model work, but one candidate repetition failed during GLiNER health checks. | Review | Useful as a structured-secret diagnostic, but less attractive than `rules_only` when labels are fully rule-covered. | +| `rules_filter_guardrail` on the same shell-secrets slice | Retry-enabled rerun completed all 6 cases. It preserved all 12 signatures, reduced seed validation candidates from 11 to 0, median pipeline latency from 8.0s to 3.9s, requests from 12.5 to 7.0, and tokens from 10,966 to 3,647. | Review | Useful as a mixed-workload probe; keep it review-gated because final entity provenance is rule-only for this slice. | +| `rules_filter_guardrail` on a mixed biography/legal/shell probe | After changing rule filtering to preserve different-label overlaps, repeated two-row biography, one-row legal, and three-row shell runs had no stable or unstable signature loss. Median pipeline latency moved from 28.5s to 20.0s on biography, 19.4s to 18.2s on legal, and 8.7s to 6.1s on shell. | Review | Historical positive probe only; the larger five-row non-shell repeat below did not preserve this signal. 
| +| `rules_filter_guardrail` on offset biography/legal slices | After hardening rule filtering so only fully covered same-label spans are skipped and rule reinsertion is additive, the five-row biography offset slice had no signature loss but moved into review because requests increased slightly while tokens decreased. The richer two-row legal offset slice rejected: latency, requests, and tokens regressed and one repetition missed three `court_name` signatures while adding one rule-backed `date_of_birth`. | Mixed | The hardened strategy is safer than the first version, but it still needs per-workload gates and is not a broad legal/prose default. | +| `rules_filter_guardrail` on current five-row biography/legal repeats | Biography preserved stable and unstable signatures but regressed latency from 37.8s to 45.9s and requests from 20.5 to 22.5. Legal improved latency from 60.6s to 51.2s and tokens from 63,072 to 61,568, but lost five stable `date` signatures and made two stable `person` signatures unstable. | Reject | Do not promote this as a prose/legal default; the safety and latency tradeoff is workload-dependent and fails the legal signature gate. | +| `rules_guardrail` on a five-row legal slice | Same-suite repeated comparison against default preserved stable and unstable signatures, but latency regressed from 39.6s to 47.1s, requests rose from 20.0 to 20.5, and tokens were roughly flat at 60,998 to 60,757. | Mixed | Deterministic date guardrails can improve coverage without signature loss, but they are not a legal-prose performance win on this slice. | +| `detector_only` and `rules_guardrail_detector_only` on prose/legal slices | Faster on one-row smoke checks, but lost baseline signatures on biography and legal samples. A current detector-only isolation rerun moved biography 27.3s → 0.9s and 8,416 → 526 tokens, but lost two `first_name` and one `organization_name` signatures. 
Legal moved 52.0s → 1.0s and 14,095 → 1,078 tokens, kept `date_of_birth`, but still lost one `date` and one `nationality` signature while adding many extra spans. | Reject | Local finalization alone is not a safe replacement for validation and augmentation on contextual text. The legal rerun is useful diagnostically because raw detector output kept `date_of_birth`, so a later native-validation miss likely came from validation behavior rather than detector seeding. | +| One-shot DD-free direct detection on biography/legal row 0 | Biography completed in 5.1s with 902 tokens but shared only 18/22 baseline signatures. Legal completed in 5.8s with 1,308 tokens but shared only 19/22 baseline signatures. | Reject as replacement | This is a useful speed boundary and prompt experiment, but a single extraction prompt drops core detections on non-shell workloads. | +| Standalone direct-detection five-row probe | A fresh local probe compared one direct extraction call per row against the current staged direct reference. On legal, compact direct detection moved from 62.3s and 15 requests to 17.1s and 5 requests, but shared only 75/147 reference signatures and missed 72. Recall prompting improved legal to 31.1s, 109 final entities, 102 shared, and 45 missed. On biographies, compact direct detection moved from 85.7s and 15 requests to 21.2s and 5 requests, with 91/102 shared signatures, 11 missed, and no extras; recall prompting regressed to 62 shared and 40 missed. Outputs: `/tmp/anonymizer-perf-goal/direct-detection-legal-r5-compact-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-legal-r5-recall-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-biography-r5-compact-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-biography-r5-recall-after-guard`. The benchmark harness can now run this value-only prompt shape through `native_single_pass_values` and `native_single_pass_values_recall`. 
| Mixed diagnostic | The one-call path is the clearest lower-bound latency test, but it is not a general anonymization-safe replacement. Compact one-call extraction may deserve workload-specific follow-up for biographies; legal still needs augmentation or a stronger candidate source. Recall prompting is not monotonic across domains. | +| Staged DD-free detection on biography/legal row 0 | Biography improved to 20/22 shared signatures but took 13.7s and 4,550 tokens. Legal stayed at 19/22 shared signatures while taking 17.5s and 6,425 tokens. Hosted GLiNER seeding was unavailable due to a `DEGRADED function cannot be invoked` response for `nvidia/gliner-pii`. | Mixed diagnostic | A native no-DataDesigner executor is still plausible, but only if it preserves phase boundaries with much cheaper seed/validation prompts or deterministic code. Naive direct LLM phases are not enough. | +| Chunked-excerpt validation in staged DD-free detection | On current one-row reruns, biography preserved the same 20/22 shared signatures as full-text validation but moved from 10.8s, 4,527 tokens, and 3 model requests to 13.3s, 5,648 tokens, and 6 requests. Legal preserved the same 19/22 shared signatures but moved from 14.7s, 6,425 tokens, and 3 requests to 17.2s, 7,727 tokens, and 7 requests. | Reject | Splitting direct validation into local excerpts increases repeated instruction overhead and request count on these non-shell rows. Do not pursue validator excerpting as a standalone no-DD speedup unless longer records show a different request/token crossover. | +| Rules-seeded staged DD-free detection | `rules` improved biography/legal latency but still lost baseline signatures; legal row 0 reached 20/22 shared signatures at 7.1s and 2,213 tokens. On shell-secrets, validation reclassified a database URL as a password and lost one baseline URL signature.
`rules-trusted` fixed that shell loss and preserved all 12 shell signatures with one augmentation call per row, but still missed biography/legal signatures. With `--skip-augmentation-when-rule-covered`, trusted rules preserved all 12 shell signatures with zero model usage. | Mixed diagnostic | Deterministic seed spans are useful, but rule-covered spans should not always go through LLM validation. A native executor needs workload gates and should short-circuit locally when every requested label is rule-covered. | +| Rules + direct LLM staged DD-free detection | `rules-plus-direct-llm` preserved all 12 shell-secrets signatures while avoiding validation, but still used two model calls per row and 726-938 tokens because the direct seed and augmentation phases still ran. On row-0 smokes it looked like the most plausible mixed no-DD path: biography shared 20/22 signatures at 10.8s and 4,465 tokens, and legal shared 20/22 at 11.5s and 5,929 tokens. The five-row gate rejected it: biography shared only 80/114 baseline signatures, lost 34 signatures, and took 85.7s versus the DD baseline's 32.9-47.8s; legal shared 108/145, lost 37 signatures, and took 62.3s versus the DD baseline's ~39.5s. | Reject for contextual workloads | Trusting deterministic structured spans locally is still useful, but direct LLM seed/validation/augmentation is not a safe or faster replacement for DataDesigner-backed contextual detection on prose/legal slices. Keep no-DD promotion limited to fully rule-covered structured-secret lanes unless a new native executor passes repeated signature gates. | +| Rules router staged DD-free detection | `rules-router` preserved all 12 shell-secrets signatures with no seed, validation, or augmentation model calls. The mixed/contextual fall-through did not generalize: on five biographies it shared 96/114 default signatures and lost 18 baseline signatures; on five legal rows it shared 86/145 default signatures and lost 59 baseline signatures. 
The benchmark-safe expression of this result is `rules_covered_or_default`, which short-circuits only fully rule-covered label sets and otherwise runs default detection. | Mixed | Keep the router only for the zero-model rule-covered structured-secret lane. Do not use the direct local LLM fall-through as a prose/legal replacement; use default Anonymizer or another signature-gated strategy for contextual rows. | +| Integrated `native_rules_router` benchmark with corrected direct-call metering | A five-row benchmark-harness run on biography/legal confirmed the staged finding. Biography moved from 32.9s to 85.6s, requests from 20 to 15, and tokens from 43,354 to 26,644, but entities fell from 114 to 102 and 34 baseline signatures were uncovered. Legal moved from 54.3s to 62.3s, requests from 21 to 15, and tokens from 60,649 to 31,894, but 37 baseline signatures were uncovered. Both workloads rejected. | Reject for contextual workloads | Direct native calls can reduce request and token counts while still losing safety and wall-time. Treat lower token counts as insufficient evidence; contextual promotion requires signature preservation and latency improvement together. | +| Integrated `native_candidate_validate_no_augment` smoke | One-row biography/legal benchmark-harness smoke proved the no-augmentation native executor is much cheaper but unsafe. Biography moved from 24.8s to 5.9s, requests from 4 to 2, and tokens from 8,092 to 2,000, but entities fell from 15 to 12 and lost `age`, `first_name`, and `organization_name` signatures. Legal moved from 49.8s to 10.9s, requests from 4 to 2, and tokens from 13,791 to 3,823, but entities fell from 23 to 21 and lost `date`, `date_of_birth`, and `nationality` signatures. Both rows had zero original-value leaks. | Reject for contextual workloads | Removing augmentation from the native executor gives the expected speed boundary, but augmentation or a stronger candidate source remains load-bearing for contextual recall. 
Keep this arm as a diagnostic, not a promotion candidate. | +| Integrated `detector_native_validate_no_augment` smoke | Keeping the detector seed and replacing DataDesigner validation/augmentation with direct validation is much cheaper, but quality remains workload-dependent. Biography still rejects: latest one-row rerun moved 26.6s -> 6.7s and 8,398 -> 2,347 tokens, but entities fell 15 -> 14 and two augmenter-sourced child `first_name` signatures were uncovered. A focused one-row legal repeat improved median latency from 15.0s to 11.0s, requests from 4 to 3, and tokens from 9,516 to 4,150 with zero leaks. After row-parallel direct validation plus deterministic DOB-context label normalization, a wider three-row, two-repeat legal gate moved median elapsed from 40.6s to 21.3s, requests from 12.5 to 6.5, and tokens from 37,972 to 17,902 with zero original-value leaks. The split verdicts were `value_protection=pass`, `signature_parity=review`, and `performance=improved`: a filing-date span that baseline labeled `date_of_birth` was protected as `date`, while separate birth-context years were added as `date_of_birth`. | Mixed: biography reject, legal label-policy review | The promising shape is not "remove DataDesigner everywhere"; it is "keep DD as fallback, use deterministic fast lanes where provably covered, and only replace validation when a native validator preserves both coverage and label semantics across repeated gates." The legal repeats now show a real latency win, but a DD-free candidate may protect values correctly while disagreeing with DataDesigner label semantics. That should stay review-gated until label policy says whether such covered reclassification is acceptable. | +| Integrated `detector_native_validate_no_augment` substitute gate | A one-row legal substitute smoke first showed the same review shape as the redact gate: latency moved from 21.1s to 15.2s, requests from 5 to 3, and tokens from 12,192 to 6,871 with zero original-value leaks. 
The wider three-row, two-repeat substitute gate still improved performance, but rejected on safety: median pipeline latency moved from 44.0s to 33.4s, requests from 15.0 to 9.0, tokens from 47,958 to 28,465.5, and failed requests from 3.0 to 0.0, while both baseline and candidate leaked two original `date` values across two row-runs. Replacement-map coverage was complete; local replay showed the leak was a substitute collision where one synthetic date reused another protected original date in the same record. The candidate added 11 stable signatures, but had covered label mismatches including a stable `date_of_birth` -> `date` mismatch. | Reject for substitute promotion | Native validation reduces detection cost even when substitute still uses normal replacement-map generation, but speed cannot promote a substitute strategy while original values survive in replaced output. The leak appears in the default substitute arm too, so this is a baseline substitute safety issue separate from the native validator. | +| Integrated `gliner_native_validate_*` no-DataDesigner gate | A one-row biography/legal smoke tested direct hosted GLiNER seeding outside DataDesigner plus direct native validation, with and without direct native augmentation. Biography no-augment rejected despite improving latency/tokens because it lost two `first_name` signatures. Biography with native augmentation passed the one-row gate: latency 13.7s -> 10.2s, requests 4 -> 3, tokens 8,033 -> 5,035, entities 22 -> 24, zero leaks, and only candidate-only additions. After bounded per-row parallelism and targeted label-boundary guidance in the integrated no-DD executor, a repeat-3 five-row biography gate improved median wall time 40.7s -> 25.5s, requests 21 -> 15, and tokens 43,371 -> 27,643 with zero original-value leaks and no case failures. 
The guidance removed the earlier `first_name` label mismatches, but repeat comparison rejected the candidate: four baseline signatures were only covered with mismatched labels (`degree`: 1, `last_name`: 2, `place_name`: 1), and six stable baseline signatures became unstable (`degree`: 1, `last_name`: 2, `organization_name`: 2, `place_name`: 1). Legal improved latency/tokens in both one-row arms, but stayed review-gated because a generic filing date that baseline labeled `date_of_birth` was protected as `date` by the candidate; the seed guardrail correctly does not promote dates without birth/DOB context. | Reject for contextual biographies | Direct GLiNER outside DataDesigner is a useful performance diagnostic, but repeated stable-signature gates block promotion on this biography slice. Lower requests/tokens plus faster wall time are insufficient if label semantics are unstable. | +| Integrated `native_single_pass` benchmark smoke | One-row benchmark-harness smoke on biography/legal showed the speed boundary for collapsing detection into one direct provider call. Biography improved latency 10.3s → 1.7s, requests 4 → 1, and tokens 5,059 → 597, but found 4 entities versus 7 and lost three `person` signatures, so it rejected. Legal improved latency 19.2s → 1.1s, requests 5 → 1, and tokens 7,107 → 838 while preserving both signatures, so that single row was viable. | Mixed diagnostic | The one-call native extractor is worth keeping as a benchmark arm, but it is not safe for broad contextual use. Promotion needs repeated workload-specific signature gates; a legal-row win does not cancel the biography miss. | +| Integrated `native_single_pass` five-row gate | After adding a local deterministic rule guardrail, the larger biography/legal run still rejected both contextual workloads. Biography moved from 24.1s to 8.3s, requests from 21 to 5, and tokens from 26,759 to 3,078, but entities fell from 36 to 21 and it lost 16 `person` signatures. 
Legal moved from 35.7s to 6.1s, requests from 21 to 5, and tokens from 38,569 to 5,781, but entities fell from 14 to 12 and it lost three `person` signatures. | Reject for contextual workloads | Local rules can cheaply protect deterministic secret shapes, but they do not fix contextual recall. Collapsing detection to one direct call remains a useful lower-bound latency experiment, not a safe contextual replacement. | +| Integrated `native_single_pass_recall` five-row gate | The recall prompt improved raw recall, especially on legal text, but still rejected both workloads. Biography moved from 23.0s to 10.2s, requests from 21 to 5, and tokens from 26,730 to 4,072, but entities fell from 36 to 26 and it still lost 16 `person` signatures. Legal moved from 32.2s to 8.7s, requests from 21 to 5, and tokens from 38,085 to 6,885; entity count rose from 14 to 20, but two baseline `person` signatures were still uncovered. | Reject for contextual workloads | Prompt recall can improve counts without satisfying anonymization safety. One-call contextual extraction remains below the signature gate even when it is much faster and cheaper than default. | +| Integrated `native_single_pass_values*` value-only five-row gate | Two repetitions on five NVIDIA biography rows and five TAB legal rows confirmed the value-only one-call prompt shape is only a speed boundary. Compact values mode improved latency by 55.6% on biographies and 68.9% on legal, with 15-15.5 fewer requests and 31,770-60,491 fewer tokens, but rejected both workloads after losing 45 biography and 123 legal baseline-only signatures. Recall values mode still rejected: it improved latency by 31.9% on biographies and 38.7% on legal, but lost 40 biography and 96 legal baseline-only signatures. Output: `/tmp/anonymizer-native-values-paired-r5`. | Reject for contextual workloads | Returning values instead of offsets makes parsing cheaper but does not solve contextual recall. 
Keep this arm in the harness as a lower-bound diagnostic; do not promote one-call extraction on biographies or legal text without a different seed source or repeated signature parity. | +| Structured fast-lane router tightening | `rules_covered_or_default` now short-circuits only the structured-secret labels `api_key`, `email`, `http_cookie`, `password`, `pin`, `unique_id`, `url`, and `user_name`. Narrow prose rule labels such as `date_of_birth`, `organization_name`, `religious_belief`, and `street_address` fall back to default detection unless `rules_only` is explicitly selected. A shell-secret smoke still found 12 entities across 3 records with 0 model rows, 0 requests, and 0 tokens. | Review | This preserves the no-DataDesigner win without assuming all inputs are shell logs. Local prose rules remain useful as explicit experiments or guardrails, but they are not complete enough for automatic contextual anonymization. | +| Narrow prose-label augmentation skip probe | On one synthetic `organization_name` + `street_address` row, `rules_covered_or_default` correctly fell back to model-backed detection instead of using the zero-model fast lane. A repeat-3 comparison then found `rules_guardrail_no_augment` preserved the same two signatures with zero leaks while moving median latency 3.0s → 2.6s, requests 4 → 3, and tokens 3,069 → 2,133. | Candidate | Skipping augmentation can be viable for tightly scoped prose-label slices when detector+validator already recover the needed spans. This is not a broad prose default; promote only through repeated signature gates, especially on biographies and legal text where augmentation may carry recall. | +| Real biography/legal no-augmentation check | On two NVIDIA biography rows, pure `no_augment` rejected: latency regressed 24.1s → 28.8s, entities fell 48 → 46, and two `first_name` signatures were lost. 
`rules_guardrail_no_augment` improved biography latency/tokens (24.1s → 18.3s, 17,992 → 11,905 tokens) but still rejected after losing the same two `first_name` signatures and using rule-sourced spans. On two TAB legal rows at offset 2, `no_augment` preserved signatures and reduced tokens but regressed latency (27.2s → 38.6s) and increased failed-request rate; `rules_guardrail_no_augment` preserved signatures with modest latency/token gains but remained review-gated because it introduced rule-sourced spans. | Mixed: biography reject, legal review | The synthetic augmentation-skip win does not generalize to biography prose. Augmentation remains load-bearing for contextual name recall, and legal gains need repeated runs plus failed-request scrutiny before promotion. | +| `rules_covered_or_default` mixed benchmark harness run | A two-row synthetic shell-secret run initially exposed a rule hole: default found one sudo-stdin password that the rule-only path missed. After adding a narrow `echo "..." | sudo -S` password rule, the rerun preserved all 9 shell signatures with detection latency 21.4s → 0.004s, requests 8 → 0, and tokens 9,854 → 0. One-row biography and legal contextual configs included `person`, so they fell back to model-backed detection and matched default entity counts. | Review | This is the safest implementation shape for the no-DataDesigner idea: use local rules only where labels and observed signatures prove coverage, and treat every missed signature as a rule-quality bug or a reason to fall back. | +| `rules_covered_or_default` current mixed fallback run | Current-code rerun completed all 6 cases. Shell secrets preserved all 9 signatures with pipeline latency 23.1s → 0.005s, requests 8 → 0, and tokens 10,173 → 0. The biography and legal configs included `person`, so both candidate cases fell back to model-backed detection and matched default entity counts and signatures: biography 7/7, legal 2/2. 
| Review | The router is behaving as designed after the rule-only `tagged_text` contract fix: structured secret configs can short-circuit locally, while contextual non-shell configs stay on default detection. | +| `rules_covered_or_default` repeated shell-secret run | A three-repetition shell-only suite completed all 6 cases. The candidate preserved all 9 final signatures in every repetition with median detection latency 29.4s → 0.004s, requests 9 → 0, and tokens 10,112 → 0. Default detection was unstable on this tiny slice: one repetition missed one `api_key`, so stable signatures were 8 for default and 9 for the rules path. The comparison remained review-gated, not viable, because all candidate spans were rule-sourced. | Review | Repeated evidence strengthens the structured-secret fast path but also shows why promotion should use stable-signature comparisons rather than treating default as perfectly deterministic on every repetition. | +| `rules_covered_or_default` on non-shell structured secrets | A four-row JSON/env/HTTP-header/YAML-style suite initially rejected after exposing two deterministic-rule gaps: URLs swallowed trailing semicolon separators and `session_id=...` cookie values were not protected. After tightening URL boundaries and adding a narrow `session_id` assignment rule, the rerun preserved all 17 default signatures while moving detection latency 25.8s → 0.010s, requests 16 → 0, and tokens 19,167 → 0. A repeat-3 run then kept all candidate signatures stable: default produced 15, 15, and 16 entities with median 18,822 tokens, while the rules path produced 17 entities every time with zero model requests and zero tokens. | Review | The no-DataDesigner fast lane is not shell-specific, but it must remain rule-coverage and signature-gated. Treat every structured-secret miss as either a narrow rule bug with tests or a reason to fall back to default detection. 
| +| `local_structured_substitute` on non-shell structured secrets | A four-row JSON/env/HTTP-header/YAML-style substitute suite preserved the same 17 final entities with zero original-value leaks. In a repeat-3 run, DataDesigner-backed substitute had median pipeline latency 38.1s, 4 requests, and 13,967 tokens for replacement-map generation; individual DD-backed runs ranged from 30.7s to 62.4s. `local_structured_substitute` had median latency 0.005s, 0 requests, and 0 tokens while preserving the same 17 replacements. | Review | Replacement-map generation is now another defensible no-DataDesigner lane for structured labels. Keep it benchmark-only until repeated gates and a policy decision define which structured labels deserve public API support. | +| `local_structured_substitute` with model-backed detection fallback | A one-row audit-style structured-identifier suite requested `api_key`, `http_cookie`, `pin`, `unique_id`, and `user_name`, so `rules_covered_or_default` fell back to normal model-backed detection in both arms. Both arms found the same 5 final entities with zero original-value leaks. Local replacement removed the replacement-map workflow, moving pipeline latency 53.6s → 33.0s, requests 5 → 4, and tokens 11,547 → 7,694. The pairwise comparison marked the candidate viable. | Candidate | Local replacement-map generation can help even when detection still needs DataDesigner. This is a cleaner promotion path than rule-only detection because contextual detection provenance is preserved; keep rejecting contextual replacement labels such as `person`. | +| `local_structured_substitute` with default detection on varied audit/config/HTTP identifiers | A four-row repeat-3 suite isolated replacement-map generation by keeping default model-backed detection in both arms. 
After adding a local synthetic-original collision guard, the guarded rerun kept value protection clean: zero original leaks, zero missing replacement-map entries, and zero synthetic-original collisions. Local substitute moved median pipeline latency 18.8s -> 12.7s, requests 21 -> 17, and tokens 24,324 -> 17,015. A current fixed-trace replay held detection constant at 21 entities and measured replacement only: DataDesigner substitute took 6.15s while local structured substitute took 0.003s, with 21/21 replacements and zero leaks/collisions in both arms. Regenerating the older repeat comparison with split verdicts moved the strategy-screen group out of `needs_split_verdict_rerun`; adding the fixed-trace replay comparison moved it to `promising_needs_review`. All three rows have `value_protection=pass`; the replay row has `signature_parity=pass` and `candidate_verdict=candidate_viable`, while one full-pipeline pairwise row has `signature_parity=review` because two covered signatures used different labels (`api_key`, `unique_id`). The comparison now tags this drift as `replacement_only_detection_instability` because detection strategy did not change. | Promising needs review | This is the cleanest structured-label promotion path because detector provenance stays model-backed in full-pipeline runs and the replacement backend passes fixed-trace replay. It is not fully promoted because normal pairwise runs still need monitoring for provider reliability and detection-run label drift. | +| `local_structured_substitute` fixed-trace replay on biography structured labels | A five-row NVIDIA synthetic biography replay used model-backed detection for `date_of_birth`, `organization_name`, `religious_belief`, and `street_address`, then replayed both substitute backends three times on the same 56 detected entities. 
After making local replacement-map generation avoid per-record duplicate synthetic values, both arms produced 159 replacements across 15 replacement attempts with zero duplicate synthetics, zero missing replacement-map entries, zero original-value leaks, and zero synthetic-original collisions. DataDesigner substitute took 23.59s for replacement-map generation and local structured substitute took 0.006s. The replay comparison marks `value_protection=pass`, `signature_parity=pass`, `safety=pass`, and `candidate_verdict=candidate_viable`. Output: `/tmp/anonymizer-perf-goal/biography-supported-structured-replacement-replay-repeat3.json`; comparison: `/tmp/anonymizer-perf-goal/biography-supported-structured-replacement-replay-repeat3-comparison.csv`; screen: `/tmp/anonymizer-perf-goal/strategy-screen-local-substitute-with-biography-replay-groups.csv`. | Candidate | This broadens the replacement-only result beyond shell or config logs without claiming DD-free contextual detection. The speed and leak profile are strong, the duplicate-collapse issue is fixed for this slice, and repeated replacement-only evidence shows the local backend can preserve replacement-map coverage when detection is held fixed. The remaining gate is policy: decide which structured labels and text shapes are eligible for deterministic substitute generation in production-facing configuration. | +| Expanded `rules_covered_or_default` + `local_structured_substitute` on an audit-style structured identifier record | After adding narrow keyed coverage for `http_cookie`, `pin`, `unique_id`, and `user_name`, the candidate protected all baseline signatures, found one additional `unique_id`, had zero original-value leaks, and moved pipeline latency 9.2s → 0.005s, requests 5 → 0, and tokens 6,075 → 0. | Review | This extends the no-DataDesigner fast lane beyond shell logs into keyed audit/config/HTTP-style structured records. 
It remains review-gated because every final span is rule-sourced and this run used one row. | +| Expanded `rules_covered_or_default` + `local_structured_substitute` on varied audit/config/HTTP identifiers | A four-row repeat-3 suite preserved every baseline-only signature through containing or overlapping candidate spans, with zero original-value leaks. Median pipeline latency moved 21.1s → 0.006s, requests 21 → 0, and tokens 24,332 → 0. The comparison records 8 exact baseline-only signatures, 8 candidate-covered signatures, 2 span-boundary mismatches, and 0 uncovered signatures. | Review | This is the strongest no-DataDesigner result so far for non-shell structured records. It is still not a default: all final spans are rule-sourced, and two protected values had different span boundaries, such as a span that includes the `token=` key prefix versus a span that covers only the value, so promotion needs a workload policy gate. | +| Row-aware `rules_covered_or_default` + local substitute smoke | A four-row JSON/env/HTTP-header/YAML-style suite initially rejected because quoted JSON `user`/`pin` keys were not rule-covered. After adding quoted-key coverage and changing the router to fall back per row on suspicious uncovered structured assignments, the structured candidate moved pipeline latency 9.7s -> 0.0s, requests 20 -> 0, and tokens 20,080 -> 0 while matching entity count 10 -> 10 with zero original-value leaks. One-row biography and legal controls included `person`, used default detection in both arms, and passed comparison gates. | Review | The no-DataDesigner path is now safer: eligible labels are necessary but not sufficient, and rows with uncovered structured fields go through normal detection. The structured candidate still stays review-gated because one `HF_TOKEN` value was protected under a different label/boundary than the default `http_cookie` span.
| +| Row-aware `rules_covered_or_default` + local substitute repeat gate | A focused repeat-3 split-verdict suite reran the same four structured rows after the row-aware router change. All 6 cases completed. Default substitute had median pipeline latency 12.4s, 21 requests, and 20,071 tokens; the row-aware rules/local candidate had median latency 0.006s, 0 requests, and 0 tokens. Both arms found 10 entities in every repetition and had zero original-value leaks or synthetic-original collisions. The split-verdict comparison has `value_protection=pass` but remained `safety=review` and `signature_parity=review`: one stable baseline `http_cookie` signature was protected by the candidate under an `api_key` label with a span/boundary mismatch. Output: `/tmp/anonymizer-perf-goal/structured-fastlane-split-r3`. | Needs viable split verdict | This is a large structured fast-lane performance win, but not promotion-ready. The next decision is whether the covered `http_cookie` -> `api_key` mismatch is acceptable value protection for this workload or whether the deterministic rules need to match baseline label semantics more closely. | +| `bio-vmax10-w80` validator window tuning | Rejected on biography rows 6-10: latency, requests, and tokens regressed, and stable `field_of_study` and `state` signatures were lost. | Reject | Smaller validation windows need per-workload proof; prompt-size savings can be outweighed by more calls and lost context. | +| Text augmenter routing at `temperature: 0.3` | A one-row biography smoke test passed, but repeated five-row slices did not: rows 0-4 preserved signatures while latency regressed from 40.4s to 45.9s and requests from 21.0 to 21.5; rows 5-9 rejected after latency regressed from 41.0s to 52.1s and two stable `state` signatures became unstable. | Reject | JSON-validator/text-augmenter routing at the default text temperature is not a reliable prose speedup on these slices. 
| +| Text augmenter routing at `temperature: 0.7` | Passed the first biography slice, then failed on rows 6-10 by losing a stable `university` signature and regressing latency. | Reject | Do not promote the routing pattern from a single positive slice. | +| `rules_guardrail_no_augment` on legal prose | Improved latency/tokens on legal rows 2-3, but lost two stable `first_name` signatures. | Reject | Augmentation remains load-bearing for contextual names, even when aggregate entity counts look acceptable. | + +No broad replacement for the default prose/legal detection path has passed the +current repeated signature checks. The only strong performance result so far is +workload-scoped: deterministic rules for tightly bounded, rule-covered secret +scans. + +When DataDesigner message traces are enabled, interpret failed request counts +through `observed_non_bridge_*` metrics before drawing provider-reliability +conclusions. Across 13 local trace files, the local-vLLM +`SyncClientUnavailableError` rows were 104 near-zero-latency sync-to-async +bridge fallbacks with zero token usage; they are adapter accounting, not model +work. GLiNER `ProviderError` rows are different: the same trace set had 20 real +detector failures, which can invalidate otherwise faster detector-heavy +candidates. + +Do not expand deterministic rules into contextual names merely to recover the +failed candidates above. The rejected prose and legal runs lost labels such as +`first_name`, `field_of_study`, `state`, and `university`; these require context +and separate precision evidence. The rule layer should stay narrow unless a new +label has high-confidence syntax and false-positive tests. + +## Validator Chunk Tuning + +The detector validator can dominate replace-mode latency on records with many +candidate entities. 
Tune `Detect.validation_max_entities_per_call` and +`Detect.validation_excerpt_window_chars` together: + +- `validation_max_entities_per_call` controls how many candidate entities go + into each validator call. Lower values create more calls, but Anonymizer can + overlap those calls through the validator pool. +- `validation_excerpt_window_chars` controls how much text surrounds each + validation chunk. Lower values reduce prompt size, but can hide context the + validator needs for labels such as `date_of_birth`, `race_ethnicity`, or + legal roles. + +Run these sweeps per workload. A window that is safe for short biographies may +drop legal identifiers, and a legal-safe window may erase the speedup on short +records. + +Example config fragment: + +```yaml +configs: + - id: legal-vmax10-w160 + detect: + validation_max_entities_per_call: 10 + validation_excerpt_window_chars: 160 + entity_labels: [first_name, last_name, court_name, date, date_of_birth] + replace: + strategy: hash + digest_length: 12 +``` + +Use the aggregate analysis first: + +```bash +uv run python tools/measurement/analyze_benchmark_output.py \ + benchmark-runs/legal-window-sweep \ + --json +``` + +Then compare every faster candidate against a higher-context reference: + +```bash +uv run python tools/measurement/extract_signature_deltas.py \ + /tmp/reference/legal__default-window__r000.detection-artifacts.jsonl \ + /tmp/candidate/legal__vmax10-w160__r000.detection-artifacts.jsonl \ + --baseline-artifact-root /tmp/reference/artifacts \ + --candidate-artifact-root /tmp/candidate/artifacts \ + --baseline-config default-window \ + --candidate-config vmax10-w160 \ + --workload legal \ + --output /tmp/legal-vmax10-w160-deltas.csv \ + --format csv +``` + +Treat a candidate as unsafe until signature deltas are clean on repeated runs. 
+In one local vLLM check with two repetitions, a biography sample went from +24.6s with the default window to 17.8s with `vmax10/w80`, with all 50 stable +entity signatures preserved. A one-row legal sample went from 21.2s with the +default window to 13.2s with `vmax10/w160`, with all 28 stable signatures +preserved. Both candidates increased request and token counts, so the comparison +tool marks them for review instead of as automatic wins. + +The biography `vmax10/w80` result did not hold on the next five biography rows. +With `row_offset: 5`, median latency regressed from 31.8s to 33.6s, requests +from 20.0 to 43.0, and tokens from 44,367.0 to 68,407.5. The comparison also +lost stable `field_of_study` and `state` signatures, with an additional +unstable `university` loss, so the tool rejected the candidate. Recheck this +tuning on the target dataset because smaller windows can miss sensitive +attributes and because the extra parallel validator calls can overwhelm any +prompt-size savings. + +## Augmentation Ablation + +Use `experimental_detection_strategy: rules_guardrail_no_augment` to measure +what happens when the detector keeps GLiNER, validation, and deterministic rule +guardrails, but skips LLM augmentation. Treat this as an ablation, not as a +replacement for the default pipeline. + +In a local vLLM check with two repetitions, removing augmentation from the +two-row biography sample reduced work but consistently lost two stable +`first_name` signatures. The comparison tool rejected both the default-window +and `vmax10/w80` no-augmentation candidates. This indicates augmentation is +load-bearing for prose records where contextual names and quasi-identifiers +matter. + +The same ablation preserved all 28 stable signatures on a one-row legal sample. +With the default validation window, latency moved from 21.2s to 18.3s, requests +from 5 to 4, and tokens from 11,327.5 to 7,654. 
With `vmax10/w160`, latency +moved from 13.2s to 9.5s, requests from 8 to 7, and tokens from 16,604 to +12,881. Compared directly against the default-window baseline, the combined +legal candidate is faster but still needs review because validator chunking +increases requests and tokens. + +That legal no-augmentation result also failed to generalize to the next two +legal records. On `row_offset: 1` with two rows and two repetitions, comparing +`legal-noaugment-vmax10-w160` against the same-window full augmentation baseline +improved latency from 23.9s to 21.5s, requests from 28.0 to 26.0, and tokens +from 61,780.5 to 50,905.0, but the candidate lost two stable `first_name` +signatures and one unstable `date` signature. The comparison rejected it despite +the performance improvement. + +Use this ablation when `augmented_new_final_value_count` is near zero for the +target workload and repeated signature deltas are clean. Do not generalize a +single legal row to the rest of a legal dataset, and do not generalize legal +results to biography, support-ticket, shell-history, or mixed prose data without +rerunning the comparison on that workload. + +## Augmenter Routing and Temperature + +The detection validator and augmenter do different jobs. Keep them separable in +model configs when testing local endpoints: + +- validators benefit from deterministic JSON-oriented settings; +- augmenters may work better through a text alias, because DataDesigner + structured parsing can be fragile on local OpenAI-compatible endpoints; +- augmenter temperature changes can alter retry pressure and output shape, so + evaluate them with repeated signature comparison, not only entity counts. + +In one local vLLM biography run with two repetitions, keeping the validator on +`local-nemotron-json` while routing the augmenter to a text alias with +`temperature: 0.7` was the first prose candidate that passed the current safety +gate and improved performance. 
Median latency moved from 24.2s to 21.6s, +requests from 8 to 6, and tokens from 17,938.5 to 11,921. The comparison had no +baseline-only or unstable-lost signatures across 48 stable signatures, so the +tool marked it `candidate_viable`. + +The same routing/temperature candidate also held on a five-row biography slice +with two repetitions, though the gain was smaller. Median latency moved from +40.4s to 38.0s, requests from 21.0 to 20.5, and tokens from 43,367.5 to +43,043.0. It preserved all 114 stable baseline signatures; one candidate-only +`place_name` appeared in one repetition, so the comparison still marked the +candidate `candidate_viable`. + +This result did not generalize cleanly to the next five biography rows. On a +second slice using `row_offset: 5`, the same candidate was rejected: median +latency moved from 41.0s to 47.5s, requests from 21.0 to 21.5, and tokens were +effectively unchanged at 44,708.0 to 44,670.0. More importantly, the comparison +lost one stable `university` signature and had unstable losses for +`field_of_study` and `university`. Treat this routing as an experiment to +retest on each workload, not as a default candidate yet. +When the two temp-0.7 config IDs are grouped with `--config-aliases`, the +biography family result is `conflicting_evidence`: three comparison rows, two +viable rows, one reject row, best latency -10.4%, worst latency +16.0%, and +stable losses for `field_of_study` and `university`. + +On a two-row legal slice with two repetitions, the same augmenter routing did +not materially improve latency or requests: median latency moved from 27.3s to +27.5s, requests stayed at 8, and tokens moved from 24,460.0 to 24,296.5. It +preserved stable signatures, but the rule-guardrail legal strategy remains +review-gated and this routing should be treated as neutral for that sample. 
+Also compare it against prompt-only changes such as +`prose_augment_focus`: in the same biography slice, prose-focused augmentation +preserved signatures and reduced requests/tokens, but wall time increased from +24.2s to 26.4s, so the tool kept it in review. + +Parser compatibility is a separate concern. A text-model suite without +`dd_parser_compat: raw_json` produced a failed biography case in local testing; +the raw-parser compatibility mode fixed that failure, but increased latency and +tokens on both biography and legal slices. Treat raw-parser compatibility as an +endpoint interoperability fix, not as a performance optimization. + +## Detector-Only Ablation + +Use `experimental_detection_strategy: detector_only` to measure the lower bound +of the detection phase when GLiNER output is trusted directly and only local +span finalization runs afterward. Use +`experimental_detection_strategy: rules_guardrail_detector_only` to measure the +same path with deterministic high-confidence rule spans unioned into the final +entity set. Both remove LLM validation and LLM augmentation from the detection +phase, so they are diagnostic ablations rather than deployable strategies. + +The comparison tool marks these candidates with +`candidate_skips_llm_validation`, which forces `safety_verdict: review` even +when entity signatures match on the sampled records. The rule-guardrail variant +also gets `candidate_uses_rule_entities` when rule spans survive. Promote either +path only if independent precision checks show false positives are acceptable +for the target workflow and repeated signature deltas remain clean. + +In a one-row cross-workload smoke check, detector-only improved latency and +token counts on biography, legal, and shell-secrets slices, but all three +candidates were rejected by signature comparison. 
Biography moved from 13.7s to +0.9s and lost two baseline `first_name` signatures; legal moved from 15.3s to +1.0s and lost one `nationality` signature while increasing final entity count +from 22 to 39; shell-secrets moved from 6.6s to 0.8s and still lost one +baseline `api_key` signature. This is a useful lower-bound measurement, but it +shows why validation/augmentation or deterministic rule coverage remain +load-bearing for anonymization. + +The `rules_guardrail_detector_only` variant did not fix prose/legal losses in +the same one-row check: biography still lost two `first_name` signatures and +legal still lost one `nationality` signature. It did preserve all shell-secret +baseline signatures while moving latency from 4.6s to 0.8s, requests from 4 to +1, and tokens from 3,969 to 85. Treat that as a narrow structured-secret +candidate. It remains review-gated because it skips LLM validation and relies +on deterministic rules. + +On the three-row shell-secrets slice with three successful candidate +repetitions, `rules_guardrail_detector_only` preserved all stable baseline +signatures while moving median latency from 7.2s to 3.2s, requests from 12 to 4, +and tokens from 11,019 to 198. The final entity set came from 9 detector spans +and 3 rule spans. It still had local GLiNER `ProviderError` health-check +failures and remains slower than `rules_only`, which used zero model calls and +zero tokens on the same fully rule-covered labels. + +## Deterministic Rules for Structured Secrets + +Use `experimental_detection_strategy: rules_only` only when the workload is a +bounded secret-scanning task and every requested label is covered by the +deterministic rules. Current rule coverage is intentionally narrow: +`api_key`, `date_of_birth`, `email`, `http_cookie`, `organization_name`, +`password`, `pin`, `religious_belief`, `street_address`, `unique_id`, `url`, +and `user_name`. 
The `http_cookie`, `pin`, `unique_id`, and `user_name` rules +cover keyed or command-style structured patterns only. They do not recognize +free-form names, narrative identifiers, or arbitrary prose mentions. + +The zero-model detector is implemented by +`EntityDetectionWorkflow.detect_with_high_confidence_rules()`. The benchmark +strategy delegates to that internal engine method, but no user-facing config +selects it outside the benchmark harness. + +Use `experimental_detection_strategy: rules_covered_or_default` for mixed +benchmark suites where some configs are structured-secret scans and others +include contextual labels such as `person`, `organization_name`, or +`street_address`. It runs the same zero-model path for structured fast-lane +cases, but does not attempt a DataDesigner-free replacement for contextual +prose or legal records. + +A mixed local-vLLM smoke run on June 8, 2026 used two synthetic shell-secret +rows plus one biography and one legal row. The first shell run found that +`rules_covered_or_default` missed a sudo stdin password that default +augmentation caught; after adding a narrow `echo "..." | sudo -S` rule, the +rerun preserved all 9 shell signatures with zero model requests and zero tokens. +The biography and legal configs requested `person`, so they correctly fell back +to model-backed detection and matched default entity counts. Keep this strategy +signature-gated: a missed default signature is a rule-quality bug or a fallback +signal, not acceptable drift. + +A follow-up three-repetition shell-only run kept all 9 candidate signatures +stable while default detection had 8 stable signatures because one `api_key` +was absent from one default repetition. The comparison still returned +`candidate_verdict=review` because the candidate had no detector-sourced final +spans. This is the intended behavior: repeated clean signatures can justify a +workload-scoped fast lane, but rule-only provenance should remain an explicit +review decision. 
+ +For substitute workloads, use +`experimental_replacement_strategy: local_structured_substitute` to bypass the +DataDesigner replacement-generator call. The local substitute generator writes a +normal replacement map and stamps `_replacement_map_source=local_structured` so +measurement estimates do not count a replacement-map LLM call. It only supports +structured labels. Pair it with `rules_covered_or_default` when all requested +labels are also rule-covered; otherwise detection can still use the default +model-backed path while replacement-map generation stays local. If a config +includes `person` or another contextual label, preflight fails instead of +silently producing poor local substitutes. + +On the current four-row non-shell structured-secret suite, +DataDesigner-backed substitute preserved 17 entities with zero original-value +leaks but had repeat-3 median latency 38.1s, 4 requests, and 13,967 tokens in +replacement-map generation. The local structured substitute arm preserved the +same 17 entities, had zero original-value leaks, and had repeat-3 median latency +0.005s with 0 requests and 0 tokens. The repeat output used for this result is +`/tmp/anonymizer-perf-goal/structured-secrets-local-substitute-repeat3`. + +The local substitute backend can also combine with model-backed detection. In +the first one-row audit-style structured-identifier smoke, `api_key`, +`http_cookie`, `pin`, `unique_id`, and `user_name` were not all rule-covered, so +detection fell back to the default model path in both arms. The local substitute +arm still removed the replacement-map DataDesigner workflow, moving pipeline +latency from 53.6s to 33.0s, requests from 5 to 4, and tokens from 11,547 to +7,694 while preserving the same 5 final entities and zero original-value leaks. +The output used for that result is +`/tmp/anonymizer-perf-goal/structured-identifiers-local-substitute`. 
+ +Use `replay_replacement_strategies.py` when you need to hold detection fixed and +isolate replacement-map generation: + +```bash +uv run python tools/measurement/replay_replacement_strategies.py \ + /tmp/anonymizer-perf-goal/structured_identifiers_varied.csv \ + --text-column text \ + --labels api_key,http_cookie,password,pin,unique_id,user_name \ + --nrows 5 \ + --replacement-repetitions 3 \ + --model-configs /stable-cache/anonymizer/local-vllm-json-models.yaml \ + --model-providers /stable-cache/anonymizer/local-vllm-providers.yaml \ + --dd-parser-compat raw_json \ + --comparison-output /tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay-comparison.csv \ + --json +``` + +The current fixed-trace replay detected 21 entities once, then ran both +substitute backends on that same trace. DataDesigner substitute took 6.04s for +replacement-map generation; local structured substitute took 0.003s. Both arms +produced 21 replacements, zero missing replacement-map entries, zero +original-value leaks, and zero synthetic-original collisions. The JSON output +used for this result is +`/tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay.json`. +A rerun after adding an LLM replacement-map collision guard produced the same +21/21 complete, leak-free, collision-free result. In that rerun, DataDesigner +substitute took 6.22s and local structured substitute took 0.003s; the updated +JSON output is +`/tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay-after-llm-guard.json`. +When `--replacement-repetitions` is greater than one, detection still runs once +and only the substitute backends repeat. The summary rows aggregate replacement +latency, missing-map counts, leaks, collisions, duplicate synthetics, and source +counts across those repeated backend passes. 
When `--comparison-output` is set, +the replay tool also writes a one-row comparison CSV with +`value_protection_verdict`, `signature_parity_verdict`, `safety_verdict`, +`performance_verdict`, and `candidate_verdict`. This lets +`screen_strategy_comparisons.py` include fixed-trace replacement evidence +alongside normal pairwise benchmark comparisons. Missing local replacement-map +entries, original-value leaks, and synthetic-original collisions fail the replay +candidate even if the elapsed-time delta is large. +If the DD substitute baseline misses replacement-map entries or leaks original +values while the local backend covers them, the replay comparison emits +candidate-covers-baseline flags and the strategy screen recommends +`candidate_covers_baseline_defects` for all-review groups of that shape. Treat +that as a baseline-independent safety-rule prompt: inspect the candidate's +missing, leak, collision, duplicate-synthetic, and supported-label columns +rather than requiring exact parity with a known-flawed substitute baseline. + +After adding narrow keyed rules for `http_cookie`, `pin`, `unique_id`, and +`user_name`, the same audit-style label set can now short-circuit both +detection and local replacement for a structured record. In a one-row local +vLLM check, default detection plus DataDesigner substitute found 4 entities and +missed the `unique_id`; the rules/local arm found 5 entities, had zero +original-value leaks, and moved pipeline latency from 9.2s to 0.005s, requests +from 5 to 0, and tokens from 6,075 to 0. The pairwise comparison remains +`review`, not `candidate_viable`, because the candidate has rule-only +provenance and the evidence is a single row. The output used for this result is +`/tmp/anonymizer-perf-goal/structured-identifiers-expanded-rules`. 
+ +On a three-row shell-secrets slice with labels `[api_key, password, email, url]`, +`rules_only` preserved all 12 stable signatures across three repetitions while +moving median latency from 7.2s to 0.004s, requests from 12 to 0, and tokens +from 11,019 to 0 in the refreshed failure-aware comparison. The comparison tool +still marks the candidate for review because it has no contextual detector spans +and skips LLM validation. That is the right gate: a pure rule strategy is +acceptable only when missing contextual spans is part of the test contract. + +`rules_seed_no_augment` preserved the same 12 signatures and reduced median +tokens from 11,017 to 7,732, but median latency moved from 8.0s to 8.5s on the +same slice. In this run, seeding rules into the validator path reduced token +work but did not improve end-to-end latency. Prefer `rules_only` for tightly +scoped secret scans; prefer rule guardrails plus contextual detection for prose, +legal text, support tickets, and mixed records. + +Use `rules_filter_guardrail` as the mixed-workload version of that idea. It +keeps LLM augmentation, but rule-covered spans are not sent to the seed +validator. The rule spans are reinserted before augmentation so the augmenter +does not waste work rediscovering them. This is a candidate for datasets that +combine structured secrets with contextual prose; it still needs repeated +signature comparison because filtered detector spans no longer receive the +LLM validator's reclassification/drop pass. In a local shell-secrets smoke run, +the completed candidate repetition reduced seed validation candidates to zero +and preserved all stable signatures, but the repeated comparison rejected it +because a later candidate case hit a GLiNER health-check rate limit. + ## Metric Interpretation Use metrics as signals, not as a single score. @@ -320,6 +1781,7 @@ Use metrics as signals, not as a single score. 
Latency and throughput: - `elapsed_sec`: wall time for a measured stage or DataDesigner workflow. + Staged DD-free detection cases report end-to-end case wall time here. - `rows_per_sec`: completed output rows per second for the measured block. - `tokens_per_sec`: observed total tokens per second when token usage exists. - `text_length_tokens_bucket`: a coarse text-size bucket for comparing similar @@ -332,13 +1794,96 @@ LLM usage: zero values mean the provider path did not expose usage, not necessarily that no tokens were consumed. - `observed_total_requests`, `observed_successful_requests`, and - `observed_failed_requests`: request counts when DataDesigner exposes them. + `observed_failed_requests`: request counts when DataDesigner or a native + benchmark model workflow exposes them. +- `observed_failed_request_rate`: failed requests divided by total requests. + Case and group tables expose this as the end-to-end retry pressure for a + strategy; model usage tables expose it per workflow/model. Sort by this + together with total token count to find retry-heavy workflow/model pairs. +- `observed_bridge_fallback_requests`: DataDesigner sync-to-async bridge + fallbacks, derived from message traces when `--dd-trace` is enabled. Treat + these as adapter accounting, not provider/model failures. +- `model_elapsed_sec`: staged DD-free detection only; sum of direct model-call + durations for seed, validation, and augmentation. This stays `0.0` for fully + local rule-covered runs even when `elapsed_sec` records nonzero local work. +- `observed_non_bridge_total_requests`, + `observed_non_bridge_failed_requests`, and + `observed_non_bridge_failed_request_rate`: request metrics after subtracting + sync-to-async bridge fallbacks. Prefer these fields over raw failed-request + counts when diagnosing provider reliability from traced runs. - `nominal_llm_call_count`: an internal estimate based on the Anonymizer pipeline shape. 
Treat it as expected work, not observed provider traffic. +- `seed_validation_candidate_count`: number of detector candidates sent to the + seed validator, derived from detection artifacts without storing values. +- `estimated_seed_validation_chunk_count`: estimated validator chunk count after + applying `detect.validation_max_entities_per_call`. If this does not change + between benchmark configs, chunk-size experiments are not expected to reduce + successful validator calls. Entity and quality metrics: - `final_entity_count`: entities that survive detection and validation. +- `original_value_leak_count`: number of final entity original values that + still appear verbatim in the replaced or rewritten output text. This is a + conservative replace/rewrite safety signal and stores only counts, not raw + values. +- `original_value_leak_label_counts`: per-label counts for those surviving + original values. The analysis tables aggregate these as + `original_value_leak_record_count`, `sum_original_value_leak_count`, + `leaking_case_count`, and `median_original_value_leak_count`. +- `replacement_missing_final_entity_count`: number of final entity occurrences + whose original value has no entry in the replacement map. This is sanitized + replacement-map coverage, not raw leakage text. +- `replacement_missing_final_value_count`: number of unique final entity values + with no replacement-map entry. Compare it with + `original_value_leak_count` to distinguish omitted replacement-map entries + from replacement-application or metric issues. +- `replacement_missing_final_entity_label_counts`: per-label counts for missing + replacement-map coverage. +- `replacement_synthetic_original_collision_count`: number of final entity + occurrences whose original value was reused as a synthetic replacement value + elsewhere in the same record. This is a substitute safety signal; map + coverage can be complete while this is nonzero. 
+- `replacement_synthetic_original_collision_value_count`: number of unique + protected original values reused as synthetic replacement values. +- `replacement_synthetic_original_collision_label_counts`: per-label counts for + synthetic-original collisions. +- `artifact_final_detector_entity_count`, + `artifact_final_rule_entity_count`, and + `artifact_final_augmenter_entity_count`: final entity source counts derived + from detection artifact sidecars. These are useful safety signals for + rule-backed benchmark strategies. +- `artifact_final_entity_signature_count` and + `artifact_final_entity_signature_hashes`: opaque final-span signatures derived + from detection artifacts. `artifact_final_entity_signature_labels` maps each + hash to a label, but still does not include raw entity values. Use these to + catch and triage safety regressions where total entity count is unchanged but + the candidate lost a baseline-protected span. +- `baseline_only_candidate_covered_signature_count`, + `baseline_only_candidate_overlapping_signature_count`, and + `baseline_only_candidate_uncovered_signature_count`: comparison-only fields + from `compare_strategy_pairs.py`. These split exact signature deltas into + baseline spans protected by a containing candidate span, protected by a + high-overlap or small keyed-boundary candidate span, or not protected by any + candidate span metadata. Overlapping coverage sets `span_boundary_mismatch` + and keeps the candidate in review; uncovered signatures set + `entity_signature_loss` and fail the safety verdict. +- `baseline_only_candidate_label_mismatch_signature_count`: comparison-only + field for baseline signatures whose raw span is covered by the candidate, but + under a different label. This sets `covered_label_mismatch` and keeps the + candidate in review because the value is protected but label semantics may no + longer match replacement/audit expectations. 
+- `value_protection_verdict`: comparison-only pass/review/fail verdict focused + on whether candidate output still protects baseline values. Covered + label-mismatch spans can still pass this axis because the sensitive value is + protected, while uncovered signatures, candidate leaks, and candidate case + failures fail it. +- `signature_parity_verdict`: comparison-only pass/review/fail verdict focused + on exact baseline signature semantics. Covered label mismatches and boundary + mismatches review-gate this axis even when `value_protection_verdict` passes. + This split is useful for DataDesigner-free experiments: a candidate can be a + plausible protection backend while still requiring label-policy review before + it can replace a DataDesigner-backed baseline. - `final_entity_label_counts`: per-label entity counts serialized as JSON in exported tabular files. - `ground_truth_*`: precision, recall, F1, false positives, and false negatives @@ -351,6 +1896,12 @@ Error and reliability metrics: - `failed_record_count`: records dropped by a DataDesigner workflow. - `status`: completion state for a stage or workflow. +- `case_failed`: true when a benchmark case has any error-status stage or + DataDesigner workflow measurement. +- `error_stage_count`, `error_ndd_workflow_count`, and + `error_model_workflow_count`: error-status measurement rows counted per case. +- `failed_case_count` and `failed_case_rate`: group-level failed-case count and + rate for a workload/config/strategy. - `summary.json` case errors: runner-level failures, such as invalid inputs or model endpoint failures. @@ -366,6 +1917,6 @@ time. Stage and workflow rows carry timing. To explain a slow run, first find the slow stage, then inspect the records in that run for text length, entity count, nominal call count, and rewrite repair signals. -When token or request fields are missing, check `ndd_workflow.model_usage`. -The measurement layer records deeper provider usage only when DataDesigner -returns it. 
+When token or request fields are missing, check `ndd_workflow.model_usage` and +`model_workflow.model_usage`. The measurement layer records deeper provider +usage only when the underlying executor returns it. diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py new file mode 100644 index 00000000..b5844523 --- /dev/null +++ b/tools/measurement/analyze_benchmark_output.py @@ -0,0 +1,1180 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze joined benchmark measurements and detection artifact sidecars. + +Usage: + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --output analysis + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --detection-artifacts current.jsonl + uv run python tools/measurement/analyze_benchmark_output.py benchmark-runs/suite-id --json +""" + +from __future__ import annotations + +import json +import logging +import math +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, computed_field + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.benchmark_output") + +_SYNC_CLIENT_UNAVAILABLE_ERROR = "SyncClientUnavailableError" + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +class CaseAnalysisRow(BaseModel): + suite_id: str | None = None + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | 
None = None + repetition: int | None = None + case_id: str + run_id: str + case_failed: bool = False + error_stage_count: int = 0 + error_ndd_workflow_count: int = 0 + error_model_workflow_count: int = 0 + pipeline_elapsed_sec: float | None = None + ndd_workflow_count: int = 0 + ndd_elapsed_sec_total: float = 0.0 + observed_total_requests: int = 0 + observed_successful_requests: int = 0 + observed_input_tokens: int = 0 + observed_output_tokens: int = 0 + observed_total_tokens: int = 0 + observed_failed_requests: int = 0 + observed_failed_request_rate: float | None = None + dd_trace_record_count: int = 0 + dd_trace_error_count: int = 0 + dd_trace_sync_client_unavailable_count: int = 0 + observed_bridge_fallback_requests: int | None = None + observed_non_bridge_total_requests: int | None = None + observed_non_bridge_failed_requests: int | None = None + observed_non_bridge_failed_request_rate: float | None = None + route_total_row_count: float | None = None + route_rule_row_count: float | None = None + route_fallback_row_count: float | None = None + final_entity_count: float | None = None + replacement_count: float | None = None + replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + replacement_missing_final_value_count: float | None = None + replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + replacement_synthetic_original_collision_value_count: float | None = None + original_value_leak_count: float | None = None + original_value_leak_record_count: int = 0 + original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + validation_max_entities_per_call: int | None = None + detection_artifact_rows: int = 0 + seed_entity_count: float | None = None + seed_validation_candidate_count: float | None = None + 
estimated_seed_validation_chunk_count: float | None = None + augmented_entity_count: float | None = None + augmented_new_final_value_count: float | None = None + artifact_final_entity_count: float | None = None + artifact_final_detector_entity_count: float | None = None + artifact_final_rule_entity_count: float | None = None + artifact_final_augmenter_entity_count: float | None = None + artifact_final_entity_signature_count: float | None = None + artifact_final_entity_signature_hashes: list[str] = Field(default_factory=list) + artifact_final_entity_signature_labels: dict[str, str] = Field(default_factory=dict) + artifact_final_entity_signature_details: dict[str, dict[str, Any]] = Field(default_factory=dict) + + +class GroupAnalysisRow(BaseModel): + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + case_count: int + failed_case_count: int = 0 + failed_case_rate: float | None = None + error_stage_count: int = 0 + error_ndd_workflow_count: int = 0 + error_model_workflow_count: int = 0 + median_pipeline_elapsed_sec: float | None = None + median_ndd_elapsed_sec_total: float | None = None + median_observed_total_requests: float | None = None + median_observed_successful_requests: float | None = None + median_observed_input_tokens: float | None = None + median_observed_output_tokens: float | None = None + median_observed_total_tokens: float | None = None + median_observed_failed_requests: float | None = None + median_observed_failed_request_rate: float | None = None + median_observed_bridge_fallback_requests: float | None = None + median_observed_non_bridge_total_requests: float | None = None + median_observed_non_bridge_failed_requests: float | None = None + median_observed_non_bridge_failed_request_rate: float | None = None + median_route_total_row_count: float | None = None + median_route_rule_row_count: float | None = None + 
median_route_fallback_row_count: float | None = None + median_final_entity_count: float | None = None + median_replacement_missing_final_entity_count: float | None = None + median_replacement_missing_final_value_count: float | None = None + replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + median_replacement_synthetic_original_collision_count: float | None = None + median_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + sum_original_value_leak_count: float | None = None + leaking_case_count: int = 0 + median_original_value_leak_count: float | None = None + median_seed_entity_count: float | None = None + median_seed_validation_candidate_count: float | None = None + median_estimated_seed_validation_chunk_count: float | None = None + median_augmented_entity_count: float | None = None + median_augmented_new_final_value_count: float | None = None + median_artifact_final_entity_count: float | None = None + median_artifact_final_detector_entity_count: float | None = None + median_artifact_final_rule_entity_count: float | None = None + median_artifact_final_augmenter_entity_count: float | None = None + median_artifact_final_entity_signature_count: float | None = None + + +class ModelUsageAnalysisRow(BaseModel): + suite_id: str | None = None + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | None = None + repetition: int | None = None + case_id: str + run_id: str + workflow_name: str | None = None + model_alias: str | None = None + model_name: str + model_provider_name: str | None = None + ndd_elapsed_sec: float | None = None + observed_total_requests: int = 0 + observed_successful_requests: int = 0 + observed_failed_requests: int = 0 + observed_input_tokens: 
int = 0 + observed_output_tokens: int = 0 + observed_total_tokens: int = 0 + observed_reasoning_tokens: int | None = None + observed_failed_request_rate: float | None = None + + +class ModelUsageGroupAnalysisRow(BaseModel): + workload_id: str | None = None + config_id: str | None = None + experimental_detection_strategy: str | None = None + experimental_replacement_strategy: str | None = None + dd_parser_compat: str | None = None + workflow_name: str | None = None + model_alias: str | None = None + model_name: str + model_provider_name: str | None = None + case_count: int + workflow_count: int + sum_observed_total_requests: int = 0 + sum_observed_successful_requests: int = 0 + sum_observed_failed_requests: int = 0 + sum_observed_input_tokens: int = 0 + sum_observed_output_tokens: int = 0 + sum_observed_total_tokens: int = 0 + sum_observed_reasoning_tokens: int | None = None + observed_failed_request_rate: float | None = None + median_observed_total_requests: float | None = None + median_observed_failed_requests: float | None = None + median_observed_total_tokens: float | None = None + + +class BenchmarkOutputAnalysis(BaseModel): + benchmark_dir: str + detection_artifacts_path: str | None = None + cases: list[CaseAnalysisRow] = Field(default_factory=list) + groups: list[GroupAnalysisRow] = Field(default_factory=list) + model_usage: list[ModelUsageAnalysisRow] = Field(default_factory=list) + model_usage_groups: list[ModelUsageGroupAnalysisRow] = Field(default_factory=list) + + @computed_field + @property + def case_count(self) -> int: + return len(self.cases) + + @computed_field + @property + def group_count(self) -> int: + return len(self.groups) + + @computed_field + @property + def model_usage_count(self) -> int: + return len(self.model_usage) + + @computed_field + @property + def model_usage_group_count(self) -> int: + return len(self.model_usage_groups) + + +class TableSummary(BaseModel): + table: str + rows: int + path: str + + +class 
AnalysisExportResult(BaseModel): + output_dir: str + format: ExportFormat + tables: list[TableSummary] + manifest_path: str + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def analyze_benchmark_output( + benchmark_dir: Path, + *, + detection_artifacts: Path | None = None, +) -> BenchmarkOutputAnalysis: + measurements = read_jsonl_table(benchmark_dir / "measurements.jsonl", required=True) + artifacts_path = detection_artifacts or benchmark_dir / "detection-artifacts.jsonl" + artifacts = read_jsonl_table(artifacts_path, required=detection_artifacts is not None) + traces = read_trace_summary_table(benchmark_dir / "traces") + cases = [ + _build_case_row(case_id, measurements, artifacts, traces) + for case_id in _case_ids(measurements, artifacts, traces) + ] + model_usage = build_model_usage_rows(measurements) + return BenchmarkOutputAnalysis( + benchmark_dir=str(benchmark_dir), + detection_artifacts_path=str(artifacts_path) if not artifacts.empty else None, + cases=cases, + groups=build_group_rows(cases), + model_usage=model_usage, + model_usage_groups=build_model_usage_group_rows(model_usage), + ) + + +def read_jsonl_table(path: Path, *, required: bool) -> pd.DataFrame: + if not path.exists(): + if required: + raise ValueError(f"input path does not exist: {path}") + return pd.DataFrame() + if path.is_dir(): + raise ValueError(f"input path is a directory: {path}") + raw = pd.read_json(path, lines=True) + if raw.empty: + return raw + return pd.json_normalize(raw.to_dict("records"), sep=".") + + +def read_trace_summary_table(trace_path: Path) -> 
pd.DataFrame: + """Read DD trace files into a sanitized table with no prompt/response text.""" + if not trace_path.exists(): + return pd.DataFrame() + if trace_path.is_file(): + paths = [trace_path] + elif trace_path.is_dir(): + paths = sorted(trace_path.rglob("*.jsonl")) + else: + raise ValueError(f"trace path is not a file or directory: {trace_path}") + + rows: list[dict[str, Any]] = [] + for path in paths: + for line in path.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + record = json.loads(line) + if not isinstance(record, dict) or record.get("record_type") != "dd_message_trace": + continue + run_tags = record.get("run_tags") if isinstance(record.get("run_tags"), dict) else {} + rows.append( + { + "record_type": "dd_message_trace", + "run_id": record.get("run_id"), + "run_tags.case_id": run_tags.get("case_id"), + "run_tags.workload_id": run_tags.get("workload_id"), + "run_tags.config_id": run_tags.get("config_id"), + "run_tags.experimental_detection_strategy": run_tags.get("experimental_detection_strategy"), + "run_tags.experimental_replacement_strategy": run_tags.get("experimental_replacement_strategy"), + "run_tags.dd_parser_compat": run_tags.get("dd_parser_compat"), + "run_tags.repetition": run_tags.get("repetition"), + "workflow_name": record.get("workflow_name"), + "model_alias": record.get("model_alias"), + "status": record.get("status"), + "error_type": record.get("error_type"), + "is_async": record.get("is_async"), + } + ) + return pd.DataFrame(rows) + + +def _case_ids(*frames: pd.DataFrame) -> list[str]: + values: set[str] = set() + for dataframe in frames: + for column in ("run_tags.case_id", "case_id", "run_id"): + if column in dataframe.columns: + values.update(str(value) for value in dataframe[column].dropna().tolist()) + return sorted(values) + + +def _build_case_row( + case_id: str, + measurements: pd.DataFrame, + artifacts: pd.DataFrame, + traces: pd.DataFrame, +) -> CaseAnalysisRow: + measurement_rows = 
_rows_for_case(measurements, case_id) + artifact_rows = _rows_for_case(artifacts, case_id) + trace_rows = _rows_for_case(traces, case_id) + record_rows = _records_of_type(measurement_rows, "record") + ndd_rows = _records_of_type(measurement_rows, "ndd_workflow") + model_rows = _model_workflow_rows(measurement_rows) + stage_rows = _records_of_type(measurement_rows, "stage") + pipeline_rows = _pipeline_stage_rows(measurement_rows) + validation_max_entities_per_call = _first_int([measurement_rows], ["detect.validation_max_entities_per_call"]) + request_metrics = _case_request_metrics(model_rows) + return CaseAnalysisRow( + suite_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.suite_id", "suite_id"]), + workload_id=_first_value( + [measurement_rows, artifact_rows, trace_rows], ["run_tags.workload_id", "workload_id"] + ), + config_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.config_id", "config_id"]), + experimental_detection_strategy=_first_value([measurement_rows], ["run_tags.experimental_detection_strategy"]), + experimental_replacement_strategy=_first_value( + [measurement_rows, trace_rows], + ["run_tags.experimental_replacement_strategy"], + ), + dd_parser_compat=_first_value([measurement_rows], ["run_tags.dd_parser_compat"]), + repetition=_first_int([measurement_rows, artifact_rows, trace_rows], ["run_tags.repetition", "repetition"]), + case_id=case_id, + run_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_id"]) or case_id, + **_case_failure_metrics(stage_rows=stage_rows, ndd_rows=ndd_rows, model_rows=model_rows), + pipeline_elapsed_sec=_sum_or_none(pipeline_rows, "elapsed_sec"), + ndd_workflow_count=len(ndd_rows), + ndd_elapsed_sec_total=_sum_or_zero(ndd_rows, "elapsed_sec"), + **request_metrics, + **_case_trace_metrics(trace_rows, request_metrics=request_metrics), + route_total_row_count=_sum_or_none(model_rows, "route_total_row_count"), + route_rule_row_count=_sum_or_none(model_rows, 
def _case_request_metrics(model_rows: pd.DataFrame) -> dict[str, int | float | None]:
    """Total the observed request and token counters for one case.

    Each counter is the integer sum of the same-named column in
    ``model_rows`` (absent columns count as zero via ``_sum_or_zero``). The
    failure rate is derived from the failed and total request sums.
    """
    counter_names = (
        "observed_total_requests",
        "observed_successful_requests",
        "observed_input_tokens",
        "observed_output_tokens",
        "observed_total_tokens",
        "observed_failed_requests",
    )
    metrics: dict[str, int | float | None] = {
        name: int(_sum_or_zero(model_rows, name)) for name in counter_names
    }
    metrics["observed_failed_request_rate"] = _request_failure_rate(
        failed=metrics["observed_failed_requests"],
        total=metrics["observed_total_requests"],
    )
    return metrics
"observed_non_bridge_failed_request_rate": _request_failure_rate( + failed=non_bridge_failed, + total=non_bridge_total, + ), + } + + +def _case_failure_metrics( + *, + stage_rows: pd.DataFrame, + ndd_rows: pd.DataFrame, + model_rows: pd.DataFrame, +) -> dict[str, bool | int]: + error_stage_count = _error_status_count(stage_rows) + error_ndd_workflow_count = _error_status_count(ndd_rows) + error_model_workflow_count = _error_status_count(model_rows) + return { + "case_failed": error_stage_count > 0 or error_ndd_workflow_count > 0 or error_model_workflow_count > 0, + "error_stage_count": error_stage_count, + "error_ndd_workflow_count": error_ndd_workflow_count, + "error_model_workflow_count": error_model_workflow_count, + } + + +def _error_status_count(rows: pd.DataFrame) -> int: + if "status" not in rows.columns: + return 0 + statuses = rows["status"].astype(str).str.lower() + return int(statuses.isin({"error", "failed"}).sum()) + + +def _case_artifact_metrics( + artifact_rows: pd.DataFrame, + *, + validation_max_entities_per_call: int | None, +) -> dict[str, int | float | list[str] | dict[str, str] | None]: + signature_hashes = _artifact_signature_hashes(artifact_rows) + return { + "detection_artifact_rows": len(artifact_rows), + "seed_entity_count": _sum_or_none(artifact_rows, "seed_entity_count"), + "seed_validation_candidate_count": _sum_or_none(artifact_rows, "seed_validation_candidate_count"), + "estimated_seed_validation_chunk_count": _estimated_validation_chunk_count( + artifact_rows, + validation_max_entities_per_call=validation_max_entities_per_call, + ), + "augmented_entity_count": _sum_or_none(artifact_rows, "augmented_entity_count"), + "augmented_new_final_value_count": _sum_or_none(artifact_rows, "augmented_new_final_value_count"), + "artifact_final_entity_count": _sum_or_none(artifact_rows, "final_entity_count"), + "artifact_final_detector_entity_count": _sum_or_none(artifact_rows, "final_source_counts.detector"), + "artifact_final_rule_entity_count": 
_sum_or_none(artifact_rows, "final_source_counts.rule"), + "artifact_final_augmenter_entity_count": _sum_or_none(artifact_rows, "final_source_counts.augmenter"), + "artifact_final_entity_signature_count": _signature_count(artifact_rows, signature_hashes=signature_hashes), + "artifact_final_entity_signature_hashes": signature_hashes, + "artifact_final_entity_signature_labels": _artifact_signature_labels(artifact_rows), + "artifact_final_entity_signature_details": _artifact_signature_details(artifact_rows), + } + + +def _rows_for_case(dataframe: pd.DataFrame, case_id: str) -> pd.DataFrame: + if dataframe.empty: + return dataframe + masks = [ + dataframe[column].astype(str) == case_id + for column in ("run_tags.case_id", "case_id", "run_id") + if column in dataframe.columns + ] + if not masks: + return dataframe.iloc[0:0] + mask = masks[0] + for next_mask in masks[1:]: + mask = mask | next_mask + return dataframe[mask] + + +def _records_of_type(dataframe: pd.DataFrame, record_type: str) -> pd.DataFrame: + if "record_type" not in dataframe.columns: + return dataframe.iloc[0:0] + return dataframe[dataframe["record_type"] == record_type] + + +def _records_of_types(dataframe: pd.DataFrame, record_types: set[str]) -> pd.DataFrame: + if "record_type" not in dataframe.columns: + return dataframe.iloc[0:0] + return dataframe[dataframe["record_type"].isin(record_types)] + + +def _model_workflow_rows(dataframe: pd.DataFrame) -> pd.DataFrame: + return _records_of_types(dataframe, {"ndd_workflow", "model_workflow"}) + + +def _pipeline_stage_rows(dataframe: pd.DataFrame) -> pd.DataFrame: + stages = _records_of_type(dataframe, "stage") + if "stage" not in stages.columns: + return stages.iloc[0:0] + return stages[stages["stage"] == "Anonymizer._run_internal"] + + +_MODEL_USAGE_SUFFIXES = { + ".request_usage.total_requests": "observed_total_requests", + ".request_usage.successful_requests": "observed_successful_requests", + ".request_usage.failed_requests": 
# Column-name suffixes that carry per-model metadata, mapped to the field name
# each one populates. `_model_usage_column_parts` matches these (together with
# `_MODEL_USAGE_SUFFIXES`) against ``model_usage.<key><suffix>`` columns, and
# `_model_usage_metadata` reads the values back using the same column pattern.
_MODEL_USAGE_METADATA_SUFFIXES = {
    ".model_alias": "model_alias",
    ".model_name": "model_name",
    ".model_provider_name": "model_provider_name",
}
def _model_usage_column_parts(column: str) -> tuple[str, str] | None:
    """Split a ``model_usage.<key><suffix>`` column into ``(key, field)``.

    ``field`` is the name mapped to the first matching suffix among the usage
    and metadata suffix tables. Returns ``None`` when the column lacks the
    ``model_usage.`` prefix or ends with no known suffix.
    """
    namespace = "model_usage."
    if column.startswith(namespace):
        known_suffixes = _MODEL_USAGE_SUFFIXES | _MODEL_USAGE_METADATA_SUFFIXES
        for ending, field_name in known_suffixes.items():
            if column.endswith(ending):
                model_key = column[len(namespace) : len(column) - len(ending)]
                return model_key, field_name
    return None
_has_observed_model_usage(usage: dict[str, int | float | None]) -> bool: + return any(value not in (None, 0) for value in usage.values()) + + +def _string_from_row(data: dict[str, Any], columns: list[str]) -> str | None: + for column in columns: + value = data.get(column) + if value is not None and not pd.isna(value): + return str(value) + return None + + +def _int_from_row(data: dict[str, Any], columns: list[str]) -> int | None: + value = _string_from_row(data, columns) + return int(float(value)) if value is not None else None + + +def _float_from_row(data: dict[str, Any], columns: list[str]) -> float | None: + value = _string_from_row(data, columns) + return float(value) if value is not None else None + + +def _coerce_int(value: Any) -> int: + return int(float(value)) + + +def _first_value(frames: list[pd.DataFrame], columns: list[str]) -> str | None: + for frame in frames: + for column in columns: + if column not in frame.columns: + continue + values = frame[column].dropna() + if not values.empty: + return str(values.iloc[0]) + return None + + +def _first_int(frames: list[pd.DataFrame], columns: list[str]) -> int | None: + value = _first_value(frames, columns) + return int(float(value)) if value is not None else None + + +def _coalesce_number(*values: float | None) -> float | None: + for value in values: + if value is not None: + return value + return None + + +def _artifact_signature_hashes(artifact_rows: pd.DataFrame) -> list[str]: + if "final_entity_signature_hashes" not in artifact_rows.columns: + return [] + values: set[str] = set() + for raw in artifact_rows["final_entity_signature_hashes"].dropna(): + values.update(_coerce_string_list(raw)) + return sorted(values) + + +def _artifact_signature_labels(artifact_rows: pd.DataFrame) -> dict[str, str]: + labels: dict[str, str] = {} + if "final_entity_signature_labels" in artifact_rows.columns: + for raw in artifact_rows["final_entity_signature_labels"].dropna(): + labels.update(_coerce_string_dict(raw)) + for 
column in artifact_rows.columns: + prefix = "final_entity_signature_labels." + if not column.startswith(prefix): + continue + signature_hash = column.removeprefix(prefix) + for value in artifact_rows[column].dropna(): + labels[signature_hash] = str(value) + return dict(sorted(labels.items())) + + +def _artifact_signature_details(artifact_rows: pd.DataFrame) -> dict[str, dict[str, Any]]: + details: dict[str, dict[str, Any]] = {} + if "final_entity_signature_details" in artifact_rows.columns: + for raw in artifact_rows["final_entity_signature_details"].dropna(): + details.update(_coerce_detail_map(raw)) + prefix = "final_entity_signature_details." + for column in artifact_rows.columns: + if not column.startswith(prefix): + continue + remainder = column.removeprefix(prefix) + signature_hash, _, field = remainder.partition(".") + if not signature_hash or not field: + continue + for value in artifact_rows[column].dropna(): + details.setdefault(signature_hash, {})[field] = _json_scalar(value) + return dict(sorted(details.items())) + + +def _coerce_detail_map(raw: object) -> dict[str, dict[str, Any]]: + if isinstance(raw, str): + try: + raw = json.loads(raw) + except json.JSONDecodeError: + return {} + if not isinstance(raw, dict): + return {} + details: dict[str, dict[str, Any]] = {} + for signature_hash, value in raw.items(): + if isinstance(value, dict): + details[str(signature_hash)] = {str(key): _json_scalar(item) for key, item in value.items()} + return details + + +def _json_scalar(value: object) -> Any: + if hasattr(value, "item"): + try: + return value.item() + except ValueError: + return value + return value + + +def _coerce_string_list(raw: object) -> list[str]: + if isinstance(raw, list): + return [str(item) for item in raw] + return [] + + +def _coerce_string_dict(raw: object) -> dict[str, str]: + if isinstance(raw, dict): + return {str(key): str(value) for key, value in raw.items()} + return {} + + +def _signature_count(artifact_rows: pd.DataFrame, *, 
signature_hashes: list[str]) -> float | None: + if signature_hashes: + return float(len(signature_hashes)) + return _sum_or_none(artifact_rows, "final_entity_signature_count") + + +def _sum_or_zero(dataframe: pd.DataFrame, column: str) -> float: + if column not in dataframe.columns: + return 0.0 + return float(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) + + +def _sum_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + if column not in dataframe.columns: + return None + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.sum()) + + +def _positive_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + values = pd.to_numeric(dataframe[column], errors="coerce").fillna(0) + return int((values > 0).sum()) + + +def _sum_prefixed_ints(dataframe: pd.DataFrame, prefix: str) -> dict[str, int]: + totals: dict[str, int] = {} + base_column = prefix.removesuffix(".") + if base_column in dataframe.columns: + for value in dataframe[base_column].dropna().tolist(): + for key, count in _coerce_count_mapping(value).items(): + totals[key] = totals.get(key, 0) + count + for column in sorted(col for col in dataframe.columns if col.startswith(prefix)): + value = _sum_int_or_zero(dataframe, column) + if value: + totals[column.removeprefix(prefix)] = value + return totals + + +def _coerce_count_mapping(value: object) -> dict[str, int]: + payload = value + if isinstance(value, str): + try: + payload = json.loads(value) + except json.JSONDecodeError: + return {} + if not isinstance(payload, dict): + return {} + counts: dict[str, int] = {} + for key, count in payload.items(): + numeric = pd.to_numeric(pd.Series([count]), errors="coerce").dropna() + if not numeric.empty and numeric.iloc[0]: + counts[str(key)] = int(numeric.iloc[0]) + return counts + + +def build_group_rows(cases: list[CaseAnalysisRow]) -> list[GroupAnalysisRow]: + if not 
def build_model_usage_group_rows(model_usage: list[ModelUsageAnalysisRow]) -> list[ModelUsageGroupAnalysisRow]:
    """Aggregate per-workflow model usage rows into one row per group.

    Rows are grouped by workload, config, both experimental strategies,
    parser compatibility, workflow and model identity. Null keys are kept as
    their own group (``dropna=False``). An empty input yields an empty list.
    """
    if not model_usage:
        return []
    group_columns = [
        "workload_id",
        "config_id",
        "experimental_detection_strategy",
        "experimental_replacement_strategy",
        "dd_parser_compat",
        "workflow_name",
        "model_alias",
        "model_name",
        "model_provider_name",
    ]
    table = pd.DataFrame([row.model_dump() for row in model_usage])
    grouped = table.groupby(group_columns, dropna=False)
    return [_build_model_usage_group_row(keys, group) for keys, group in grouped]
case_count=int(group["case_id"].nunique()), + workflow_count=len(group), + sum_observed_total_requests=total_requests, + sum_observed_successful_requests=_sum_int_or_zero(group, "observed_successful_requests"), + sum_observed_failed_requests=failed_requests, + sum_observed_input_tokens=_sum_int_or_zero(group, "observed_input_tokens"), + sum_observed_output_tokens=_sum_int_or_zero(group, "observed_output_tokens"), + sum_observed_total_tokens=_sum_int_or_zero(group, "observed_total_tokens"), + sum_observed_reasoning_tokens=reasoning_sum, + observed_failed_request_rate=_request_failure_rate(failed=failed_requests, total=total_requests), + median_observed_total_requests=_median_or_none(group, "observed_total_requests"), + median_observed_failed_requests=_median_or_none(group, "observed_failed_requests"), + median_observed_total_tokens=_median_or_none(group, "observed_total_tokens"), + ) + + +def _build_group_row(keys: tuple[Any, Any, Any, Any], group: pd.DataFrame) -> GroupAnalysisRow: + workload_id, config_id, detection_strategy, replacement_strategy = keys + case_count = int(group["case_id"].nunique()) + failed_case_count = _sum_bool_or_zero(group, "case_failed") + return GroupAnalysisRow( + workload_id=_none_if_nan(workload_id), + config_id=_none_if_nan(config_id), + experimental_detection_strategy=_none_if_nan(detection_strategy), + experimental_replacement_strategy=_none_if_nan(replacement_strategy), + case_count=case_count, + failed_case_count=failed_case_count, + failed_case_rate=_request_failure_rate(failed=failed_case_count, total=case_count), + error_stage_count=_sum_int_or_zero(group, "error_stage_count"), + error_ndd_workflow_count=_sum_int_or_zero(group, "error_ndd_workflow_count"), + error_model_workflow_count=_sum_int_or_zero(group, "error_model_workflow_count"), + median_pipeline_elapsed_sec=_median_or_none(group, "pipeline_elapsed_sec"), + median_ndd_elapsed_sec_total=_median_or_none(group, "ndd_elapsed_sec_total"), + 
median_observed_total_requests=_median_or_none(group, "observed_total_requests"), + median_observed_successful_requests=_median_or_none(group, "observed_successful_requests"), + median_observed_input_tokens=_median_or_none(group, "observed_input_tokens"), + median_observed_output_tokens=_median_or_none(group, "observed_output_tokens"), + median_observed_total_tokens=_median_or_none(group, "observed_total_tokens"), + median_observed_failed_requests=_median_or_none(group, "observed_failed_requests"), + median_observed_failed_request_rate=_median_or_none(group, "observed_failed_request_rate"), + median_observed_bridge_fallback_requests=_median_or_none(group, "observed_bridge_fallback_requests"), + median_observed_non_bridge_total_requests=_median_or_none(group, "observed_non_bridge_total_requests"), + median_observed_non_bridge_failed_requests=_median_or_none(group, "observed_non_bridge_failed_requests"), + median_observed_non_bridge_failed_request_rate=_median_or_none( + group, + "observed_non_bridge_failed_request_rate", + ), + median_route_total_row_count=_median_or_none(group, "route_total_row_count"), + median_route_rule_row_count=_median_or_none(group, "route_rule_row_count"), + median_route_fallback_row_count=_median_or_none(group, "route_fallback_row_count"), + median_final_entity_count=_median_or_none(group, "final_entity_count"), + median_replacement_missing_final_entity_count=_median_or_none( + group, + "replacement_missing_final_entity_count", + ), + median_replacement_missing_final_value_count=_median_or_none(group, "replacement_missing_final_value_count"), + replacement_missing_final_entity_label_counts=_sum_prefixed_ints( + group, + "replacement_missing_final_entity_label_counts.", + ), + median_replacement_synthetic_original_collision_count=_median_or_none( + group, + "replacement_synthetic_original_collision_count", + ), + median_replacement_synthetic_original_collision_value_count=_median_or_none( + group, + 
"replacement_synthetic_original_collision_value_count", + ), + replacement_synthetic_original_collision_label_counts=_sum_prefixed_ints( + group, + "replacement_synthetic_original_collision_label_counts.", + ), + sum_original_value_leak_count=_sum_or_none(group, "original_value_leak_count"), + leaking_case_count=_positive_count(group, "original_value_leak_count"), + median_original_value_leak_count=_median_or_none(group, "original_value_leak_count"), + median_seed_entity_count=_median_or_none(group, "seed_entity_count"), + median_seed_validation_candidate_count=_median_or_none(group, "seed_validation_candidate_count"), + median_estimated_seed_validation_chunk_count=_median_or_none(group, "estimated_seed_validation_chunk_count"), + median_augmented_entity_count=_median_or_none(group, "augmented_entity_count"), + median_augmented_new_final_value_count=_median_or_none(group, "augmented_new_final_value_count"), + median_artifact_final_entity_count=_median_or_none(group, "artifact_final_entity_count"), + median_artifact_final_detector_entity_count=_median_or_none(group, "artifact_final_detector_entity_count"), + median_artifact_final_rule_entity_count=_median_or_none(group, "artifact_final_rule_entity_count"), + median_artifact_final_augmenter_entity_count=_median_or_none(group, "artifact_final_augmenter_entity_count"), + median_artifact_final_entity_signature_count=_median_or_none(group, "artifact_final_entity_signature_count"), + ) + + +def _none_if_nan(value: object) -> str | None: + if pd.isna(value): + return None + return str(value) + + +def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.median()) + + +def _sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: + return int(_sum_or_zero(dataframe, column)) + + +def _sum_bool_or_zero(dataframe: pd.DataFrame, column: str) -> int: + if column not in 
dataframe.columns: + return 0 + return int(dataframe[column].fillna(False).astype(bool).sum()) + + +def _sum_int_or_none(dataframe: pd.DataFrame, column: str) -> int | None: + value = _sum_or_none(dataframe, column) + return int(value) if value is not None else None + + +def _request_failure_rate(*, failed: object, total: object) -> float | None: + failed_value = _optional_number(failed) + total_value = _optional_number(total) + if failed_value is None or total_value is None or total_value <= 0: + return None + return failed_value / total_value + + +def _optional_number(value: object) -> float | None: + if value is None or pd.isna(value): + return None + return float(value) + + +def _estimated_validation_chunk_count( + artifact_rows: pd.DataFrame, + *, + validation_max_entities_per_call: int | None, +) -> float | None: + if validation_max_entities_per_call is None or validation_max_entities_per_call <= 0: + return None + if "seed_validation_candidate_count" not in artifact_rows.columns: + return None + counts = pd.to_numeric(artifact_rows["seed_validation_candidate_count"], errors="coerce").dropna() + if counts.empty: + return None + return float(sum(math.ceil(count / validation_max_entities_per_call) for count in counts if count > 0)) + + +def write_analysis_tables( + result: BenchmarkOutputAnalysis, + output_dir: Path, + export_format: ExportFormat, +) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows( + result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, CaseAnalysisRow + ), + _write_model_rows( + result.groups, output_dir / f"group_analysis.{export_format.value}", export_format, GroupAnalysisRow + ), + _write_model_rows( + result.model_usage, + output_dir / f"model_analysis.{export_format.value}", + export_format, + ModelUsageAnalysisRow, + ), + _write_model_rows( + result.model_usage_groups, + output_dir / f"model_group_analysis.{export_format.value}", + export_format, + 
def render_result(result: BenchmarkOutputAnalysis, *, json_output: bool) -> str:
    """Format an analysis result for the terminal.

    With ``json_output`` the result's own indented JSON dump is returned.
    Otherwise the text is a header line with the case, group and model-row
    counts, followed by one ``- workload/config/detection/replacement: ...``
    line per group.
    """
    if json_output:
        return result.model_dump_json(indent=2)

    def describe(group) -> str:
        label = (
            f"{group.workload_id}/{group.config_id}/"
            f"{group.experimental_detection_strategy}/{group.experimental_replacement_strategy}"
        )
        details = ", ".join(
            (
                f"cases={group.case_count}",
                f"median_entities={group.median_final_entity_count}",
                f"failed_cases={group.failed_case_count}/{group.case_count}",
                f"median_requests={group.median_observed_total_requests}",
                f"median_tokens={group.median_observed_total_tokens}",
                f"median_failed_request_rate={group.median_observed_failed_request_rate}",
                f"median_aug_new_final={group.median_augmented_new_final_value_count}",
            )
        )
        return f"- {label}: {details}"

    header = (
        f"Analyzed {result.case_count} case(s) across {result.group_count} group(s); "
        f"model rows={result.model_usage_count}"
    )
    return "\n".join([header, *(describe(group) for group in result.groups)])
None, + detection_artifacts: Annotated[Path | None, cyclopts.Parameter("--detection-artifacts")] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_benchmark_output(benchmark_dir, detection_artifacts=detection_artifacts) + if output is not None: + write_analysis_tables(result, output, format) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/analyze_dd_traces.py b/tools/measurement/analyze_dd_traces.py new file mode 100644 index 00000000..157beec9 --- /dev/null +++ b/tools/measurement/analyze_dd_traces.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyze DataDesigner message traces without emitting raw prompt or response text. 
# One sanitized row per ``dd_message_trace`` record, built by ``_trace_row``.
# The response_* fields describe the response (length, shape, flags) rather
# than carrying its text.
class TraceAnalysisRow(BaseModel):
    # Where the record was read from: source file path and 1-based line number.
    trace_file: str
    line_number: int
    # Identifiers taken from the record's ``run_tags`` mapping.
    suite_id: str | None = None
    workload_id: str | None = None
    config_id: str | None = None
    case_id: str | None = None
    # ``_trace_row`` stores "" here when the record has no run_id.
    run_id: str
    # Copied from the top-level record fields of the same names.
    workflow_name: str | None = None
    model_alias: str | None = None
    model_name: str | None = None
    model_provider_name: str | None = None
    status: str | None = None
    error_type: str | None = None
    elapsed_sec: float | None = None
    # Derived from the record's ``messages`` entry.
    message_count: int = 0
    # NOTE(review): presumably total characters across prompt messages; confirm in _trace_row.
    prompt_chars: int = 0
    # NOTE(review): response classification and flags are computed outside the
    # visible part of _trace_row; confirm their exact definitions there.
    response_shape: str
    response_chars: int = 0
    response_has_thinking: bool = False
    response_has_json_fence: bool = False
    response_has_embedded_json: bool = False
    # NOTE(review): presumably read from the record's ``usage`` mapping; confirm.
    input_tokens: int | None = None
    output_tokens: int | None = None
    total_tokens: int | None = None
def analyze_trace_path(trace_path: Path) -> TraceAnalysis:
    """Analyze every trace record reachable from *trace_path*.

    Each trace file yielded by ``_iter_trace_files`` is read with
    ``_iter_trace_records`` and every record is converted into a
    ``TraceAnalysisRow``. The rows are then summarised into group rows.

    Args:
        trace_path: A trace file, or a directory containing trace files.

    Returns:
        The per-record rows together with their grouped summary.
    """
    analysis_rows: list[TraceAnalysisRow] = []
    for trace_file in _iter_trace_files(trace_path):
        for line_number, record in _iter_trace_records(trace_file):
            analysis_rows.append(_trace_row(record, trace_file=trace_file, line_number=line_number))
    grouped_rows = build_trace_group_rows(analysis_rows)
    return TraceAnalysis(trace_path=str(trace_path), rows=analysis_rows, groups=grouped_rows)
[trace_path] + if not trace_path.is_dir(): + raise ValueError(f"trace path is not a file or directory: {trace_path}") + return sorted(trace_path.rglob("*.jsonl")) + + +def _iter_trace_records(path: Path) -> list[tuple[int, dict[str, Any]]]: + records: list[tuple[int, dict[str, Any]]] = [] + for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): + if not line.strip(): + continue + payload = json.loads(line) + if isinstance(payload, dict) and payload.get("record_type") == "dd_message_trace": + records.append((line_number, payload)) + return records + + +def _trace_row(record: dict[str, Any], *, trace_file: Path, line_number: int) -> TraceAnalysisRow: + response = record.get("response") if isinstance(record.get("response"), dict) else None + response_content = _string_or_none(response.get("content")) if response is not None else None + reasoning_content = _string_or_none(response.get("reasoning_content")) if response is not None else None + run_tags = record.get("run_tags") if isinstance(record.get("run_tags"), dict) else {} + usage = record.get("usage") if isinstance(record.get("usage"), dict) else {} + return TraceAnalysisRow( + trace_file=str(trace_file), + line_number=line_number, + suite_id=_string_or_none(run_tags.get("suite_id")), + workload_id=_string_or_none(run_tags.get("workload_id")), + config_id=_string_or_none(run_tags.get("config_id")), + case_id=_string_or_none(run_tags.get("case_id")), + run_id=str(record.get("run_id") or ""), + workflow_name=_string_or_none(record.get("workflow_name")), + model_alias=_string_or_none(record.get("model_alias")), + model_name=_string_or_none(record.get("model_name")), + model_provider_name=_string_or_none(record.get("model_provider_name")), + status=_string_or_none(record.get("status")), + error_type=_string_or_none(record.get("error_type")), + elapsed_sec=_float_or_none(record.get("elapsed_sec")), + message_count=_message_count(record.get("messages")), + 
def build_trace_group_rows(rows: list[TraceAnalysisRow]) -> list[TraceGroupAnalysisRow]:
    """Summarise per-record rows into one group row per distinct key combination.

    Records are grouped on suite, workload, config, workflow, model alias,
    model name, provider, status, error type and response shape.

    Args:
        rows: Per-record analysis rows; may be empty.

    Returns:
        One ``TraceGroupAnalysisRow`` per group in sorted key order, or an
        empty list when *rows* is empty.
    """
    if not rows:
        return []
    key_columns = [
        "suite_id",
        "workload_id",
        "config_id",
        "workflow_name",
        "model_alias",
        "model_name",
        "model_provider_name",
        "status",
        "error_type",
        "response_shape",
    ]
    frame = pd.DataFrame([row.model_dump() for row in rows])
    summaries: list[TraceGroupAnalysisRow] = []
    # dropna=False keeps records with missing identifiers in groups of their own.
    for key_values, members in frame.groupby(key_columns, dropna=False, sort=True):
        summaries.append(_build_trace_group_row(key_values, members))
    return summaries
embedded_json_count=int(group["response_has_embedded_json"].sum()), + median_elapsed_sec=_median_or_none(group, "elapsed_sec"), + median_prompt_chars=_median_or_none(group, "prompt_chars"), + median_response_chars=_median_or_none(group, "response_chars"), + sum_input_tokens=_sum_int_or_zero(group, "input_tokens"), + sum_output_tokens=_sum_int_or_zero(group, "output_tokens"), + sum_total_tokens=_sum_int_or_zero(group, "total_tokens"), + ) + + +def _response_shape(content: str | None) -> str: + if not content: + return "none" + if _has_json_fence(content): + return "fenced_json" + stripped = content.strip() + try: + json.loads(stripped) + except json.JSONDecodeError: + return "embedded_json" if _has_embedded_json(content) else "text" + return "raw_json" + + +def _has_thinking(content: str | None, reasoning_content: str | None) -> bool: + return bool(reasoning_content) or "" in (content or "") + + +def _has_json_fence(content: str | None) -> bool: + return bool(content and JSON_FENCE_RE.search(content)) + + +def _has_embedded_json(content: str | None) -> bool: + if not content: + return False + decoder = json.JSONDecoder() + for start, char in enumerate(content): + if char not in "{[": + continue + try: + parsed, _ = decoder.raw_decode(content, start) + except json.JSONDecodeError: + continue + if isinstance(parsed, dict | list): + return True + return False + + +def _message_count(messages: object) -> int: + return len(messages) if isinstance(messages, list) else 0 + + +def _message_text_chars(messages: object) -> int: + if not isinstance(messages, list): + return 0 + return sum(_message_content_chars(message) for message in messages) + + +def _message_content_chars(message: object) -> int: + if not isinstance(message, dict): + return 0 + return _content_chars(message.get("content")) + + +def _content_chars(content: object) -> int: + if isinstance(content, str): + return len(content) + if isinstance(content, list): + return sum(_content_item_chars(item) for item in 
content) + return 0 + + +def _content_item_chars(item: object) -> int: + if isinstance(item, str): + return len(item) + if isinstance(item, dict): + return len(str(item.get("text") or "")) + return 0 + + +def _string_or_none(value: object) -> str | None: + return str(value) if value is not None and not pd.isna(value) else None + + +def _float_or_none(value: object) -> float | None: + return float(value) if value is not None and not pd.isna(value) else None + + +def _int_or_none(value: object) -> int | None: + return int(value) if value is not None and not pd.isna(value) else None + + +def _none_if_nan(value: object) -> str | None: + if pd.isna(value): + return None + return str(value) + + +def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.median()) + + +def _sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + return int(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) + + +def write_analysis_tables(result: TraceAnalysis, output_dir: Path, export_format: ExportFormat) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows(result.rows, output_dir / f"trace_analysis.{export_format.value}", export_format), + _write_model_rows(result.groups, output_dir / f"trace_group_analysis.{export_format.value}", export_format), + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def _write_model_rows(rows: list[BaseModel], path: Path, export_format: ExportFormat) -> TableSummary: + table = pd.DataFrame([row.model_dump() for row in rows]) + if 
def render_result(result: TraceAnalysis, *, json_output: bool) -> str:
    """Render a trace analysis for standard output.

    Args:
        result: The analysis to render.
        json_output: When true, return the analysis as indented JSON;
            otherwise return a headline followed by one bullet per group.

    Returns:
        The rendered text, without a trailing newline.
    """
    if json_output:
        return result.model_dump_json(indent=2)

    def describe(group: TraceGroupAnalysisRow) -> str:
        # Prefer the configured alias and fall back to the raw model name.
        model_label = group.model_alias or group.model_name
        label = f"{group.workload_id}/{group.config_id}/{group.workflow_name}/{model_label}/{group.response_shape}"
        return (
            f"- {label}: traces={group.trace_record_count}, errors={group.error_count}, "
            f"thinking={group.thinking_count}, json_fence={group.json_fence_count}, "
            f"tokens={group.sum_total_tokens}"
        )

    headline = f"Analyzed {result.trace_record_count} trace record(s) across {result.group_count} group(s)"
    return "\n".join([headline, *(describe(group) for group in result.groups)])
class ExportFormat(StrEnum):
    """File formats that ``write_rows`` can emit for the analysis table."""

    # Written with ``DataFrame.to_parquet``.
    parquet = "parquet"
    # Written with ``DataFrame.to_csv``.
    csv = "csv"
    # Written with ``DataFrame.to_json(orient="records", lines=True)``.
    jsonl = "jsonl"


class LogFormat(StrEnum):
    """How ``log_bad_input`` reports an input error."""

    # Reported through the module logger.
    plain = "plain"
    # Reported as a single JSON object written to stderr.
    json = "json"
def analyze_artifacts(
    artifact_path: Path,
    *,
    parquet_files: Iterable[Path] | None = None,
) -> DetectionArtifactAnalysis:
    """Analyze detection parquet batches found under *artifact_path*.

    Args:
        artifact_path: Artifact root directory.
        parquet_files: Explicit parquet files to analyze. When omitted, the
            files are discovered with ``iter_detection_parquet_files``.

    Returns:
        An analysis holding one row per dataframe row of every parquet file.

    Raises:
        ValueError: If *artifact_path* is missing or is not a directory.
    """
    if not (artifact_path.exists() and artifact_path.is_dir()):
        raise ValueError(f"artifact path is not a directory: {artifact_path}")
    if parquet_files is None:
        sources = iter_detection_parquet_files(artifact_path)
    else:
        sources = parquet_files
    collected: list[DetectionArtifactRow] = []
    for source in sources:
        collected.extend(_analyze_parquet_file(source, artifact_root=artifact_path))
    return DetectionArtifactAnalysis(artifact_path=str(artifact_path), rows=collected)
def _analyze_parquet_file(parquet_file: Path, *, artifact_root: Path) -> list[DetectionArtifactRow]:
    """Build one analysis row per dataframe row of a detection parquet batch.

    The workflow name is the name of the directory two levels above the
    parquet file, and the batch file is recorded relative to *artifact_root*.

    Args:
        parquet_file: Parquet batch to read.
        artifact_root: Root used to relativise the recorded batch path.

    Returns:
        The analysis rows in dataframe iteration order.
    """
    frame = pd.read_parquet(parquet_file)
    workflow = parquet_file.parents[1].name
    relative_batch = str(parquet_file.relative_to(artifact_root))
    analysed: list[DetectionArtifactRow] = []
    # ``row_index`` is the dataframe's index label for the row.
    for row_index, record in frame.iterrows():
        analysed.append(
            _analyze_dataframe_row(
                record,
                workflow_name=workflow,
                batch_file=relative_batch,
                row_index=row_index,
            )
        )
    return analysed
def _parse_entities(raw: object) -> list[EntitySchema]:
    """Parse a raw entities payload into validated entities.

    The payload is reduced to its ``entities`` list, validated through
    ``EntitiesSchema``, and entities with an empty value or label are dropped.

    Args:
        raw: Cell value holding the entities payload.

    Returns:
        The validated entities that have both a value and a label.
    """
    items = _extract_payload_list(raw, key="entities")
    validated = EntitiesSchema.model_validate({"entities": items})
    usable: list[EntitySchema] = []
    for candidate in validated.entities:
        if not candidate.value or not candidate.label:
            continue
        usable.append(candidate)
    return usable
entity.label + ] + + +def _parse_validation_candidate_count(raw: object) -> int: + values = _extract_payload_list(raw, key="candidates") + parsed = ValidationCandidatesSchema.model_validate({"candidates": values}) + return len(parsed.candidates) + + +def _extract_payload_list(raw: object, *, key: str) -> list[object]: + payload = _coerce_payload(raw) + if isinstance(payload, dict): + return _coerce_list(payload.get(key)) + return _coerce_list(payload) + + +def _coerce_payload(raw: object) -> object: + if _is_missing(raw): + return {} + if hasattr(raw, "tolist"): + raw = raw.tolist() + if not isinstance(raw, str): + return raw + text = raw.strip() + if not text: + return {} + try: + return json.loads(text) + except json.JSONDecodeError: + try: + return ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + + +def _coerce_list(value: object) -> list[object]: + value = _coerce_payload(value) + if isinstance(value, list): + return value + return [] + + +def _is_missing(value: object) -> bool: + return value is None or (isinstance(value, float) and math.isnan(value)) + + +def _value_key(value: str) -> str: + return " ".join(value.casefold().split()) + + +def _entity_signature_hashes(entities: list[EntitySchema], *, row_index: int) -> list[str]: + signatures = {_entity_signature_hash(entity, row_index=row_index) for entity in entities} + return sorted(signatures) + + +def _entity_signature_labels(entities: list[EntitySchema], *, row_index: int) -> dict[str, str]: + labels = {_entity_signature_hash(entity, row_index=row_index): entity.label for entity in entities} + return dict(sorted(labels.items())) + + +def _entity_signature_details(entities: list[EntitySchema], *, row_index: int) -> dict[str, dict[str, object]]: + details = { + _entity_signature_hash(entity, row_index=row_index): { + "label": entity.label, + "source": entity.source, + "row_index": int(row_index), + "start_position": entity.start_position, + "end_position": entity.end_position, + 
"value_hash": _entity_value_hash(entity.value), + "value_length": len(entity.value), + } + for entity in entities + } + return dict(sorted(details.items())) + + +def _entity_signature_hash(entity: EntitySchema, *, row_index: int) -> str: + payload = json.dumps( + { + "row": row_index, + "label": entity.label, + "value": _value_key(entity.value), + "start": entity.start_position, + "end": entity.end_position, + }, + ensure_ascii=True, + sort_keys=True, + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] + + +def _entity_value_hash(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + + +def _weak_api_key_shape_counts(entities: list[EntitySchema]) -> Counter[str]: + counts: Counter[str] = Counter() + for entity in entities: + if entity.label == "api_key" and not _looks_like_api_key(entity.value): + counts[entity.label] += 1 + return counts + + +def _looks_like_api_key(value: str) -> bool: + stripped = value.strip() + if API_KEY_PREFIX_RE.search(stripped): + return True + compact = re.sub(r"[\s'\";:,/]+", "", stripped) + if len(compact) < 20: + return False + return bool(re.search(r"[A-Za-z]", compact)) and bool(re.search(r"\d", compact)) + + +def _count_by(entities: list[EntitySchema], field: str) -> dict[str, int]: + counts = Counter(str(getattr(entity, field)) for entity in entities if getattr(entity, field)) + return dict(sorted(counts.items())) + + +def write_rows(rows: list[DetectionArtifactRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def render_result(result: DetectionArtifactAnalysis, *, json_output: bool) -> str: + if 
@app.default
def main(
    artifact_path: Path,
    *,
    output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None,
    format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.jsonl,
    json_output: Annotated[bool, cyclopts.Parameter("--json")] = False,
    log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain,
) -> None:
    # CLI entry point: analyze the artifacts under `artifact_path`, optionally
    # write the rows to `output` in `format`, then print the rendered result.
    # Exits with status 125 when the input cannot be analyzed.
    # (Documented with comments rather than a docstring so the generated
    # --help text is left untouched.)
    configure_logging(log_format)
    try:
        result = analyze_artifacts(artifact_path)
        if output is not None:
            write_rows(result.rows, output, format)
    except (OSError, ValueError) as exc:
        # OSError covers unreadable parquet files and unwritable output paths.
        # The sibling trace-analysis CLI reports those as bad input too, so
        # they should not surface as a raw traceback here.
        log_bad_input(str(exc))
        raise SystemExit(125) from exc
    sys.stdout.write(render_result(result, json_output=json_output) + "\n")
class ExportFormat(StrEnum):
    """Table formats selectable with ``--format``; recorded on ``AnalysisExportResult.format``."""

    parquet = "parquet"
    csv = "csv"
    jsonl = "jsonl"


class LogFormat(StrEnum):
    """How ``log_bad_input`` reports an input error."""

    # Reported through the module logger.
    plain = "plain"
    # Reported as a single JSON object written to stderr.
    json = "json"
class StagedGroupAnalysisRow(BaseModel):
    """Aggregate metrics for the staged-detection cases that share one ``seed_source``.

    ``build_group_rows`` buckets cases by ``seed_source`` and produces one of
    these rows per bucket.
    """

    # Grouping key; None when the cases carry no seed source.
    seed_source: str | None = None
    case_count: int
    completed_case_count: int = 0
    error_case_count: int = 0
    failed_case_rate: float | None = None
    rule_covered_case_count: int = 0
    # ``*_sum`` / ``*_mean`` fields presumably aggregate the same-named
    # per-case fields -- TODO confirm against ``_build_group_row``.
    elapsed_sec_sum: float | None = None
    elapsed_sec_mean: float | None = None
    model_elapsed_sec_sum: float | None = None
    model_elapsed_sec_mean: float | None = None
    model_phase_count_sum: int = 0
    model_request_count_sum: int = 0
    prompt_tokens_sum: int = 0
    completion_tokens_sum: int = 0
    total_tokens_sum: int = 0
    final_entity_count_sum: int = 0
    final_entity_signature_count_sum: int = 0
    baseline_final_entity_signature_count_sum: int = 0
    shared_final_entity_signature_count_sum: int = 0
    baseline_only_final_entity_signature_count_sum: int = 0
    direct_only_final_entity_signature_count_sum: int = 0
    baseline_shared_signature_rate: float | None = None
    baseline_loss_signature_rate: float | None = None
    # Defaults to "review" until a verdict is assigned.
    fast_lane_verdict: str = "review"
    flags: list[str] = Field(default_factory=list)
def resolve_case_path(input_path: Path) -> Path:
    """Resolve a CLI input path to the staged-detection case file.

    A directory is resolved to the ``staged-detection-cases.jsonl`` file
    inside it; a file path is returned as given.

    Args:
        input_path: A probe output directory or a case file.

    Returns:
        Path of the case file.

    Raises:
        ValueError: If the input path does not exist, if the case file does
            not exist, or if the resolved path is itself a directory.
    """
    if not input_path.exists():
        raise ValueError(f"input path does not exist: {input_path}")
    if input_path.is_dir():
        input_path = input_path / "staged-detection-cases.jsonl"
    if not input_path.exists():
        raise ValueError(f"staged detection case file does not exist: {input_path}")
    if input_path.is_dir():
        raise ValueError(f"input path is a directory: {input_path}")
    return input_path


def read_case_records(case_path: Path) -> list[dict[str, Any]]:
    """Read the staged-detection case records from a UTF-8 JSONL file.

    Blank lines are skipped. Records are kept when their ``record_type`` is
    absent or equals ``"staged_detection_case"``; other record types are
    ignored.

    Args:
        case_path: JSONL file to read.

    Returns:
        The kept records in file order.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If a line is not valid JSON or does not decode to an object.
    """
    records: list[dict[str, Any]] = []
    # Iterate the file handle rather than ``read_text().splitlines()``:
    # ``str.splitlines`` also breaks on U+2028, U+2029 and U+0085, which may
    # appear unescaped inside JSON strings and would cut a record in half.
    with case_path.open(encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            record = json.loads(line)
            if not isinstance(record, dict):
                raise ValueError(f"JSONL row is not an object in {case_path}")
            if record.get("record_type") in (None, "staged_detection_case"):
                records.append(record)
    return records
def build_group_rows(cases: list[StagedCaseAnalysisRow]) -> list[StagedGroupAnalysisRow]:
    """Summarize cases per ``seed_source``.

    Cases are bucketed by their ``seed_source`` (``None`` is a valid bucket),
    the buckets are ordered with ``_group_sort_key``, and each bucket is turned
    into one group row by ``_build_group_row``.
    """
    by_seed: dict[str | None, list[StagedCaseAnalysisRow]] = {}
    for entry in cases:
        by_seed.setdefault(entry.seed_source, []).append(entry)

    summaries: list[StagedGroupAnalysisRow] = []
    for source, members in sorted(by_seed.items(), key=_group_sort_key):
        summaries.append(_build_group_row(source, members))
    return summaries
+ rule_covered_count=rule_covered_count, + ) + return StagedGroupAnalysisRow( + seed_source=seed_source, + case_count=case_count, + completed_case_count=case_count - error_count, + error_case_count=error_count, + failed_case_rate=_rate(error_count, case_count), + rule_covered_case_count=rule_covered_count, + elapsed_sec_sum=_sum_optional_float(rows, "elapsed_sec"), + elapsed_sec_mean=_mean_optional_float(rows, "elapsed_sec"), + model_elapsed_sec_sum=_sum_optional_float(rows, "model_elapsed_sec"), + model_elapsed_sec_mean=_mean_optional_float(rows, "model_elapsed_sec"), + model_phase_count_sum=sum(row.model_phase_count for row in rows), + model_request_count_sum=model_request_count, + prompt_tokens_sum=sum(row.prompt_tokens for row in rows), + completion_tokens_sum=sum(row.completion_tokens for row in rows), + total_tokens_sum=sum(row.total_tokens for row in rows), + final_entity_count_sum=sum(row.final_entity_count for row in rows), + final_entity_signature_count_sum=sum(row.final_entity_signature_count for row in rows), + baseline_final_entity_signature_count_sum=baseline_total, + shared_final_entity_signature_count_sum=shared_total, + baseline_only_final_entity_signature_count_sum=baseline_only_total, + direct_only_final_entity_signature_count_sum=_sum_optional_int( + rows, "direct_only_final_entity_signature_count" + ), + baseline_shared_signature_rate=_rate(shared_total, baseline_total), + baseline_loss_signature_rate=_rate(baseline_only_total, baseline_total), + fast_lane_verdict=_fast_lane_verdict(flags), + flags=flags, + ) + + +def _group_sort_key(item: tuple[str | None, list[StagedCaseAnalysisRow]]) -> str: + return item[0] or "" + + +def _fast_lane_flags( + *, + case_count: int, + error_count: int, + baseline_total: int, + baseline_only_total: int, + model_request_count: int, + rule_covered_count: int, +) -> list[str]: + flags: list[str] = [] + if case_count < _FAST_LANE_MIN_CASES: + flags.append("too_few_cases") + if error_count: + 
flags.append("case_errors") + if baseline_total == 0: + flags.append("missing_baseline_comparison") + if baseline_only_total: + flags.append("baseline_signature_loss") + if model_request_count: + flags.append("uses_model") + if rule_covered_count != case_count: + flags.append("not_fully_rule_covered") + return flags + + +def _fast_lane_verdict(flags: list[str]) -> str: + if "case_errors" in flags or "baseline_signature_loss" in flags: + return "reject" + if not flags: + return "fast_lane_candidate" + return "review" + + +def build_label_delta_rows(cases: list[StagedCaseAnalysisRow]) -> list[LabelDeltaAnalysisRow]: + counts: Counter[tuple[str | None, str, str]] = Counter() + for case in cases: + for label, count in case.baseline_only_label_counts.items(): + counts[(case.seed_source, "baseline_only", label)] += count + for label, count in case.direct_only_label_counts.items(): + counts[(case.seed_source, "direct_only", label)] += count + return [ + LabelDeltaAnalysisRow(seed_source=seed_source, delta_type=delta_type, label=label, count=count) + for (seed_source, delta_type, label), count in sorted(counts.items(), key=_label_delta_sort_key) + ] + + +def _label_delta_sort_key(item: tuple[tuple[str | None, str, str], int]) -> tuple[str, str, str]: + (seed_source, delta_type, label), _ = item + return (seed_source or "", delta_type, label) + + +def _sum_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: + values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] + return sum(values) if values else None + + +def _mean_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: + values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] + return sum(values) / len(values) if values else None + + +def _sum_optional_int(rows: list[StagedCaseAnalysisRow], field_name: str) -> int: + return sum(value for value in 
(_optional_int(getattr(row, field_name)) for row in rows) if value is not None) + + +def _rate(numerator: object, denominator: object) -> float | None: + numerator_value = _optional_float(numerator) + denominator_value = _optional_float(denominator) + if numerator_value is None or denominator_value is None or denominator_value <= 0: + return None + return numerator_value / denominator_value + + +def _dict_value(raw: object) -> dict[str, Any]: + return raw if isinstance(raw, dict) else {} + + +def _counter_dict(raw: object) -> dict[str, int]: + return {str(key): _int_value(value) for key, value in _dict_value(raw).items()} + + +def _optional_str(raw: object) -> str | None: + return str(raw) if raw is not None else None + + +def _optional_float(raw: object) -> float | None: + if raw is None: + return None + return float(raw) + + +def _optional_int(raw: object) -> int | None: + if raw is None: + return None + return int(float(raw)) + + +def _int_value(raw: object) -> int: + return _optional_int(raw) or 0 + + +def write_analysis_tables( + result: StagedDetectionOutputAnalysis, + output_dir: Path, + export_format: ExportFormat, +) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + _write_model_rows( + result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, StagedCaseAnalysisRow + ), + _write_model_rows( + result.groups, + output_dir / f"group_analysis.{export_format.value}", + export_format, + StagedGroupAnalysisRow, + ), + _write_model_rows( + result.label_deltas, + output_dir / f"label_delta_analysis.{export_format.value}", + export_format, + LabelDeltaAnalysisRow, + ), + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def _write_model_rows( + rows: 
list[BaseModel], + path: Path, + export_format: ExportFormat, + row_model: type[BaseModel], +) -> TableSummary: + table = _rows_to_table(rows, row_model) + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + return TableSummary(table=path.stem, rows=len(table), path=str(path)) + + +def _rows_to_table(rows: list[BaseModel], row_model: type[BaseModel]) -> pd.DataFrame: + if rows: + return pd.json_normalize([row.model_dump() for row in rows], sep=".") + return pd.DataFrame(columns=list(row_model.model_fields)) + + +def render_result(result: StagedDetectionOutputAnalysis, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [f"Analyzed {result.case_count} staged detection case(s) across {result.group_count} group(s)"] + lines.extend(_render_group_line(group, result.label_deltas) for group in result.groups) + return "\n".join(lines) + + +def _render_group_line(group: StagedGroupAnalysisRow, label_deltas: list[LabelDeltaAnalysisRow]) -> str: + label = group.seed_source or "" + lost = _top_labels(label_deltas, seed_source=group.seed_source, delta_type="baseline_only") + return ( + f"- {label}: cases={group.case_count}, errors={group.error_case_count}, " + f"verdict={group.fast_lane_verdict}, flags={_label_count_summary(group.flags)}, " + f"elapsed_sum={_fmt_float(group.elapsed_sec_sum)}s, " + f"model_elapsed_sum={_fmt_float(group.model_elapsed_sec_sum)}s, " + f"requests={group.model_request_count_sum}, tokens={group.total_tokens_sum}, " + f"shared={group.shared_final_entity_signature_count_sum}/" + f"{group.baseline_final_entity_signature_count_sum}, " + f"baseline_only={group.baseline_only_final_entity_signature_count_sum}, " + f"direct_only={group.direct_only_final_entity_signature_count_sum}, lost_labels={lost}" + ) + + +def _label_count_summary(items: 
list[str]) -> str: + return "[]" if not items else "[" + ", ".join(items) + "]" + + +def _top_labels(label_deltas: list[LabelDeltaAnalysisRow], *, seed_source: str | None, delta_type: str) -> str: + matches = [delta for delta in label_deltas if delta.seed_source == seed_source and delta.delta_type == delta_type] + if not matches: + return "{}" + items = sorted(matches, key=lambda delta: (-delta.count, delta.label))[:8] + return "{" + ", ".join(f"{item.label}:{item.count}" for item in items) + "}" + + +def _fmt_float(value: float | None) -> str: + return "n/a" if value is None else f"{value:.3f}" + + +@app.default +def main( + input_path: Path, + *, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + result = analyze_staged_detection_output(input_path) + if output is not None: + write_analysis_tables(result, output, format) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py new file mode 100644 index 00000000..42faf61f --- /dev/null +++ b/tools/measurement/compare_strategy_pairs.py @@ -0,0 +1,1459 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Compare benchmark case-analysis rows for baseline/candidate strategy pairs. 
class ComparisonSummary(BaseModel):
    """Roll-up of the verdicts across all workload comparisons of one run.

    Built by ``summarize_comparisons`` from the list of ``ComparisonRow``
    objects (one row per compared workload).
    """

    # Number of comparison rows that were summarized.
    comparison_count: int = 0
    # For each verdict field on ``ComparisonRow``: verdict value -> number of rows
    # carrying it. ``summarize_comparisons`` seeds every enum value with 0, so a
    # populated summary lists verdicts that never occurred as well.
    value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict)
    signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict)
    safety_verdict_counts: dict[str, int] = Field(default_factory=dict)
    performance_verdict_counts: dict[str, int] = Field(default_factory=dict)
    candidate_verdict_counts: dict[str, int] = Field(default_factory=dict)
    # Sorted workload ids, bucketed by each row's ``candidate_verdict``
    # (``candidate_viable``, ``review`` and ``reject`` respectively).
    candidate_viable_workloads: list[str] = Field(default_factory=list)
    review_workloads: list[str] = Field(default_factory=list)
    rejected_workloads: list[str] = Field(default_factory=list)
| None = None + observed_non_bridge_total_requests_delta: float | None = None + baseline_observed_non_bridge_failed_requests: float | None = None + candidate_observed_non_bridge_failed_requests: float | None = None + observed_non_bridge_failed_requests_delta: float | None = None + baseline_original_value_leak_count: float | None = None + candidate_original_value_leak_count: float | None = None + original_value_leak_count_delta: float | None = None + baseline_original_value_leak_record_count: float | None = None + candidate_original_value_leak_record_count: float | None = None + original_value_leak_record_count_delta: float | None = None + baseline_replacement_missing_final_entity_count: float | None = None + candidate_replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_count_delta: float | None = None + baseline_replacement_missing_final_value_count: float | None = None + candidate_replacement_missing_final_value_count: float | None = None + replacement_missing_final_value_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_count: float | None = None + candidate_replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_value_count: float | None = None + candidate_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_value_count_delta: float | None = None + value_protection_verdict: SafetyVerdict | None = None + signature_parity_verdict: SafetyVerdict | None = None + safety_verdict: SafetyVerdict | None = None + performance_verdict: PerformanceVerdict | None = None + candidate_verdict: CandidateVerdict | None = None + baseline_final_entity_count: float | None = None + candidate_final_entity_count: float | None = None + final_entity_count_delta: float | None = None + 
baseline_seed_validation_candidate_count: float | None = None + candidate_seed_validation_candidate_count: float | None = None + seed_validation_candidate_count_delta: float | None = None + baseline_augmented_entity_count: float | None = None + candidate_augmented_entity_count: float | None = None + augmented_entity_count_delta: float | None = None + baseline_augmented_new_final_value_count: float | None = None + candidate_augmented_new_final_value_count: float | None = None + augmented_new_final_value_count_delta: float | None = None + baseline_detector_entity_count: float | None = None + candidate_detector_entity_count: float | None = None + baseline_rule_entity_count: float | None = None + candidate_rule_entity_count: float | None = None + baseline_augmenter_entity_count: float | None = None + candidate_augmenter_entity_count: float | None = None + baseline_only_final_entity_signature_count: int | None = None + candidate_only_final_entity_signature_count: int | None = None + shared_final_entity_signature_count: int | None = None + baseline_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + shared_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_covered_signature_count: int | None = None + baseline_only_candidate_overlapping_signature_count: int | None = None + baseline_only_candidate_uncovered_signature_count: int | None = None + baseline_only_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_only_candidate_label_mismatch_signature_count: int | None = None + 
baseline_only_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_final_entity_signature_count: int | None = None + candidate_stable_final_entity_signature_count: int | None = None + stable_final_entity_signature_count_delta: int | None = None + baseline_stable_candidate_unstable_final_entity_signature_count: int | None = None + candidate_stable_baseline_unstable_final_entity_signature_count: int | None = None + shared_stable_final_entity_signature_count: int | None = None + baseline_stable_candidate_covered_signature_count: int | None = None + baseline_stable_candidate_overlapping_signature_count: int | None = None + baseline_stable_candidate_uncovered_signature_count: int | None = None + baseline_stable_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_label_mismatch_signature_count: int | None = None + baseline_stable_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_stable_candidate_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_stable_baseline_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + shared_stable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = 
def read_case_analysis(path: Path) -> pd.DataFrame:
    """Load a case-analysis table from a parquet, CSV or JSONL file.

    The reader is chosen from the file extension, compared case-insensitively.

    Args:
        path: Location of the exported case-analysis table.

    Returns:
        The loaded table, guaranteed to contain the required id columns.

    Raises:
        ValueError: If ``path`` is missing or a directory, has an unsupported
            extension, or the table lacks a required column.
    """
    if not path.exists() or path.is_dir():
        raise ValueError(f"case analysis path is not a file: {path}")
    # Normalize the extension so exports named e.g. ``CASES.CSV`` or
    # ``cases.PARQUET`` are not rejected as an unsupported format.
    suffix = path.suffix.lower()
    if suffix == ".parquet":
        table = pd.read_parquet(path)
    elif suffix == ".csv":
        table = pd.read_csv(path)
    elif suffix == ".jsonl":
        table = pd.read_json(path, lines=True)
    else:
        raise ValueError(f"unsupported case analysis format: {path.suffix}")
    _validate_case_analysis_columns(table)
    return table


def _validate_case_analysis_columns(table: pd.DataFrame) -> None:
    """Raise ``ValueError`` unless the table has the columns every comparison needs."""
    required = {"workload_id", "config_id", "case_id"}
    missing = sorted(required - set(table.columns))
    if missing:
        raise ValueError(f"case analysis is missing required column(s): {', '.join(missing)}")
def compare_case_tables(
    baseline_table: pd.DataFrame,
    candidate_table: pd.DataFrame,
    *,
    baseline_config: str | None = None,
    candidate_config: str | None = None,
    baseline_strategy: str | None = None,
    candidate_strategy: str | None = None,
) -> list[ComparisonRow]:
    """Compare a baseline selection against a candidate selection, per workload.

    Each side is narrowed with ``_select_rows`` using its config/strategy
    selector (baseline first, then candidate). One ``ComparisonRow`` is then
    produced for every workload returned by ``_common_workloads``.
    """
    baseline_rows = _select_rows(
        baseline_table,
        config_id=baseline_config,
        strategy=baseline_strategy,
        selector_name="baseline",
    )
    candidate_rows = _select_rows(
        candidate_table,
        config_id=candidate_config,
        strategy=candidate_strategy,
        selector_name="candidate",
    )
    comparisons: list[ComparisonRow] = []
    for workload in _common_workloads(baseline_rows, candidate_rows):
        comparisons.append(_compare_workload(workload, baseline_rows, candidate_rows))
    return comparisons
counts[verdict] = counts.get(verdict, 0) + 1 + return counts + + +def _workloads_by_candidate_verdict(rows: list[ComparisonRow], verdict: CandidateVerdict) -> list[str]: + return sorted(row.workload_id for row in rows if _verdict_value(row.candidate_verdict) == verdict.value) + + +def _verdict_value(value: object) -> str | None: + if value is None: + return None + if isinstance(value, StrEnum): + return value.value + return str(value) + + +def _select_rows( + table: pd.DataFrame, + *, + config_id: str | None, + strategy: str | None, + selector_name: str, +) -> pd.DataFrame: + if (config_id is None) == (strategy is None): + raise ValueError(f"{selector_name} selector must specify exactly one of config or strategy") + column, value = ("config_id", config_id) if config_id is not None else ("experimental_detection_strategy", strategy) + if column not in table.columns: + raise ValueError(f"case analysis is missing selector column: {column}") + selected = table[table[column].astype(str) == str(value)] + if selected.empty: + raise ValueError(f"{selector_name} selector matched no rows: {column}={value}") + _validate_unique_config_per_workload(selected, selector_name=selector_name) + return selected + + +def _validate_unique_config_per_workload(rows: pd.DataFrame, *, selector_name: str) -> None: + counts = rows.groupby("workload_id")["config_id"].nunique() + ambiguous = sorted(str(index) for index, count in counts.items() if count > 1) + if ambiguous: + raise ValueError(f"{selector_name} selector matched multiple configs for workload(s): {', '.join(ambiguous)}") + + +def _common_workloads(baseline: pd.DataFrame, candidate: pd.DataFrame) -> list[str]: + baseline_workloads = set(str(value) for value in baseline["workload_id"].dropna()) + candidate_workloads = set(str(value) for value in candidate["workload_id"].dropna()) + common = sorted(baseline_workloads & candidate_workloads) + if not common: + raise ValueError("baseline and candidate selectors have no workloads in 
def _compare_workload(workload_id: str, baseline: pd.DataFrame, candidate: pd.DataFrame) -> ComparisonRow:
    """Build the comparison row for one workload.

    Both tables are filtered to rows whose ``workload_id`` (compared as a
    string) equals ``workload_id``; each side is summarized with
    ``_summarize_selector`` and the summaries are combined with the metrics
    from ``_comparison_metrics``.
    """
    baseline_mask = baseline["workload_id"].astype(str) == workload_id
    base = _summarize_selector(baseline[baseline_mask])
    candidate_mask = candidate["workload_id"].astype(str) == workload_id
    cand = _summarize_selector(candidate[candidate_mask])

    base_config = str(base["config_id"])
    cand_config = str(cand["config_id"])
    base_strategy = _optional_string(base.get("experimental_detection_strategy"))
    cand_strategy = _optional_string(cand.get("experimental_detection_strategy"))
    base_replacement = _optional_string(base.get("experimental_replacement_strategy"))
    cand_replacement = _optional_string(cand.get("experimental_replacement_strategy"))
    base_cases = int(base["case_count"])
    cand_cases = int(cand["case_count"])
    metrics = _comparison_metrics(base, cand)

    return ComparisonRow(
        workload_id=workload_id,
        baseline_config_id=base_config,
        candidate_config_id=cand_config,
        baseline_strategy=base_strategy,
        candidate_strategy=cand_strategy,
        baseline_replacement_strategy=base_replacement,
        candidate_replacement_strategy=cand_replacement,
        baseline_case_count=base_cases,
        candidate_case_count=cand_cases,
        **metrics,
    )
rows, + "replacement_synthetic_original_collision_count", + ) + summary["replacement_synthetic_original_collision_value_count"] = _sum_or_none( + rows, + "replacement_synthetic_original_collision_value_count", + ) + summary["original_value_leak_count"] = _sum_or_none(rows, "original_value_leak_count") + summary["original_value_leak_record_count"] = _sum_or_none(rows, "original_value_leak_record_count") + summary["original_value_leak_label_counts"] = _sum_label_counts(rows, "original_value_leak_label_counts") + summary["replacement_synthetic_original_collision_label_counts"] = _sum_label_counts( + rows, + "replacement_synthetic_original_collision_label_counts", + ) + summary["failed_case_count"] = _failed_case_count(rows) + summary["failed_case_rate"] = _rate(summary["failed_case_count"], case_count) + summary["artifact_final_entity_signature_hashes"] = _signature_hashes(rows) + summary["stable_artifact_final_entity_signature_hashes"] = _stable_signature_hashes(rows) if case_count > 1 else [] + summary["artifact_final_entity_signature_labels"] = _signature_labels(rows) + summary["artifact_final_entity_signature_details"] = _signature_details(rows) + return summary + + +def _single_string(rows: pd.DataFrame, column: str) -> str | None: + if column not in rows.columns: + return None + values = sorted({str(value) for value in rows[column].dropna()}) + return values[0] if values else None + + +_NUMERIC_COLUMNS = [ + "pipeline_elapsed_sec", + "observed_total_requests", + "observed_total_tokens", + "observed_failed_requests", + "observed_bridge_fallback_requests", + "observed_non_bridge_total_requests", + "observed_non_bridge_failed_requests", + "replacement_missing_final_entity_count", + "replacement_missing_final_value_count", + "replacement_synthetic_original_collision_count", + "replacement_synthetic_original_collision_value_count", + "final_entity_count", + "seed_validation_candidate_count", + "augmented_entity_count", + "augmented_new_final_value_count", + 
"artifact_final_detector_entity_count", + "artifact_final_rule_entity_count", + "artifact_final_augmenter_entity_count", +] + + +def _median_or_none(rows: pd.DataFrame, column: str) -> float | None: + if column not in rows.columns: + return None + values = pd.to_numeric(rows[column], errors="coerce").dropna() + return float(values.median()) if not values.empty else None + + +def _sum_or_none(rows: pd.DataFrame, column: str) -> float | None: + if column not in rows.columns: + return None + values = pd.to_numeric(rows[column], errors="coerce").dropna() + return float(values.sum()) if not values.empty else None + + +def _comparison_metrics(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + metrics = _named_metric_deltas(baseline, candidate) + metrics.update(_source_counts(baseline, candidate)) + metrics.update(_entity_signature_deltas(baseline, candidate)) + metrics["flags"] = _comparison_flags( + metrics, + baseline_strategy=_optional_string(baseline.get("experimental_detection_strategy")), + candidate_strategy=_optional_string(candidate.get("experimental_detection_strategy")), + baseline_replacement_strategy=_optional_string(baseline.get("experimental_replacement_strategy")), + candidate_replacement_strategy=_optional_string(candidate.get("experimental_replacement_strategy")), + ) + metrics.update(_comparison_verdicts(metrics)) + return metrics + + +def _named_metric_deltas(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: + return { + **_metric_delta("failed_case_count", baseline, candidate), + **_metric_delta("failed_case_rate", baseline, candidate), + **_metric_delta("pipeline_elapsed_sec", baseline, candidate, include_pct=True), + **_metric_delta("observed_total_requests", baseline, candidate), + **_metric_delta("observed_total_tokens", baseline, candidate), + **_metric_delta("observed_failed_requests", baseline, candidate), + **_metric_delta("observed_bridge_fallback_requests", baseline, candidate), 
def _metric_delta(
    name: str,
    baseline: dict[str, object],
    candidate: dict[str, object],
    *,
    include_pct: bool = False,
) -> dict[str, float | None]:
    """Build the baseline/candidate/delta entries for one named metric.

    Args:
        name: Metric key looked up in both summaries.
        baseline: Baseline summary mapping.
        candidate: Candidate summary mapping.
        include_pct: When true, also emit ``{name}_delta_pct``.

    Returns:
        A mapping with ``baseline_{name}``, ``candidate_{name}`` and
        ``{name}_delta`` (plus the percentage entry when requested).
    """
    before = _optional_float(baseline.get(name))
    after = _optional_float(candidate.get(name))

    row: dict[str, float | None] = {}
    row[f"baseline_{name}"] = before
    row[f"candidate_{name}"] = after
    row[f"{name}_delta"] = _delta(before, after)
    if not include_pct:
        return row
    row[f"{name}_delta_pct"] = _delta_pct(before, after)
    return row
_optional_float(baseline.get("artifact_final_augmenter_entity_count")), + "candidate_augmenter_entity_count": _optional_float(candidate.get("artifact_final_augmenter_entity_count")), + "baseline_original_value_leak_label_counts": _coerce_count_map( + baseline.get("original_value_leak_label_counts") + ), + "candidate_original_value_leak_label_counts": _coerce_count_map( + candidate.get("original_value_leak_label_counts") + ), + "baseline_replacement_synthetic_original_collision_label_counts": _coerce_count_map( + baseline.get("replacement_synthetic_original_collision_label_counts") + ), + "candidate_replacement_synthetic_original_collision_label_counts": _coerce_count_map( + candidate.get("replacement_synthetic_original_collision_label_counts") + ), + } + + +def _comparison_flags( + metrics: dict[str, object], + *, + baseline_strategy: str | None, + candidate_strategy: str | None, + baseline_replacement_strategy: str | None, + candidate_replacement_strategy: str | None, +) -> list[str]: + flags: list[str] = [] + _append_if_positive(flags, metrics, "baseline_failed_case_count", "baseline_case_failures") + _append_if_positive(flags, metrics, "candidate_failed_case_count", "candidate_case_failures") + _append_if_positive(flags, metrics, "baseline_original_value_leak_count", "baseline_original_value_leak") + _append_if_positive(flags, metrics, "candidate_original_value_leak_count", "candidate_original_value_leak") + _append_if_positive( + flags, + metrics, + "baseline_replacement_synthetic_original_collision_count", + "baseline_replacement_synthetic_original_collision", + ) + _append_if_positive( + flags, + metrics, + "candidate_replacement_synthetic_original_collision_count", + "candidate_replacement_synthetic_original_collision", + ) + _append_if_positive( + flags, + metrics, + "baseline_replacement_missing_final_entity_count", + "baseline_replacement_missing_final_entity", + ) + _append_if_positive( + flags, + metrics, + 
"candidate_replacement_missing_final_entity_count", + "candidate_replacement_missing_final_entity", + ) + _append_if_negative(flags, metrics, "final_entity_count_delta", "entity_count_loss") + _append_if_positive(flags, metrics, _signature_loss_metric(metrics), "entity_signature_loss") + _append_if_positive( + flags, + metrics, + "baseline_only_candidate_overlapping_signature_count", + "span_boundary_mismatch", + ) + _append_if_positive( + flags, + metrics, + "baseline_only_candidate_label_mismatch_signature_count", + "covered_label_mismatch", + ) + _append_if_positive( + flags, + metrics, + _stable_signature_loss_metric(metrics), + "stable_entity_signature_loss", + ) + failed_request_delta = ( + "observed_non_bridge_failed_requests_delta" + if _has_metric_pair(metrics, "observed_non_bridge_failed_requests") + else "observed_failed_requests_delta" + ) + _append_if_positive(flags, metrics, failed_request_delta, "failed_request_increase") + _append_if_positive(flags, metrics, "observed_bridge_fallback_requests_delta", "bridge_fallback_increase") + _append_if_positive(flags, metrics, "observed_total_tokens_delta", "token_increase") + _append_if_positive(flags, metrics, "observed_total_requests_delta", "request_increase") + if _candidate_lacks_detector_entities(metrics): + flags.append("no_candidate_detector_entities") + if _optional_float(metrics.get("candidate_rule_entity_count")): + flags.append("candidate_uses_rule_entities") + if candidate_strategy in _SKIPS_LLM_VALIDATION_STRATEGIES: + flags.append("candidate_skips_llm_validation") + if _replacement_only_detection_instability( + flags, + baseline_strategy=baseline_strategy, + candidate_strategy=candidate_strategy, + baseline_replacement_strategy=baseline_replacement_strategy, + candidate_replacement_strategy=candidate_replacement_strategy, + ): + flags.append("replacement_only_detection_instability") + return flags + + +_DETECTION_INSTABILITY_FLAGS = { + "covered_label_mismatch", + "entity_count_loss", + 
"entity_signature_loss", + "span_boundary_mismatch", + "stable_entity_signature_loss", +} + + +def _replacement_only_detection_instability( + flags: list[str], + *, + baseline_strategy: str | None, + candidate_strategy: str | None, + baseline_replacement_strategy: str | None, + candidate_replacement_strategy: str | None, +) -> bool: + if baseline_strategy != candidate_strategy: + return False + if not baseline_replacement_strategy or not candidate_replacement_strategy: + return False + if baseline_replacement_strategy == candidate_replacement_strategy: + return False + return bool(set(flags) & _DETECTION_INSTABILITY_FLAGS) + + +def _signature_loss_metric(metrics: dict[str, object]) -> str: + if metrics.get("baseline_only_candidate_uncovered_signature_count") is not None: + return "baseline_only_candidate_uncovered_signature_count" + return "baseline_only_final_entity_signature_count" + + +def _stable_signature_loss_metric(metrics: dict[str, object]) -> str: + if metrics.get("baseline_stable_candidate_uncovered_signature_count") is not None: + return "baseline_stable_candidate_uncovered_signature_count" + return "baseline_stable_candidate_unstable_final_entity_signature_count" + + +_SKIPS_LLM_VALIDATION_STRATEGIES = {"detector_only", "rules_guardrail_detector_only", "rules_only"} + + +def _has_metric_pair(metrics: dict[str, object], name: str) -> bool: + return ( + _optional_float(metrics.get(f"baseline_{name}")) is not None + and _optional_float(metrics.get(f"candidate_{name}")) is not None + ) + + +def _comparison_verdicts(metrics: dict[str, object]) -> dict[str, str]: + value_protection_verdict = _value_protection_verdict(metrics) + signature_parity_verdict = _signature_parity_verdict(metrics) + safety_verdict = _safety_verdict(metrics) + performance_verdict = _performance_verdict(metrics) + return { + "value_protection_verdict": value_protection_verdict.value, + "signature_parity_verdict": signature_parity_verdict.value, + "safety_verdict": safety_verdict.value, 
+ "performance_verdict": performance_verdict.value, + "candidate_verdict": _candidate_verdict(safety_verdict, performance_verdict).value, + } + + +def _value_protection_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & { + "candidate_case_failures", + "candidate_original_value_leak", + "candidate_replacement_missing_final_entity", + "candidate_replacement_synthetic_original_collision", + "entity_signature_loss", + "stable_entity_signature_loss", + }: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "baseline_case_failures", + "baseline_original_value_leak", + "baseline_replacement_missing_final_entity", + "baseline_replacement_synthetic_original_collision", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _signature_parity_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & {"candidate_case_failures", "entity_signature_loss", "stable_entity_signature_loss"}: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "span_boundary_mismatch", + "covered_label_mismatch", + "entity_count_loss", + "baseline_case_failures", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _safety_verdict(metrics: dict[str, object]) -> SafetyVerdict: + flags = set(_coerce_flag_list(metrics.get("flags"))) + if flags & { + "candidate_case_failures", + "candidate_original_value_leak", + "candidate_replacement_missing_final_entity", + "candidate_replacement_synthetic_original_collision", + "entity_signature_loss", + "stable_entity_signature_loss", + }: + return SafetyVerdict.fail + if _entity_count_loss_without_signature_artifacts(flags, metrics): + return SafetyVerdict.fail + if flags & { + "no_candidate_detector_entities", + 
"candidate_uses_rule_entities", + "candidate_skips_llm_validation", + "failed_request_increase", + "bridge_fallback_increase", + "span_boundary_mismatch", + "covered_label_mismatch", + }: + return SafetyVerdict.review + if flags & { + "baseline_case_failures", + "baseline_original_value_leak", + "baseline_replacement_missing_final_entity", + "baseline_replacement_synthetic_original_collision", + "entity_count_loss", + }: + return SafetyVerdict.review + return SafetyVerdict.passed + + +def _entity_count_loss_without_signature_artifacts(flags: set[str], metrics: dict[str, object]) -> bool: + return "entity_count_loss" in flags and metrics.get("baseline_only_final_entity_signature_count") is None + + +def _performance_verdict(metrics: dict[str, object]) -> PerformanceVerdict: + deltas = [ + _optional_float(metrics.get("pipeline_elapsed_sec_delta")), + _optional_float(metrics.get("observed_total_requests_delta")), + _optional_float(metrics.get("observed_total_tokens_delta")), + ] + known = [value for value in deltas if value is not None] + if not known: + return PerformanceVerdict.unknown + improved = any(value < 0 for value in known) + regressed = any(value > 0 for value in known) + if improved and regressed: + return PerformanceVerdict.mixed + if improved: + return PerformanceVerdict.improved + if regressed: + return PerformanceVerdict.regressed + return PerformanceVerdict.unchanged + + +def _candidate_verdict( + safety_verdict: SafetyVerdict, + performance_verdict: PerformanceVerdict, +) -> CandidateVerdict: + if safety_verdict == SafetyVerdict.fail: + return CandidateVerdict.reject + if safety_verdict == SafetyVerdict.passed and performance_verdict == PerformanceVerdict.improved: + return CandidateVerdict.candidate_viable + return CandidateVerdict.review + + +def _coerce_flag_list(value: object) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value] + return [] + + +def _candidate_lacks_detector_entities(metrics: dict[str, object]) -> 
def _known_non_detector_candidate_count(metrics: dict[str, object]) -> float | None:
    """Add up the candidate's rule and augmenter entity counts.

    Returns ``None`` when neither count is available; otherwise a missing
    count contributes zero to the total.
    """
    rule_count = _optional_float(metrics.get("candidate_rule_entity_count"))
    augmenter_count = _optional_float(metrics.get("candidate_augmenter_entity_count"))
    if rule_count is None and augmenter_count is None:
        return None

    total = 0.0
    for count in (rule_count, augmenter_count):
        if count:  # None and zero both add nothing
            total += count
    return total
"shared_final_entity_signature_label_counts": {}, + "baseline_only_candidate_covered_signature_count": None, + "baseline_only_candidate_overlapping_signature_count": None, + "baseline_only_candidate_uncovered_signature_count": None, + "baseline_only_candidate_covered_signature_label_counts": {}, + "baseline_only_candidate_overlapping_signature_label_counts": {}, + "baseline_only_candidate_uncovered_signature_label_counts": {}, + "baseline_only_candidate_label_mismatch_signature_count": None, + "baseline_only_candidate_label_mismatch_signature_label_counts": {}, + } + baseline_only = baseline_signatures - candidate_signatures + candidate_only = candidate_signatures - baseline_signatures + shared = baseline_signatures & candidate_signatures + coverage = _candidate_span_coverage( + baseline_only, + baseline_details=baseline_details, + candidate_details=candidate_details, + baseline_labels=baseline_labels, + ) + return { + "baseline_only_final_entity_signature_count": len(baseline_only), + "candidate_only_final_entity_signature_count": len(candidate_only), + "shared_final_entity_signature_count": len(shared), + "baseline_only_final_entity_signature_label_counts": _signature_label_counts( + baseline_only, + baseline_labels, + ), + "candidate_only_final_entity_signature_label_counts": _signature_label_counts( + candidate_only, + candidate_labels, + ), + "shared_final_entity_signature_label_counts": _signature_label_counts(shared, baseline_labels), + **coverage, + **_stable_entity_signature_deltas( + baseline, + candidate, + baseline_labels, + candidate_labels, + baseline_details, + candidate_details, + ), + } + + +def _stable_entity_signature_deltas( + baseline: dict[str, object], + candidate: dict[str, object], + baseline_labels: dict[str, str], + candidate_labels: dict[str, str], + baseline_details: dict[str, dict[str, object]], + candidate_details: dict[str, dict[str, object]], +) -> dict[str, object]: + baseline_case_count = 
_optional_float(baseline.get("case_count")) + candidate_case_count = _optional_float(candidate.get("case_count")) + if ( + baseline_case_count is None + or candidate_case_count is None + or baseline_case_count < 2 + or candidate_case_count < 2 + ): + return _empty_stable_signature_deltas() + baseline_stable = set(_coerce_signature_list(baseline.get("stable_artifact_final_entity_signature_hashes"))) + candidate_stable = set(_coerce_signature_list(candidate.get("stable_artifact_final_entity_signature_hashes"))) + if not baseline_stable and not candidate_stable: + return _empty_stable_signature_deltas() + baseline_stable_candidate_unstable = baseline_stable - candidate_stable + candidate_stable_baseline_unstable = candidate_stable - baseline_stable + shared_stable = baseline_stable & candidate_stable + stable_candidate_details = { + signature: detail for signature, detail in candidate_details.items() if signature in candidate_stable + } + coverage = _candidate_span_coverage( + baseline_stable_candidate_unstable, + baseline_details=baseline_details, + candidate_details=stable_candidate_details, + baseline_labels=baseline_labels, + prefix="baseline_stable_candidate", + ) + return { + "baseline_stable_final_entity_signature_count": len(baseline_stable), + "candidate_stable_final_entity_signature_count": len(candidate_stable), + "stable_final_entity_signature_count_delta": len(candidate_stable) - len(baseline_stable), + "baseline_stable_candidate_unstable_final_entity_signature_count": len(baseline_stable_candidate_unstable), + "candidate_stable_baseline_unstable_final_entity_signature_count": len(candidate_stable_baseline_unstable), + "shared_stable_final_entity_signature_count": len(shared_stable), + "baseline_stable_candidate_unstable_final_entity_signature_label_counts": _signature_label_counts( + baseline_stable_candidate_unstable, + baseline_labels, + ), + "candidate_stable_baseline_unstable_final_entity_signature_label_counts": _signature_label_counts( + 
def _candidate_span_coverage(
    baseline_signatures: set[str],
    *,
    baseline_details: dict[str, dict[str, object]],
    candidate_details: dict[str, dict[str, object]],
    baseline_labels: dict[str, str],
    prefix: str = "baseline_only_candidate",
) -> dict[str, object]:
    """Classify baseline signatures by how candidate spans match them.

    Each signature is passed to ``_candidate_span_match_kind`` together with
    all candidate details. Signatures reported as ``"contained"`` or
    ``"overlapping"`` count as covered; everything else is uncovered. A
    signature is additionally recorded as a label mismatch when the matcher
    says so.

    Args:
        baseline_signatures: Signature hashes to classify.
        baseline_details: Per-signature detail mappings for the baseline.
        candidate_details: Per-signature detail mappings for the candidate.
        baseline_labels: Signature hash to label, used for the label tallies.
        prefix: Prefix applied to every key in the returned mapping.

    Returns:
        Counts and per-label counts for the covered, overlapping, uncovered
        and label-mismatch groups, keyed with ``prefix``.
    """
    candidate_spans = candidate_details.values()
    covered: set[str] = set()
    overlapping: set[str] = set()
    mismatched: set[str] = set()

    for signature in baseline_signatures:
        kind, has_label_mismatch = _candidate_span_match_kind(
            baseline_details.get(signature),
            candidate_spans,
        )
        if kind == "contained" or kind == "overlapping":
            covered.add(signature)
        if kind == "overlapping":
            overlapping.add(signature)
        if has_label_mismatch:
            mismatched.add(signature)

    uncovered = baseline_signatures - covered

    def by_label(signatures: set[str]) -> dict[str, int]:
        return _signature_label_counts(signatures, baseline_labels)

    return {
        f"{prefix}_covered_signature_count": len(covered),
        f"{prefix}_overlapping_signature_count": len(overlapping),
        f"{prefix}_uncovered_signature_count": len(uncovered),
        f"{prefix}_covered_signature_label_counts": by_label(covered),
        f"{prefix}_overlapping_signature_label_counts": by_label(overlapping),
        f"{prefix}_uncovered_signature_label_counts": by_label(uncovered),
        f"{prefix}_label_mismatch_signature_count": len(mismatched),
        f"{prefix}_label_mismatch_signature_label_counts": by_label(mismatched),
    }
"end_position")) + baseline_label = _optional_string(_detail_value(baseline_detail, "label")) + if baseline_row is None or baseline_start is None or baseline_end is None: + return None, False + baseline_length = baseline_end - baseline_start + if baseline_length <= 0: + return None, False + first_mismatched_match: tuple[str, bool] | None = None + for candidate_detail in candidate_details: + candidate_row = _optional_int(_detail_value(candidate_detail, "row_index")) + candidate_start = _optional_int(_detail_value(candidate_detail, "start_position")) + candidate_end = _optional_int(_detail_value(candidate_detail, "end_position")) + if candidate_row != baseline_row or candidate_start is None or candidate_end is None: + continue + candidate_label = _optional_string(_detail_value(candidate_detail, "label")) + labels_mismatch = _labels_mismatch(baseline_label, candidate_label) + if candidate_start <= baseline_start and candidate_end >= baseline_end: + if not labels_mismatch: + return "contained", False + first_mismatched_match = first_mismatched_match or ("contained", True) + continue + overlap = max(0, min(baseline_end, candidate_end) - max(baseline_start, candidate_start)) + if overlap / baseline_length >= _MIN_CANDIDATE_OVERLAP_RATIO: + if not labels_mismatch: + return "overlapping", False + first_mismatched_match = first_mismatched_match or ("overlapping", True) + continue + if _is_small_boundary_gap( + baseline_start=baseline_start, + baseline_end=baseline_end, + baseline_label=baseline_label, + candidate_start=candidate_start, + candidate_end=candidate_end, + overlap=overlap, + ): + if not labels_mismatch: + return "overlapping", False + first_mismatched_match = first_mismatched_match or ("overlapping", True) + return first_mismatched_match or (None, False) + + +def _labels_mismatch(baseline_label: str | None, candidate_label: str | None) -> bool: + return baseline_label is not None and candidate_label is not None and baseline_label != candidate_label + + +def 
_is_small_boundary_gap( + *, + baseline_start: int, + baseline_end: int, + baseline_label: str | None, + candidate_start: int, + candidate_end: int, + overlap: int, +) -> bool: + if baseline_label not in _BOUNDARY_NORMALIZED_BASELINE_LABELS: + return False + if overlap < _MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS: + return False + omitted_left = max(0, candidate_start - baseline_start) + omitted_right = max(0, baseline_end - candidate_end) + return omitted_left + omitted_right <= _MAX_CANDIDATE_BOUNDARY_GAP_CHARS + + +def _detail_value(detail: object, key: str) -> object: + if not isinstance(detail, dict): + return None + return detail.get(key) + + +def _empty_stable_signature_deltas() -> dict[str, object]: + return { + "baseline_stable_final_entity_signature_count": None, + "candidate_stable_final_entity_signature_count": None, + "stable_final_entity_signature_count_delta": None, + "baseline_stable_candidate_unstable_final_entity_signature_count": None, + "candidate_stable_baseline_unstable_final_entity_signature_count": None, + "shared_stable_final_entity_signature_count": None, + "baseline_stable_candidate_covered_signature_count": None, + "baseline_stable_candidate_overlapping_signature_count": None, + "baseline_stable_candidate_uncovered_signature_count": None, + "baseline_stable_candidate_covered_signature_label_counts": {}, + "baseline_stable_candidate_overlapping_signature_label_counts": {}, + "baseline_stable_candidate_uncovered_signature_label_counts": {}, + "baseline_stable_candidate_label_mismatch_signature_count": None, + "baseline_stable_candidate_label_mismatch_signature_label_counts": {}, + "baseline_stable_candidate_unstable_final_entity_signature_label_counts": {}, + "candidate_stable_baseline_unstable_final_entity_signature_label_counts": {}, + "shared_stable_final_entity_signature_label_counts": {}, + } + + +def _signature_hashes(rows: pd.DataFrame) -> list[str]: + hashes: set[str] = set() + for signature_set in _signature_hash_sets(rows): + 
hashes.update(signature_set) + return sorted(hashes) + + +def _stable_signature_hashes(rows: pd.DataFrame) -> list[str]: + signature_sets = _signature_hash_sets(rows) + if not signature_sets: + return [] + return sorted(set.intersection(*signature_sets)) + + +def _signature_hash_sets(rows: pd.DataFrame) -> list[set[str]]: + if "artifact_final_entity_signature_hashes" not in rows.columns: + return [] + return [set(_coerce_signature_list(value)) for value in rows["artifact_final_entity_signature_hashes"].tolist()] + + +def _signature_labels(rows: pd.DataFrame) -> dict[str, str]: + labels: dict[str, str] = {} + if "artifact_final_entity_signature_labels" in rows.columns: + for value in rows["artifact_final_entity_signature_labels"].tolist(): + labels.update(_coerce_signature_labels(value)) + prefix = "artifact_final_entity_signature_labels." + for column in rows.columns: + if not column.startswith(prefix): + continue + signature_hash = column.removeprefix(prefix) + for value in rows[column].tolist(): + if not _is_missing_cell(value): + labels[signature_hash] = str(value) + return dict(sorted(labels.items())) + + +def _signature_details(rows: pd.DataFrame) -> dict[str, dict[str, object]]: + details: dict[str, dict[str, object]] = {} + if "artifact_final_entity_signature_details" in rows.columns: + for value in rows["artifact_final_entity_signature_details"].tolist(): + details.update(_coerce_signature_details(value)) + prefix = "artifact_final_entity_signature_details." 
+ for column in rows.columns: + if not column.startswith(prefix): + continue + remainder = column.removeprefix(prefix) + signature_hash, _, field = remainder.partition(".") + if not signature_hash or not field: + continue + for value in rows[column].tolist(): + if not _is_missing_cell(value): + details.setdefault(signature_hash, {})[field] = _json_scalar(value) + return dict(sorted(details.items())) + + +def _sum_label_counts(rows: pd.DataFrame, field: str) -> dict[str, int]: + counts: dict[str, int] = {} + if field in rows.columns: + for value in rows[field].tolist(): + _merge_count_map(counts, _coerce_count_map(value)) + prefix = f"{field}." + for column in rows.columns: + if not column.startswith(prefix): + continue + total = _sum_or_none(rows, column) + if total: + counts[column.removeprefix(prefix)] = counts.get(column.removeprefix(prefix), 0) + int(total) + return dict(sorted(counts.items())) + + +def _failed_case_count(rows: pd.DataFrame) -> int: + if "case_failed" in rows.columns: + return int(rows["case_failed"].map(_coerce_bool).sum()) + error_columns = [ + column + for column in ("error_stage_count", "error_ndd_workflow_count", "error_model_workflow_count") + if column in rows.columns + ] + if not error_columns: + return 0 + error_counts = rows[error_columns].apply(pd.to_numeric, errors="coerce").fillna(0).sum(axis=1) + return int((error_counts > 0).sum()) + + +def _coerce_bool(value: object) -> bool: + if _is_missing_cell(value): + return False + if isinstance(value, bool): + return value + if isinstance(value, int | float): + return value != 0 + text = str(value).strip().lower() + return text in {"1", "true", "t", "yes", "y"} + + +def _rate(count: object, total: object) -> float | None: + count_value = _optional_float(count) + total_value = _optional_float(total) + if count_value is None or total_value is None or total_value <= 0: + return None + return count_value / total_value + + +def _signature_label_counts(signatures: set[str], labels: dict[str, 
str]) -> dict[str, int]: + counts: dict[str, int] = {} + for signature_hash in signatures: + label = labels.get(signature_hash, "unknown") + counts[label] = counts.get(label, 0) + 1 + return dict(sorted(counts.items())) + + +def _coerce_signature_list(value: object) -> list[str]: + if _is_missing_cell(value): + return [] + if hasattr(value, "tolist"): + value = value.tolist() + if isinstance(value, list | tuple | set): + return [str(item) for item in value if not _is_missing_cell(item)] + if not isinstance(value, str): + return [] + text = value.strip() + if not text: + return [] + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return [] + return _coerce_signature_list(parsed) + + +def _coerce_signature_labels(value: object) -> dict[str, str]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + return {str(key): str(item) for key, item in value.items() if not _is_missing_cell(item)} + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_signature_labels(parsed) + + +def _coerce_signature_details(value: object) -> dict[str, dict[str, object]]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + details: dict[str, dict[str, object]] = {} + for signature_hash, raw_detail in value.items(): + if not isinstance(raw_detail, dict): + continue + details[str(signature_hash)] = { + str(key): _json_scalar(item) for key, item in raw_detail.items() if not _is_missing_cell(item) + } + return dict(sorted(details.items())) + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: 
+ parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_signature_details(parsed) + + +def _json_scalar(value: object) -> object: + if hasattr(value, "item"): + try: + return value.item() + except ValueError: + return value + return value + + +def _coerce_count_map(value: object) -> dict[str, int]: + if _is_missing_cell(value): + return {} + if hasattr(value, "to_dict"): + value = value.to_dict() + if isinstance(value, dict): + counts: dict[str, int] = {} + for key, item in value.items(): + if _is_missing_cell(item): + continue + count = _optional_float(item) + if count: + counts[str(key)] = int(count) + return dict(sorted(counts.items())) + if not isinstance(value, str): + return {} + text = value.strip() + if not text: + return {} + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (SyntaxError, ValueError): + return {} + return _coerce_count_map(parsed) + + +def _merge_count_map(target: dict[str, int], update: dict[str, int]) -> None: + for key, value in update.items(): + target[key] = target.get(key, 0) + value + + +def _is_missing_cell(value: object) -> bool: + return value is None or (isinstance(value, float) and pd.isna(value)) + + +def _append_if_negative(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: + value = _optional_float(metrics.get(field)) + if value is not None and value < 0: + flags.append(flag) + + +def _append_if_positive(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: + value = _optional_float(metrics.get(field)) + if value is not None and value > 0: + flags.append(flag) + + +def _optional_float(value: object) -> float | None: + if value is None or pd.isna(value): + return None + return float(value) + + +def _optional_int(value: object) -> int | None: + number = _optional_float(value) + return int(number) if number 
is not None else None + + +def _optional_string(value: object) -> str | None: + if value is None or pd.isna(value): + return None + return str(value) + + +def _delta(baseline: float | None, candidate: float | None) -> float | None: + return candidate - baseline if baseline is not None and candidate is not None else None + + +def _delta_pct(baseline: float | None, candidate: float | None) -> float | None: + if baseline is None or candidate is None or baseline == 0: + return None + return ((candidate - baseline) / baseline) * 100 + + +def write_comparisons(rows: list[ComparisonRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + table = _normalize_table_cells(table) + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def _normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame: + normalized = table.copy() + for column in normalized.columns: + if normalized[column].map(_is_nested_cell).any(): + normalized[column] = normalized[column].map(_json_cell) + return normalized + + +def _is_nested_cell(value: object) -> bool: + return isinstance(value, dict | list) + + +def _json_cell(value: object) -> object: + if not _is_nested_cell(value): + return value + return json.dumps(value, ensure_ascii=True, sort_keys=True) + + +def render_result(result: ComparisonResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + lines = [ + f"Compared {result.comparison_count} workload(s): " + f"viable={result.summary.candidate_verdict_counts.get(CandidateVerdict.candidate_viable.value, 0)}, " + f"review={result.summary.candidate_verdict_counts.get(CandidateVerdict.review.value, 0)}, " + 
f"reject={result.summary.candidate_verdict_counts.get(CandidateVerdict.reject.value, 0)}" + ] + for row in result.comparisons: + lines.append( + f"- {row.workload_id}: verdict={row.candidate_verdict or 'unknown'} " + f"(safety={row.safety_verdict or 'unknown'}, " + f"value_protection={row.value_protection_verdict or 'unknown'}, " + f"signature_parity={row.signature_parity_verdict or 'unknown'}, " + f"performance={row.performance_verdict or 'unknown'}), " + f"replacement={row.baseline_replacement_strategy or 'unknown'}" + f"->{row.candidate_replacement_strategy or 'unknown'}, " + f"elapsed {_number_pair(row.baseline_pipeline_elapsed_sec, row.candidate_pipeline_elapsed_sec, suffix='s')}, " + f"entities {row.baseline_final_entity_count}->{row.candidate_final_entity_count}, " + f"requests {_number_pair(row.baseline_observed_total_requests, row.candidate_observed_total_requests)}, " + f"tokens {_number_pair(row.baseline_observed_total_tokens, row.candidate_observed_total_tokens)}, " + f"original_value_leaks " + f"{_number_pair(row.baseline_original_value_leak_count, row.candidate_original_value_leak_count)}, " + f"lost_labels={_label_count_summary(row.baseline_only_final_entity_signature_label_counts)}, " + "covered_label_mismatch_labels=" + f"{_label_count_summary(row.baseline_only_candidate_label_mismatch_signature_label_counts)}, " + f"unstable_lost_labels={_label_count_summary(row.baseline_stable_candidate_unstable_final_entity_signature_label_counts)}, " + f"leak_labels={_label_count_summary(row.candidate_original_value_leak_label_counts)}, " + f"flags={','.join(row.flags) if row.flags else 'none'}" + ) + return "\n".join(lines) + + +def _number_pair(baseline: float | None, candidate: float | None, *, suffix: str = "") -> str: + return f"{_format_number(baseline, suffix=suffix)}->{_format_number(candidate, suffix=suffix)}" + + +def _format_number(value: float | None, *, suffix: str = "") -> str: + if value is None: + return "unknown" + return f"{value:.1f}{suffix}" 
+ + +def _label_count_summary(counts: dict[str, int]) -> str: + if not counts: + return "none" + return ",".join(f"{label}:{count}" for label, count in sorted(counts.items())) + + +def _selector_label(*, config: str | None, strategy: str | None) -> str: + if config is not None: + return f"config:{config}" + if strategy is not None: + return f"strategy:{strategy}" + return "" + + +@app.default +def main( + case_analysis: Path, + *, + candidate_case_analysis: Annotated[Path | None, cyclopts.Parameter("--candidate-case-analysis")] = None, + baseline_config: Annotated[str | None, cyclopts.Parameter("--baseline-config")] = None, + candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None, + baseline_strategy: Annotated[str | None, cyclopts.Parameter("--baseline-strategy")] = None, + candidate_strategy: Annotated[str | None, cyclopts.Parameter("--candidate-strategy")] = None, + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.csv, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + baseline_table = read_case_analysis(case_analysis) + candidate_table = read_case_analysis(candidate_case_analysis) if candidate_case_analysis else baseline_table + comparisons = compare_case_tables( + baseline_table, + candidate_table, + baseline_config=baseline_config, + candidate_config=candidate_config, + baseline_strategy=baseline_strategy, + candidate_strategy=candidate_strategy, + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + result = ComparisonResult( + input_path=str(case_analysis), + candidate_input_path=str(candidate_case_analysis) if candidate_case_analysis else None, + baseline_selector=_selector_label(config=baseline_config, 
strategy=baseline_strategy), + candidate_selector=_selector_label(config=candidate_config, strategy=candidate_strategy), + summary=summarize_comparisons(comparisons), + comparisons=comparisons, + ) + if output is not None: + write_comparisons(comparisons, output, format) + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/dd_parser_compat.py b/tools/measurement/dd_parser_compat.py new file mode 100644 index 00000000..d3aa441c --- /dev/null +++ b/tools/measurement/dd_parser_compat.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark-only DataDesigner structured-output parser compatibility modes.""" + +from __future__ import annotations + +import json +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from enum import StrEnum +from typing import Any + +from data_designer.engine.models.parsers.errors import ParserException +from data_designer.engine.models.recipes import response_recipes as recipes +from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate + + +class DDParserCompatMode(StrEnum): + none = "none" + raw_json = "raw_json" + + +@contextmanager +def dd_parser_compat_context(mode: DDParserCompatMode) -> Iterator[None]: + """Temporarily patch DataDesigner parsers for benchmark-only compatibility.""" + if mode == DDParserCompatMode.none: + yield + return + if mode != DDParserCompatMode.raw_json: + raise ValueError(f"unsupported DataDesigner parser compatibility mode: {mode}") + + original_pydantic = recipes.PydanticResponseRecipe._build_parser_fn + original_structured = recipes.StructuredResponseRecipe._build_parser_fn + recipes.PydanticResponseRecipe._build_parser_fn = _tolerant_pydantic_builder(original_pydantic) # type: ignore[method-assign] + 
recipes.StructuredResponseRecipe._build_parser_fn = _tolerant_structured_builder(original_structured) # type: ignore[method-assign] + try: + yield + finally: + recipes.PydanticResponseRecipe._build_parser_fn = original_pydantic # type: ignore[method-assign] + recipes.StructuredResponseRecipe._build_parser_fn = original_structured # type: ignore[method-assign] + + +def _tolerant_pydantic_builder(original: Callable[..., Callable[[str], Any]]) -> Callable[..., Callable[[str], Any]]: + def build_parser(self: Any) -> Callable[[str], Any]: + base_parse = original(self) + + def parse(response: str) -> Any: + try: + return base_parse(response) + except ParserException as exc: + try: + return self.data_type.model_validate(_load_embedded_json(response)) + except Exception: + raise exc + + return parse + + return build_parser + + +def _tolerant_structured_builder( + original: Callable[..., Callable[[str], dict]], +) -> Callable[..., Callable[[str], dict]]: + def build_parser(self: Any) -> Callable[[str], dict]: + base_parse = original(self) + + def parse(response: str) -> dict: + try: + return base_parse(response) + except ParserException as exc: + try: + return validate(_load_embedded_json(response), **self._validate_args) + except (json.JSONDecodeError, JSONSchemaValidationError, TypeError, ValueError): + raise exc + + return parse + + return build_parser + + +def _load_embedded_json(response: str) -> Any: + """Return the largest JSON object/array embedded in a model response.""" + decoder = json.JSONDecoder() + stripped = response.strip() + try: + return json.loads(stripped) + except json.JSONDecodeError: + pass + + best: tuple[int, int, Any] | None = None + for start, char in enumerate(response): + if char not in "{[": + continue + try: + parsed, end = decoder.raw_decode(response, start) + except json.JSONDecodeError: + continue + if not isinstance(parsed, dict | list): + continue + # Prefer the candidate that consumes the most response text. 
That + # selects the outer response object instead of nested item objects. + if best is None or end > best[1] or (end == best[1] and start < best[0]): + best = (start, end, parsed) + + if best is None: + raise json.JSONDecodeError("No embedded JSON object or array found", response, 0) + return best[2] diff --git a/tools/measurement/detection_strategies.py b/tools/measurement/detection_strategies.py new file mode 100644 index 00000000..e55e677a --- /dev/null +++ b/tools/measurement/detection_strategies.py @@ -0,0 +1,2436 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Experimental detection strategies for benchmark-only performance probes.""" + +from __future__ import annotations + +import json +import re +import time +from collections import Counter +from collections.abc import Callable, Iterator +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + +import pandas as pd +from data_designer.config import custom_column_generator +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig +from data_designer.config.models import ModelConfig +from dd_parser_compat import _load_embedded_json +from direct_detection_probe import DirectDetectionRequest, DirectGenerationRequest, PromptMode, build_direct_prompt +from staged_detection_probe import ( + CaseStatus as StagedCaseStatus, +) +from staged_detection_probe import ( + DirectCompletion, + DirectDetectionClient, + GlinerSeedClient, + HttpxDirectDetectionClient, + SeedSource, + StagedDetectionCase, + StagedDetectionRequest, + StagedExecutionConfig, + ValidationPromptMode, + _run_augmentation_phase, + _run_validation_phase, + execute_staged_detection_case, + normalize_birth_context_final_entities, +) + +from 
anonymizer.config.models import DetectionModelSelection +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_INITIAL_TAGGED_TEXT, + COL_MERGED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, + COL_VALIDATED_ENTITIES, + COL_VALIDATED_SEED_ENTITIES, + COL_VALIDATION_DECISIONS, + _jinja, +) +from anonymizer.engine.detection import detection_workflow as dw +from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams, make_chunked_validation_generator +from anonymizer.engine.detection.custom_columns import ( + apply_validation_and_finalize, + apply_validation_to_seed_entities, + enrich_validation_decisions, + merge_and_build_candidates, + parse_detected_entities, + prepare_validation_inputs, +) +from anonymizer.engine.detection.postprocess import ( + EntitySpan, + build_tagged_text, + expand_entity_occurrences, + get_tag_notation, + parse_raw_entities, + resolve_overlaps, +) +from anonymizer.engine.detection.rules import detect_high_confidence_entities +from anonymizer.engine.ndd.adapter import FailedRecord +from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases +from anonymizer.engine.row_partitioning import merge_and_reorder, split_rows +from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema, ValidationCandidatesSchema +from anonymizer.measurement import record_model_workflow + +_NATIVE_DIRECT_MODEL_ALIAS = "native-direct" +_NATIVE_DIRECT_MODEL_NAME = "nvidia/nemotron-3-super" +_NATIVE_DIRECT_MODEL_PROVIDER = "local-vllm" +_NATIVE_DIRECT_ENDPOINT = "http://gpu-dev-pod-serve-svc:8000/v1" +_NATIVE_DIRECT_MAX_TOKENS = 4096 +_NATIVE_DIRECT_TIMEOUT_SEC = 180.0 +_GLINER_DIRECT_MODEL_ALIAS = "gliner-direct" +_GLINER_DIRECT_MODEL_NAME = "nvidia/gliner-pii" +_GLINER_DIRECT_MODEL_PROVIDER = "nvidia" +_NATIVE_STAGED_MAX_WORKERS = 4 
+_STRUCTURED_ASSIGNMENT_RE = re.compile( + r"(?" + r"api[_-]?key|aws[_-]?access[_-]?key[_-]?id|access[_-]?key[_-]?id|hf[_-]?token|" + r"token|auth[_-]?token|session[_-]?id|authorization|" + r"password|pass|secret|aws[_-]?secret[_-]?access[_-]?key|django[_-]?secret|database[_-]?url|" + r"pin|user(?:_?name)?|username|login|account|cookie|" + r"trace[-_]?id|request[-_]?id|req[-_]?id|order[-_]?id|tenant[-_]?id|unique[-_]?id|" + r"url|uri|endpoint|callback|email" + r")['\"]?\s*[:=]\s*" + r"(?:['\"](?P[^'\"\r\n]+)['\"]|(?P[^\s'\",;]+))", + flags=re.IGNORECASE, +) + + +class ExperimentalDetectionStrategy(StrEnum): + default = "default" + prose_augment_focus = "prose_augment_focus" + compact_validation = "compact_validation" + rules_guardrail_compact_validation = "rules_guardrail_compact_validation" + rules_guardrail = "rules_guardrail" + rules_filter_guardrail = "rules_filter_guardrail" + no_augment = "no_augment" + rules_seed_no_augment = "rules_seed_no_augment" + rules_guardrail_no_augment = "rules_guardrail_no_augment" + rules_filter_guardrail_no_augment = "rules_filter_guardrail_no_augment" + rules_guardrail_detector_only = "rules_guardrail_detector_only" + detector_only = "detector_only" + rules_only = "rules_only" + rules_covered_or_default = "rules_covered_or_default" + native_rules_router = "native_rules_router" + native_candidate_validate_no_augment = "native_candidate_validate_no_augment" + detector_native_validate_no_augment = "detector_native_validate_no_augment" + detector_native_validate_native_augment = "detector_native_validate_native_augment" + gliner_native_validate_no_augment = "gliner_native_validate_no_augment" + gliner_native_validate_native_augment = "gliner_native_validate_native_augment" + native_single_pass = "native_single_pass" + native_single_pass_recall = "native_single_pass_recall" + native_single_pass_values = "native_single_pass_values" + native_single_pass_values_recall = "native_single_pass_values_recall" + + +_DetectAndValidate = 
Callable[..., dw.EntityDetectionResult] +_AugmentPrompt = Callable[..., str] +_MaterializeFinalEntities = Callable[..., dict] +PROSE_AUGMENT_FOCUS_TEXT = """\ +Contextual prose recall focus: +- Re-scan untagged narrative prose for organization and institution names, named facilities, labs, research centers, street or place names, self-described beliefs, occupations, titles, family member names, and other quasi-identifiers that combine with already tagged entities. +- Prefer allowed labels when present, especially organization_name, company_name, place_name, religious_belief, political_view, occupation, first_name, last_name, city, state, university, degree, field_of_study, language, race_ethnicity, and age. +- Do not add generic common nouns, section headings, or labels outside the allowed list when strict labels are required. +""" + + +@dataclass(frozen=True) +class _NoAugmentOptions: + include_rules: bool + final_rule_guardrail: bool = False + filter_rule_overlaps: bool = False + rule_labels: tuple[str, ...] 
@contextmanager
def experimental_detection_strategy_context(
    strategy: ExperimentalDetectionStrategy,
    *,
    rule_labels: list[str] | None = None,
    native_client: DirectDetectionClient | None = None,
    gliner_seed_client: GlinerSeedClient | None = None,
) -> Iterator[None]:
    """Temporarily apply a benchmark-only detection strategy.

    For the ``default`` strategy nothing is patched. Otherwise the context
    monkeypatches the ``dw`` module for the duration of the ``with`` body:

    * ``dw._materialize_final_entities`` is wrapped when ``rule_labels`` is
      non-empty, so those labels are added to any explicit allow-list;
    * ``dw._get_augment_prompt`` is wrapped for ``prose_augment_focus``;
    * for every other strategy,
      ``dw.EntityDetectionWorkflow.detect_and_validate_entities`` is replaced
      by the method returned from ``_method_for_strategy``.

    All three attributes are restored on exit, including when building a
    replacement or the ``with`` body raises.
    """
    if strategy == ExperimentalDetectionStrategy.default:
        yield
        return

    original_method = dw.EntityDetectionWorkflow.detect_and_validate_entities
    original_augment_prompt = dw._get_augment_prompt
    original_materialize_final_entities = dw._materialize_final_entities
    try:
        # Patch inside the ``try`` so that a failure while constructing a
        # replacement cannot leave an earlier patch permanently installed.
        if rule_labels:
            dw._materialize_final_entities = _make_rule_label_materializer(  # type: ignore[assignment]
                original_materialize_final_entities,
                rule_labels=rule_labels,
            )
        if strategy == ExperimentalDetectionStrategy.prose_augment_focus:
            dw._get_augment_prompt = _make_prose_augment_prompt(original_augment_prompt)  # type: ignore[assignment]
        else:
            dw.EntityDetectionWorkflow.detect_and_validate_entities = _method_for_strategy(  # type: ignore[method-assign]
                strategy,
                original=original_method,
                rule_labels=rule_labels,
                native_client=native_client,
                gliner_seed_client=gliner_seed_client,
            )
        yield
    finally:
        # Restoring an attribute that was never patched is a harmless no-op.
        dw.EntityDetectionWorkflow.detect_and_validate_entities = original_method  # type: ignore[method-assign]
        dw._get_augment_prompt = original_augment_prompt  # type: ignore[assignment]
        dw._materialize_final_entities = original_materialize_final_entities  # type: ignore[assignment]
ExperimentalDetectionStrategy.compact_validation: + if original is None: + raise ValueError("compact_validation requires the original detection method") + return _make_default_compact_validation_method(original) + if strategy == ExperimentalDetectionStrategy.rules_guardrail: + if original is None: + raise ValueError("rules_guardrail requires the original detection method") + return _make_default_with_rule_guardrail_method(original, rule_labels=rule_labels) + if strategy == ExperimentalDetectionStrategy.rules_filter_guardrail: + return _make_validated_augmented_rule_filter_guardrail_method(rule_labels=rule_labels) + if strategy == ExperimentalDetectionStrategy.rules_guardrail_compact_validation: + if original is None: + raise ValueError("rules_guardrail_compact_validation requires the original detection method") + return _make_default_with_rule_guardrail_method( + original, + rule_labels=rule_labels, + compact_validation=True, + ) + if strategy == ExperimentalDetectionStrategy.no_augment: + return _make_validated_no_augment_method(include_rules=False) + if strategy == ExperimentalDetectionStrategy.rules_seed_no_augment: + return _make_validated_no_augment_method(include_rules=True, rule_labels=rule_labels) + if strategy == ExperimentalDetectionStrategy.rules_guardrail_no_augment: + return _make_validated_no_augment_method( + include_rules=False, + final_rule_guardrail=True, + rule_labels=rule_labels, + ) + if strategy == ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment: + return _make_validated_no_augment_method( + include_rules=False, + final_rule_guardrail=True, + filter_rule_overlaps=True, + rule_labels=rule_labels, + ) + if strategy == ExperimentalDetectionStrategy.rules_guardrail_detector_only: + return _make_detector_only_with_rule_guardrail_method(rule_labels=rule_labels) + if strategy == ExperimentalDetectionStrategy.detector_only: + return _detect_with_detector_only + if strategy == ExperimentalDetectionStrategy.rules_covered_or_default: + if 
original is None: + raise ValueError("rules_covered_or_default requires the original detection method") + return _make_rules_covered_or_default_method(original) + if strategy == ExperimentalDetectionStrategy.rules_only: + return _detect_with_rules_only + if strategy == ExperimentalDetectionStrategy.native_rules_router: + return _make_native_rules_router_method(native_client=native_client) + if strategy == ExperimentalDetectionStrategy.native_candidate_validate_no_augment: + return _make_native_candidate_validate_no_augment_method(native_client=native_client) + if strategy == ExperimentalDetectionStrategy.detector_native_validate_no_augment: + return _make_detector_native_validate_no_augment_method(native_client=native_client) + if strategy == ExperimentalDetectionStrategy.detector_native_validate_native_augment: + return _make_detector_native_validate_native_augment_method(native_client=native_client) + if strategy == ExperimentalDetectionStrategy.gliner_native_validate_no_augment: + return _make_gliner_native_validate_no_augment_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + ) + if strategy == ExperimentalDetectionStrategy.gliner_native_validate_native_augment: + return _make_gliner_native_validate_native_augment_method( + native_client=native_client, + gliner_seed_client=gliner_seed_client, + ) + if strategy == ExperimentalDetectionStrategy.native_single_pass: + return _make_native_single_pass_method(native_client=native_client) + if strategy == ExperimentalDetectionStrategy.native_single_pass_recall: + return _make_native_single_pass_method(native_client=native_client, recall_prompt=True) + if strategy == ExperimentalDetectionStrategy.native_single_pass_values: + return _make_native_single_pass_method(native_client=native_client, value_only_prompt=True) + if strategy == ExperimentalDetectionStrategy.native_single_pass_values_recall: + return _make_native_single_pass_method( + native_client=native_client, + recall_prompt=True, + 
def _make_default_compact_validation_method(original: _DetectAndValidate) -> _DetectAndValidate:
    """Build a detection method that delegates to ``original`` in compact-validation mode.

    The returned method forwards every argument unchanged and additionally
    passes ``validation_single_chunk_full_text=False``.
    """

    def detect_and_validate_entities(
        self: dw.EntityDetectionWorkflow,
        dataframe: pd.DataFrame,
        *,
        model_configs: list[ModelConfig],
        selected_models: DetectionModelSelection,
        gliner_detection_threshold: float,
        validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
        validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
        entity_labels: list[str] | None = None,
        data_summary: str | None = None,
        preview_num_records: int | None = None,
    ) -> dw.EntityDetectionResult:
        forwarded = {
            "model_configs": model_configs,
            "selected_models": selected_models,
            "gliner_detection_threshold": gliner_detection_threshold,
            "validation_max_entities_per_call": validation_max_entities_per_call,
            "validation_excerpt_window_chars": validation_excerpt_window_chars,
            # The one deviation from the default workflow.
            "validation_single_chunk_full_text": False,
            "entity_labels": entity_labels,
            "data_summary": data_summary,
            "preview_num_records": preview_num_records,
        }
        return original(self, dataframe, **forwarded)

    return detect_and_validate_entities
original( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=not compact_validation, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + output = _apply_rule_guardrail( + result.dataframe.copy(), + labels=_rule_labels_for_detection(entity_labels, extra_rule_labels=rule_labels), + ) + return dw.EntityDetectionResult(dataframe=output, failed_records=result.failed_records) + + return detect_and_validate_entities + + +def _make_validated_augmented_rule_filter_guardrail_method( + *, + rule_labels: list[str] | None = None, +) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + return _run_validated_augmented_rule_filter_guardrail_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + entity_labels=entity_labels, + data_summary=data_summary, + rule_labels=rule_labels, + ) + + return detect_and_validate_entities + + +def 
_run_validated_augmented_rule_filter_guardrail_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + entity_labels: list[str] | None, + data_summary: str | None, + rule_labels: list[str] | None, +) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + detection_result = workflow._adapter.run_workflow( + dataframe, + model_configs=workflow_model_configs, + columns=_validated_augmented_rule_filter_guardrail_columns( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + strict_labels=entity_labels is not None, + rule_labels=rule_labels, + ), + workflow_name="entity-detection-rules-filter-guardrail", + preview_num_records=preview_num_records, + ) + return dw.EntityDetectionResult( + dataframe=detection_result.dataframe.copy(), + failed_records=detection_result.failed_records, + ) + + +def _validated_augmented_rule_filter_guardrail_columns( + *, + selected_models: DetectionModelSelection, + labels: list[str], + data_summary: str | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + strict_labels: bool, + rule_labels: list[str] | None, +) -> list[LLMTextColumnConfig | LLMStructuredColumnConfig | CustomColumnConfig]: + validator_params = _validator_params( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + 
validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + ) + rule_detection_labels = _rule_labels_for_detection(labels, extra_rule_labels=rule_labels) + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, + prompt=_jinja(COL_TEXT), + model_alias=_detector_alias(selected_models), + ), + CustomColumnConfig( + name=COL_SEED_ENTITIES, + generator_function=_make_parse_detected_entities_filtering_rules(rule_detection_labels), + ), + CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), + _validation_decisions_column(selected_models, validator_params), + CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), + CustomColumnConfig( + name=COL_SEED_ENTITIES_JSON, + generator_function=_make_apply_validation_to_seed_entities_with_additive_rule_guardrail( + rule_detection_labels + ), + ), + LLMStructuredColumnConfig( + name=COL_AUGMENTED_ENTITIES, + prompt=dw._get_augment_prompt(data_summary=data_summary, labels=labels, strict_labels=strict_labels), + model_alias=resolve_model_alias("entity_augmenter", selected_models), + output_format=AugmentedEntitiesSchema, + ), + CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates), + CustomColumnConfig( + name=COL_DETECTED_ENTITIES, + generator_function=_make_apply_validation_and_finalize_with_additive_rule_guardrail(rule_detection_labels), + ), + ] + + +def _rule_labels_for_detection( + entity_labels: list[str] | None, + *, + extra_rule_labels: list[str] | tuple[str, ...] 
def _merge_rule_guardrail_spans(final_spans: list[EntitySpan], rule_spans: list[EntitySpan]) -> list[EntitySpan]:
    """Combine detected spans with rule spans, letting rules win label conflicts.

    A detected span is dropped when some rule span has the same start and end
    position but a different label. The surviving detected spans followed by
    all rule spans are then passed through ``resolve_overlaps``.
    """

    def contradicted_by_rule(entity: EntitySpan) -> bool:
        for rule in rule_spans:
            if rule.start_position != entity.start_position:
                continue
            if rule.end_position != entity.end_position:
                continue
            if rule.label != entity.label:
                return True
        return False

    survivors = [entity for entity in final_spans if not contradicted_by_rule(entity)]
    return resolve_overlaps(survivors + list(rule_spans))
include_rules=include_rules, + final_rule_guardrail=final_rule_guardrail, + filter_rule_overlaps=filter_rule_overlaps, + rule_labels=tuple(rule_labels or ()), + ) + + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + return _run_validated_no_augment_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + entity_labels=entity_labels, + data_summary=data_summary, + options=options, + ) + + return detect_and_validate_entities + + +def _run_validated_no_augment_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + entity_labels: list[str] | None, + data_summary: str | None, + options: _NoAugmentOptions, +) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + detection_result = workflow._adapter.run_workflow( + 
dataframe, + model_configs=workflow_model_configs, + columns=_validated_no_augment_columns( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + options=options, + ), + workflow_name=_workflow_name_for_no_augment(options), + preview_num_records=preview_num_records, + ) + return dw.EntityDetectionResult( + dataframe=detection_result.dataframe.copy(), + failed_records=detection_result.failed_records, + ) + + +def _validated_no_augment_columns( + *, + selected_models: DetectionModelSelection, + labels: list[str], + data_summary: str | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + options: _NoAugmentOptions, +) -> list[LLMTextColumnConfig | CustomColumnConfig]: + validator_params = _validator_params( + selected_models=selected_models, + labels=labels, + data_summary=data_summary, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + ) + parse_generator = _parse_generator( + labels=_rule_labels_for_detection(labels, extra_rule_labels=options.rule_labels), + include_rules=options.include_rules, + filter_rule_overlaps=options.filter_rule_overlaps, + ) + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, prompt=_jinja(COL_TEXT), model_alias=_detector_alias(selected_models) + ), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_generator), + CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), + _validation_decisions_column(selected_models, validator_params), + CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), + CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=apply_validation_to_seed_entities), + CustomColumnConfig(name=COL_AUGMENTED_ENTITIES, 
def _workflow_name_for_no_augment(options: _NoAugmentOptions) -> str:
    """Return the workflow name matching the first enabled no-augment option.

    Flags are checked in priority order: ``filter_rule_overlaps``, then
    ``final_rule_guardrail``, then ``include_rules``. With none set, the plain
    no-augment name is returned.
    """
    names_by_flag = (
        ("filter_rule_overlaps", "entity-detection-rules-filter-guardrail-no-augment"),
        ("final_rule_guardrail", "entity-detection-rules-guardrail-no-augment"),
        ("include_rules", "entity-detection-rules-no-augment"),
    )
    for flag_name, workflow_name in names_by_flag:
        if getattr(options, flag_name):
            return workflow_name
    return "entity-detection-no-augment"
def _detect_with_detector_only(
    self: dw.EntityDetectionWorkflow,
    dataframe: pd.DataFrame,
    *,
    model_configs: list[ModelConfig],
    selected_models: DetectionModelSelection,
    gliner_detection_threshold: float,
    validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
    validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
    validation_single_chunk_full_text: bool = True,
    entity_labels: list[str] | None = None,
    data_summary: str | None = None,
    preview_num_records: int | None = None,
) -> dw.EntityDetectionResult:
    """Run the detector-only workflow without a rule guardrail.

    Written with the ``detect_and_validate_entities`` method signature. The three
    ``validation_*`` arguments and ``data_summary`` are accepted but not forwarded:
    only the detector-related arguments reach ``_run_detector_only_detection``.
    """
    return _run_detector_only_detection(
        self,
        dataframe,
        model_configs=model_configs,
        selected_models=selected_models,
        gliner_detection_threshold=gliner_detection_threshold,
        entity_labels=entity_labels,
        preview_num_records=preview_num_records,
        # ``None`` selects the plain finalizer (no rule guardrail) downstream.
        rule_labels=None,
        workflow_name="entity-detection-detector-only",
    )
def _run_detector_only_detection(
    workflow: dw.EntityDetectionWorkflow,
    dataframe: pd.DataFrame,
    *,
    model_configs: list[ModelConfig],
    selected_models: DetectionModelSelection,
    gliner_detection_threshold: float,
    entity_labels: list[str] | None,
    preview_num_records: int | None,
    rule_labels: list[str] | None,
    workflow_name: str,
) -> dw.EntityDetectionResult:
    """Execute the detector-only column pipeline through the workflow adapter.

    Resolves the detection labels, injects detector parameters into the model
    configs, runs the columns from ``_detector_only_columns`` and returns a copy
    of the resulting dataframe together with the adapter's failed records.
    """
    resolved_labels = dw._resolve_detection_labels(entity_labels)
    detector_configs = workflow._inject_detector_params(
        model_configs=model_configs,
        selected_models=selected_models,
        labels=resolved_labels,
        gliner_detection_threshold=gliner_detection_threshold,
    )
    pipeline_columns = _detector_only_columns(selected_models, rule_labels=rule_labels)
    outcome = workflow._adapter.run_workflow(
        dataframe,
        model_configs=detector_configs,
        columns=pipeline_columns,
        workflow_name=workflow_name,
        preview_num_records=preview_num_records,
    )
    result_frame = outcome.dataframe.copy()
    return dw.EntityDetectionResult(dataframe=result_frame, failed_records=outcome.failed_records)
@custom_column_generator(
    required_columns=[COL_TEXT, COL_SEED_ENTITIES],
    side_effect_columns=[COL_TAGGED_TEXT],
)
def _finalize_detector_only(row: dict[str, Any]) -> dict[str, Any]:
    """Turn the seed entities of a row into the final detected-entities payload.

    Seed spans are expanded with ``expand_entity_occurrences``, written to
    ``COL_DETECTED_ENTITIES`` as a JSON-mode ``EntitiesSchema`` dump, and used to
    build ``COL_TAGGED_TEXT``. The row is mutated and returned.
    """
    source_text = str(row.get(COL_TEXT, ""))
    seed_spans = _entity_spans_from_payload(row.get(COL_SEED_ENTITIES, {}))
    expanded = expand_entity_occurrences(text=source_text, entities=seed_spans)
    schema = EntitiesSchema(entities=[item.as_dict() for item in expanded])
    row[COL_DETECTED_ENTITIES] = schema.model_dump(mode="json")
    row[COL_TAGGED_TEXT] = build_tagged_text(text=source_text, entities=expanded)
    return row
def _parse_generator(
    *,
    labels: list[str],
    include_rules: bool,
    filter_rule_overlaps: bool,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    """Pick the seed-entity parser for the requested rule handling.

    ``filter_rule_overlaps`` takes precedence over ``include_rules``; when neither
    is set the stock ``parse_detected_entities`` generator is returned unchanged.
    """
    if not filter_rule_overlaps and not include_rules:
        return parse_detected_entities
    make_parser = (
        _make_parse_detected_entities_filtering_rules
        if filter_rule_overlaps
        else _make_parse_detected_entities_with_rules
    )
    return make_parser(labels)
def _make_apply_validation_and_finalize_with_rule_guardrail(
    labels: list[str],
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    """Create a finalizer that re-applies the rule guardrail after validation.

    The returned generator first runs ``apply_validation_and_finalize`` and then
    merges the resulting entities with the high-confidence rule spans for
    ``labels``, rewriting ``COL_DETECTED_ENTITIES`` and ``COL_TAGGED_TEXT``.
    """

    @custom_column_generator(
        required_columns=[COL_TEXT, COL_MERGED_ENTITIES, COL_VALIDATED_ENTITIES],
        side_effect_columns=[COL_TAGGED_TEXT],
    )
    def apply_validation_and_finalize_with_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]:
        finalized = apply_validation_and_finalize(row)
        source_text = str(finalized.get(COL_TEXT, ""))
        # _guarded_entities parses the payload, detects rule spans and merges them,
        # in that order.
        merged_spans = _guarded_entities(
            text=source_text,
            raw_entities=finalized.get(COL_DETECTED_ENTITIES, {}),
            labels=labels,
        )
        payload = EntitiesSchema(entities=[span.as_dict() for span in merged_spans])
        finalized[COL_DETECTED_ENTITIES] = payload.model_dump(mode="json")
        finalized[COL_TAGGED_TEXT] = build_tagged_text(text=source_text, entities=merged_spans)
        return finalized

    return apply_validation_and_finalize_with_rule_guardrail
def _make_apply_validation_to_seed_entities_with_rule_guardrail(
    labels: list[str],
) -> Callable[[dict[str, Any]], dict[str, Any]]:
    """Create a seed-validation step that merges rule spans into the validated seeds.

    The returned generator runs ``apply_validation_to_seed_entities`` and then
    overwrites ``COL_VALIDATED_SEED_ENTITIES``, ``COL_SEED_ENTITIES_JSON`` and
    ``COL_INITIAL_TAGGED_TEXT`` with the guardrail-merged entities.
    """

    @custom_column_generator(
        required_columns=[COL_TEXT, COL_SEED_ENTITIES, COL_VALIDATED_ENTITIES],
        side_effect_columns=[COL_INITIAL_TAGGED_TEXT, COL_SEED_ENTITIES_JSON, COL_VALIDATED_SEED_ENTITIES],
    )
    def apply_validation_to_seed_entities_with_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]:
        validated_row = apply_validation_to_seed_entities(row)
        source_text = str(validated_row.get(COL_TEXT, ""))
        merged_spans = _guarded_entities(
            text=source_text,
            raw_entities=validated_row.get(COL_VALIDATED_SEED_ENTITIES, {}),
            labels=labels,
        )
        # One list feeds both the schema dump and the JSON string so they stay in sync.
        entity_dicts = [span.as_dict() for span in merged_spans]
        validated_row[COL_VALIDATED_SEED_ENTITIES] = EntitiesSchema(entities=entity_dicts).model_dump(mode="json")
        validated_row[COL_SEED_ENTITIES_JSON] = json.dumps(entity_dicts)
        validated_row[COL_INITIAL_TAGGED_TEXT] = build_tagged_text(text=source_text, entities=merged_spans)
        return validated_row

    return apply_validation_to_seed_entities_with_rule_guardrail
def _is_rule_covered_detector_span(entity: EntitySpan, spans: list[EntitySpan]) -> bool:
    """Tell whether some span with the same label fully contains ``entity``.

    Containment is inclusive at both ends: a span with identical boundaries
    counts as covering the entity.
    """
    for candidate in spans:
        if not entity.label == candidate.label:
            continue
        starts_at_or_before = candidate.start_position <= entity.start_position
        if starts_at_or_before and candidate.end_position >= entity.end_position:
            return True
    return False
gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + client = native_client or HttpxDirectDetectionClient() + return _run_native_single_pass_detection( + dataframe, + labels=labels, + client=client, + data_summary=data_summary, + preview_num_records=preview_num_records, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ) + + return detect_and_validate_entities + + +def _run_native_single_pass_detection( + dataframe: pd.DataFrame, + *, + labels: list[str], + client: DirectDetectionClient, + data_summary: str | None, + preview_num_records: int | None, + recall_prompt: bool, + value_only_prompt: bool, +) -> dw.EntityDetectionResult: + source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records: list[FailedRecord] = [] + + for index, row in source_df.iterrows(): + output_row, failed_record = _execute_native_single_pass_row( + row, + index=index, + labels=labels, + client=client, + data_summary=data_summary, + recall_prompt=recall_prompt, + value_only_prompt=value_only_prompt, + ) + if failed_record is not None: + failed_records.append(failed_record) + continue + if output_row is None: + failed_records.append(_native_single_pass_failed_record(index, error="native single-pass produced no row")) + continue + output_rows.append(output_row) + output_indices.append(index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe(source_df, output_rows=output_rows, 
def _execute_native_single_pass_row(
    row: pd.Series,
    *,
    index: object,
    labels: list[str],
    client: DirectDetectionClient,
    data_summary: str | None,
    recall_prompt: bool,
    value_only_prompt: bool,
) -> tuple[dict[str, Any] | None, FailedRecord | None]:
    """Run native single-pass detection for one row and record its metrics.

    Returns ``(output_row, None)`` on success. When the model request raises, or
    when parsing the response raises, returns ``(None, failed_record)`` instead;
    the exception is not propagated.
    """
    text = str(row.get(COL_TEXT, ""))
    started = time.perf_counter()
    try:
        completion = _complete_native_single_pass(
            text=text,
            labels=labels,
            client=client,
            data_summary=data_summary,
            recall_prompt=recall_prompt,
            value_only_prompt=value_only_prompt,
        )
    except Exception as exc:  # noqa: BLE001 - benchmark experiment records case-local failures
        # No completion object exists on this path, so elapsed time is measured locally.
        _record_native_single_pass_request_error(elapsed_sec=time.perf_counter() - started)
        return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}")
    try:
        spans = _native_single_pass_spans(text, completion.content, labels=labels)
    except Exception as exc:  # noqa: BLE001 - parser fragility is a benchmark failure mode
        # The request completed; only the parse failed, so the completion is still
        # recorded, with status "error" and one failed record.
        _record_native_single_pass_completion(completion, status="error", output_row_count=0, failed_record_count=1)
        return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}")
    _record_native_single_pass_completion(completion, status="completed", output_row_count=1, failed_record_count=0)
    return _native_single_pass_result_row(row, spans=spans, labels=labels), None
timeout_sec=_NATIVE_DIRECT_TIMEOUT_SEC, + ) + ) + + +def _native_single_pass_prompt( + *, + text: str, + labels: list[str], + data_summary: str | None, + recall_prompt: bool = False, + value_only_prompt: bool = False, +) -> str: + if value_only_prompt: + return build_direct_prompt( + DirectDetectionRequest( + case_id="native-single-pass-values", + text=text, + labels=labels, + prompt_mode=PromptMode.recall if recall_prompt else PromptMode.compact, + data_summary=data_summary, + ) + ) + + label_text = dw._format_label_examples(labels) if recall_prompt else ", ".join(labels) + recall_block = _native_single_pass_recall_block() if recall_prompt else "" + summary = f"\nData context: {data_summary}\n" if data_summary else "" + return f"""Extract privacy-sensitive entities from the input text in one pass. +{summary} +Use only these labels: +{label_text} + +Rules: +- Return exact substrings from the input text. +- Do not invent values. +- `start` and `end` must be zero-based Python character offsets, where text[start:end] equals `value`. +- Missing a sensitive value is worse than returning one extra plausible value. +- Skip generic nouns, syntax, and non-sensitive filler. +{recall_block}- Return only a JSON object with this shape: + {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "start": 0, "end": 0, "reason": "short reason"}}]}} + +Input text: +--- +{text} +---""" + + +def _native_single_pass_recall_block() -> str: + return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. +- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. 
+""" + + +def _native_single_pass_spans(text: str, content: str, *, labels: list[str]) -> list[EntitySpan]: + payload = _load_embedded_json(content) + if not isinstance(payload, dict): + raise ValueError("native single-pass response must be a JSON object") + entities = payload.get("entities") + if not isinstance(entities, list): + raise ValueError("native single-pass response must contain an entities list") + spans: list[EntitySpan] = [] + allowed = set(labels) + for item in entities: + if not isinstance(item, dict): + continue + value = str(item.get("value", "")).strip() + label = str(item.get("label", "")).strip() + if not value or label not in allowed: + continue + offset_span = _native_single_pass_offset_span(text, item, value=value, label=label) + if offset_span is not None: + spans.append(offset_span) + continue + spans.extend(_native_single_pass_value_spans(text, value=value, label=label)) + return resolve_overlaps(spans) + + +def _native_single_pass_offset_span( + text: str, + item: dict[str, Any], + *, + value: str, + label: str, +) -> EntitySpan | None: + start = _coerce_native_offset(item.get("start")) + end = _coerce_native_offset(item.get("end")) + if start is None or end is None or start < 0 or end <= start or end > len(text): + return None + if text[start:end] != value: + return None + return _native_single_pass_span(value=value, label=label, start=start, end=end) + + +def _coerce_native_offset(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, str) and value.strip().isdigit(): + return int(value.strip()) + return None + + +def _native_single_pass_value_spans(text: str, *, value: str, label: str) -> list[EntitySpan]: + spans: list[EntitySpan] = [] + start = 0 + while True: + match_start = text.find(value, start) + if match_start < 0: + return spans + match_end = match_start + len(value) + spans.append(_native_single_pass_span(value=value, label=label, 
def _record_native_single_pass_completion(
    completion: Any,
    *,
    status: str,
    output_row_count: int,
    failed_record_count: int,
) -> None:
    """Record workflow metrics for one native single-pass request that returned.

    Elapsed time and token usage are read from ``completion`` when present; a
    missing or falsy ``elapsed_sec`` becomes ``0.0`` and a missing or falsy
    ``usage`` becomes an empty dict. The request is always counted as successful.
    """
    reported_elapsed = getattr(completion, "elapsed_sec", 0.0) or 0.0
    elapsed = float(reported_elapsed)
    reported_usage = getattr(completion, "usage", {}) or {}
    usage_summary = _native_single_pass_model_usage(
        successful_requests=1,
        failed_requests=0,
        usage=dict(reported_usage),
    )
    record_model_workflow(
        workflow_name="entity-detection-native-single-pass",
        model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS],
        input_row_count=1,
        output_row_count=output_row_count,
        failed_record_count=failed_record_count,
        elapsed_sec=elapsed,
        status=status,
        model_usage=usage_summary,
    )
def _native_single_pass_model_usage(
    *,
    successful_requests: int,
    failed_requests: int,
    usage: dict[str, int],
) -> dict[str, dict[str, Any]]:
    """Build the per-model usage mapping for the native direct model.

    The result has a single entry keyed by the native model alias, carrying the
    model identity, request counters (including their sum) and the token usage
    derived from ``usage`` by ``_native_token_usage``.
    """
    request_usage = {
        "successful_requests": successful_requests,
        "failed_requests": failed_requests,
        "total_requests": successful_requests + failed_requests,
    }
    model_entry: dict[str, Any] = {
        "model_alias": _NATIVE_DIRECT_MODEL_ALIAS,
        "model_name": _NATIVE_DIRECT_MODEL_NAME,
        "model_provider_name": _NATIVE_DIRECT_MODEL_PROVIDER,
        "request_usage": request_usage,
        "token_usage": _native_token_usage(usage),
    }
    return {_NATIVE_DIRECT_MODEL_ALIAS: model_entry}
def _make_detector_native_validate_no_augment_method(
    native_client: DirectDetectionClient | None,
) -> _DetectAndValidate:
    """Build the detector-seeded, natively validated method with augmentation skipped.

    Thin preset over ``_make_detector_native_validate_method``: it fixes the
    workflow names and passes ``skip_augmentation=True``. ``native_client`` is
    forwarded unchanged (``None`` is allowed by the signature).
    """
    return _make_detector_native_validate_method(
        native_client,
        workflow_name="entity-detection-detector-native-validate-no-augment",
        # The seed stage uses the same name with a "-seed" suffix.
        seed_workflow_name="entity-detection-detector-native-validate-no-augment-seed",
        skip_augmentation=True,
    )
client = native_client or HttpxDirectDetectionClient() + return _run_detector_native_validate_detection( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + client=client, + data_summary=data_summary, + workflow_name=workflow_name, + seed_workflow_name=seed_workflow_name, + skip_augmentation=skip_augmentation, + ) + + return detect_and_validate_entities + + +def _run_detector_native_validate_detection( + workflow: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + labels: list[str], + gliner_detection_threshold: float, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + validation_single_chunk_full_text: bool, + client: DirectDetectionClient, + data_summary: str | None, + workflow_name: str, + seed_workflow_name: str, + skip_augmentation: bool, +) -> dw.EntityDetectionResult: + workflow_model_configs = workflow._inject_detector_params( + model_configs=model_configs, + selected_models=selected_models, + labels=labels, + gliner_detection_threshold=gliner_detection_threshold, + ) + seed_result = workflow._adapter.run_workflow( + dataframe, + model_configs=workflow_model_configs, + columns=_detector_native_validate_seed_columns(selected_models), + workflow_name=seed_workflow_name, + preview_num_records=preview_num_records, + ) + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records = list(seed_result.failed_records) + validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) + tasks = [ + 
_NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) + for ordinal, (index, row) in enumerate(seed_result.dataframe.iterrows()) + ] + params = _DetectorNativeValidationParams( + labels=labels, + client=client, + data_summary=data_summary, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + + for result in _execute_detector_native_validate_tasks(tasks, params=params): + _record_detector_native_validation_result(result) + if result.failed_record is not None: + failed_records.append(result.failed_record) + continue + if result.output_row is None: + failed_records.append( + _native_failed_record( + result.index, + workflow_name=workflow_name, + error="native detector-seed validation produced no row", + ) + ) + continue + output_rows.append(result.output_row) + output_indices.append(result.index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe( + seed_result.dataframe, output_rows=output_rows, output_indices=output_indices + ), + failed_records=failed_records, + ) + + +def _execute_detector_native_validate_tasks( + tasks: list[_NativeStagedTask], + *, + params: _DetectorNativeValidationParams, +) -> list[_DetectorNativeValidationRowResult]: + if not tasks: + return [] + worker_count = _native_staged_worker_count(len(tasks)) + if worker_count == 1: + return [_execute_detector_native_validate_task(task, params=params) for task in tasks] + with ThreadPoolExecutor(max_workers=worker_count) as executor: + return list( + executor.map( + lambda task: _execute_detector_native_validate_task(task, params=params), + tasks, + ) + ) + + +def _execute_detector_native_validate_task( + task: _NativeStagedTask, + *, + params: _DetectorNativeValidationParams, +) -> _DetectorNativeValidationRowResult: + return 
_execute_detector_native_validate_row( + task.row, + index=task.index, + ordinal=task.ordinal, + labels=params.labels, + client=params.client, + data_summary=params.data_summary, + validation_prompt_mode=params.validation_prompt_mode, + validation_max_entities_per_call=params.validation_max_entities_per_call, + validation_excerpt_window_chars=params.validation_excerpt_window_chars, + workflow_name=params.workflow_name, + skip_augmentation=params.skip_augmentation, + ) + + +def _detector_native_validate_seed_columns( + selected_models: DetectionModelSelection, +) -> list[LLMTextColumnConfig | CustomColumnConfig]: + return [ + LLMTextColumnConfig( + name=COL_RAW_DETECTED, + prompt=_jinja(COL_TEXT), + model_alias=_detector_alias(selected_models), + ), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), + CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), + ] + + +def _execute_detector_native_validate_row( + row: pd.Series, + *, + index: object, + ordinal: int, + labels: list[str], + client: DirectDetectionClient, + data_summary: str | None, + validation_prompt_mode: ValidationPromptMode, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + workflow_name: str, + skip_augmentation: bool, +) -> _DetectorNativeValidationRowResult: + output_row = row.to_dict() + request = _native_staged_request(row, index=index, ordinal=ordinal, labels=labels, data_summary=data_summary) + config = StagedExecutionConfig( + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + skip_augmentation=skip_augmentation, + ) + request_count = _native_validation_request_count( + output_row, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + ) + if not skip_augmentation: + request_count 
+= 1 + started = time.perf_counter() + try: + validation_completion = _run_validation_phase(output_row, request, client, config) + augmentation_completion = _run_augmentation_phase(output_row, request, client, config) + completion = _combine_detector_native_completions([validation_completion, augmentation_completion]) + merge_and_build_candidates(output_row) + apply_validation_and_finalize(output_row) + normalize_birth_context_final_entities(output_row, allowed_labels=labels) + except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures + return _DetectorNativeValidationRowResult( + ordinal=ordinal, + index=index, + workflow_name=workflow_name, + output_row=None, + failed_record=_native_failed_record( + index, + workflow_name=workflow_name, + error=f"{type(exc).__name__}: {exc}", + ), + completion=None, + elapsed_sec=time.perf_counter() - started, + request_count=request_count, + ) + return _DetectorNativeValidationRowResult( + ordinal=ordinal, + index=index, + workflow_name=workflow_name, + output_row=output_row, + failed_record=None, + completion=completion, + elapsed_sec=time.perf_counter() - started, + request_count=request_count, + ) + + +def _native_validation_request_count( + row: dict[str, Any], + *, + validation_prompt_mode: ValidationPromptMode, + validation_max_entities_per_call: int, +) -> int: + candidate_count = len(ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})).candidates) + if candidate_count == 0: + return 0 + if validation_prompt_mode == ValidationPromptMode.full_text: + return 1 + return (candidate_count + validation_max_entities_per_call - 1) // validation_max_entities_per_call + + +def _combine_detector_native_completions(completions: list[DirectCompletion]) -> DirectCompletion: + return DirectCompletion( + content=completions[-1].content if completions else "", + elapsed_sec=sum(float(completion.elapsed_sec or 0.0) for completion in completions), + 
usage=_sum_usage_dicts([dict(completion.usage or {}) for completion in completions]), + ) + + +def _record_detector_native_validation_completion( + completion: DirectCompletion, + *, + request_count: int, + workflow_name: str, +) -> None: + record_model_workflow( + workflow_name=workflow_name, + model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + input_row_count=1, + output_row_count=1, + failed_record_count=0, + elapsed_sec=float(completion.elapsed_sec or 0.0), + model_usage=_native_single_pass_model_usage( + successful_requests=request_count, + failed_requests=0, + usage=dict(completion.usage or {}), + ), + ) + + +def _record_detector_native_validation_result(result: _DetectorNativeValidationRowResult) -> None: + if result.completion is not None: + _record_detector_native_validation_completion( + result.completion, + request_count=result.request_count, + workflow_name=result.workflow_name, + ) + return + if result.failed_record is not None: + _record_detector_native_validation_error( + elapsed_sec=result.elapsed_sec, + request_count=result.request_count, + workflow_name=result.workflow_name, + ) + + +def _record_detector_native_validation_error(*, elapsed_sec: float, request_count: int, workflow_name: str) -> None: + record_model_workflow( + workflow_name=workflow_name, + model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + input_row_count=1, + output_row_count=0, + failed_record_count=1, + elapsed_sec=elapsed_sec, + status="error", + model_usage=_native_single_pass_model_usage( + successful_requests=0, + failed_requests=max(request_count, 1), + usage={}, + ), + ) + + +def _make_native_staged_method( + *, + native_client: DirectDetectionClient | None, + gliner_seed_client: GlinerSeedClient | None, + seed_source: SeedSource, + workflow_name: str, + skip_augmentation: bool, +) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: 
DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + _ = self, model_configs, selected_models, gliner_detection_threshold + labels = dw._resolve_detection_labels(entity_labels) + client = native_client or HttpxDirectDetectionClient() + return _run_native_staged_detection( + dataframe, + labels=labels, + client=client, + gliner_seed_client=gliner_seed_client, + data_summary=data_summary, + preview_num_records=preview_num_records, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + seed_source=seed_source, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + + return detect_and_validate_entities + + +def _run_native_staged_detection( + dataframe: pd.DataFrame, + *, + labels: list[str], + client: DirectDetectionClient, + gliner_seed_client: GlinerSeedClient | None, + data_summary: str | None, + preview_num_records: int | None, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + validation_single_chunk_full_text: bool, + seed_source: SeedSource, + workflow_name: str, + skip_augmentation: bool, +) -> dw.EntityDetectionResult: + source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() + output_rows: list[dict[str, Any]] = [] + output_indices: list[Any] = [] + failed_records: list[FailedRecord] = [] + validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) + params = _NativeStagedRunParams( + 
labels=labels, + client=client, + gliner_seed_client=gliner_seed_client, + data_summary=data_summary, + validation_prompt_mode=validation_prompt_mode, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + seed_source=seed_source, + workflow_name=workflow_name, + skip_augmentation=skip_augmentation, + ) + tasks = [ + _NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) + for ordinal, (index, row) in enumerate(source_df.iterrows()) + ] + + for result in _execute_native_staged_tasks(tasks, params=params): + if result.failed_record is not None: + failed_records.append(result.failed_record) + continue + if result.output_row is None: + failed_records.append( + _native_failed_record( + result.index, + workflow_name=workflow_name, + error="native staged detection produced no row", + ) + ) + continue + if result.case is not None: + _record_native_direct_usage(result.case, workflow_name=workflow_name) + output_rows.append(result.output_row) + output_indices.append(result.index) + + return dw.EntityDetectionResult( + dataframe=_native_output_dataframe(source_df, output_rows=output_rows, output_indices=output_indices), + failed_records=failed_records, + ) + + +def _execute_native_staged_tasks( + tasks: list[_NativeStagedTask], + *, + params: _NativeStagedRunParams, +) -> list[_NativeStagedRowResult]: + if not tasks: + return [] + worker_count = _native_staged_worker_count(len(tasks)) + if worker_count == 1: + return [_execute_native_staged_task(task, params=params) for task in tasks] + with ThreadPoolExecutor(max_workers=worker_count) as executor: + return list(executor.map(lambda task: _execute_native_staged_task(task, params=params), tasks)) + + +def _native_staged_worker_count(task_count: int) -> int: + return max(1, min(task_count, _NATIVE_STAGED_MAX_WORKERS)) + + +def _execute_native_staged_task( + task: _NativeStagedTask, + *, + params: _NativeStagedRunParams, +) -> 
_NativeStagedRowResult: + try: + request = _native_staged_request( + task.row, + index=task.index, + ordinal=task.ordinal, + labels=params.labels, + data_summary=params.data_summary, + ) + execution = execute_staged_detection_case( + request, + client=params.client, + seed_client=params.gliner_seed_client, + seed_source=params.seed_source, + skip_augmentation=params.skip_augmentation, + validation_prompt_mode=params.validation_prompt_mode, + validation_max_entities_per_call=params.validation_max_entities_per_call, + validation_excerpt_window_chars=params.validation_excerpt_window_chars, + ) + if execution.case.status != StagedCaseStatus.completed: + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=None, + failed_record=_native_failed_record( + task.index, + workflow_name=params.workflow_name, + error=execution.case.error, + ), + case=execution.case, + ) + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=_native_detection_result_row(task.row, execution_row=execution.row), + failed_record=None, + case=execution.case, + ) + except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures + return _NativeStagedRowResult( + ordinal=task.ordinal, + index=task.index, + output_row=None, + failed_record=_native_failed_record( + task.index, + workflow_name=params.workflow_name, + error=f"{type(exc).__name__}: {exc}", + ), + case=None, + ) + + +def _native_validation_prompt_mode(validation_single_chunk_full_text: bool) -> ValidationPromptMode: + return ValidationPromptMode.full_text if validation_single_chunk_full_text else ValidationPromptMode.chunked_excerpt + + +def _native_failed_record(index: object, *, workflow_name: str, error: str | None) -> FailedRecord: + return FailedRecord( + record_id=str(index), + step=workflow_name, + reason=error or "native staged detection failed", + ) + + +def _record_native_direct_usage(case: StagedDetectionCase, *, workflow_name: str) -> 
None: + model_usage = _native_staged_model_usage(case) + if not model_usage: + return + record_model_workflow( + workflow_name=workflow_name, + model_aliases=sorted(model_usage), + input_row_count=1, + output_row_count=1 if case.status == StagedCaseStatus.completed else 0, + failed_record_count=0 if case.status == StagedCaseStatus.completed else 1, + elapsed_sec=case.model_elapsed_sec or case.elapsed_sec or 0.0, + model_usage=model_usage, + ) + + +def _native_staged_model_usage(case: StagedDetectionCase) -> dict[str, dict[str, Any]]: + usage: dict[str, dict[str, Any]] = {} + native_requests = 0 + native_usage: list[dict[str, int]] = [] + + if case.phase_model_work.seed: + if case.seed_source == SeedSource.gliner: + usage[_GLINER_DIRECT_MODEL_ALIAS] = _direct_model_usage_entry( + alias=_GLINER_DIRECT_MODEL_ALIAS, + model_name=_GLINER_DIRECT_MODEL_NAME, + provider_name=_GLINER_DIRECT_MODEL_PROVIDER, + successful_requests=case.phase_model_requests.seed, + usage=case.phase_usage.seed, + ) + else: + native_requests += case.phase_model_requests.seed + native_usage.append(case.phase_usage.seed) + if case.phase_model_work.validation: + native_requests += case.phase_model_requests.validation + native_usage.append(case.phase_usage.validation) + if case.phase_model_work.augmentation: + native_requests += case.phase_model_requests.augmentation + native_usage.append(case.phase_usage.augmentation) + if native_requests: + usage[_NATIVE_DIRECT_MODEL_ALIAS] = _direct_model_usage_entry( + alias=_NATIVE_DIRECT_MODEL_ALIAS, + model_name=_NATIVE_DIRECT_MODEL_NAME, + provider_name=_NATIVE_DIRECT_MODEL_PROVIDER, + successful_requests=native_requests, + usage=_sum_usage_dicts(native_usage), + ) + return usage + + +def _direct_model_usage_entry( + *, + alias: str, + model_name: str, + provider_name: str, + successful_requests: int, + usage: dict[str, int], +) -> dict[str, Any]: + return { + "model_alias": alias, + "model_name": model_name, + "model_provider_name": provider_name, + 
"request_usage": { + "successful_requests": successful_requests, + "failed_requests": 0, + "total_requests": successful_requests, + }, + "token_usage": _native_token_usage(usage), + } + + +def _sum_usage_dicts(usages: list[dict[str, int]]) -> dict[str, int]: + totals: Counter[str] = Counter() + for usage in usages: + for key, value in usage.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def _native_token_usage(usage: dict[str, int]) -> dict[str, int]: + input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0)) + output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0)) + total_tokens = usage.get("total_tokens", input_tokens + output_tokens) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens, + } + + +def _native_staged_request( + row: pd.Series, + *, + index: object, + ordinal: int, + labels: list[str], + data_summary: str | None, +) -> StagedDetectionRequest: + return StagedDetectionRequest( + case_id=f"native-rules-router-{ordinal}", + text=str(row.get(COL_TEXT, "")), + labels=labels, + row_index=_safe_row_index(index, fallback=ordinal), + data_summary=data_summary, + ) + + +def _native_detection_result_row(row: pd.Series, *, execution_row: dict[str, Any]) -> dict[str, Any]: + output_row = row.to_dict() + output_row[COL_DETECTED_ENTITIES] = execution_row.get( + COL_DETECTED_ENTITIES, + EntitiesSchema().model_dump(mode="json"), + ) + output_row[COL_TAGGED_TEXT] = execution_row.get(COL_TAGGED_TEXT, str(row.get(COL_TEXT, ""))) + return output_row + + +def _safe_row_index(index: object, *, fallback: int) -> int: + try: + return int(index) # type: ignore[arg-type] + except (TypeError, ValueError): + return fallback + + +def _native_output_dataframe( + source_df: pd.DataFrame, + *, + output_rows: list[dict[str, Any]], + output_indices: list[Any], +) -> pd.DataFrame: + if output_rows: + return pd.DataFrame(output_rows, 
index=output_indices) + output = source_df.iloc[0:0].copy() + output[COL_DETECTED_ENTITIES] = pd.Series(dtype="object") + output[COL_TAGGED_TEXT] = pd.Series(dtype="object") + return output + + +def _detect_with_rules_only( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, +) -> dw.EntityDetectionResult: + return self.detect_with_high_confidence_rules(dataframe, entity_labels=entity_labels) + + +def _make_rules_covered_or_default_method(original: _DetectAndValidate) -> _DetectAndValidate: + def detect_and_validate_entities( + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, + validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, + entity_labels: list[str] | None = None, + data_summary: str | None = None, + preview_num_records: int | None = None, + ) -> dw.EntityDetectionResult: + labels = dw._resolve_detection_labels(entity_labels) + if _labels_are_rules_only(labels): + return _detect_rules_covered_rows_or_default( + original, + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + 
validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + labels=labels, + ) + return original( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + + return detect_and_validate_entities + + +def _detect_rules_covered_rows_or_default( + original: _DetectAndValidate, + self: dw.EntityDetectionWorkflow, + dataframe: pd.DataFrame, + *, + model_configs: list[ModelConfig], + selected_models: DetectionModelSelection, + gliner_detection_threshold: float, + validation_max_entities_per_call: int, + validation_excerpt_window_chars: int, + validation_single_chunk_full_text: bool, + entity_labels: list[str] | None, + data_summary: str | None, + preview_num_records: int | None, + labels: list[str], +) -> dw.EntityDetectionResult: + started = time.perf_counter() + if dataframe.empty: + result = _detect_with_rules_only( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + _record_rules_covered_route( + started=started, + total_row_count=0, + rule_row_count=0, + fallback_row_count=0, + result=result, + 
) + return result + + coverage_mask = dataframe[COL_TEXT].apply( + lambda text: _structured_assignments_are_rule_covered(str(text), labels=labels) + ) + if bool(coverage_mask.all()): + result = _detect_with_rules_only( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + _record_rules_covered_route( + started=started, + total_row_count=len(dataframe), + rule_row_count=len(dataframe), + fallback_row_count=0, + result=result, + ) + return result + if not bool(coverage_mask.any()): + result = original( + self, + dataframe, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + _record_rules_covered_route( + started=started, + total_row_count=len(dataframe), + rule_row_count=0, + fallback_row_count=len(dataframe), + result=result, + ) + return result + + rule_rows, default_rows = split_rows( + dataframe, + column=COL_TEXT, + predicate=lambda text: _structured_assignments_are_rule_covered(str(text), labels=labels), + ) + + rule_result = _detect_with_rules_only( + self, + rule_rows, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + 
validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + default_result = original( + self, + default_rows, + model_configs=model_configs, + selected_models=selected_models, + gliner_detection_threshold=gliner_detection_threshold, + validation_max_entities_per_call=validation_max_entities_per_call, + validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, + entity_labels=entity_labels, + data_summary=data_summary, + preview_num_records=preview_num_records, + ) + result = dw.EntityDetectionResult( + dataframe=merge_and_reorder(rule_result.dataframe, default_result.dataframe), + failed_records=[*rule_result.failed_records, *default_result.failed_records], + ) + _record_rules_covered_route( + started=started, + total_row_count=len(dataframe), + rule_row_count=len(rule_rows), + fallback_row_count=len(default_rows), + result=result, + ) + return result + + +def _record_rules_covered_route( + *, + started: float, + total_row_count: int, + rule_row_count: int, + fallback_row_count: int, + result: dw.EntityDetectionResult, +) -> None: + record_model_workflow( + workflow_name="entity-detection-rules-covered-router", + model_aliases=[], + input_row_count=total_row_count, + output_row_count=len(result.dataframe), + failed_record_count=len(result.failed_records), + elapsed_sec=time.perf_counter() - started, + status="completed" if not result.failed_records else "partial", + extra_fields={ + "route_total_row_count": total_row_count, + "route_rule_row_count": rule_row_count, + "route_fallback_row_count": fallback_row_count, + }, + ) + + +def _structured_assignments_are_rule_covered(text: str, *, labels: list[str]) -> bool: + allowed_labels = set(labels) + rule_spans = detect_high_confidence_entities(text, 
labels=labels) + covered_ranges = [(span.start_position, span.end_position) for span in rule_spans] + for match in _STRUCTURED_ASSIGNMENT_RE.finditer(text): + label = _structured_assignment_label(match.group("key")) + if label not in allowed_labels: + continue + start, end = _structured_assignment_value_span(match) + if not _range_overlaps_any(start, end, covered_ranges): + return False + return True + + +def _structured_assignment_value_span(match: re.Match[str]) -> tuple[int, int]: + if match.group("quoted") is not None: + return match.span("quoted") + return match.span("bare") + + +def _range_overlaps_any(start: int, end: int, ranges: list[tuple[int, int]]) -> bool: + return any(start < range_end and end > range_start for range_start, range_end in ranges) + + +def _structured_assignment_label(key: str) -> str: + normalized = key.lower().replace("-", "_") + if normalized in {"api_key", "aws_access_key_id", "access_key_id", "hf_token", "token", "auth_token", "session_id"}: + return "api_key" + if normalized == "authorization": + return "api_key" + if normalized in {"password", "pass", "secret", "aws_secret_access_key", "django_secret"}: + return "password" + if normalized == "database_url": + return "url" + if normalized == "pin": + return "pin" + if normalized in {"user", "username", "user_name", "login", "account"}: + return "user_name" + if normalized == "cookie": + return "http_cookie" + if normalized in {"trace_id", "request_id", "req_id", "order_id", "tenant_id", "unique_id"}: + return "unique_id" + if normalized in {"url", "uri", "endpoint", "callback"}: + return "url" + if normalized == "email": + return "email" + return "" + + +def _labels_are_rules_only(labels: list[str]) -> bool: + return dw.labels_are_supported_by_structured_rule_fast_lane(labels) diff --git a/tools/measurement/direct_detection_probe.py b/tools/measurement/direct_detection_probe.py new file mode 100644 index 00000000..aefb09e5 --- /dev/null +++ 
b/tools/measurement/direct_detection_probe.py @@ -0,0 +1,559 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run benchmark-only DD-free direct entity extraction probes. + +Usage: + uv run python tools/measurement/direct_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography --labels age,city,first_name,last_name,occupation \ + --output /tmp/direct-probe --overwrite +""" + +from __future__ import annotations + +import json +import logging +import shutil +import sys +from collections import Counter +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any, Protocol + +import cyclopts +import httpx +import pandas as pd +from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities +from dd_parser_compat import _load_embedded_json +from pydantic import BaseModel, Field, ValidationError, model_validator + +from anonymizer.engine.detection.detection_workflow import _format_label_examples +from anonymizer.engine.detection.postprocess import EntitySpan, apply_augmented_entities, expand_entity_occurrences +from anonymizer.engine.schemas import EntitySchema + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.direct_detection_probe") + + +class CaseStatus(StrEnum): + completed = "completed" + error = "error" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class PromptMode(StrEnum): + compact = "compact" + recall = "recall" + + +class DirectCompletion(BaseModel): + content: str + elapsed_sec: float + usage: dict[str, Any] = Field(default_factory=dict) + + +class DirectDetectionRequest(BaseModel): + case_id: str + text: str + labels: list[str] = Field(min_length=1) + row_index: int = 0 + prompt_mode: PromptMode = PromptMode.compact + data_summary: str | None = None + + @model_validator(mode="after") + def 
normalize_labels(self) -> DirectDetectionRequest: + self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) + if not self.labels: + raise ValueError("labels must contain at least one non-empty label") + return self + + +class DirectGenerationRequest(BaseModel): + endpoint: str + model: str + prompt: str + max_tokens: int = Field(gt=0) + temperature: float = 0.0 + top_p: float = 1.0 + timeout_sec: float = Field(gt=0) + json_mode: bool = True + disable_thinking: bool = True + + +class SignatureComparison(BaseModel): + baseline_final_entity_signature_count: int + shared_final_entity_signature_count: int + baseline_only_final_entity_signature_count: int + direct_only_final_entity_signature_count: int + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + direct_only_label_counts: dict[str, int] = Field(default_factory=dict) + + +class DirectDetectionCase(BaseModel): + case_id: str + row_index: int + status: CaseStatus + elapsed_sec: float | None = None + usage: dict[str, Any] = Field(default_factory=dict) + raw_suggestion_count: int = 0 + allowed_suggestion_count: int = 0 + final_entity_count: int = 0 + final_entity_signature_count: int = 0 + final_entity_signature_hashes: list[str] = Field(default_factory=list) + final_label_counts: dict[str, int] = Field(default_factory=dict) + comparison: SignatureComparison | None = None + artifact: DetectionArtifactRow | None = None + error: str | None = None + + +class DirectDetectionRun(BaseModel): + input_path: str + text_column: str + endpoint: str + model: str + prompt_mode: PromptMode + labels: list[str] + rows: list[DirectDetectionCase] = Field(default_factory=list) + + @property + def error_count(self) -> int: + return sum(1 for row in self.rows if row.status == CaseStatus.error) + + +class DirectDetectionClient(Protocol): + def complete(self, request: DirectGenerationRequest) -> DirectCompletion: ... 
class HttpxDirectDetectionClient:
    """``DirectDetectionClient`` implementation that POSTs to ``<endpoint>/chat/completions`` with httpx."""

    def complete(self, request: DirectGenerationRequest) -> DirectCompletion:
        """Send one chat-completion request and wrap the first returned choice.

        Non-success HTTP statuses raise through ``response.raise_for_status()``;
        a response without the expected ``choices[0].message`` structure raises
        the corresponding lookup error.
        """
        url = f"{request.endpoint.rstrip('/')}/chat/completions"
        response = httpx.post(url, json=build_chat_payload(request), timeout=request.timeout_sec)
        response.raise_for_status()
        body = response.json()
        message = body["choices"][0]["message"]
        # A missing or null "content" is recorded as an empty string.
        text = message.get("content") or ""
        return DirectCompletion(
            content=str(text),
            elapsed_sec=float(response.elapsed.total_seconds()),
            usage=body.get("usage") or {},
        )
Return JSON only."}, + {"role": "user", "content": request.prompt}, + ], + "temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": request.max_tokens, + } + if request.json_mode: + payload["response_format"] = {"type": "json_object"} + if request.disable_thinking: + payload["chat_template_kwargs"] = {"enable_thinking": False} + return payload + + +def build_direct_prompt(request: DirectDetectionRequest) -> str: + label_text = ( + _format_label_examples(request.labels) + if request.prompt_mode == PromptMode.recall + else ", ".join(request.labels) + ) + recall_block = _recall_block() if request.prompt_mode == PromptMode.recall else "" + summary = f"\nData context: {request.data_summary}\n" if request.data_summary else "" + return f"""Extract privacy-sensitive entities from the input text. +{summary} +Use only these labels: +{label_text} + +Rules: +- Return exact substrings from the input text. +- Do not invent values. +- Prefer specific labels from the allowed list. +- Skip generic nouns, syntax, and non-sensitive filler. +{recall_block}- Return only a JSON object with this shape: + {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} + +Input text: +--- +{request.text} +---""" + + +def _recall_block() -> str: + return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. +- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. 
def _extract_suggestions(content: str) -> list[dict[str, Any]]:
    """Return the ``"entities"`` list embedded in a model reply.

    Yields ``[]`` when the parsed payload is not a dict or its ``"entities"``
    entry is not a list.
    """
    parsed = _load_embedded_json(content)
    entities = parsed.get("entities") if isinstance(parsed, dict) else None
    if isinstance(entities, list):
        return entities
    return []
def filter_direct_suggestions(suggestions: list[dict[str, Any]], labels: list[str]) -> list[dict[str, str]]:
    """Keep only suggestions that carry a non-empty value and an allowed label.

    Args:
        suggestions: Raw items taken from the model's ``"entities"`` list. Items
            that are not dicts are ignored.
        labels: The labels the probe was asked to extract.

    Returns:
        ``{"value", "label"}`` dicts with both fields stripped, in input order.
    """
    allowed = set(labels)
    kept: list[dict[str, str]] = []
    for item in suggestions:
        if not isinstance(item, dict):
            continue
        raw_value = item.get("value")
        raw_label = item.get("label")
        # A JSON null must be dropped before stringifying: str(None) is the
        # truthy text "None", which would otherwise pass as a real entity value.
        if raw_value is None or raw_label is None:
            continue
        value = str(raw_value).strip()
        label = str(raw_label).strip()
        if value and label in allowed:
            kept.append({"value": value, "label": label})
    return kept
def _label_counts(hashes: set[str], labels: dict[str, str]) -> dict[str, int]:
    """Count signatures per label, ordered by label name.

    Signatures with no entry in ``labels`` are counted under ``"unknown"``.
    """
    tally = Counter()
    for signature in hashes:
        tally[labels.get(signature, "unknown")] += 1
    return {label: tally[label] for label in sorted(tally)}
def _signature_labels(row: dict[str, Any]) -> dict[str, str]:
    """Map signature hash -> label for one baseline artifact row.

    Two layouts are accepted:

    * a nested ``"final_entity_signature_labels"`` mapping, which is what this
      tool writes to ``direct-detection-artifacts.jsonl``;
    * flattened ``"final_entity_signature_labels.<hash>"`` keys.

    Null labels are skipped. A flattened key overrides the nested entry for
    the same hash.
    """
    prefix = "final_entity_signature_labels."
    labels: dict[str, str] = {}
    nested = row.get("final_entity_signature_labels")
    if isinstance(nested, dict):
        labels.update({str(key): str(value) for key, value in nested.items() if value is not None})
    labels.update(
        {
            key.removeprefix(prefix): str(value)
            for key, value in row.items()
            if key.startswith(prefix) and value is not None
        }
    )
    return labels
def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines: one ASCII-escaped, key-sorted object per line."""
    encoded_lines = (json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n" for record in rows)
    with path.open("w", encoding="utf-8") as target:
        target.writelines(encoded_lines)
def parse_labels(raw: str) -> list[str]:
    """Split a comma-separated ``--labels`` value into clean, unique label names.

    Blank entries are dropped and surrounding whitespace is stripped. Repeated
    labels are collapsed to their first occurrence, so the label list recorded
    for the run matches the de-duplicated list each request actually uses.
    """
    stripped = (part.strip() for part in raw.split(","))
    # dict.fromkeys keeps first-seen order while removing repeats.
    return list(dict.fromkeys(label for label in stripped if label))
+#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Extract masked context for entity signature differences between benchmark artifacts. + +Usage: + uv run python tools/measurement/extract_signature_deltas.py \ + baseline/detection-artifacts.jsonl candidate/detection-artifacts.jsonl \ + --baseline-artifact-root baseline/artifacts --candidate-artifact-root candidate/artifacts \ + --baseline-config legal-default --candidate-config legal-rules-guardrail \ + --workload legal-r2 --output deltas.csv --format csv +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from analyze_detection_artifacts import _entity_signature_hash +from pydantic import BaseModel, Field, ValidationError + +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT +from anonymizer.engine.detection.rules import detect_high_confidence_entities +from anonymizer.engine.schemas import EntitiesSchema, EntitySchema + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.signature_deltas") + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class DeltaSide(StrEnum): + baseline_only = "baseline_only" + candidate_only = "candidate_only" + + +class ContextResolution(StrEnum): + parquet = "parquet" + artifact_details = "artifact_details" + rule = "rule" + metadata_only = "metadata_only" + + +_log_format = LogFormat.plain + + +class SignatureDeltaRow(BaseModel): + workload_id: str + row_index: int + side: DeltaSide + config_id: str + signature_hash: str + label: str | None = None + source: str | None = None + start_position: int | None = None + end_position: int | None = None + 
def log_bad_input(error: str) -> None:
    """Report an input error: one JSON object on stderr in JSON log mode, otherwise a logger error record."""
    if _log_format != LogFormat.json:
        logger.error("bad_input error=%s", error)
        return
    record = json.dumps(
        {"level": "error", "event": "bad_input", "error": error},
        ensure_ascii=True,
        sort_keys=True,
    )
    sys.stderr.write(record + "\n")
baseline_artifacts=str(baseline_artifacts), + candidate_artifacts=str(candidate_artifacts), + workload_id=workload, + baseline_config=baseline_config, + candidate_config=candidate_config, + delta_count=len(rows), + rows=rows, + ) + + +def _artifact_side(artifacts_path: Path, artifact_root: Path, config_id: str | None) -> _ArtifactSide: + return _ArtifactSide( + artifacts_path=str(artifacts_path), + artifact_root=str(artifact_root), + config_id=config_id, + ) + + +def _read_artifact_rows(path: Path) -> list[dict[str, object]]: + if not path.exists() or path.is_dir(): + raise ValueError(f"detection artifact path is not a file: {path}") + with path.open(encoding="utf-8") as source: + return [json.loads(line) for line in source if line.strip()] + + +def _select_artifact_rows( + rows: list[dict[str, object]], + *, + config_id: str | None, + workload: str | None, +) -> dict[tuple[str, int], dict[str, object]]: + selected = [_row for _row in rows if _matches(_row, config_id=config_id, workload=workload)] + if not selected: + raise ValueError("artifact selector matched no rows") + return {_artifact_key(row): row for row in selected} + + +def _matches(row: dict[str, object], *, config_id: str | None, workload: str | None) -> bool: + if config_id is not None and str(row.get("config_id")) != config_id: + return False + if workload is not None and str(row.get("workload_id")) != workload: + return False + return True + + +def _artifact_key(row: dict[str, object]) -> tuple[str, int]: + return str(row.get("workload_id")), int(row.get("row_index", 0)) + + +def _compare_artifact_rows( + baseline: dict[tuple[str, int], dict[str, object]], + candidate: dict[tuple[str, int], dict[str, object]], + *, + baseline_side: _ArtifactSide, + candidate_side: _ArtifactSide, + context_window: int, +) -> list[SignatureDeltaRow]: + rows: list[SignatureDeltaRow] = [] + for key in sorted(set(baseline) & set(candidate)): + rows.extend( + _row_signature_deltas(key, baseline[key], candidate[key], 
def _signature_labels(row: dict[str, object]) -> dict[str, str]:
    """Map signature hash -> label for one artifact row.

    The labels may arrive as a nested ``"final_entity_signature_labels"``
    mapping (optionally JSON-encoded as a string), as flattened
    ``"final_entity_signature_labels.<hash>"`` keys, or both. Flattened keys
    override nested entries for the same hash.

    Null labels are skipped rather than stringified to ``"None"``, so callers
    see a missing label as ``None`` via ``dict.get``. All kept labels are
    coerced to ``str``.

    Raises:
        json.JSONDecodeError: If the nested value is a string that is not
            valid JSON (a ``ValueError`` subclass).
    """
    nested = row.get("final_entity_signature_labels", {})
    if isinstance(nested, str):
        nested = json.loads(nested)
    merged: dict[str, str] = {}
    if isinstance(nested, dict):
        merged.update({str(key): str(value) for key, value in nested.items() if value is not None})
    prefix = "final_entity_signature_labels."
    merged.update(
        {
            key.removeprefix(prefix): str(value)
            for key, value in row.items()
            if key.startswith(prefix) and value is not None
        }
    )
    return merged
def _resolve_signature_context(
    artifact_row: dict[str, object],
    signature: str,
    label: str | None,
    artifact_root: Path,
    context_window: int,
) -> dict[str, object]:
    """Resolve context fields for one signature, trying the sources in a fixed order.

    Order: entities stored in the batch parquet, then rule-based detection,
    then the details recorded on the artifact row. If none of them produces a
    context, only ``{"resolution": ContextResolution.metadata_only}`` is returned.
    """
    primary_resolvers = (
        lambda: _parquet_entity_context(artifact_row, signature, artifact_root, context_window),
        lambda: _rule_entity_context(artifact_row, signature, label, artifact_root, context_window),
    )
    # Later resolvers run only when the earlier ones found nothing.
    for resolve in primary_resolvers:
        found = resolve()
        if found is not None:
            return found
    fallback = _artifact_detail_context(artifact_row, signature, label, artifact_root, context_window)
    if fallback:
        return fallback
    return {"resolution": ContextResolution.metadata_only}
def _artifact_detail_context(
    artifact_row: dict[str, object],
    signature: str,
    label: str | None,
    artifact_root: Path,
    context_window: int,
) -> dict[str, object] | None:
    """Build context fields from the per-signature details stored on the artifact row.

    Returns ``None`` when the row has no details for ``signature``. When the
    positions, value hash or label are incomplete, only the metadata fields
    are returned. Otherwise a masked context is added if the source record can
    be loaded; without it the resolution stays ``metadata_only``.
    """
    details = _signature_details(artifact_row).get(signature)
    if details is None:
        return None

    start = _optional_int(details.get("start_position"))
    end = _optional_int(details.get("end_position"))
    digest = _optional_string(details.get("value_hash"))
    # The label recorded in the details wins over the caller-supplied one.
    effective_label = _optional_string(details.get("label")) or label
    if any(part is None for part in (start, end, digest, effective_label)):
        return _metadata_context_from_details(details)

    record = _artifact_record(artifact_row, artifact_root)
    if record is None:
        masked = None
    else:
        masked = _masked_context_from_details(
            record[0],
            label=effective_label,
            value_hash=digest,
            start_position=start,
            end_position=end,
            window=context_window,
        )
    resolution = ContextResolution.metadata_only if masked is None else ContextResolution.artifact_details
    return {
        "source": _optional_string(details.get("source")),
        "start_position": start,
        "end_position": end,
        "value_hash": digest,
        "value_length": _optional_int(details.get("value_length")),
        "masked_context": masked,
        "resolution": resolution,
    }
def _artifact_record(artifact_row: dict[str, object], artifact_root: Path) -> tuple[str, int, pd.Series] | None:
    """Load the source record that an artifact row points at.

    Args:
        artifact_row: Artifact row carrying ``batch_file`` and ``row_index``.
        artifact_root: Directory that ``batch_file`` is resolved against.

    Returns:
        ``(text, row_index, row)`` for the referenced parquet row, or ``None``
        when the row names no batch file, the file does not exist, or
        ``row_index`` does not address a row of that file.
    """
    batch_file = _optional_string(artifact_row.get("batch_file"))
    if batch_file is None:
        return None
    parquet_file = artifact_root / batch_file
    if not parquet_file.exists():
        return None
    row_index = int(artifact_row.get("row_index", 0))
    frame = pd.read_parquet(parquet_file)
    # .iloc raises IndexError past the end and silently counts from the end for
    # negative positions; either would crash the tool or attach another
    # record's text to this signature, so treat both as "record unavailable".
    if not 0 <= row_index < len(frame):
        return None
    row = frame.iloc[row_index]
    return str(row.get(COL_TEXT, "")), row_index, row
str, entity: EntitySchema, window: int) -> str: + before = text[max(0, entity.start_position - window) : entity.start_position] + after = text[entity.end_position : entity.end_position + window] + placeholder = f"[{entity.label.upper()}:{_value_hash(entity.value)}]" + return (before + placeholder + after).replace("\n", " ") + + +def _masked_context_from_details( + text: str, + *, + label: str, + value_hash: str, + start_position: int, + end_position: int, + window: int, +) -> str: + before = text[max(0, start_position - window) : start_position] + after = text[end_position : end_position + window] + placeholder = f"[{label.upper()}:{value_hash}]" + return (before + placeholder + after).replace("\n", " ") + + +def _value_hash(value: str) -> str: + return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] + + +def _optional_int(value: object) -> int | None: + try: + if pd.isna(value): + return None + except (TypeError, ValueError): + pass + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _optional_string(value: object) -> str | None: + if value is None: + return None + try: + if pd.isna(value): + return None + except (TypeError, ValueError): + pass + return str(value) + + +def write_rows(rows: list[SignatureDeltaRow], output_path: Path, export_format: ExportFormat) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + table = pd.json_normalize([row.model_dump() for row in rows], sep=".") + if export_format == ExportFormat.parquet: + table.to_parquet(output_path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(output_path, index=False) + else: + table.to_json(output_path, orient="records", lines=True) + + +def render_result(result: SignatureDeltaResult, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + counts = pd.Series([row.side.value for row in result.rows]).value_counts().to_dict() if result.rows else {} + return ( + f"Extracted {result.delta_count} 
@app.default
def main(
    baseline_artifacts: Path,
    candidate_artifacts: Path,
    *,
    baseline_artifact_root: Annotated[Path, cyclopts.Parameter("--baseline-artifact-root")],
    candidate_artifact_root: Annotated[Path, cyclopts.Parameter("--candidate-artifact-root")],
    baseline_config: Annotated[str | None, cyclopts.Parameter("--baseline-config")] = None,
    candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None,
    workload: Annotated[str | None, cyclopts.Parameter("--workload")] = None,
    context_window: Annotated[int, cyclopts.Parameter("--context-window")] = 40,
    output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None,
    format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.jsonl,
    json_output: Annotated[bool, cyclopts.Parameter("--json")] = False,
    log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain,
) -> None:
    # CLI entry point: extract the deltas, optionally export them, then print a summary.
    # Input problems (ValueError / ValidationError) are reported and exit with status 125.
    configure_logging(log_format)
    extraction_options = {
        "baseline_artifact_root": baseline_artifact_root,
        "candidate_artifact_root": candidate_artifact_root,
        "baseline_config": baseline_config,
        "candidate_config": candidate_config,
        "workload": workload,
        "context_window": context_window,
    }
    try:
        result = extract_signature_deltas(baseline_artifacts, candidate_artifacts, **extraction_options)
    except (ValueError, ValidationError) as exc:
        log_bad_input(str(exc))
        raise SystemExit(125) from exc
    else:
        if output is not None:
            write_rows(result.rows, output, format)
        sys.stdout.write(render_result(result, json_output=json_output) + "\n")
@contextmanager
def experimental_replacement_strategy_context(strategy: ExperimentalReplacementStrategy) -> Iterator[None]:
    """Run the enclosed block with a benchmark-only replacement strategy active.

    ``default`` leaves the workflow untouched. ``local_structured_substitute``
    swaps ``LlmReplaceWorkflow.generate_map_only`` for the local structured
    implementation and puts the previous method back on exit, including when
    the block raises.

    Raises:
        ValueError: if *strategy* is neither of the two values above.
    """
    if strategy == ExperimentalReplacementStrategy.default:
        # Nothing to patch for the default strategy.
        yield
        return

    # Capture the current method first so it can be reinstated afterwards.
    saved_method = lrw.LlmReplaceWorkflow.generate_map_only
    if strategy != ExperimentalReplacementStrategy.local_structured_substitute:
        raise ValueError(f"unsupported experimental replacement strategy: {strategy}")
    lrw.LlmReplaceWorkflow.generate_map_only = _local_structured_generate_map_only  # type: ignore[method-assign]
    try:
        yield
    finally:
        lrw.LlmReplaceWorkflow.generate_map_only = saved_method  # type: ignore[method-assign]
preview_num_records + return lrw.LlmReplaceResult( + dataframe=apply_structured_substitution_maps(dataframe, entities_column=entities_column), + failed_records=[], + ) diff --git a/tools/measurement/replay_replacement_strategies.py b/tools/measurement/replay_replacement_strategies.py new file mode 100644 index 00000000..f5512aa5 --- /dev/null +++ b/tools/measurement/replay_replacement_strategies.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Replay substitute strategies on one fixed detection trace. + +Usage: + uv run python tools/measurement/replay_replacement_strategies.py data.csv \ + --text-column text --labels api_key,http_cookie,password,pin,unique_id,user_name \ + --model-configs /stable-cache/anonymizer/local-vllm-json-models.yaml \ + --model-providers /stable-cache/anonymizer/local-vllm-providers.yaml \ + --dd-parser-compat raw_json --json +""" + +from __future__ import annotations + +import json +import logging +import sys +import time +from collections import Counter +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any + +import cyclopts +import pandas as pd +from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context +from pydantic import BaseModel, Field, ValidationError +from replacement_strategies import ( + ExperimentalReplacementStrategy, + experimental_replacement_strategy_context, +) + +from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect +from anonymizer.config.replace_strategies import Redact, Substitute +from anonymizer.engine.constants import ( + COL_FINAL_ENTITIES, + COL_REPLACED_TEXT, + COL_REPLACEMENT_MAP, + COL_REPLACEMENT_MAP_SOURCE, +) +from anonymizer.engine.schemas import EntitiesSchema, EntityReplacementMapSchema +from anonymizer.interface.anonymizer import Anonymizer, _unrename_output_columns +from 
class ReplacementReplaySummary(BaseModel):
    """Outcome of replaying one replacement strategy over the detection trace.

    Built per run by ``summarize_replacement_dataframe`` (or directly by
    ``run_strategy`` when the run raised) and folded across repetitions by
    ``aggregate_strategy_summaries``, which sums the counters below.
    """

    # Strategy that produced this summary, and whether the run completed or raised.
    strategy: ReplacementReplayStrategy
    status: ReplayStatus
    # Number of runs folded into this summary (1 for a single run).
    repetition_count: int = 1
    # ``time.perf_counter`` duration of the strategy run in seconds; summed when
    # aggregated. Stays None for a run that raised.
    elapsed_sec: float | None = None
    # Number of dataframe rows summarised.
    row_count: int = 0
    # Sums of the per-row ``ReplacementRowMetrics`` counters of the same names.
    final_entity_count: int = 0
    replacement_count: int = 0
    missing_count: int = 0
    collision_count: int = 0
    leak_count: int = 0
    duplicate_synthetic_count: int = 0
    # Label -> count breakdowns of the missing / collision / leak counters.
    missing_labels: dict[str, int] = Field(default_factory=dict)
    collision_labels: dict[str, int] = Field(default_factory=dict)
    leak_labels: dict[str, int] = Field(default_factory=dict)
    # Occurrences of each value in the replacement-map-source column (empty when
    # that column is absent from the result dataframe).
    replacement_map_sources: dict[str, int] = Field(default_factory=dict)
    # ``str(exc)`` of a failed run; "; "-joined across repetitions when aggregated.
    error: str | None = None
"local_structured_substitute_replay" + baseline_strategy: str = "default" + candidate_strategy: str = "default" + baseline_replacement_strategy: str = "default" + candidate_replacement_strategy: str = "local_structured_substitute" + baseline_case_count: int + candidate_case_count: int + baseline_pipeline_elapsed_sec: float | None = None + candidate_pipeline_elapsed_sec: float | None = None + pipeline_elapsed_sec_delta: float | None = None + pipeline_elapsed_sec_delta_pct: float | None = None + baseline_final_entity_count: int + candidate_final_entity_count: int + final_entity_count_delta: int + baseline_replacement_count: int + candidate_replacement_count: int + replacement_count_delta: int + baseline_replacement_missing_final_entity_count: int + candidate_replacement_missing_final_entity_count: int + replacement_missing_final_entity_count_delta: int + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_count: int + candidate_replacement_synthetic_original_collision_count: int + replacement_synthetic_original_collision_count_delta: int + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_duplicate_synthetic_replacement_count: int + candidate_duplicate_synthetic_replacement_count: int + duplicate_synthetic_replacement_count_delta: int + baseline_original_value_leak_count: int + candidate_original_value_leak_count: int + original_value_leak_count_delta: int + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + value_protection_verdict: str + 
def parse_labels(raw: str) -> list[str]:
    """Split a comma-separated ``--labels`` value into unique, trimmed labels.

    Blank entries are dropped and duplicates removed while keeping the order of
    first appearance.

    Raises:
        ValueError: if no non-empty label remains.
    """
    # A dict keeps insertion order, which gives order-preserving de-duplication.
    ordered: dict[str, None] = {}
    for piece in raw.split(","):
        name = piece.strip()
        if name:
            ordered.setdefault(name, None)
    if not ordered:
        raise ValueError("--labels must contain at least one non-empty label")
    return list(ordered)
def aggregate_strategy_summaries(
    *,
    strategy: ReplacementReplayStrategy,
    summaries: list[ReplacementReplaySummary],
) -> ReplacementReplaySummary:
    """Fold the per-repetition summaries of one strategy into a single summary.

    Counters and label breakdowns are summed, elapsed times are added up (None
    when no repetition recorded one), the status is ``error`` as soon as any
    repetition failed, and the non-empty error messages are joined with "; ".
    """

    def merged(field_name: str) -> dict[str, int]:
        # Sum one dict-valued field across repetitions, keys sorted.
        tally: Counter[str] = Counter()
        for item in summaries:
            tally.update(getattr(item, field_name))
        return dict(sorted(tally.items()))

    def total(field_name: str) -> int:
        # Sum one integer field across repetitions.
        return sum(getattr(item, field_name) for item in summaries)

    timings = [item.elapsed_sec for item in summaries if item.elapsed_sec is not None]
    messages = [item.error for item in summaries if item.error]
    any_failed = any(item.status == ReplayStatus.error for item in summaries)

    return ReplacementReplaySummary(
        strategy=strategy,
        status=ReplayStatus.error if any_failed else ReplayStatus.completed,
        repetition_count=len(summaries),
        elapsed_sec=sum(timings) if timings else None,
        row_count=total("row_count"),
        final_entity_count=total("final_entity_count"),
        replacement_count=total("replacement_count"),
        missing_count=total("missing_count"),
        collision_count=total("collision_count"),
        leak_count=total("leak_count"),
        duplicate_synthetic_count=total("duplicate_synthetic_count"),
        missing_labels=merged("missing_labels"),
        collision_labels=merged("collision_labels"),
        leak_labels=merged("leak_labels"),
        replacement_map_sources=merged("replacement_map_sources"),
        error="; ".join(messages) if messages else None,
    )
def execute_strategy(
    anonymizer: Anonymizer,
    dataframe: pd.DataFrame,
    strategy: ReplacementReplayStrategy,
    dd_parser_compat: DDParserCompatMode,
) -> pd.DataFrame:
    """Run the substitute step for *strategy* inside the context it requires.

    The local structured strategy runs under the experimental replacement
    context; every other strategy runs under the parser-compat context.
    """
    is_local = strategy == ReplacementReplayStrategy.local_structured_substitute
    scope = (
        experimental_replacement_strategy_context(ExperimentalReplacementStrategy.local_structured_substitute)
        if is_local
        else dd_parser_compat_context(dd_parser_compat)
    )
    with scope:
        return execute_substitute(anonymizer, dataframe)
def missing_replacements(entities: list[dict[str, Any]], replacements: list[dict[str, str]]) -> list[dict[str, Any]]:
    """Return the entities that have no replacement for their (value, label) pair.

    Input order is preserved; an entity counts as covered only when a
    replacement carries both the same ``original`` value and the same ``label``.
    """
    covered = set()
    for item in replacements:
        covered.add((item["original"], item["label"]))

    uncovered: list[dict[str, Any]] = []
    for entity in entities:
        pair = (entity.get("value"), entity.get("label"))
        if pair not in covered:
            uncovered.append(entity)
    return uncovered
def count_labels(items: list[dict[str, Any]]) -> dict[str, int]:
    """Count items per label, returning a dict ordered by label.

    Items with a missing or falsy ``label`` are ignored; labels are stringified
    before counting.
    """
    tally: Counter[str] = Counter()
    for item in items:
        label = item.get("label")
        if label:
            tally[str(label)] += 1
    return {name: tally[name] for name in sorted(tally)}
ReplacementReplayComparisonRow: + summaries = {summary.strategy: summary for summary in result.strategies} + baseline = summaries.get(ReplacementReplayStrategy.dd_substitute) + candidate = summaries.get(ReplacementReplayStrategy.local_structured_substitute) + if baseline is None or candidate is None: + raise ValueError( + "replacement replay result must include dd_substitute and local_structured_substitute summaries" + ) + flags = replay_comparison_flags(baseline, candidate) + value_protection = replay_value_protection_verdict(flags) + signature_parity = replay_signature_parity_verdict(baseline, candidate, flags) + safety = replay_safety_verdict(value_protection, signature_parity) + performance = replay_performance_verdict(baseline.elapsed_sec, candidate.elapsed_sec) + return ReplacementReplayComparisonRow( + workload_id=workload_id or Path(result.input_path).stem, + baseline_case_count=baseline.row_count, + candidate_case_count=candidate.row_count, + baseline_pipeline_elapsed_sec=baseline.elapsed_sec, + candidate_pipeline_elapsed_sec=candidate.elapsed_sec, + pipeline_elapsed_sec_delta=delta(baseline.elapsed_sec, candidate.elapsed_sec), + pipeline_elapsed_sec_delta_pct=delta_pct(baseline.elapsed_sec, candidate.elapsed_sec), + baseline_final_entity_count=baseline.final_entity_count, + candidate_final_entity_count=candidate.final_entity_count, + final_entity_count_delta=candidate.final_entity_count - baseline.final_entity_count, + baseline_replacement_count=baseline.replacement_count, + candidate_replacement_count=candidate.replacement_count, + replacement_count_delta=candidate.replacement_count - baseline.replacement_count, + baseline_replacement_missing_final_entity_count=baseline.missing_count, + candidate_replacement_missing_final_entity_count=candidate.missing_count, + replacement_missing_final_entity_count_delta=candidate.missing_count - baseline.missing_count, + baseline_replacement_missing_final_entity_label_counts=baseline.missing_labels, + 
candidate_replacement_missing_final_entity_label_counts=candidate.missing_labels, + baseline_replacement_synthetic_original_collision_count=baseline.collision_count, + candidate_replacement_synthetic_original_collision_count=candidate.collision_count, + replacement_synthetic_original_collision_count_delta=candidate.collision_count - baseline.collision_count, + baseline_replacement_synthetic_original_collision_label_counts=baseline.collision_labels, + candidate_replacement_synthetic_original_collision_label_counts=candidate.collision_labels, + baseline_duplicate_synthetic_replacement_count=baseline.duplicate_synthetic_count, + candidate_duplicate_synthetic_replacement_count=candidate.duplicate_synthetic_count, + duplicate_synthetic_replacement_count_delta=( + candidate.duplicate_synthetic_count - baseline.duplicate_synthetic_count + ), + baseline_original_value_leak_count=baseline.leak_count, + candidate_original_value_leak_count=candidate.leak_count, + original_value_leak_count_delta=candidate.leak_count - baseline.leak_count, + baseline_original_value_leak_label_counts=baseline.leak_labels, + candidate_original_value_leak_label_counts=candidate.leak_labels, + value_protection_verdict=value_protection, + signature_parity_verdict=signature_parity, + safety_verdict=safety, + performance_verdict=performance, + candidate_verdict=replay_candidate_verdict(safety, performance), + flags=flags, + ) + + +def replay_comparison_flags( + baseline: ReplacementReplaySummary, + candidate: ReplacementReplaySummary, +) -> list[str]: + flags: list[str] = [] + if baseline.status == ReplayStatus.error: + flags.append("baseline_case_failures") + if candidate.status == ReplayStatus.error: + flags.append("candidate_case_failures") + if baseline.missing_count: + flags.append("baseline_replacement_missing_final_entity") + if candidate.missing_count: + flags.append("candidate_replacement_missing_final_entity") + if baseline.missing_count and candidate.missing_count < baseline.missing_count: 
def replay_value_protection_verdict(flags: list[str]) -> str:
    """Classify value protection as ``"fail"``, ``"review"`` or ``"pass"`` from flags.

    * ``"fail"``: any candidate-side defect or count-loss flag is present.
    * ``"review"``: the baseline had case failures, or a baseline defect is
      flagged without the matching ``candidate_covers_...`` flag.
    * ``"pass"``: otherwise.
    """
    raised = frozenset(flags)

    blocking = (
        "candidate_case_failures",
        "candidate_original_value_leak",
        "candidate_replacement_missing_final_entity",
        "candidate_replacement_synthetic_original_collision",
        "entity_count_loss",
        "replacement_count_loss",
    )
    if any(flag in raised for flag in blocking):
        return "fail"

    if "baseline_case_failures" in raised:
        return "review"

    # Each baseline defect is acceptable only if the candidate fully covers it.
    remedies = {
        "baseline_original_value_leak": "candidate_covers_baseline_original_value_leak",
        "baseline_replacement_missing_final_entity": "candidate_covers_baseline_replacement_missing_final_entity",
        "baseline_replacement_synthetic_original_collision": (
            "candidate_covers_baseline_replacement_synthetic_original_collision"
        ),
    }
    for defect, remedy in remedies.items():
        if defect in raised and remedy not in raised:
            return "review"
    return "pass"
def write_replay_comparison(
    result: ReplacementReplayResult,
    output: Path,
    *,
    workload_id: str | None = None,
) -> None:
    """Write the baseline-vs-candidate comparison for *result* as a one-row CSV.

    Args:
        result: Replay result holding both strategy summaries.
        output: Destination CSV path; missing parent directories are created.
        workload_id: Optional identifier forwarded to ``replay_comparison_row``.

    Raises:
        ValueError: propagated from ``replay_comparison_row`` when a required
            strategy summary is absent.
    """
    row = replay_comparison_row(result, workload_id=workload_id).model_dump(mode="json")
    # Store the flag list as JSON text so the CSV cell holds a parseable value.
    row["flags"] = json.dumps(row["flags"], ensure_ascii=True)
    # ``to_csv`` does not create directories; without this a fresh output
    # location fails with OSError after the whole replay has already run.
    output.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame([row]).to_csv(output, index=False)
cyclopts.Parameter("--workload-id")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + configure_logging(log_format) + try: + if nrows is not None and nrows <= 0: + raise ValueError("--nrows must be greater than zero") + if replacement_repetitions <= 0: + raise ValueError("--replacement-repetitions must be greater than zero") + result = run_replacement_replay( + source=source, + text_column=text_column, + labels=parse_labels(labels), + nrows=nrows, + model_configs=model_configs, + model_providers=model_providers, + artifact_path=artifact_path, + dd_parser_compat=dd_parser_compat, + replacement_repetitions=replacement_repetitions, + ) + except (ValidationError, ValueError, FileNotFoundError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + rendered = render_result(result, json_output=json_output) + if output is not None: + output.write_text(rendered + "\n", encoding="utf-8") + else: + sys.stdout.write(rendered + "\n") + if comparison_output is not None: + write_replay_comparison(result, comparison_output, workload_id=workload_id) + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 13843940..acb967d8 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -13,6 +13,7 @@ import shutil import sys import time +from dataclasses import dataclass from enum import StrEnum from pathlib import Path from typing import Annotated, Any @@ -21,10 +22,24 @@ import pandas as pd import pyarrow.parquet as pq import yaml +from analyze_detection_artifacts import ( + analyze_artifacts, + build_detection_artifact_row_from_entities, + iter_detection_parquet_files, +) from data_designer.config.models import ModelProvider from data_designer.config.utils.io_helpers import load_config_file +from dd_parser_compat import 
DDParserCompatMode, dd_parser_compat_context +from detection_strategies import ( + ExperimentalDetectionStrategy, + experimental_detection_strategy_context, +) from export_measurements import ExportFormat, export_tables, read_measurements from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator +from replacement_strategies import ( + ExperimentalReplacementStrategy, + experimental_replacement_strategy_context, +) from anonymizer.config.anonymizer_config import ( AnonymizerConfig, @@ -36,8 +51,16 @@ ) from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance +from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_FINAL_ENTITIES +from anonymizer.engine.detection.rules import ( + STRUCTURED_RULE_FAST_LANE_LABELS, + SUPPORTED_RULE_LABELS, + detect_high_confidence_entities, +) from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references +from anonymizer.engine.replace.structured_substitute import SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS +from anonymizer.engine.schemas import EntitiesSchema, EntitySchema from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import MeasurementConfig, configured_measurement_session @@ -80,6 +103,8 @@ class WorkloadSpec(BaseModel): text_column: str = "text" id_column: str | None = None data_summary: str | None = None + row_limit: int | None = Field(default=None, ge=1) + row_offset: int = Field(default=0, ge=0) class ReplaceSpec(BaseModel): @@ -112,6 +137,9 @@ class ConfigSpec(BaseModel): replace: str | ReplaceSpec | None = None rewrite: RewriteSpec | None = None emit_telemetry: bool = False + experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default + experimental_replacement_strategy: 
ExperimentalReplacementStrategy = ExperimentalReplacementStrategy.default + experimental_rule_labels: list[str] | None = None @model_validator(mode="after") def validate_mode(self) -> "ConfigSpec": @@ -147,6 +175,9 @@ class BenchmarkSpec(BaseModel): model_configs: str | None = None model_providers: str | None = None artifact_path: str | None = None + dd_parser_compat: DDParserCompatMode = DDParserCompatMode.none + case_retries: int = Field(default=0, ge=0) + case_retry_backoff_sec: float = Field(default=0.0, ge=0.0) workloads: list[WorkloadSpec] = Field(min_length=1) configs: list[ConfigSpec] = Field(min_length=1) matrix: list[MatrixEntry] | None = Field(default=None, min_length=1) @@ -182,8 +213,11 @@ class BenchmarkCase(BaseModel): status: CaseStatus = CaseStatus.planned elapsed_sec: float | None = None measurement_path: str | None = None + detection_artifact_path: str | None = None trace_path: str | None = None error: str | None = None + attempt_count: int = 0 + attempt_errors: list[str] = Field(default_factory=list) class BenchmarkResult(BaseModel): @@ -192,9 +226,66 @@ class BenchmarkResult(BaseModel): measurement_path: str summary_path: str table_dir: str | None + detection_artifact_analysis_path: str | None = None cases: list[BenchmarkCase] +@dataclass(frozen=True) +class _CaseRunPaths: + raw_path: Path + artifact_output_path: Path + trace_path: Path | None + artifact_snapshot: dict[str, int] | None + + +@dataclass(frozen=True) +class _CaseExecution: + input_data: AnonymizerInput + trace_dataframe: pd.DataFrame | None = None + + +_TRACE_FINAL_ARTIFACT_STRATEGIES = { + ExperimentalDetectionStrategy.rules_guardrail, + ExperimentalDetectionStrategy.rules_covered_or_default, + ExperimentalDetectionStrategy.rules_guardrail_compact_validation, + ExperimentalDetectionStrategy.rules_filter_guardrail, + ExperimentalDetectionStrategy.native_rules_router, + ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + 
ExperimentalDetectionStrategy.detector_native_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_native_augment, + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + ExperimentalDetectionStrategy.native_single_pass, + ExperimentalDetectionStrategy.native_single_pass_recall, + ExperimentalDetectionStrategy.native_single_pass_values, + ExperimentalDetectionStrategy.native_single_pass_values_recall, +} +_RULE_BACKED_STRATEGIES = { + ExperimentalDetectionStrategy.rules_guardrail, + ExperimentalDetectionStrategy.rules_guardrail_compact_validation, + ExperimentalDetectionStrategy.rules_filter_guardrail, + ExperimentalDetectionStrategy.rules_seed_no_augment, + ExperimentalDetectionStrategy.rules_guardrail_no_augment, + ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment, + ExperimentalDetectionStrategy.rules_guardrail_detector_only, + ExperimentalDetectionStrategy.rules_only, +} + +_FINAL_ARTIFACT_KEYS = { + "final_entity_count", + "weak_api_key_shape_count", + "final_entity_signature_count", + "final_entity_signature_hashes", + "final_entity_signature_labels", + "final_entity_signature_details", + "weak_api_key_shape_label_counts", + "final_label_counts", + "final_source_counts", +} + +_FINAL_ARTIFACT_PREFIXES = tuple(f"{key}." 
for key in _FINAL_ARTIFACT_KEYS if key != "final_entity_signature_hashes") + + def configure_logging(log_format: LogFormat) -> None: global _log_format @@ -260,30 +351,61 @@ def preflight_suite(spec: BenchmarkSpec, *, spec_path: Path) -> None: """Validate cheap suite inputs before any benchmark case consumes model time.""" base_dir = spec_path.parent errors: list[str] = [] - parsed_models = None + parsed_models = _preflight_model_configs(spec, base_dir=base_dir, errors=errors) + + _preflight_model_providers_with_errors(spec, base_dir=base_dir, errors=errors) + errors.extend(_preflight_workload_errors(spec, base_dir=base_dir)) + errors.extend(_preflight_config_errors(spec, parsed_models=parsed_models)) + if errors: + raise ValueError("Benchmark preflight failed:\n- " + "\n- ".join(errors)) + +def _preflight_model_configs(spec: BenchmarkSpec, *, base_dir: Path, errors: list[str]) -> Any | None: try: - parsed_models = parse_model_configs(_resolve_config_source(spec.model_configs, base_dir)) + return parse_model_configs(_resolve_config_source(spec.model_configs, base_dir)) except Exception as exc: errors.append(f"model_configs invalid: {exc}") + return None + +def _preflight_model_providers_with_errors( + spec: BenchmarkSpec, + *, + base_dir: Path, + errors: list[str], +) -> None: try: _preflight_model_providers(spec, base_dir=base_dir) except Exception as exc: errors.append(f"model_providers invalid: {exc}") + +def _preflight_workload_errors(spec: BenchmarkSpec, *, base_dir: Path) -> list[str]: + errors: list[str] = [] for workload in spec.workloads: try: _preflight_workload(workload, base_dir=base_dir) except Exception as exc: errors.append(str(exc)) + return errors + +def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) -> list[str]: + errors: list[str] = [] for config in spec.configs: try: anonymizer_config = build_anonymizer_config(config) except Exception as exc: errors.append(f"config '{config.id}' invalid: {exc}") continue + try: + 
_preflight_experimental_detection_strategy(config, anonymizer_config) + except Exception as exc: + errors.append(f"config '{config.id}' experimental_detection_strategy invalid: {exc}") + try: + _preflight_experimental_replacement_strategy(config, anonymizer_config) + except Exception as exc: + errors.append(f"config '{config.id}' experimental_replacement_strategy invalid: {exc}") if parsed_models is None: continue try: @@ -296,9 +418,64 @@ def preflight_suite(spec: BenchmarkSpec, *, spec_path: Path) -> None: ) except ValueError as exc: errors.append(f"config '{config.id}' model aliases invalid: {exc}") + return errors - if errors: - raise ValueError("Benchmark preflight failed:\n- " + "\n- ".join(errors)) + +def _preflight_experimental_detection_strategy(config: ConfigSpec, anonymizer_config: AnonymizerConfig) -> None: + _preflight_experimental_rule_labels(config) + if config.experimental_detection_strategy != ExperimentalDetectionStrategy.rules_only: + return + entity_labels = anonymizer_config.detect.entity_labels + supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) + if entity_labels is None: + raise ValueError( + f"`rules_only` requires explicit detect.entity_labels limited to deterministic rule labels: {supported}" + ) + unsupported = sorted(set(entity_labels) - SUPPORTED_RULE_LABELS) + if unsupported: + raise ValueError( + f"unsupported high-confidence rule labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + +def _preflight_experimental_rule_labels(config: ConfigSpec) -> None: + if not config.experimental_rule_labels: + return + supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) + if config.experimental_detection_strategy not in _RULE_BACKED_STRATEGIES: + raise ValueError( + "experimental_rule_labels requires a rule-backed strategy: " + + ", ".join(sorted(strategy.value for strategy in _RULE_BACKED_STRATEGIES)) + ) + unsupported = sorted(set(config.experimental_rule_labels) - SUPPORTED_RULE_LABELS) + if unsupported: + raise 
ValueError( + f"unsupported experimental_rule_labels: {', '.join(unsupported)}; supported labels: {supported}" + ) + + +def _preflight_experimental_replacement_strategy( + config: ConfigSpec, + anonymizer_config: AnonymizerConfig, +) -> None: + if config.experimental_replacement_strategy == ExperimentalReplacementStrategy.default: + return + if config.experimental_replacement_strategy != ExperimentalReplacementStrategy.local_structured_substitute: + raise ValueError(f"unsupported strategy: {config.experimental_replacement_strategy}") + if not isinstance(anonymizer_config.replace, Substitute): + raise ValueError("local_structured_substitute requires replace: substitute") + entity_labels = anonymizer_config.detect.entity_labels + supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) + if entity_labels is None: + raise ValueError( + "`local_structured_substitute` requires explicit detect.entity_labels limited to " + f"structured substitute labels: {supported}" + ) + unsupported = sorted(set(entity_labels) - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) + if unsupported: + raise ValueError( + f"unsupported local structured substitute labels: {', '.join(unsupported)}; supported labels: {supported}" + ) def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: @@ -322,6 +499,8 @@ def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: def _preflight_workload(workload: WorkloadSpec, *, base_dir: Path) -> None: resolved_source = _resolve_input_source(workload.source, base_dir) + if _workload_has_row_slice(workload) and not _is_local_input_source(str(resolved_source)): + raise ValueError(f"workload '{workload.id}' row slicing requires a local workload source") input_data = AnonymizerInput( source=str(resolved_source), text_column=workload.text_column, @@ -373,23 +552,85 @@ def run_suite( trace_dir=trace_dir, ) anonymizer = Anonymizer(**contexts["anonymizer_kwargs"]) - cases = [ - _run_case(case, spec, 
contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast) + cases = _run_cases(spec, contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast, export=export) + measurement_path = combine_measurements(cases, output_dir / "measurements.jsonl") + should_export = _should_export_measurements(export=export, measurement_path=measurement_path) + table_dir = _export_suite_tables(measurement_path, output_dir=output_dir, should_export=should_export) + artifact_analysis_path = _combine_suite_detection_artifacts( + cases, output_dir=output_dir, should_export=should_export + ) + result = _benchmark_result( + spec, + output_dir=output_dir, + measurement_path=measurement_path, + table_dir=table_dir, + artifact_analysis_path=artifact_analysis_path, + cases=cases, + ) + write_summary(result) + return result + + +def _run_cases( + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + fail_fast: bool, + export: bool, +) -> list[BenchmarkCase]: + return [ + _run_case( + case, + spec, + contexts=contexts, + anonymizer=anonymizer, + fail_fast=fail_fast, + export_detection_artifacts=export, + ) for case in build_cases(spec) ] - measurement_path = combine_measurements(cases, output_dir / "measurements.jsonl") - should_export = export and measurement_path.stat().st_size > 0 - table_dir = export_measurement_tables(measurement_path, output_dir / "tables") if should_export else None - result = BenchmarkResult( + + +def _should_export_measurements(*, export: bool, measurement_path: Path) -> bool: + return export and measurement_path.stat().st_size > 0 + + +def _export_suite_tables(measurement_path: Path, *, output_dir: Path, should_export: bool) -> Path | None: + if not should_export: + return None + return export_measurement_tables(measurement_path, output_dir / "tables") + + +def _combine_suite_detection_artifacts( + cases: list[BenchmarkCase], + *, + output_dir: Path, + should_export: bool, +) -> Path | None: + if not should_export: + return None + 
return combine_detection_artifact_analysis(cases, output_dir / "detection-artifacts.jsonl") + + +def _benchmark_result( + spec: BenchmarkSpec, + *, + output_dir: Path, + measurement_path: Path, + table_dir: Path | None, + artifact_analysis_path: Path | None, + cases: list[BenchmarkCase], +) -> BenchmarkResult: + return BenchmarkResult( suite_id=spec.suite_id, output_dir=str(output_dir), measurement_path=str(measurement_path), summary_path=str(output_dir / "summary.json"), table_dir=str(table_dir) if table_dir is not None else None, + detection_artifact_analysis_path=str(artifact_analysis_path) if artifact_analysis_path is not None else None, cases=cases, ) - write_summary(result) - return result def _build_contexts( @@ -409,6 +650,8 @@ def _build_contexts( "raw_dir": output_dir / "raw", "dd_trace": dd_trace, "trace_dir": trace_dir or output_dir / "traces", + "dd_parser_compat": spec.dd_parser_compat, + "artifact_path": artifact_path, "anonymizer_kwargs": { "model_configs": _resolve_config_source(spec.model_configs, base_dir), "model_providers": _resolve_config_source(spec.model_providers, base_dir), @@ -424,44 +667,326 @@ def _run_case( contexts: dict[str, Any], anonymizer: Anonymizer, fail_fast: bool, + export_detection_artifacts: bool, ) -> BenchmarkCase: - raw_path = contexts["raw_dir"] / f"{case.case_id}.jsonl" - trace_path = _case_trace_path(case, contexts=contexts) started = time.perf_counter() - try: - workload = _get_item(contexts["workloads"], case.workload_id, "workload") - config = _get_item(contexts["configs"], case.config_id, "config") - _execute_case( - anonymizer, - workload, - config, - raw_path=raw_path, - trace_path=trace_path, - case=case, - spec=spec, - base_dir=contexts["base_dir"], - dd_trace=contexts["dd_trace"], - ) - return case.model_copy( - update={ - "status": CaseStatus.completed, - "elapsed_sec": time.perf_counter() - started, - "measurement_path": str(raw_path), - "trace_path": str(trace_path) if trace_path is not None else None, - } 
- ) - except Exception as exc: - if fail_fast: - raise - return case.model_copy( - update={ - "status": CaseStatus.error, - "elapsed_sec": time.perf_counter() - started, - "measurement_path": str(raw_path), - "trace_path": str(trace_path) if trace_path is not None else None, - "error": str(exc), - } - ) + attempt_errors: list[str] = [] + max_attempts = 1 if fail_fast else spec.case_retries + 1 + for attempt_number in range(1, max_attempts + 1): + paths = _case_run_paths(case, contexts=contexts, export_detection_artifacts=export_detection_artifacts) + try: + return _run_case_success( + case, + spec, + contexts=contexts, + anonymizer=anonymizer, + paths=paths, + started=started, + attempt_count=attempt_number, + attempt_errors=attempt_errors, + ) + except Exception as exc: + if fail_fast: + raise + attempt_errors.append(str(exc)) + if attempt_number >= max_attempts: + return _run_case_error( + case, + contexts=contexts, + paths=paths, + started=started, + error=exc, + attempt_count=attempt_number, + attempt_errors=attempt_errors, + ) + _sleep_before_case_retry(spec, case=case, attempt_number=attempt_number, error=exc) + + raise RuntimeError("unreachable benchmark retry state") + + +def _sleep_before_case_retry( + spec: BenchmarkSpec, + *, + case: BenchmarkCase, + attempt_number: int, + error: Exception, +) -> None: + logger.warning( + "case %s attempt %d failed; retrying after %.2fs: %s", + case.case_id, + attempt_number, + spec.case_retry_backoff_sec, + error, + ) + if spec.case_retry_backoff_sec > 0: + time.sleep(spec.case_retry_backoff_sec) + + +def _case_run_paths( + case: BenchmarkCase, + *, + contexts: dict[str, Any], + export_detection_artifacts: bool, +) -> _CaseRunPaths: + return _CaseRunPaths( + raw_path=contexts["raw_dir"] / f"{case.case_id}.jsonl", + artifact_output_path=contexts["raw_dir"] / f"{case.case_id}.detection-artifacts.jsonl", + trace_path=_case_trace_path(case, contexts=contexts), + 
artifact_snapshot=snapshot_detection_artifacts(contexts["artifact_path"]) + if export_detection_artifacts + else None, + ) + + +def _run_case_success( + case: BenchmarkCase, + spec: BenchmarkSpec, + *, + contexts: dict[str, Any], + anonymizer: Anonymizer, + paths: _CaseRunPaths, + started: float, + attempt_count: int, + attempt_errors: list[str], +) -> BenchmarkCase: + workload = _get_item(contexts["workloads"], case.workload_id, "workload") + config = _get_item(contexts["configs"], case.config_id, "config") + execution = _execute_case( + anonymizer, + workload, + config, + raw_path=paths.raw_path, + trace_path=paths.trace_path, + case=case, + spec=spec, + base_dir=contexts["base_dir"], + dd_trace=contexts["dd_trace"], + dd_parser_compat=contexts["dd_parser_compat"], + ) + detection_artifact_path = _case_detection_artifact_path( + contexts, + paths, + case=case, + config=config, + execution=execution, + ) + return _case_with_result( + case, + status=CaseStatus.completed, + started=started, + raw_path=paths.raw_path, + detection_artifact_path=detection_artifact_path, + trace_path=paths.trace_path, + attempt_count=attempt_count, + attempt_errors=attempt_errors, + ) + + +def _run_case_error( + case: BenchmarkCase, + *, + contexts: dict[str, Any], + paths: _CaseRunPaths, + started: float, + error: Exception, + attempt_count: int, + attempt_errors: list[str], +) -> BenchmarkCase: + detection_artifact_path = _export_case_detection_artifacts_if_requested( + contexts, + paths.artifact_output_path, + case=case, + artifact_snapshot=paths.artifact_snapshot, + ) + return _case_with_result( + case, + status=CaseStatus.error, + started=started, + raw_path=paths.raw_path, + detection_artifact_path=detection_artifact_path, + trace_path=paths.trace_path, + error=str(error), + attempt_count=attempt_count, + attempt_errors=attempt_errors, + ) + + +def _case_detection_artifact_path( + contexts: dict[str, Any], + paths: _CaseRunPaths, + *, + case: BenchmarkCase, + config: ConfigSpec, + 
execution: _CaseExecution, +) -> Path | None: + detection_artifact_path = _export_case_detection_artifacts_if_requested( + contexts, + paths.artifact_output_path, + case=case, + artifact_snapshot=paths.artifact_snapshot, + ) + detection_artifact_path = _trace_final_artifact_path_if_requested( + config, + detection_artifact_path, + paths.artifact_output_path, + case=case, + trace_dataframe=execution.trace_dataframe, + ) + if detection_artifact_path is not None or paths.artifact_snapshot is None: + return detection_artifact_path + return export_rules_only_case_detection_artifacts( + config, + execution.input_data, + paths.artifact_output_path, + case=case, + ) + + +def _trace_final_artifact_path_if_requested( + config: ConfigSpec, + detection_artifact_path: Path | None, + output_path: Path, + *, + case: BenchmarkCase, + trace_dataframe: pd.DataFrame | None, +) -> Path | None: + if config.experimental_detection_strategy not in _TRACE_FINAL_ARTIFACT_STRATEGIES: + return detection_artifact_path + if trace_dataframe is None: + return detection_artifact_path + return patch_case_detection_artifacts_from_trace_dataframe( + detection_artifact_path or output_path, + trace_dataframe, + case=case, + replace_existing=config.experimental_detection_strategy + == ExperimentalDetectionStrategy.rules_covered_or_default, + ) + + +def patch_case_detection_artifacts_from_trace_dataframe( + output_path: Path, + trace_dataframe: pd.DataFrame, + *, + case: BenchmarkCase | None = None, + replace_existing: bool = False, +) -> Path | None: + final_rows = _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe) + if not final_rows: + return None + if replace_existing: + patched = final_rows + else: + rows = _read_detection_artifact_payloads(output_path) if output_path.exists() else [] + patched = _merge_final_entity_artifact_rows(rows, final_rows) + if case is not None: + patched = [_with_case_metadata(row, case=case) for row in patched] + write_detection_artifact_payloads(patched, 
output_path) + return output_path + + +def _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe: pd.DataFrame) -> list[dict[str, Any]]: + entities_column = _trace_final_entities_column(trace_dataframe) + if entities_column is None: + return [] + return [ + _final_entity_artifact_row(raw_entities, row_index=row_index) + for row_index, raw_entities in enumerate(trace_dataframe[entities_column]) + ] + + +def _trace_final_entities_column(trace_dataframe: pd.DataFrame) -> str | None: + if COL_FINAL_ENTITIES in trace_dataframe.columns: + return COL_FINAL_ENTITIES + if COL_DETECTED_ENTITIES in trace_dataframe.columns: + return COL_DETECTED_ENTITIES + return None + + +def _final_entity_artifact_row(raw_entities: object, *, row_index: int) -> dict[str, Any]: + entities = EntitiesSchema.from_raw(raw_entities).entities + return build_detection_artifact_row_from_entities( + workflow_name="entity-detection-final-trace", + batch_file="trace_dataframe", + row_index=row_index, + seed_entities=[], + seed_validation_candidate_count=0, + merged_validation_candidate_count=0, + augmented_entities=[], + final_entities=entities, + ).model_dump() + + +def _read_detection_artifact_payloads(output_path: Path) -> list[dict[str, Any]]: + with output_path.open(encoding="utf-8") as source: + return [json.loads(line) for line in source if line.strip()] + + +def _merge_final_entity_artifact_rows( + rows: list[dict[str, Any]], + final_rows: list[dict[str, Any]], +) -> list[dict[str, Any]]: + patched = [_patch_final_entity_artifact_row(row, final_row) for row, final_row in zip(rows, final_rows)] + return patched + rows[len(final_rows) :] + final_rows[len(rows) :] + + +def _patch_final_entity_artifact_row(row: dict[str, Any], final_row: dict[str, Any]) -> dict[str, Any]: + clean_row = _without_final_entity_artifact_fields(row) + return {**clean_row, **_final_entity_artifact_fields(final_row)} + + +def _without_final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: + return 
{ + key: value + for key, value in row.items() + if key not in _FINAL_ARTIFACT_KEYS and not key.startswith(_FINAL_ARTIFACT_PREFIXES) + } + + +def _final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: + return {key: row[key] for key in _FINAL_ARTIFACT_KEYS} + + +def _case_with_result( + case: BenchmarkCase, + *, + status: CaseStatus, + started: float, + raw_path: Path, + detection_artifact_path: Path | None, + trace_path: Path | None, + attempt_count: int, + attempt_errors: list[str], + error: str | None = None, +) -> BenchmarkCase: + return case.model_copy( + update={ + "status": status, + "elapsed_sec": time.perf_counter() - started, + "measurement_path": str(raw_path), + "detection_artifact_path": (str(detection_artifact_path) if detection_artifact_path is not None else None), + "trace_path": str(trace_path) if trace_path is not None else None, + "error": error, + "attempt_count": attempt_count, + "attempt_errors": list(attempt_errors), + } + ) + + +def _export_case_detection_artifacts_if_requested( + contexts: dict[str, Any], + output_path: Path, + *, + case: BenchmarkCase, + artifact_snapshot: dict[str, int] | None, +) -> Path | None: + if artifact_snapshot is None: + return None + return export_case_detection_artifact_analysis( + contexts["artifact_path"], + output_path, + case=case, + artifact_snapshot=artifact_snapshot, + ) def _case_trace_path(case: BenchmarkCase, *, contexts: dict[str, Any]) -> Path | None: @@ -481,7 +1006,15 @@ def _execute_case( spec: BenchmarkSpec, base_dir: Path, dd_trace: DDTraceMode, -) -> None: + dd_parser_compat: DDParserCompatMode, +) -> _CaseExecution: + anonymizer_config = build_anonymizer_config(config) + input_data = build_input( + workload, + base_dir, + slice_dir=raw_path.parent / "inputs", + case_id=case.case_id, + ) measurement = MeasurementConfig( output_path=raw_path, run_id=case.case_id, @@ -493,18 +1026,99 @@ def _execute_case( fail_on_write_error=True, ) with 
configured_measurement_session(measurement): - anonymizer.run(config=build_anonymizer_config(config), data=build_input(workload, base_dir)) - - -def build_input(workload: WorkloadSpec, base_dir: Path) -> AnonymizerInput: + with dd_parser_compat_context(dd_parser_compat): + with experimental_detection_strategy_context( + config.experimental_detection_strategy, + rule_labels=config.experimental_rule_labels, + ): + with experimental_replacement_strategy_context(config.experimental_replacement_strategy): + result = anonymizer.run( + config=anonymizer_config, + data=input_data, + ) + return _CaseExecution(input_data=input_data, trace_dataframe=getattr(result, "trace_dataframe", None)) + + +def build_input( + workload: WorkloadSpec, + base_dir: Path, + *, + slice_dir: Path | None = None, + case_id: str | None = None, +) -> AnonymizerInput: + resolved_source = _resolve_input_source(workload.source, base_dir) + source = ( + _materialize_sliced_source(workload, resolved_source, slice_dir=slice_dir, case_id=case_id) + if _workload_has_row_slice(workload) + else resolved_source + ) return AnonymizerInput( - source=str(_resolve_input_source(workload.source, base_dir)), + source=str(source), text_column=workload.text_column, id_column=workload.id_column, data_summary=workload.data_summary, ) +def _workload_has_row_slice(workload: WorkloadSpec) -> bool: + return workload.row_limit is not None or workload.row_offset > 0 + + +def _is_local_input_source(source: str) -> bool: + return "://" not in source + + +def _materialize_sliced_source( + workload: WorkloadSpec, + source: str | Path, + *, + slice_dir: Path | None, + case_id: str | None, +) -> Path: + if not _is_local_input_source(str(source)): + raise ValueError(f"workload '{workload.id}' row slicing requires a local workload source") + if slice_dir is None or case_id is None: + raise ValueError("row slicing requires slice_dir and case_id") + source_path = Path(source) + suffix = infer_input_source_suffix(str(source_path)) + 
dataframe = _read_local_input_dataframe(source_path, suffix=suffix) + sliced = dataframe.iloc[_slice_bounds(workload)] + slice_dir.mkdir(parents=True, exist_ok=True) + destination = slice_dir / f"{_safe_case_filename(case_id)}{suffix}" + _write_local_input_dataframe(sliced, destination, suffix=suffix) + return destination + + +def _slice_bounds(workload: WorkloadSpec) -> slice: + start = workload.row_offset + stop = start + workload.row_limit if workload.row_limit is not None else None + return slice(start, stop) + + +def _read_local_input_dataframe(source: Path, *, suffix: str) -> pd.DataFrame: + if suffix == ".csv": + return pd.read_csv(source) + if suffix == ".parquet": + return pd.read_parquet(source) + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. Use {supported_formats}.") + + +def _write_local_input_dataframe(dataframe: pd.DataFrame, destination: Path, *, suffix: str) -> None: + if suffix == ".csv": + dataframe.to_csv(destination, index=False) + return + if suffix == ".parquet": + dataframe.to_parquet(destination, index=False) + return + supported_formats = " or ".join(SUPPORTED_IO_FORMATS) + raise ValueError(f"Unsupported input format: {suffix}. 
Use {supported_formats}.") + + +def _safe_case_filename(case_id: str) -> str: + return "".join(char if char.isalnum() or char in "._-" else "_" for char in case_id) + + def build_anonymizer_config(config: ConfigSpec) -> AnonymizerConfig: detect = Detect.model_validate(config.detect) if config.replace is not None: @@ -566,6 +1180,26 @@ def combine_measurements(cases: list[BenchmarkCase], destination: Path) -> Path: return destination +def combine_detection_artifact_analysis(cases: list[BenchmarkCase], destination: Path) -> Path | None: + chunks: list[str] = [] + for case in cases: + if case.detection_artifact_path is None: + continue + source = Path(case.detection_artifact_path) + if source.exists(): + chunks.append(_jsonl_chunk(source.read_text(encoding="utf-8"))) + if not chunks: + return None + destination.write_text("".join(chunks), encoding="utf-8") + return destination + + +def _jsonl_chunk(text: str) -> str: + if not text or text.endswith("\n"): + return text + return text + "\n" + + def export_measurement_tables(measurement_path: Path, table_dir: Path) -> Path: dataframe = read_measurements(measurement_path) export_tables( @@ -574,6 +1208,139 @@ def export_measurement_tables(measurement_path: Path, table_dir: Path) -> Path: return table_dir +def snapshot_detection_artifacts(artifact_path: Path) -> dict[str, int]: + if not artifact_path.exists(): + return {} + return { + str(parquet_file.relative_to(artifact_path)): parquet_file.stat().st_mtime_ns + for parquet_file in iter_detection_parquet_files(artifact_path) + } + + +def changed_detection_artifact_files(artifact_path: Path, snapshot: dict[str, int]) -> list[Path]: + if not artifact_path.exists(): + return [] + changed: list[Path] = [] + for parquet_file in iter_detection_parquet_files(artifact_path): + key = str(parquet_file.relative_to(artifact_path)) + if snapshot.get(key) != parquet_file.stat().st_mtime_ns: + changed.append(parquet_file) + return changed + + +def export_detection_artifact_analysis( + 
artifact_path: Path, + output_path: Path, + *, + artifact_snapshot: dict[str, int] | None = None, +) -> Path | None: + if not artifact_path.exists(): + return None + parquet_files = ( + changed_detection_artifact_files(artifact_path, artifact_snapshot) if artifact_snapshot is not None else None + ) + analysis = analyze_artifacts(artifact_path, parquet_files=parquet_files) + if not analysis.rows: + return None + write_detection_artifact_payloads([row.model_dump() for row in analysis.rows], output_path) + return output_path + + +def export_case_detection_artifact_analysis( + artifact_path: Path, + output_path: Path, + *, + case: BenchmarkCase, + artifact_snapshot: dict[str, int], +) -> Path | None: + if not artifact_path.exists(): + return None + parquet_files = changed_detection_artifact_files(artifact_path, artifact_snapshot) + analysis = analyze_artifacts(artifact_path, parquet_files=parquet_files) + if not analysis.rows: + return None + write_detection_artifact_payloads( + [_with_case_metadata(row.model_dump(), case=case) for row in analysis.rows], + output_path, + ) + return output_path + + +def export_rules_only_case_detection_artifacts( + config: ConfigSpec, + input_data: AnonymizerInput, + output_path: Path, + *, + case: BenchmarkCase, +) -> Path | None: + if not _is_local_input_source(input_data.source): + return None + labels = build_anonymizer_config(config).detect.entity_labels + if not _uses_rules_only_artifact_export(config, labels): + return None + source = Path(input_data.source) + dataframe = _read_local_input_dataframe(source, suffix=infer_input_source_suffix(str(source))) + rows = [ + _with_case_metadata( + _rules_only_artifact_row( + text=record[input_data.text_column], + labels=labels, + row_index=int(row_index), + ), + case=case, + ) + for row_index, record in dataframe.iterrows() + ] + if not rows: + return None + write_detection_artifact_payloads(rows, output_path) + return output_path + + +def _uses_rules_only_artifact_export(config: 
ConfigSpec, labels: list[str] | None) -> bool: + if labels is None: + return False + if config.experimental_detection_strategy == ExperimentalDetectionStrategy.rules_only: + return True + if config.experimental_detection_strategy != ExperimentalDetectionStrategy.rules_covered_or_default: + return False + return set(labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) + + +def _rules_only_artifact_row(*, text: object, labels: list[str], row_index: int) -> dict[str, Any]: + entities = [ + EntitySchema.model_validate(span.as_dict()) + for span in detect_high_confidence_entities(str(text), labels=labels) + ] + return build_detection_artifact_row_from_entities( + workflow_name="entity-detection-rules-only", + batch_file="synthetic-rules-only", + row_index=row_index, + seed_entities=entities, + seed_validation_candidate_count=len(entities), + merged_validation_candidate_count=len(entities), + augmented_entities=[], + final_entities=entities, + ).model_dump() + + +def _with_case_metadata(row: dict[str, Any], *, case: BenchmarkCase) -> dict[str, Any]: + return { + "suite_id": case.suite_id, + "workload_id": case.workload_id, + "config_id": case.config_id, + "repetition": case.repetition, + "case_id": case.case_id, + "run_id": case.case_id, + **row, + } + + +def write_detection_artifact_payloads(rows: list[dict[str, Any]], output_path: Path) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + pd.json_normalize(rows, sep=".").to_json(output_path, orient="records", lines=True) + + def write_summary(result: BenchmarkResult) -> None: Path(result.summary_path).write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") @@ -590,12 +1357,17 @@ def render_result(result: BenchmarkResult, *, json_output: bool) -> str: def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: + config = next(item for item in spec.configs if item.id == case.config_id) return { "suite_id": spec.suite_id, "workload_id": case.workload_id, "config_id": 
case.config_id, "repetition": case.repetition, "case_id": case.case_id, + "experimental_detection_strategy": config.experimental_detection_strategy.value, + "experimental_replacement_strategy": config.experimental_replacement_strategy.value, + "experimental_rule_labels": config.experimental_rule_labels, + "dd_parser_compat": spec.dd_parser_compat.value, } @@ -655,6 +1427,7 @@ def dry_run_result( measurement_path=str(output_dir / "measurements.jsonl"), summary_path=str(output_dir / "summary.json"), table_dir=str(output_dir / "tables") if export else None, + detection_artifact_analysis_path=str(output_dir / "detection-artifacts.jsonl") if export else None, cases=cases, ) diff --git a/tools/measurement/screen_strategy_comparisons.py b/tools/measurement/screen_strategy_comparisons.py new file mode 100644 index 00000000..59dc280e --- /dev/null +++ b/tools/measurement/screen_strategy_comparisons.py @@ -0,0 +1,1166 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Screen strategy-comparison CSVs across benchmark analysis directories. 
+ +Usage: + uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ + uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ --output strategy-screen.csv + uv run python tools/measurement/screen_strategy_comparisons.py run-a/analysis run-b/analysis --json +""" + +from __future__ import annotations + +import ast +import json +import logging +import sys +from enum import StrEnum +from pathlib import Path +from typing import Annotated + +import cyclopts +import pandas as pd +from pydantic import BaseModel, Field, model_validator + +app = cyclopts.App(help=__doc__) +logger = logging.getLogger("measurement.strategy_screen") + +COMPARISON_COLUMNS = { + "workload_id", + "baseline_config_id", + "candidate_config_id", + "safety_verdict", + "performance_verdict", + "candidate_verdict", +} + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +class GroupBy(StrEnum): + strategy = "strategy" + strategy_workload_family = "strategy_workload_family" + strategy_workload = "strategy_workload" + + +class ScreenRow(BaseModel): + source_path: str + workload_id: str + workload_family: str | None = None + baseline_config_id: str + candidate_config_id: str + baseline_strategy: str | None = None + candidate_strategy: str | None = None + baseline_replacement_strategy: str | None = None + candidate_replacement_strategy: str | None = None + baseline_case_count: int | None = None + candidate_case_count: int | None = None + value_protection_verdict: str | None = None + signature_parity_verdict: str | None = None + safety_verdict: str + performance_verdict: str + candidate_verdict: str + evidence_level: str = "legacy" + pipeline_elapsed_sec_delta_pct: float | None = None + observed_total_requests_delta: float | None = None + observed_total_tokens_delta: float | None = None + final_entity_count_delta: float | None = None + 
augmented_entity_count_delta: float | None = None + augmented_new_final_value_count_delta: float | None = None + baseline_original_value_leak_count: float | None = None + candidate_original_value_leak_count: float | None = None + original_value_leak_count_delta: float | None = None + baseline_original_value_leak_record_count: float | None = None + candidate_original_value_leak_record_count: float | None = None + original_value_leak_record_count_delta: float | None = None + baseline_replacement_missing_final_entity_count: float | None = None + candidate_replacement_missing_final_entity_count: float | None = None + replacement_missing_final_entity_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_count: float | None = None + candidate_replacement_synthetic_original_collision_count: float | None = None + replacement_synthetic_original_collision_count_delta: float | None = None + baseline_replacement_synthetic_original_collision_value_count: float | None = None + candidate_replacement_synthetic_original_collision_value_count: float | None = None + replacement_synthetic_original_collision_value_count_delta: float | None = None + baseline_duplicate_synthetic_replacement_count: float | None = None + candidate_duplicate_synthetic_replacement_count: float | None = None + duplicate_synthetic_replacement_count_delta: float | None = None + baseline_only_final_entity_signature_count: float | None = None + candidate_only_final_entity_signature_count: float | None = None + shared_final_entity_signature_count: float | None = None + baseline_stable_final_entity_signature_count: int | None = None + candidate_stable_final_entity_signature_count: int | None = None + baseline_stable_candidate_unstable_final_entity_signature_count: float | None = None + candidate_stable_baseline_unstable_final_entity_signature_count: float | None = None + shared_stable_final_entity_signature_count: int | None = None + flags: list[str] = Field(default_factory=list) + 
baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) + stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + + @model_validator(mode="after") + def fill_workload_family(self) -> "ScreenRow": + if self.workload_family is None: + self.workload_family = workload_family(self.workload_id) + return self + + +class ScreenSummary(BaseModel): + viable_count: int = 0 + review_count: int = 0 + reject_count: int = 0 + candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) + value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) + signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) + safety_verdict_counts: dict[str, int] = Field(default_factory=dict) + performance_verdict_counts: dict[str, int] = Field(default_factory=dict) + evidence_level_counts: dict[str, int] = Field(default_factory=dict) + + +class ScreenGroup(BaseModel): + group_key: str + candidate_strategy: str | None = None + candidate_replacement_strategy: str | None = None + candidate_config_ids: list[str] = Field(default_factory=list) + workload_ids: list[str] = Field(default_factory=list) + workload_families: list[str] = Field(default_factory=list) + row_count: int = 0 + min_baseline_case_count: int | None = None + 
min_candidate_case_count: int | None = None + viable_count: int = 0 + review_count: int = 0 + reject_count: int = 0 + has_conflicting_verdicts: bool = False + recommendation: str = "unknown" + value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) + signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) + performance_verdict_counts: dict[str, int] = Field(default_factory=dict) + evidence_level_counts: dict[str, int] = Field(default_factory=dict) + split_verdict_candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) + best_pipeline_elapsed_sec_delta_pct: float | None = None + best_observed_total_tokens_delta: float | None = None + best_observed_total_requests_delta: float | None = None + worst_pipeline_elapsed_sec_delta_pct: float | None = None + worst_observed_total_tokens_delta: float | None = None + worst_observed_total_requests_delta: float | None = None + min_shared_stable_final_entity_signature_count: int | None = None + sum_baseline_replacement_missing_final_entity_count: float | None = None + sum_candidate_replacement_missing_final_entity_count: float | None = None + sum_baseline_original_value_leak_count: float | None = None + sum_candidate_original_value_leak_count: float | None = None + sum_baseline_original_value_leak_record_count: float | None = None + sum_candidate_original_value_leak_record_count: float | None = None + sum_baseline_replacement_synthetic_original_collision_count: float | None = None + sum_candidate_replacement_synthetic_original_collision_count: float | None = None + sum_baseline_replacement_synthetic_original_collision_value_count: float | None = None + sum_candidate_replacement_synthetic_original_collision_value_count: float | None = None + sum_baseline_duplicate_synthetic_replacement_count: float | None = None + sum_candidate_duplicate_synthetic_replacement_count: float | None = None + baseline_defect_improvement_count: int = 0 + flag_counts: dict[str, int] = 
Field(default_factory=dict) + baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) + label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) + stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) + baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) + + @model_validator(mode="after") + def fill_recommendation(self) -> "ScreenGroup": + if self.recommendation == "unknown": + self.recommendation = group_recommendation(self) + return self + + +class ScreenResult(BaseModel): + input_paths: list[str] + scanned_file_count: int + comparison_file_count: int + row_count: int + duplicate_row_count: int = 0 + summary: ScreenSummary = Field(default_factory=ScreenSummary) + rows: list[ScreenRow] = Field(default_factory=list) + groups: list[ScreenGroup] = Field(default_factory=list) + + +_log_format = LogFormat.plain + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) + + +def screen_comparison_paths( + paths: list[Path], + *, + group_by: GroupBy = GroupBy.strategy, 
+ config_aliases: dict[str, str] | None = None, + source_includes: list[str] | None = None, + source_excludes: list[str] | None = None, +) -> ScreenResult: + files = filter_source_paths( + list(iter_csv_files(paths)), + includes=source_includes or [], + excludes=source_excludes or [], + ) + rows: list[ScreenRow] = [] + comparison_file_count = 0 + for csv_file in files: + table = read_csv_or_empty(csv_file) + if table is None: + continue + if not is_comparison_table(table): + continue + comparison_file_count += 1 + rows.extend(build_rows_from_table(table, source_path=csv_file)) + deduped_rows = sorted(dedupe_rows(rows), key=screen_sort_key) + groups = sorted( + summarize_groups(deduped_rows, group_by=group_by, config_aliases=config_aliases or {}), + key=group_sort_key, + ) + return ScreenResult( + input_paths=[str(path) for path in paths], + scanned_file_count=len(files), + comparison_file_count=comparison_file_count, + row_count=len(deduped_rows), + duplicate_row_count=len(rows) - len(deduped_rows), + summary=summarize_rows(deduped_rows), + rows=deduped_rows, + groups=groups, + ) + + +def iter_csv_files(paths: list[Path]) -> list[Path]: + files: list[Path] = [] + for path in paths: + if not path.exists(): + raise ValueError(f"comparison path does not exist: {path}") + if path.is_file(): + if path.suffix == ".csv": + files.append(path) + continue + files.extend(sorted(path.rglob("*.csv"))) + return sorted(set(files)) + + +def filter_source_paths(paths: list[Path], *, includes: list[str], excludes: list[str]) -> list[Path]: + return [path for path in paths if source_path_matches(path, includes=includes, excludes=excludes)] + + +def source_path_matches(path: Path, *, includes: list[str], excludes: list[str]) -> bool: + text = str(path) + if includes and not any(fragment in text for fragment in includes): + return False + return not any(fragment in text for fragment in excludes) + + +def is_comparison_table(table: pd.DataFrame) -> bool: + if "source_path" in 
table.columns: + return False + return COMPARISON_COLUMNS.issubset(set(table.columns)) + + +def read_csv_or_empty(csv_file: Path) -> pd.DataFrame | None: + try: + return pd.read_csv(csv_file) + except pd.errors.EmptyDataError: + return None + + +def read_config_aliases(path: Path | None) -> dict[str, str]: + if path is None: + return {} + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"config aliases file is not valid JSON: {path}") from exc + if not isinstance(payload, dict): + raise ValueError("config aliases file must contain a JSON object mapping config IDs to aliases") + return {str(key): str(value) for key, value in payload.items()} + + +def build_rows_from_table(table: pd.DataFrame, *, source_path: Path) -> list[ScreenRow]: + return [build_row(row, source_path=source_path) for _, row in table.iterrows()] + + +def build_row(row: pd.Series, *, source_path: Path) -> ScreenRow: + return ScreenRow( + source_path=str(source_path), + workload_id=required_string(row, "workload_id"), + workload_family=workload_family(required_string(row, "workload_id")), + baseline_config_id=required_string(row, "baseline_config_id"), + candidate_config_id=required_string(row, "candidate_config_id"), + baseline_strategy=optional_string(row.get("baseline_strategy")), + candidate_strategy=optional_string(row.get("candidate_strategy")), + baseline_replacement_strategy=optional_string(row.get("baseline_replacement_strategy")), + candidate_replacement_strategy=optional_string(row.get("candidate_replacement_strategy")), + baseline_case_count=optional_int(row.get("baseline_case_count")), + candidate_case_count=optional_int(row.get("candidate_case_count")), + value_protection_verdict=optional_string(row.get("value_protection_verdict")), + signature_parity_verdict=optional_string(row.get("signature_parity_verdict")), + safety_verdict=required_string(row, "safety_verdict"), + performance_verdict=required_string(row, 
"performance_verdict"), + candidate_verdict=required_string(row, "candidate_verdict"), + evidence_level=comparison_evidence_level(row), + pipeline_elapsed_sec_delta_pct=optional_float(row.get("pipeline_elapsed_sec_delta_pct")), + observed_total_requests_delta=optional_float(row.get("observed_total_requests_delta")), + observed_total_tokens_delta=optional_float(row.get("observed_total_tokens_delta")), + final_entity_count_delta=optional_float(row.get("final_entity_count_delta")), + augmented_entity_count_delta=optional_float(row.get("augmented_entity_count_delta")), + augmented_new_final_value_count_delta=optional_float(row.get("augmented_new_final_value_count_delta")), + baseline_original_value_leak_count=optional_float(row.get("baseline_original_value_leak_count")), + candidate_original_value_leak_count=optional_float(row.get("candidate_original_value_leak_count")), + original_value_leak_count_delta=optional_float(row.get("original_value_leak_count_delta")), + baseline_original_value_leak_record_count=optional_float(row.get("baseline_original_value_leak_record_count")), + candidate_original_value_leak_record_count=optional_float( + row.get("candidate_original_value_leak_record_count") + ), + original_value_leak_record_count_delta=optional_float(row.get("original_value_leak_record_count_delta")), + baseline_replacement_missing_final_entity_count=optional_float( + row.get("baseline_replacement_missing_final_entity_count") + ), + candidate_replacement_missing_final_entity_count=optional_float( + row.get("candidate_replacement_missing_final_entity_count") + ), + replacement_missing_final_entity_count_delta=optional_float( + row.get("replacement_missing_final_entity_count_delta") + ), + baseline_replacement_synthetic_original_collision_count=optional_float( + row.get("baseline_replacement_synthetic_original_collision_count") + ), + candidate_replacement_synthetic_original_collision_count=optional_float( + 
row.get("candidate_replacement_synthetic_original_collision_count") + ), + replacement_synthetic_original_collision_count_delta=optional_float( + row.get("replacement_synthetic_original_collision_count_delta") + ), + baseline_replacement_synthetic_original_collision_value_count=optional_float( + row.get("baseline_replacement_synthetic_original_collision_value_count") + ), + candidate_replacement_synthetic_original_collision_value_count=optional_float( + row.get("candidate_replacement_synthetic_original_collision_value_count") + ), + replacement_synthetic_original_collision_value_count_delta=optional_float( + row.get("replacement_synthetic_original_collision_value_count_delta") + ), + baseline_duplicate_synthetic_replacement_count=optional_float( + row.get("baseline_duplicate_synthetic_replacement_count") + ), + candidate_duplicate_synthetic_replacement_count=optional_float( + row.get("candidate_duplicate_synthetic_replacement_count") + ), + duplicate_synthetic_replacement_count_delta=optional_float( + row.get("duplicate_synthetic_replacement_count_delta") + ), + baseline_only_final_entity_signature_count=optional_float( + row.get("baseline_only_final_entity_signature_count") + ), + candidate_only_final_entity_signature_count=optional_float( + row.get("candidate_only_final_entity_signature_count") + ), + shared_final_entity_signature_count=optional_float(row.get("shared_final_entity_signature_count")), + baseline_stable_final_entity_signature_count=optional_int( + row.get("baseline_stable_final_entity_signature_count") + ), + candidate_stable_final_entity_signature_count=optional_int( + row.get("candidate_stable_final_entity_signature_count") + ), + baseline_stable_candidate_unstable_final_entity_signature_count=optional_float( + row.get("baseline_stable_candidate_unstable_final_entity_signature_count") + ), + candidate_stable_baseline_unstable_final_entity_signature_count=optional_float( + row.get("candidate_stable_baseline_unstable_final_entity_signature_count") + 
), + shared_stable_final_entity_signature_count=optional_int(row.get("shared_stable_final_entity_signature_count")), + flags=parse_flags(row.get("flags")), + baseline_only_label_counts=preferred_count_columns( + row, + preferred_prefix="baseline_only_candidate_uncovered_signature_label_counts", + fallback_prefix="baseline_only_final_entity_signature_label_counts", + preferred_count_column="baseline_only_candidate_uncovered_signature_count", + ), + label_mismatch_label_counts=count_columns( + row, + "baseline_only_candidate_label_mismatch_signature_label_counts", + ), + baseline_replacement_missing_final_entity_label_counts=count_columns( + row, + "baseline_replacement_missing_final_entity_label_counts", + ), + candidate_replacement_missing_final_entity_label_counts=count_columns( + row, + "candidate_replacement_missing_final_entity_label_counts", + ), + baseline_original_value_leak_label_counts=count_columns(row, "baseline_original_value_leak_label_counts"), + candidate_original_value_leak_label_counts=count_columns(row, "candidate_original_value_leak_label_counts"), + baseline_replacement_synthetic_original_collision_label_counts=count_columns( + row, + "baseline_replacement_synthetic_original_collision_label_counts", + ), + candidate_replacement_synthetic_original_collision_label_counts=count_columns( + row, + "candidate_replacement_synthetic_original_collision_label_counts", + ), + stable_lost_label_counts=preferred_count_columns( + row, + preferred_prefix="baseline_stable_candidate_uncovered_signature_label_counts", + fallback_prefix="baseline_stable_candidate_unstable_final_entity_signature_label_counts", + preferred_count_column="baseline_stable_candidate_uncovered_signature_count", + ), + ) + + +def comparison_evidence_level(row: pd.Series) -> str: + if optional_string(row.get("value_protection_verdict")) and optional_string(row.get("signature_parity_verdict")): + return "split_verdicts" + if ( + not 
is_missing(row.get("baseline_stable_final_entity_signature_count")) + or not is_missing(row.get("candidate_stable_final_entity_signature_count")) + or not is_missing(row.get("shared_stable_final_entity_signature_count")) + ): + return "stable_signatures" + if ( + not is_missing(row.get("baseline_only_final_entity_signature_count")) + or not is_missing(row.get("candidate_only_final_entity_signature_count")) + or not is_missing(row.get("shared_final_entity_signature_count")) + ): + return "signature_counts" + return "legacy" + + +def summarize_rows(rows: list[ScreenRow]) -> ScreenSummary: + candidate_counts = count_values(row.candidate_verdict for row in rows) + return ScreenSummary( + viable_count=candidate_counts.get("candidate_viable", 0), + review_count=candidate_counts.get("review", 0), + reject_count=candidate_counts.get("reject", 0), + candidate_verdict_counts=candidate_counts, + value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), + signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), + safety_verdict_counts=count_values(row.safety_verdict for row in rows), + performance_verdict_counts=count_values(row.performance_verdict for row in rows), + evidence_level_counts=count_values(row.evidence_level for row in rows), + ) + + +def summarize_groups( + rows: list[ScreenRow], + *, + group_by: GroupBy = GroupBy.strategy, + config_aliases: dict[str, str] | None = None, +) -> list[ScreenGroup]: + return [ + build_group(group_key, group_rows) + for group_key, group_rows in grouped_rows( + rows, + group_by=group_by, + config_aliases=config_aliases or {}, + ).items() + ] + + +def grouped_rows( + rows: list[ScreenRow], + *, + group_by: GroupBy, + config_aliases: dict[str, str], +) -> dict[str, list[ScreenRow]]: + groups: dict[str, list[ScreenRow]] = {} + for row in rows: + groups.setdefault(group_key_for_row(row, group_by=group_by, config_aliases=config_aliases), []).append(row) + 
return dict(sorted(groups.items())) + + +def group_key_for_row(row: ScreenRow, *, group_by: GroupBy, config_aliases: dict[str, str]) -> str: + base = group_base_for_row(row, config_aliases=config_aliases) + if group_by == GroupBy.strategy_workload_family: + return f"{base}|family:{row.workload_family}" + if group_by == GroupBy.strategy_workload: + return f"{base}|workload:{row.workload_id}" + return base + + +def group_base_for_row(row: ScreenRow, *, config_aliases: dict[str, str]) -> str: + replacement = _non_default_replacement_strategy(row) + if row.candidate_strategy and row.candidate_strategy != "default": + base = f"strategy:{row.candidate_strategy}" + return f"{base}|replacement:{replacement}" if replacement else base + if replacement: + return f"replacement:{replacement}" + if alias := config_aliases.get(row.candidate_config_id): + return f"alias:{alias}" + return f"config:{row.candidate_config_id}" + + +def _non_default_replacement_strategy(row: ScreenRow) -> str | None: + if row.candidate_replacement_strategy and row.candidate_replacement_strategy != "default": + return row.candidate_replacement_strategy + return None + + +def workload_family(workload_id: str) -> str: + parts = [part for part in workload_id.split("-") if part] + while parts and _is_workload_slice_suffix(parts[-1]): + parts.pop() + return "-".join(parts) if parts else workload_id + + +def _is_workload_slice_suffix(value: str) -> bool: + return ( + value.isdigit() + or value == "slice" + or (value.startswith("r") and value[1:].isdigit()) + or (value.startswith("offset") and value[len("offset") :].isdigit()) + ) + + +def build_group(group_key: str, rows: list[ScreenRow]) -> ScreenGroup: + verdict_counts = count_values(row.candidate_verdict for row in rows) + group = ScreenGroup( + group_key=group_key, + candidate_strategy=single_optional_value(row.candidate_strategy for row in rows), + candidate_replacement_strategy=single_optional_value(row.candidate_replacement_strategy for row in rows), + 
candidate_config_ids=sorted({row.candidate_config_id for row in rows}), + workload_ids=sorted({row.workload_id for row in rows}), + workload_families=sorted({row.workload_family or workload_family(row.workload_id) for row in rows}), + row_count=len(rows), + min_baseline_case_count=min_present_int(row.baseline_case_count for row in rows), + min_candidate_case_count=min_present_int(row.candidate_case_count for row in rows), + viable_count=verdict_counts.get("candidate_viable", 0), + review_count=verdict_counts.get("review", 0), + reject_count=verdict_counts.get("reject", 0), + has_conflicting_verdicts=len(verdict_counts) > 1, + value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), + signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), + performance_verdict_counts=count_values(row.performance_verdict for row in rows), + evidence_level_counts=count_values(row.evidence_level for row in rows), + split_verdict_candidate_verdict_counts=count_values( + row.candidate_verdict for row in rows if row.evidence_level == "split_verdicts" + ), + best_pipeline_elapsed_sec_delta_pct=min_present(row.pipeline_elapsed_sec_delta_pct for row in rows), + best_observed_total_tokens_delta=min_present(row.observed_total_tokens_delta for row in rows), + best_observed_total_requests_delta=min_present(row.observed_total_requests_delta for row in rows), + worst_pipeline_elapsed_sec_delta_pct=max_present(row.pipeline_elapsed_sec_delta_pct for row in rows), + worst_observed_total_tokens_delta=max_present(row.observed_total_tokens_delta for row in rows), + worst_observed_total_requests_delta=max_present(row.observed_total_requests_delta for row in rows), + min_shared_stable_final_entity_signature_count=min_present_int( + row.shared_stable_final_entity_signature_count for row in rows + ), + sum_baseline_replacement_missing_final_entity_count=sum_present( + row.baseline_replacement_missing_final_entity_count 
for row in rows + ), + sum_candidate_replacement_missing_final_entity_count=sum_present( + row.candidate_replacement_missing_final_entity_count for row in rows + ), + sum_baseline_original_value_leak_count=sum_present(row.baseline_original_value_leak_count for row in rows), + sum_candidate_original_value_leak_count=sum_present(row.candidate_original_value_leak_count for row in rows), + sum_baseline_original_value_leak_record_count=sum_present( + row.baseline_original_value_leak_record_count for row in rows + ), + sum_candidate_original_value_leak_record_count=sum_present( + row.candidate_original_value_leak_record_count for row in rows + ), + sum_baseline_replacement_synthetic_original_collision_count=sum_present( + row.baseline_replacement_synthetic_original_collision_count for row in rows + ), + sum_candidate_replacement_synthetic_original_collision_count=sum_present( + row.candidate_replacement_synthetic_original_collision_count for row in rows + ), + sum_baseline_replacement_synthetic_original_collision_value_count=sum_present( + row.baseline_replacement_synthetic_original_collision_value_count for row in rows + ), + sum_candidate_replacement_synthetic_original_collision_value_count=sum_present( + row.candidate_replacement_synthetic_original_collision_value_count for row in rows + ), + sum_baseline_duplicate_synthetic_replacement_count=sum_present( + row.baseline_duplicate_synthetic_replacement_count for row in rows + ), + sum_candidate_duplicate_synthetic_replacement_count=sum_present( + row.candidate_duplicate_synthetic_replacement_count for row in rows + ), + baseline_defect_improvement_count=sum(1 for row in rows if row_has_baseline_defect_improvement(row)), + flag_counts=sum_string_counts(row.flags for row in rows), + baseline_only_label_counts=sum_dict_counts(row.baseline_only_label_counts for row in rows), + label_mismatch_label_counts=sum_dict_counts(row.label_mismatch_label_counts for row in rows), + 
stable_lost_label_counts=sum_dict_counts(row.stable_lost_label_counts for row in rows), + baseline_replacement_missing_final_entity_label_counts=sum_dict_counts( + row.baseline_replacement_missing_final_entity_label_counts for row in rows + ), + candidate_replacement_missing_final_entity_label_counts=sum_dict_counts( + row.candidate_replacement_missing_final_entity_label_counts for row in rows + ), + baseline_original_value_leak_label_counts=sum_dict_counts( + row.baseline_original_value_leak_label_counts for row in rows + ), + candidate_original_value_leak_label_counts=sum_dict_counts( + row.candidate_original_value_leak_label_counts for row in rows + ), + baseline_replacement_synthetic_original_collision_label_counts=sum_dict_counts( + row.baseline_replacement_synthetic_original_collision_label_counts for row in rows + ), + candidate_replacement_synthetic_original_collision_label_counts=sum_dict_counts( + row.candidate_replacement_synthetic_original_collision_label_counts for row in rows + ), + ) + group.recommendation = group_recommendation(group) + return group + + +def group_recommendation(group: ScreenGroup) -> str: + if group.viable_count and group.reject_count: + return "conflicting_evidence" + if group_needs_split_verdict_rerun(group): + return "needs_split_verdict_rerun" + if group_needs_viable_split_verdict(group): + return "needs_viable_split_verdict" + if group.viable_count and group.review_count and group_has_baseline_defect_improvement_reviews(group): + return "promising_with_baseline_defect_improvements" + if group.viable_count and group.review_count: + return "promising_needs_review" + if group.viable_count: + if group.row_count == 1: + return "single_slice_viable" + return "candidate_family_viable" + if group.reject_count: + return "reject" + if group.review_count: + if is_baseline_defect_improvement_group(group): + return "candidate_covers_baseline_defects" + if is_replacement_replay_review_group(group): + return "replacement_replay_review" + if 
def group_needs_split_verdict_rerun(group: ScreenGroup) -> bool:
    """Return True when the group's rows call for a split-verdict rerun.

    Groups without evidence levels, with any rejected row, or whose rows all
    already carry ``split_verdicts`` evidence never need one. A group mixing
    viable and review rows needs one only when no row has split-verdict
    evidence yet; an all-review group needs one when some (but not all) rows do.
    """
    evidence = group.evidence_level_counts
    if not evidence or group.reject_count:
        return False
    split_rows = evidence.get("split_verdicts", 0)
    if split_rows == group.row_count:
        return False
    if group.viable_count and group.review_count:
        return split_rows == 0
    return group.review_count == group.row_count and bool(split_rows)


def group_needs_viable_split_verdict(group: ScreenGroup) -> bool:
    """Return True for a mixed viable/review group whose split-verdict rows have no viable verdict."""
    if not group.viable_count or not group.review_count:
        return False
    split_rows = group.evidence_level_counts.get("split_verdicts", 0)
    return bool(split_rows) and not group.split_verdict_candidate_verdict_counts.get("candidate_viable", 0)


def is_replacement_replay_review_group(group: ScreenGroup) -> bool:
    """Return True for an all-review, all-improved group using a non-default replacement
    strategy that was flagged with ``replacement_only_detection_instability``."""
    strategy = group.candidate_replacement_strategy
    if not strategy or strategy == "default":
        return False
    return (
        group.review_count == group.row_count
        and group.performance_verdict_counts.get("improved", 0) == group.review_count
        and bool(group.flag_counts.get("replacement_only_detection_instability", 0))
    )


# Row flags that mark the candidate as covering a defect present in the baseline.
_BASELINE_DEFECT_IMPROVEMENT_FLAGS = {
    "candidate_covers_baseline_original_value_leak",
    "candidate_covers_baseline_replacement_missing_final_entity",
    "candidate_covers_baseline_replacement_synthetic_original_collision",
}


def row_has_baseline_defect_improvement(row: ScreenRow) -> bool:
    """Return True when the row carries at least one baseline-defect-improvement flag."""
    return not _BASELINE_DEFECT_IMPROVEMENT_FLAGS.isdisjoint(row.flags)


def group_has_baseline_defect_improvement_reviews(group: ScreenGroup) -> bool:
    """Return True when the group has review rows and every one is a baseline-defect improvement."""
    return bool(group.review_count) and group.baseline_defect_improvement_count == group.review_count


def is_baseline_defect_improvement_group(group: ScreenGroup) -> bool:
    """Return True when every row is a review row that improved performance, passed value
    protection, has a ``review`` signature-parity verdict, and improves a baseline defect."""
    reviews = group.review_count
    return (
        reviews == group.row_count
        and group.performance_verdict_counts.get("improved", 0) == reviews
        and group.value_protection_verdict_counts.get("pass", 0) == reviews
        and group.signature_parity_verdict_counts.get("review", 0) == reviews
        and group.baseline_defect_improvement_count == reviews
    )


# Flags that put an otherwise-improved review group into the reliability bucket.
_RELIABILITY_REVIEW_FLAGS = {"failed_request_increase", "bridge_fallback_increase"}


def is_reliability_review_group(group: ScreenGroup) -> bool:
    """Return True for an all-review, all-improved group with any reliability flag counted."""
    all_review = group.review_count == group.row_count
    if not all_review or group.performance_verdict_counts.get("improved", 0) != group.review_count:
        return False
    for flag in _RELIABILITY_REVIEW_FLAGS:
        if group.flag_counts.get(flag, 0) > 0:
            return True
    return False


def is_label_policy_review_group(group: ScreenGroup) -> bool:
    """Return True for an all-review, all-improved, value-protection-passing group whose
    signature parity is under review and which shows label mismatches."""
    reviews = group.review_count
    return (
        reviews == group.row_count
        and group.performance_verdict_counts.get("improved", 0) == reviews
        and group.value_protection_verdict_counts.get("pass", 0) == reviews
        and group.signature_parity_verdict_counts.get("review", 0) == reviews
        and bool(group.label_mismatch_label_counts or group.flag_counts.get("covered_label_mismatch"))
    )
def group_performance_summary(group: ScreenGroup) -> str:
    """Summarize the group's performance verdict counts, or ``"unknown"`` when there are none."""
    verdict_counts = group.performance_verdict_counts
    return label_summary(verdict_counts) if verdict_counts else "unknown"


def single_optional_value(values: object) -> str | None:
    """Return the one distinct non-None value (as a string), or None if there are zero or several."""
    distinct = {str(item) for item in values if item is not None}
    if len(distinct) != 1:
        return None
    return next(iter(distinct))


def min_present(values: object) -> float | None:
    """Return the smallest non-None value as a float, or None when nothing is present."""
    return min((float(item) for item in values if item is not None), default=None)


def min_present_int(values: object) -> int | None:
    """Return ``min_present`` truncated to an int, or None when nothing is present."""
    smallest = min_present(values)
    if smallest is None:
        return None
    return int(smallest)


def max_present(values: object) -> float | None:
    """Return the largest non-None value as a float, or None when nothing is present."""
    return max((float(item) for item in values if item is not None), default=None)


def sum_present(values: object) -> float | None:
    """Return the sum of the non-None values as floats, or None when nothing is present."""
    numbers = [float(item) for item in values if item is not None]
    if not numbers:
        return None
    return sum(numbers)


def sum_string_counts(values: object) -> dict[str, int]:
    """Count how often each stringified item occurs across all inner iterables, sorted by key."""
    tally: dict[str, int] = {}
    for group_items in values:
        for raw in group_items:
            label = str(raw)
            tally[label] = tally.get(label, 0) + 1
    return {label: tally[label] for label in sorted(tally)}
def group_sort_key(group: ScreenGroup) -> tuple[int, int, float, float, str]:
    """Sort key for groups: non-conflicting first, then viable before review before the rest,
    then by best elapsed delta, best token delta (missing values last) and group key."""
    if group.viable_count:
        verdict_rank = 0
    elif group.review_count:
        verdict_rank = 1
    else:
        verdict_rank = 2
    best_elapsed = group.best_pipeline_elapsed_sec_delta_pct
    best_tokens = group.best_observed_total_tokens_delta
    return (
        1 if group.has_conflicting_verdicts else 0,
        verdict_rank,
        float("inf") if best_elapsed is None else best_elapsed,
        float("inf") if best_tokens is None else best_tokens,
        group.group_key,
    )


def dedupe_rows(rows: list[ScreenRow]) -> list[ScreenRow]:
    """Drop rows that are identical apart from ``source_path``, keeping the first occurrence."""
    seen: set[str] = set()
    unique: list[ScreenRow] = []
    for candidate in rows:
        comparable = {name: item for name, item in candidate.model_dump().items() if name != "source_path"}
        fingerprint = json.dumps(comparable, ensure_ascii=True, sort_keys=True)
        if fingerprint in seen:
            continue
        seen.add(fingerprint)
        unique.append(candidate)
    return unique


def count_values(values: object) -> dict[str, int]:
    """Count occurrences of each stringified value, returned sorted by key."""
    tally: dict[str, int] = {}
    for item in values:
        label = str(item)
        tally[label] = tally.get(label, 0) + 1
    return {label: tally[label] for label in sorted(tally)}


def count_present_values(values: object) -> dict[str, int]:
    """Like ``count_values`` but ignoring None entries."""
    return count_values(filter(lambda item: item is not None, values))


def screen_sort_key(row: ScreenRow) -> tuple[int, float, float, str, str]:
    """Sort key for rows: verdict rank, elapsed delta, token delta (missing values last),
    then workload id and candidate config id."""
    ranks = {"candidate_viable": 0, "review": 1, "reject": 2}
    elapsed = row.pipeline_elapsed_sec_delta_pct
    tokens = row.observed_total_tokens_delta
    return (
        ranks.get(row.candidate_verdict, 3),
        float("inf") if elapsed is None else elapsed,
        float("inf") if tokens is None else tokens,
        row.workload_id,
        row.candidate_config_id,
    )


def required_string(row: pd.Series, column: str) -> str:
    """Return the column's value as a string; raise ValueError when it is missing."""
    text = optional_string(row.get(column))
    if text is not None:
        return text
    raise ValueError(f"comparison row missing required value: {column}")
def optional_float(value: object) -> float | None:
    """Convert ``value`` to float, or return None when it is missing (None / NaN)."""
    if is_missing(value):
        return None
    return float(value)


def optional_int(value: object) -> int | None:
    """Convert ``value`` to int by way of float (so ``"3.7"`` becomes 3); None when missing."""
    number = optional_float(value)
    return int(number) if number is not None else None


def is_missing(value: object) -> bool:
    """Return True for None and for float NaN."""
    return value is None or (isinstance(value, float) and pd.isna(value))


def parse_flags(value: object) -> list[str]:
    """Normalize a flags cell into a list of strings.

    Missing values give ``[]``; lists are stringified element-wise; other
    non-string values become a one-element list. Strings are decoded with
    ``parse_nested``: a decoded list is stringified element-wise, anything
    else keeps the original string as a single flag (empty string gives ``[]``).
    """
    if is_missing(value):
        return []
    if isinstance(value, list):
        return [str(item) for item in value]
    if not isinstance(value, str):
        return [str(value)]
    parsed = parse_nested(value)
    if isinstance(parsed, list):
        return [str(item) for item in parsed]
    return [value] if value else []


def parse_nested(value: str) -> object:
    """Decode a serialized cell, trying JSON first and a Python literal second.

    Returns the original string unchanged when neither decoder accepts it.
    ``ast.literal_eval`` is documented to raise ``TypeError``, ``MemoryError``
    and ``RecursionError`` in addition to ``SyntaxError``/``ValueError``
    (for example ``"{[1]: 2}"`` fails with an unhashable-key ``TypeError``);
    all of them mean "not a usable literal" here, so they fall back to the
    raw string instead of aborting the whole screening run.
    """
    try:
        return json.loads(value)
    except (json.JSONDecodeError, RecursionError):
        pass
    try:
        return ast.literal_eval(value)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        return value


def count_columns(row: pd.Series, prefix: str) -> dict[str, int]:
    """Collect positive values from columns named ``<prefix>.<label>``.

    Returns ``{label: int(value)}`` sorted by label; missing and non-positive
    values are skipped.
    """
    label_counts: dict[str, int] = {}
    column_prefix = f"{prefix}."
    for column, raw_value in row.items():
        if not str(column).startswith(column_prefix):
            continue
        value = optional_float(raw_value)
        if value is not None and value > 0:
            label_counts[str(column)[len(column_prefix) :]] = int(value)
    return dict(sorted(label_counts.items()))


def preferred_count_columns(
    row: pd.Series,
    *,
    preferred_prefix: str,
    fallback_prefix: str,
    preferred_count_column: str,
) -> dict[str, int]:
    """Read label counts from the preferred prefix when the row has any such column or a
    non-missing ``preferred_count_column``; otherwise read them from the fallback prefix."""
    if has_comparison_field(row, preferred_prefix) or not is_missing(row.get(preferred_count_column)):
        return count_columns(row, preferred_prefix)
    return count_columns(row, fallback_prefix)
def write_rows(rows: list[ScreenRow], output_path: Path, export_format: ExportFormat) -> None:
    """Export screened rows (nested fields flattened with ``.``) as parquet, CSV or JSON lines."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    records = [row.model_dump() for row in rows]
    frame = normalize_table_cells(pd.json_normalize(records, sep="."))
    if export_format == ExportFormat.parquet:
        frame.to_parquet(output_path, index=False)
        return
    if export_format == ExportFormat.csv:
        frame.to_csv(output_path, index=False)
        return
    frame.to_json(output_path, orient="records", lines=True)


def write_groups(groups: list[ScreenGroup], output_path: Path, export_format: ExportFormat) -> None:
    """Export candidate groups (one record per group) as parquet, CSV or JSON lines."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    records = [group.model_dump() for group in groups]
    frame = normalize_table_cells(pd.DataFrame(records))
    if export_format == ExportFormat.parquet:
        frame.to_parquet(output_path, index=False)
        return
    if export_format == ExportFormat.csv:
        frame.to_csv(output_path, index=False)
        return
    frame.to_json(output_path, orient="records", lines=True)


def normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``table`` in which columns holding dict/list cells are JSON-encoded."""
    result = table.copy()
    nested_columns = [name for name in result.columns if result[name].map(is_nested_cell).any()]
    for name in nested_columns:
        result[name] = result[name].map(json_cell)
    return result


def is_nested_cell(value: object) -> bool:
    """Return True when the cell value is a dict or a list."""
    return isinstance(value, (dict, list))


def json_cell(value: object) -> object:
    """JSON-encode dict/list cells with sorted keys; pass every other value through unchanged."""
    if is_nested_cell(value):
        return json.dumps(value, ensure_ascii=True, sort_keys=True)
    return value
def render_row(row: ScreenRow) -> str:
    """Render one screened comparison row as a single-line summary.

    The line starts with ``- <workload>: <baseline>-><candidate>`` and is
    followed by space-separated ``key=value`` fields. Optional verdicts and
    replacement strategies fall back to ``unknown``; numbers, counts and label
    maps are formatted by ``format_number``, ``format_count`` and
    ``label_summary``; flags are comma-joined, or ``none`` when the row has none.
    """
    # Adjacent f-string literals are concatenated into one string; the trailing
    # space inside each literal is the field separator, so fragment order and
    # spacing define the exact output format.
    return (
        f"- {row.workload_id}: {row.baseline_config_id}->{row.candidate_config_id} "
        f"verdict={row.candidate_verdict} safety={row.safety_verdict} "
        f"value_protection={row.value_protection_verdict or 'unknown'} "
        f"signature_parity={row.signature_parity_verdict or 'unknown'} "
        f"evidence={row.evidence_level} "
        f"perf={row.performance_verdict} elapsed_delta={format_number(row.pipeline_elapsed_sec_delta_pct, '%')} "
        f"replacement={row.baseline_replacement_strategy or 'unknown'}"
        f"->{row.candidate_replacement_strategy or 'unknown'} "
        f"tokens_delta={format_number(row.observed_total_tokens_delta)} "
        f"cases={format_count(row.baseline_case_count)}/{format_count(row.candidate_case_count)} "
        f"shared_stable={format_count(row.shared_stable_final_entity_signature_count)} "
        f"aug_new_final_delta={format_number(row.augmented_new_final_value_count_delta)} "
        f"baseline_missing_replacements={format_number(row.baseline_replacement_missing_final_entity_count)} "
        f"candidate_missing_replacements={format_number(row.candidate_replacement_missing_final_entity_count)} "
        f"baseline_original_value_leaks={format_number(row.baseline_original_value_leak_count)} "
        f"candidate_original_value_leaks={format_number(row.candidate_original_value_leak_count)} "
        f"baseline_replacement_collisions="
        f"{format_number(row.baseline_replacement_synthetic_original_collision_count)} "
        f"candidate_replacement_collisions="
        f"{format_number(row.candidate_replacement_synthetic_original_collision_count)} "
        f"baseline_duplicate_synthetics={format_number(row.baseline_duplicate_synthetic_replacement_count)} "
        f"candidate_duplicate_synthetics={format_number(row.candidate_duplicate_synthetic_replacement_count)} "
        f"lost={label_summary(row.baseline_only_label_counts)} "
        f"label_mismatch={label_summary(row.label_mismatch_label_counts)} "
        f"stable_lost={label_summary(row.stable_lost_label_counts)} "
        f"baseline_missing_labels={label_summary(row.baseline_replacement_missing_final_entity_label_counts)} "
        f"candidate_missing_labels={label_summary(row.candidate_replacement_missing_final_entity_label_counts)} "
        f"baseline_leak_labels={label_summary(row.baseline_original_value_leak_label_counts)} "
        f"leak_labels={label_summary(row.candidate_original_value_leak_label_counts)} "
        f"baseline_collision_labels="
        f"{label_summary(row.baseline_replacement_synthetic_original_collision_label_counts)} "
        f"collision_labels="
        f"{label_summary(row.candidate_replacement_synthetic_original_collision_label_counts)} "
        f"flags={','.join(row.flags) if row.flags else 'none'}"
    )
f"best_elapsed_delta={format_number(group.best_pipeline_elapsed_sec_delta_pct, '%')} " + f"worst_elapsed_delta={format_number(group.worst_pipeline_elapsed_sec_delta_pct, '%')} " + f"best_tokens_delta={format_number(group.best_observed_total_tokens_delta)} " + f"worst_tokens_delta={format_number(group.worst_observed_total_tokens_delta)} " + f"best_requests_delta={format_number(group.best_observed_total_requests_delta)} " + f"worst_requests_delta={format_number(group.worst_observed_total_requests_delta)} " + f"min_cases={format_count(group.min_baseline_case_count)}/{format_count(group.min_candidate_case_count)} " + f"min_shared_stable={format_count(group.min_shared_stable_final_entity_signature_count)} " + f"baseline_defect_improvements={group.baseline_defect_improvement_count} " + f"baseline_missing_replacements=" + f"{format_number(group.sum_baseline_replacement_missing_final_entity_count)} " + f"candidate_missing_replacements=" + f"{format_number(group.sum_candidate_replacement_missing_final_entity_count)} " + f"baseline_original_value_leaks={format_number(group.sum_baseline_original_value_leak_count)} " + f"candidate_original_value_leaks={format_number(group.sum_candidate_original_value_leak_count)} " + f"baseline_replacement_collisions=" + f"{format_number(group.sum_baseline_replacement_synthetic_original_collision_count)} " + f"candidate_replacement_collisions=" + f"{format_number(group.sum_candidate_replacement_synthetic_original_collision_count)} " + f"baseline_duplicate_synthetics={format_number(group.sum_baseline_duplicate_synthetic_replacement_count)} " + f"candidate_duplicate_synthetics={format_number(group.sum_candidate_duplicate_synthetic_replacement_count)} " + f"lost={label_summary(group.baseline_only_label_counts)} " + f"label_mismatch={label_summary(group.label_mismatch_label_counts)} " + f"stable_lost={label_summary(group.stable_lost_label_counts)} " + 
def format_number(value: float | None, suffix: str = "") -> str:
    """Format ``value`` with one decimal place plus ``suffix``; ``"unknown"`` when it is None."""
    return "unknown" if value is None else f"{value:.1f}{suffix}"


def format_count(value: int | None) -> str:
    """Return the count as a string; ``"unknown"`` when it is None."""
    if value is None:
        return "unknown"
    return str(value)


def label_summary(counts: dict[str, int]) -> str:
    """Render ``{label: count}`` as ``label:count`` pairs joined by commas; ``"none"`` when empty."""
    if counts:
        pairs = [f"{label}:{count}" for label, count in counts.items()]
        return ",".join(pairs)
    return "none"
None: + configure_logging(log_format) + try: + result = screen_comparison_paths( + comparison_paths, + group_by=group_by, + config_aliases=read_config_aliases(config_aliases), + source_includes=source_include or [], + source_excludes=source_exclude or [], + ) + except ValueError as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + if output is not None: + write_rows(result.rows, output, format) + if group_output is not None: + write_groups(result.groups, group_output, format) + sys.stdout.write(render_result(result, json_output=json_output, limit=limit) + "\n") + + +if __name__ == "__main__": + app() diff --git a/tools/measurement/staged_detection_probe.py b/tools/measurement/staged_detection_probe.py new file mode 100644 index 00000000..5751e3f1 --- /dev/null +++ b/tools/measurement/staged_detection_probe.py @@ -0,0 +1,1475 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Run benchmark-only DD-free staged entity detection probes. 
+ +Usage: + uv run python tools/measurement/staged_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ + --text-column biography --labels age,city,first_name,last_name,occupation \ + --output /tmp/staged-probe --overwrite +""" + +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +import sys +import time +from collections import Counter +from dataclasses import dataclass, replace +from enum import StrEnum +from pathlib import Path +from typing import Annotated, Any, Protocol + +import cyclopts +import httpx +import pandas as pd +from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities +from dd_parser_compat import _load_embedded_json +from direct_detection_probe import ( + CaseStatus, + DirectCompletion, + DirectDetectionClient, + DirectDetectionRequest, + DirectGenerationRequest, + HttpxDirectDetectionClient, + LogFormat, + PromptMode, + SignatureComparison, + build_direct_prompt, + compare_signature_sets, + parse_labels, +) +from pydantic import BaseModel, Field, ValidationError, model_validator + +from anonymizer.engine.constants import ( + COL_AUGMENTED_ENTITIES, + COL_DETECTED_ENTITIES, + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_SEED_VALIDATION_CANDIDATES, + COL_TAG_NOTATION, + COL_TAGGED_TEXT, + COL_TEXT, + COL_VALIDATED_ENTITIES, + COL_VALIDATION_CANDIDATES, + COL_VALIDATION_DECISIONS, +) +from anonymizer.engine.detection.chunked_validation import ( + build_chunk_excerpt, + chunk_candidates, + merge_chunk_decisions, + order_candidates_by_position, +) +from anonymizer.engine.detection.custom_columns import ( + apply_validation_and_finalize, + apply_validation_to_seed_entities, + enrich_validation_decisions, + merge_and_build_candidates, + prepare_validation_inputs, +) +from anonymizer.engine.detection.postprocess import ( + VALIDATION_CONTEXT_WINDOW, + EntitySpan, + TagNotation, + apply_augmented_entities, + 
class SeedSource(StrEnum):
    """Selects how the seed phase produces the initial entity spans.

    ``_run_seed_phase`` dispatches on this value (after first checking the
    rule short-circuit, which takes precedence over every member).
    """

    direct_llm = "direct_llm"  # seed from a direct LLM completion (also the fallback branch)
    gliner = "gliner"  # seed from a GLiNER detection request
    rules = "rules"  # seed from detect_high_confidence_entities only, no model call
    # Same rule-only seeds as `rules`; the validation phase then calls
    # _trust_seed_entities for this source instead of running validation.
    rules_trusted = "rules_trusted"
    rules_plus_direct_llm = "rules_plus_direct_llm"  # rule spans merged with direct-LLM spans
    # Shares the rules+LLM seed path with rules_plus_direct_llm.
    # NOTE(review): what differs for the router is not decided in the seed phase — confirm.
    rules_router = "rules_router"


class ValidationPromptMode(StrEnum):
    """How validation prompts present the source text.

    NOTE(review): presumably ``chunked_excerpt`` pairs with the
    ``validation_max_entities_per_call`` / ``validation_excerpt_window_chars``
    settings — confirm against the validation phase.
    """

    full_text = "full_text"
    chunked_excerpt = "chunked_excerpt"


class GlinerDetectionRequest(BaseModel):
    """Parameters for a single GLiNER seed-detection call."""

    endpoint: str  # base URL; HttpxGlinerSeedClient appends "/chat/completions"
    model: str
    text: str  # sent as the single user message
    labels: list[str] = Field(min_length=1)  # at least one label is required
    threshold: float = 0.3  # forwarded verbatim in the request payload
    max_tokens: int = Field(default=4096, gt=0)
    timeout_sec: float = Field(default=120.0, gt=0)
    api_key_env: str = "NVIDIA_API_KEY"  # name of the environment variable holding the bearer token


class GlinerSeedClient(Protocol):
    """Structural interface for GLiNER seed detectors (see HttpxGlinerSeedClient)."""

    def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: ...
+ + +class StagedExecutionConfig(BaseModel): + endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1" + model: str = "nvidia/nemotron-3-super" + seed_source: SeedSource = SeedSource.direct_llm + gliner_endpoint: str = "https://integrate.api.nvidia.com/v1" + gliner_model: str = "nvidia/gliner-pii" + gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = 0.3 + max_tokens: int = Field(default=4096, gt=0) + timeout_sec: float = Field(default=180.0, gt=0) + skip_augmentation: bool = False + skip_augmentation_when_rule_covered: bool = False + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text + validation_max_entities_per_call: int = Field(default=10, gt=0) + validation_excerpt_window_chars: int = Field(default=160, gt=0) + + +class HttpxGlinerSeedClient: + def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: + api_key = os.environ.get(request.api_key_env) + if not api_key: + raise ValueError(f"{request.api_key_env} is not set") + response = httpx.post( + f"{request.endpoint.rstrip('/')}/chat/completions", + headers={"Authorization": f"Bearer {api_key}"}, + json=_gliner_payload(request), + timeout=request.timeout_sec, + ) + response.raise_for_status() + data = response.json() + return DirectCompletion( + content=_completion_content(data), + elapsed_sec=float(response.elapsed.total_seconds()), + usage=data.get("usage") or {}, + ) + + +def _gliner_payload(request: GlinerDetectionRequest) -> dict[str, Any]: + return { + "model": request.model, + "messages": [{"role": "user", "content": request.text}], + "temperature": 0, + "max_tokens": request.max_tokens, + "labels": request.labels, + "threshold": request.threshold, + "chunk_length": 384, + "overlap": 128, + "flat_ner": False, + } + + +def _completion_content(data: dict[str, Any]) -> str: + choice = (data.get("choices") or [{}])[0] + message = choice.get("message") if isinstance(choice, dict) else {} + if isinstance(message, dict): + return 
str(message.get("content") or "") + return str(choice.get("text") or "") if isinstance(choice, dict) else "" + + +class StagedDetectionRequest(BaseModel): + case_id: str + text: str + labels: list[str] = Field(min_length=1) + row_index: int = 0 + data_summary: str | None = None + + @model_validator(mode="after") + def normalize_labels(self) -> StagedDetectionRequest: + self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) + if not self.labels: + raise ValueError("labels must contain at least one non-empty label") + return self + + +class PhaseUsage(BaseModel): + seed: dict[str, Any] = Field(default_factory=dict) + validation: dict[str, Any] = Field(default_factory=dict) + augmentation: dict[str, Any] = Field(default_factory=dict) + + +class PhaseModelWork(BaseModel): + seed: bool = False + validation: bool = False + augmentation: bool = False + + +class PhaseSkipReasons(BaseModel): + seed: str | None = None + validation: str | None = None + augmentation: str | None = None + + +class PhaseModelRequests(BaseModel): + seed: int = 0 + validation: int = 0 + augmentation: int = 0 + + +class StagedDetectionCase(BaseModel): + case_id: str + row_index: int + seed_source: SeedSource = SeedSource.direct_llm + status: CaseStatus + elapsed_sec: float | None = None + model_elapsed_sec: float | None = None + phase_usage: PhaseUsage = Field(default_factory=PhaseUsage) + phase_model_work: PhaseModelWork = Field(default_factory=PhaseModelWork) + phase_skip_reasons: PhaseSkipReasons = Field(default_factory=PhaseSkipReasons) + phase_model_requests: PhaseModelRequests = Field(default_factory=PhaseModelRequests) + total_usage: dict[str, int] = Field(default_factory=dict) + model_phase_count: int = 0 + model_request_count: int = 0 + rule_covered_label_set: bool = False + seed_suggestion_count: int = 0 + seed_entity_count: int = 0 + validation_candidate_count: int = 0 + validation_decision_count: int = 0 + augmented_suggestion_count: int = 0 + 
def configure_logging(log_format: LogFormat) -> None:
    """Record the chosen log format for later messages and set up INFO-level logging."""
    global _log_format

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    _log_format = log_format


def log_bad_input(error: str) -> None:
    """Report a bad-input error: one JSON line on stderr in JSON mode, otherwise via the logger."""
    if _log_format != LogFormat.json:
        logger.error("bad_input error=%s", error)
        return
    record = {"level": "error", "event": "bad_input", "error": error}
    sys.stderr.write(json.dumps(record) + "\n")
def execute_staged_detection_case(
    request: StagedDetectionRequest,
    *,
    client: DirectDetectionClient,
    seed_client: GlinerSeedClient | None = None,
    seed_source: SeedSource = SeedSource.direct_llm,
    endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1",
    model: str = "nvidia/nemotron-3-super",
    gliner_endpoint: str = "https://integrate.api.nvidia.com/v1",
    gliner_model: str = "nvidia/gliner-pii",
    gliner_api_key_env: str = "NVIDIA_API_KEY",
    gliner_threshold: float = 0.3,
    max_tokens: int = 4096,
    timeout_sec: float = 180.0,
    skip_augmentation: bool = False,
    skip_augmentation_when_rule_covered: bool = False,
    validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text,
    validation_max_entities_per_call: int = 10,
    validation_excerpt_window_chars: int = 160,
) -> StagedDetectionExecution:
    """Run one staged detection case and return its case record together with the working row.

    The keyword arguments whose names match ``StagedExecutionConfig`` fields
    are gathered into a config object. Any exception raised while the stages
    run is converted into an errored case with an empty row rather than
    propagated; building the config itself happens outside that guard, so a
    config validation failure is raised to the caller.
    """
    # locals() must be read before any other local name is bound: at this point
    # it holds exactly the call parameters, which the helper filters down to
    # the StagedExecutionConfig field names.
    config = _execution_config_from_params(locals())
    try:
        return _run_staged_detection_execution(
            request,
            client=client,
            seed_client=seed_client,
            config=config,
        )
    except Exception as exc:  # noqa: BLE001 - benchmark probe records per-case failures
        return StagedDetectionExecution(case=_errored_case(request, config.seed_source, exc), row={})
def _errored_case(request: StagedDetectionRequest, seed_source: SeedSource, exc: Exception) -> StagedDetectionCase:
    """Build the case record for a run that failed with ``exc``."""
    message = f"{type(exc).__name__}: {exc}"
    return StagedDetectionCase(
        case_id=request.case_id,
        row_index=request.row_index,
        seed_source=seed_source,
        status=CaseStatus.error,
        error=message,
    )


def _run_staged_detection_execution(
    request: StagedDetectionRequest,
    *,
    client: DirectDetectionClient,
    seed_client: GlinerSeedClient | None,
    config: StagedExecutionConfig,
) -> StagedDetectionExecution:
    """Run the seed, validation and augmentation phases in order, finalize the row,
    and package the completed case with the wall-clock time spent on those steps."""
    start_time = time.perf_counter()
    seed_result = _run_seed_phase(
        request,
        client=client,
        seed_client=seed_client,
        config=config,
    )
    row, seed_suggestion_count, seed_completion = seed_result
    validation_completion = _run_validation_phase(row, request, client, config)
    augmentation_completion = _run_augmentation_phase(row, request, client, config)
    artifact = _finalize_row(row, request)
    # Measured before the case record is assembled, so it covers the phases only.
    elapsed_sec = time.perf_counter() - start_time
    completed = _completed_case(
        request,
        seed_completion,
        validation_completion,
        augmentation_completion,
        artifact,
        row,
        config=config,
        seed_suggestion_count=seed_suggestion_count,
        elapsed_sec=elapsed_sec,
    )
    return StagedDetectionExecution(case=completed, row=row)
def _run_gliner_seed_phase(
    request: StagedDetectionRequest,
    detector: GlinerSeedClient,
    config: StagedExecutionConfig,
) -> tuple[dict[str, Any], int, DirectCompletion]:
    """Seed the row from a GLiNER detector call.

    Returns the seed row, the raw detector entity count and the completion.
    """
    gliner_request = GlinerDetectionRequest(
        endpoint=config.gliner_endpoint,
        model=config.gliner_model,
        text=request.text,
        labels=request.labels,
        threshold=config.gliner_threshold,
        max_tokens=config.max_tokens,
        timeout_sec=config.timeout_sec,
        api_key_env=config.gliner_api_key_env,
    )
    completion = detector.detect(gliner_request)
    spans = parse_raw_entities(raw_response=completion.content, text=request.text)
    seed_row = _seed_row_from_spans(request, spans)
    suggestion_count = _raw_detector_entity_count(completion.content)
    return seed_row, suggestion_count, completion


def _run_direct_llm_seed_phase(
    request: StagedDetectionRequest,
    client: DirectDetectionClient,
    config: StagedExecutionConfig,
) -> tuple[dict[str, Any], int, DirectCompletion]:
    """Seed the row from one direct LLM completion.

    Returns the seed row, the number of suggestions and the completion.
    """
    completion = _complete(client, prompt=_seed_prompt(request), config=config)
    seed_row, suggestions = _seed_row_from_llm(request, completion.content)
    suggestion_count = len(suggestions)
    return seed_row, suggestion_count, completion


def _run_rules_plus_direct_llm_seed_phase(
    request: StagedDetectionRequest,
    client: DirectDetectionClient,
    config: StagedExecutionConfig,
) -> tuple[dict[str, Any], int, DirectCompletion]:
    """Seed the row from rule spans merged with direct-LLM spans.

    Overlaps are resolved with rule spans listed first, and validation
    candidates are then restricted to spans whose source is ``direct_seed``.
    The returned count is LLM suggestions plus rule spans.
    """
    completion = _complete(client, prompt=_seed_prompt(request), config=config)
    llm_spans, suggestions = _direct_seed_spans(request, completion.content)
    rule_spans = detect_high_confidence_entities(request.text, labels=request.labels)
    merged_spans = resolve_overlaps([*rule_spans, *llm_spans])
    seed_row = _seed_row_from_spans(request, merged_spans)
    _limit_validation_candidates_to_sources(seed_row, sources={"direct_seed"})
    total_suggestions = len(suggestions) + len(rule_spans)
    return seed_row, total_suggestions, completion
DirectCompletion]: + seed_spans = detect_high_confidence_entities(request.text, labels=request.labels) + completion = DirectCompletion(content="", elapsed_sec=0.0, usage={}) + return _seed_row_from_spans(request, seed_spans), len(seed_spans), completion + + +def _complete( + client: DirectDetectionClient, + *, + prompt: str, + config: StagedExecutionConfig, +) -> DirectCompletion: + return client.complete( + DirectGenerationRequest( + endpoint=config.endpoint, + model=config.model, + prompt=prompt, + max_tokens=config.max_tokens, + timeout_sec=config.timeout_sec, + ) + ) + + +def _seed_prompt(request: StagedDetectionRequest) -> str: + return build_direct_prompt(_direct_seed_request(request)) + + +def _direct_seed_request(request: StagedDetectionRequest) -> DirectDetectionRequest: + return DirectDetectionRequest( + case_id=request.case_id, + text=request.text, + labels=request.labels, + row_index=request.row_index, + prompt_mode=PromptMode.compact, + data_summary=request.data_summary, + ) + + +def _seed_row_from_llm(request: StagedDetectionRequest, content: str) -> tuple[dict[str, Any], list[dict[str, Any]]]: + seed_spans, suggestions = _direct_seed_spans(request, content) + return _seed_row_from_spans(request, seed_spans), suggestions + + +def _direct_seed_spans(request: StagedDetectionRequest, content: str) -> tuple[list[EntitySpan], list[dict[str, Any]]]: + suggestions = _extract_entity_suggestions(content) + seed_spans = _spans_from_suggestions(request.text, suggestions, labels=request.labels, source="direct_seed") + return seed_spans, suggestions + + +def _seed_row_from_spans(request: StagedDetectionRequest, seed_spans: list[EntitySpan]) -> dict[str, Any]: + allowed = set(request.labels) + seed_spans = [_normalize_seed_span(request.text, span, allowed_labels=allowed) for span in seed_spans] + seed_spans = [span for span in seed_spans if span.label in allowed] + row = {COL_TEXT: request.text, COL_TAG_NOTATION: get_tag_notation(text=request.text)} + 
row[COL_RAW_DETECTED] = "" + row[COL_SEED_ENTITIES] = EntitiesSchema(entities=[_span_to_entity_schema(span) for span in seed_spans]).model_dump( + mode="json" + ) + prepare_validation_inputs(row) + return row + + +def _normalize_seed_span(text: str, span: EntitySpan, *, allowed_labels: set[str]) -> EntitySpan: + if _should_promote_birth_context_date_seed(text, span, allowed_labels=allowed_labels): + return replace(span, entity_id=_seed_entity_id("date_of_birth", span), label="date_of_birth") + return span + + +def _should_promote_birth_context_date_seed( + text: str, + span: EntitySpan, + *, + allowed_labels: set[str], +) -> bool: + return span.label == "date" and "date_of_birth" in allowed_labels and _span_has_date_of_birth_context(text, span) + + +def _span_has_date_of_birth_context(text: str, span: EntitySpan) -> bool: + before_start = max(0, span.start_position - VALIDATION_CONTEXT_WINDOW) + after_end = min(len(text), span.end_position + VALIDATION_CONTEXT_WINDOW) + return _contains_date_of_birth_context( + text[before_start : span.start_position], span.value, text[span.end_position : after_end] + ) + + +def _seed_entity_id(label: str, span: EntitySpan) -> str: + return f"{label}_{span.start_position}_{span.end_position}" + + +def _limit_validation_candidates_to_sources(row: dict[str, Any], *, sources: set[str]) -> None: + text = str(row.get(COL_TEXT, "")) + seed_spans = [span for span in _seed_entity_spans(row) if span.source in sources] + row[COL_SEED_VALIDATION_CANDIDATES] = ValidationCandidatesSchema( + candidates=build_validation_candidates(text=text, entities=seed_spans) + ).model_dump(mode="json") + + +def _run_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + if config.seed_source == SeedSource.rules_trusted or _uses_rule_short_circuit(request, config): + _trust_seed_entities(row) + return DirectCompletion(content="", elapsed_sec=0.0, 
usage={}) + candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) + if not candidates.candidates: + row[COL_VALIDATION_DECISIONS] = {"decisions": []} + row[COL_VALIDATED_ENTITIES] = {"decisions": []} + apply_validation_to_seed_entities(row) + return DirectCompletion(content='{"decisions": []}', elapsed_sec=0.0, usage={}) + if config.validation_prompt_mode == ValidationPromptMode.chunked_excerpt: + completion = _run_chunked_validation_phase(row, request, candidates, client, config) + else: + completion = _run_full_text_validation_phase(row, request, candidates, client, config) + row[COL_VALIDATION_DECISIONS] = _extract_validation_decisions( + completion.content, + candidates=candidates, + labels=request.labels, + ) + enrich_validation_decisions(row) + apply_validation_to_seed_entities(row) + return completion + + +def _run_full_text_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + completion = _complete( + client, + prompt=_validation_prompt(request, candidates), + config=config, + ) + return completion + + +def _run_chunked_validation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + completions: list[DirectCompletion] = [] + chunk_results: list[RawValidationDecisionsSchema] = [] + for chunk in _validation_chunks(row, candidates, config): + completion = _complete(client, prompt=_chunk_validation_prompt(request, chunk), config=config) + completions.append(completion) + chunk_results.append( + RawValidationDecisionsSchema.from_raw( + _extract_validation_decisions(completion.content, candidates=chunk[1], labels=request.labels) + ) + ) + decisions = merge_chunk_decisions(chunk_results, candidates) + return 
_combine_completions(completions, content=json.dumps(decisions, ensure_ascii=True, sort_keys=True)) + + +def _validation_chunks( + row: dict[str, Any], + candidates: ValidationCandidatesSchema, + config: StagedExecutionConfig, +) -> list[tuple[str, ValidationCandidatesSchema]]: + seed_spans = _seed_entity_spans(row) + ordered = order_candidates_by_position(candidates, seed_spans) + return [ + (_validation_excerpt(row, chunk, seed_spans, config), _chunk_candidate_schema(chunk)) + for chunk in chunk_candidates(ordered, config.validation_max_entities_per_call) + ] + + +def _validation_excerpt( + row: dict[str, Any], + chunk: list[tuple[Any, EntitySpan]], + seed_spans: list[EntitySpan], + config: StagedExecutionConfig, +) -> str: + notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) + return build_chunk_excerpt( + text=str(row.get(COL_TEXT, "")), + chunk_spans=[span for _candidate, span in chunk], + all_spans=seed_spans, + window_chars=config.validation_excerpt_window_chars, + notation=notation, + ) + + +def _chunk_candidate_schema(chunk: list[tuple[Any, EntitySpan]]) -> ValidationCandidatesSchema: + return ValidationCandidatesSchema(candidates=[candidate for candidate, _span in chunk]) + + +def _chunk_validation_prompt( + request: StagedDetectionRequest, + chunk: tuple[str, ValidationCandidatesSchema], +) -> str: + excerpt, candidates = chunk + return _validation_prompt(request, candidates, input_text=excerpt) + + +def _seed_entity_spans(row: dict[str, Any]) -> list[EntitySpan]: + return [ + EntitySpan( + entity_id=entity.id, + value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + for entity in EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities + ] + + +def _combine_completions(completions: list[DirectCompletion], *, content: str) -> DirectCompletion: + return DirectCompletion( + content=content, + 
elapsed_sec=sum(completion.elapsed_sec for completion in completions), + usage=_sum_completion_usage(completions), + ) + + +def _sum_completion_usage(completions: list[DirectCompletion]) -> dict[str, int]: + totals: Counter[str] = Counter() + for completion in completions: + for key, value in completion.usage.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def _trust_seed_entities(row: dict[str, Any]) -> None: + candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) + row[COL_VALIDATED_ENTITIES] = ValidatedDecisionsSchema( + decisions=[ + ValidatedDecisionSchema( + id=candidate.id, + decision="keep", + value=candidate.value, + label=candidate.label, + reason="trusted deterministic rule", + ) + for candidate in candidates.candidates + ] + ).model_dump(mode="json") + apply_validation_to_seed_entities(row) + + +def _validation_prompt( + request: StagedDetectionRequest, + candidates: ValidationCandidatesSchema, + *, + input_text: str | None = None, +) -> str: + text_for_prompt = input_text if input_text is not None else request.text + label_guidance = _validation_label_guidance(request.labels) + return f"""Validate candidate privacy-sensitive entities. +Use only these decisions: keep, reclass, drop. +Use only these labels when reclassifying: {", ".join(request.labels)}. +Prefer keep when the candidate already has the right specific label. For example, reclass date_of_birth to date only when the surrounding context is not birth-related. 
+{label_guidance} +Return JSON only with this shape: +{{"decisions": [{{"id": "candidate id", "decision": "keep|reclass|drop", "proposed_label": "", "reason": "short reason"}}]}} + +Context text: +--- +{text_for_prompt} +--- + +Candidates: +{candidates.model_dump_json()} +""" + + +def _validation_label_guidance(labels: list[str]) -> str: + allowed = set(labels) + lines: list[str] = [] + if "degree" in allowed: + lines.append( + "Prefer degree for credential names and degree abbreviations such as Bachelor of Science, BSc, BA, MA, MSc, PhD, or JD; use education_level for broad levels such as undergraduate, graduate, or high school." + ) + if not lines: + return "" + return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) + + +def _run_augmentation_phase( + row: dict[str, Any], + request: StagedDetectionRequest, + client: DirectDetectionClient, + config: StagedExecutionConfig, +) -> DirectCompletion: + if _should_skip_augmentation(request, config): + row[COL_AUGMENTED_ENTITIES] = {"entities": []} + return DirectCompletion(content="", elapsed_sec=0.0, usage={}) + completion = _complete( + client, + prompt=_augmentation_prompt(request, row), + config=config, + ) + row[COL_AUGMENTED_ENTITIES] = { + "entities": _allowed_entity_suggestions(_extract_entity_suggestions(completion.content), labels=request.labels) + } + return completion + + +def _should_skip_augmentation(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: + if config.skip_augmentation: + return True + if _uses_rule_short_circuit(request, config): + return True + if not config.skip_augmentation_when_rule_covered: + return False + if config.seed_source not in {SeedSource.rules, SeedSource.rules_trusted, SeedSource.rules_plus_direct_llm}: + return False + return set(request.labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) + + +def _uses_rule_short_circuit(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: + return config.seed_source == 
SeedSource.rules_router and _is_rule_covered_label_set(request) + + +def _augmentation_prompt(request: StagedDetectionRequest, row: dict[str, Any]) -> str: + seed_entities = row.get(COL_SEED_ENTITIES_JSON, "[]") + label_guidance = _augmentation_label_guidance(request.labels) + return f"""Find additional sensitive entities not already covered by the validated seed entities. +Use only these labels: {", ".join(request.labels)}. +Return exact substrings only. Do not invent values. +{label_guidance} +Return JSON only with this shape: +{{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} + +Input text: +--- +{request.text} +--- + +Validated seed entities: +{seed_entities} +""" + + +def _augmentation_label_guidance(labels: list[str]) -> str: + allowed = set(labels) + lines: list[str] = [] + if "first_name" in allowed: + lines.append( + "For family/member lists, split personal names connected by 'and' into separate first_name or last_name entities. Do not label a list of people as organization_name." + ) + if "last_name" in allowed and ({"place_name", "company_name", "organization_name"} & allowed): + lines.append( + "If a surname appears inside a street, place, organization, company, or possessive business phrase, also return the surname substring as last_name when it is not already covered." 
+ ) + if not lines: + return "" + return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) + + +def _finalize_row(row: dict[str, Any], request: StagedDetectionRequest) -> DetectionArtifactRow: + merge_and_build_candidates(row) + apply_validation_and_finalize(row) + normalize_birth_context_final_entities(row, allowed_labels=request.labels) + seed_entities = EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities + final_entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities + augmented = _augmented_entity_schemas(row, request) + return build_detection_artifact_row_from_entities( + workflow_name="staged-direct-detection", + batch_file="staged-direct-detection", + row_index=request.row_index, + seed_entities=list(seed_entities), + seed_validation_candidate_count=_candidate_count(row.get(COL_SEED_VALIDATION_CANDIDATES)), + merged_validation_candidate_count=_candidate_count(row.get(COL_VALIDATION_CANDIDATES)), + augmented_entities=augmented, + final_entities=list(final_entities), + ) + + +def _completed_case( + request: StagedDetectionRequest, + seed: DirectCompletion, + validation: DirectCompletion, + augmentation: DirectCompletion, + artifact: DetectionArtifactRow, + row: dict[str, Any], + config: StagedExecutionConfig, + seed_suggestion_count: int, + elapsed_sec: float, +) -> StagedDetectionCase: + phase_usage = PhaseUsage(seed=seed.usage, validation=validation.usage, augmentation=augmentation.usage) + phase_model_work = _phase_model_work(request, artifact, config) + phase_model_requests = _phase_model_requests(artifact, phase_model_work, config) + model_elapsed_sec = seed.elapsed_sec + validation.elapsed_sec + augmentation.elapsed_sec + return StagedDetectionCase( + case_id=request.case_id, + row_index=request.row_index, + seed_source=config.seed_source, + status=CaseStatus.completed, + elapsed_sec=elapsed_sec, + model_elapsed_sec=model_elapsed_sec, + phase_usage=phase_usage, + 
phase_model_work=phase_model_work, + phase_skip_reasons=_phase_skip_reasons(request, artifact, config), + phase_model_requests=phase_model_requests, + total_usage=_sum_usage(phase_usage), + model_phase_count=_model_phase_count(phase_model_work), + model_request_count=_model_request_count(phase_model_requests), + rule_covered_label_set=_is_rule_covered_label_set(request), + seed_suggestion_count=seed_suggestion_count, + seed_entity_count=artifact.seed_entity_count, + validation_candidate_count=artifact.seed_validation_candidate_count, + validation_decision_count=_validated_decision_count(row.get(COL_VALIDATED_ENTITIES)), + augmented_suggestion_count=artifact.augmented_entity_count, + final_entity_count=artifact.final_entity_count, + final_entity_signature_count=artifact.final_entity_signature_count, + final_entity_signature_hashes=artifact.final_entity_signature_hashes, + final_label_counts=artifact.final_label_counts, + artifact=artifact, + ) + + +def _phase_model_work( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> PhaseModelWork: + return PhaseModelWork( + seed=_uses_seed_model(request, config), + validation=_uses_validation_model(request, artifact, config), + augmentation=not _should_skip_augmentation(request, config), + ) + + +def _uses_seed_model(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: + if _uses_rule_short_circuit(request, config): + return False + return config.seed_source in { + SeedSource.direct_llm, + SeedSource.gliner, + SeedSource.rules_plus_direct_llm, + SeedSource.rules_router, + } + + +def _uses_validation_model( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> bool: + if _uses_rule_short_circuit(request, config): + return False + if ( + config.seed_source in {SeedSource.rules_trusted, SeedSource.rules_router} + and artifact.seed_validation_candidate_count == 0 + ): + return False + if config.seed_source == 
SeedSource.rules_trusted: + return False + return artifact.seed_validation_candidate_count > 0 + + +def _phase_skip_reasons( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> PhaseSkipReasons: + return PhaseSkipReasons( + seed=_seed_skip_reason(request, config), + validation=_validation_skip_reason(request, artifact, config), + augmentation=_augmentation_skip_reason(request, config), + ) + + +def _seed_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: + if config.seed_source in {SeedSource.rules, SeedSource.rules_trusted} or _uses_rule_short_circuit(request, config): + return "deterministic_rules" + return None + + +def _validation_skip_reason( + request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig +) -> str | None: + if config.seed_source == SeedSource.rules_trusted or _uses_rule_short_circuit(request, config): + return "trusted_rules" + if artifact.seed_validation_candidate_count == 0: + return "no_seed_candidates" + return None + + +def _augmentation_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: + if config.skip_augmentation: + return "disabled" + if _should_skip_augmentation(request, config): + return "rule_covered_labels" + return None + + +def _model_phase_count(phase_model_work: PhaseModelWork) -> int: + return sum((phase_model_work.seed, phase_model_work.validation, phase_model_work.augmentation)) + + +def _phase_model_requests( + artifact: DetectionArtifactRow, + phase_model_work: PhaseModelWork, + config: StagedExecutionConfig, +) -> PhaseModelRequests: + return PhaseModelRequests( + seed=1 if phase_model_work.seed else 0, + validation=_validation_model_request_count(artifact, phase_model_work, config), + augmentation=1 if phase_model_work.augmentation else 0, + ) + + +def _validation_model_request_count( + artifact: DetectionArtifactRow, + phase_model_work: PhaseModelWork, + 
config: StagedExecutionConfig, +) -> int: + if not phase_model_work.validation: + return 0 + if config.validation_prompt_mode == ValidationPromptMode.full_text: + return 1 + return _ceil_div(artifact.seed_validation_candidate_count, config.validation_max_entities_per_call) + + +def _ceil_div(numerator: int, denominator: int) -> int: + return (numerator + denominator - 1) // denominator + + +def _model_request_count(phase_model_requests: PhaseModelRequests) -> int: + return phase_model_requests.seed + phase_model_requests.validation + phase_model_requests.augmentation + + +def _is_rule_covered_label_set(request: StagedDetectionRequest) -> bool: + return set(request.labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) + + +def _extract_entity_suggestions(content: str) -> list[dict[str, str]]: + payload = _load_embedded_json(content) + raw_entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return [ + {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} + for item in raw_entities + if isinstance(item, dict) and item.get("value") and item.get("label") + ] + + +def _allowed_entity_suggestions(suggestions: list[dict[str, str]], *, labels: list[str]) -> list[dict[str, str]]: + allowed = set(labels) + return [suggestion for suggestion in suggestions if suggestion["label"] in allowed] + + +def _extract_validation_decisions( + content: str, + *, + candidates: ValidationCandidatesSchema | None = None, + labels: list[str] | None = None, +) -> dict[str, Any]: + payload = _load_embedded_json(content) + decisions = payload.get("decisions", []) if isinstance(payload, dict) else [] + if not isinstance(decisions, list): + decisions = [] + if candidates is not None: + decisions = _preserve_specific_seed_labels(decisions, candidates) + if labels is not None: + decisions = _keep_invalid_reclass_labels(decisions, allowed_labels=set(labels)) + return {"decisions": decisions} + + +def _preserve_specific_seed_labels( + 
decisions: list[object], + candidates: ValidationCandidatesSchema, +) -> list[object]: + candidate_by_id = {candidate.id: candidate for candidate in candidates.candidates} + normalized: list[object] = [] + for decision in decisions: + if not isinstance(decision, dict): + normalized.append(decision) + continue + candidate = candidate_by_id.get(str(decision.get("id") or "")) + if candidate is None or not _should_preserve_specific_seed_label(decision, candidate): + normalized.append(decision) + continue + normalized.append( + { + **decision, + "decision": "keep", + "proposed_label": "", + "reason": _preserved_label_reason(decision.get("reason")), + } + ) + return normalized + + +def _should_preserve_specific_seed_label(decision: dict[str, object], candidate: Any) -> bool: + if str(decision.get("decision") or "") != "reclass": + return False + proposed_label = str(decision.get("proposed_label") or "") + if candidate.label == "date_of_birth" and proposed_label == "date": + return _has_date_of_birth_context(candidate) + return False + + +def _has_date_of_birth_context(candidate: Any) -> bool: + return _contains_date_of_birth_context(candidate.context_before, candidate.value, candidate.context_after) + + +def _keep_invalid_reclass_labels(decisions: list[object], *, allowed_labels: set[str]) -> list[object]: + normalized: list[object] = [] + for decision in decisions: + if not isinstance(decision, dict): + normalized.append(decision) + continue + if not _has_invalid_reclass_label(decision, allowed_labels=allowed_labels): + normalized.append(decision) + continue + normalized.append( + { + **decision, + "decision": "keep", + "proposed_label": "", + "reason": _invalid_reclass_label_reason(decision.get("reason"), decision.get("proposed_label")), + } + ) + return normalized + + +def _has_invalid_reclass_label(decision: dict[str, object], *, allowed_labels: set[str]) -> bool: + if str(decision.get("decision") or "") != "reclass": + return False + proposed_label = 
str(decision.get("proposed_label") or "").strip() + return proposed_label not in allowed_labels + + +def _invalid_reclass_label_reason(reason: object, proposed_label: object) -> str: + text = str(reason or "").strip() + suffix = f"ignored invalid reclass label {str(proposed_label or '').strip()!r}" + return f"{text}; {suffix}" if text else suffix + + +def normalize_birth_context_final_entities( + row: dict[str, Any], *, allowed_labels: list[str] | set[str] +) -> dict[str, Any]: + """Demote native date_of_birth spans without birth context back to date.""" + allowed = set(allowed_labels) + if not {"date", "date_of_birth"}.issubset(allowed): + return row + text = str(row.get(COL_TEXT, "")) + entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities + normalized: list[EntitySchema] = [] + changed = False + for entity in entities: + if entity.label == "date_of_birth" and not _entity_has_date_of_birth_context(text, entity): + normalized.append(entity.model_copy(update={"id": _entity_id("date", entity), "label": "date"})) + changed = True + else: + normalized.append(entity) + if changed: + row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=normalized).model_dump(mode="json") + row[COL_TAGGED_TEXT] = build_tagged_text( + text=text, + entities=[_entity_schema_to_span(entity) for entity in normalized], + ) + return row + + +def _entity_has_date_of_birth_context(text: str, entity: EntitySchema) -> bool: + before_start = max(0, entity.start_position - VALIDATION_CONTEXT_WINDOW) + after_end = min(len(text), entity.end_position + VALIDATION_CONTEXT_WINDOW) + return _contains_date_of_birth_context( + text[before_start : entity.start_position], + entity.value, + text[entity.end_position : after_end], + ) + + +def _entity_id(label: str, entity: EntitySchema) -> str: + return f"{label}_{entity.start_position}_{entity.end_position}" + + +def _entity_schema_to_span(entity: EntitySchema) -> EntitySpan: + return EntitySpan( + entity_id=entity.id, + 
value=entity.value, + label=entity.label, + start_position=entity.start_position, + end_position=entity.end_position, + score=entity.score, + source=entity.source, + ) + + +def _contains_date_of_birth_context(context_before: object, value: object, context_after: object) -> bool: + context = f"{context_before} {value} {context_after}" + return bool(_DATE_OF_BIRTH_CONTEXT_RE.search(context)) + + +def _preserved_label_reason(reason: object) -> str: + text = str(reason or "").strip() + suffix = "preserved specific seed label from birth-related context" + return f"{text}; {suffix}" if text else suffix + + +def _raw_detector_entity_count(content: str) -> int: + payload = _load_embedded_json(content) + entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return len(entities) if isinstance(entities, list) else 0 + + +def _spans_from_suggestions( + text: str, suggestions: list[dict[str, str]], *, labels: list[str], source: str +) -> list[EntitySpan]: + allowed = set(labels) + cleaned = [item for item in suggestions if item["value"] and item["label"] in allowed] + spans = apply_augmented_entities(text=text, entities=[], augmented_output={"entities": cleaned}) + return [ + EntitySpan( + entity_id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position, + end_position=span.end_position, + score=span.score, + source=source, + ) + for span in spans + ] + + +def _span_to_entity_schema(span: EntitySpan) -> EntitySchema: + return EntitySchema( + id=span.entity_id, + value=span.value, + label=span.label, + start_position=span.start_position, + end_position=span.end_position, + score=span.score, + source=span.source, + ) + + +def _augmented_entity_schemas(row: dict[str, Any], request: StagedDetectionRequest) -> list[EntitySchema]: + spans = _spans_from_suggestions( + request.text, + _extract_entities_from_payload(row.get(COL_AUGMENTED_ENTITIES)), + labels=request.labels, + source="augmenter", + ) + return 
[_span_to_entity_schema(span) for span in spans] + + +def _extract_entities_from_payload(payload: object) -> list[dict[str, str]]: + entities = payload.get("entities", []) if isinstance(payload, dict) else [] + return [ + {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} + for item in entities + if isinstance(item, dict) + ] + + +def _candidate_count(raw: object) -> int: + return len(ValidationCandidatesSchema.from_raw(raw).candidates) + + +def _validated_decision_count(raw: object) -> int: + return len(ValidatedDecisionsSchema.from_raw(raw).decisions) + + +def _sum_usage(usage: PhaseUsage) -> dict[str, int]: + totals: Counter[str] = Counter() + for phase in (usage.seed, usage.validation, usage.augmentation): + for key, value in phase.items(): + if isinstance(value, int): + totals[key] += value + return dict(sorted(totals.items())) + + +def apply_baseline_comparisons( + cases: list[StagedDetectionCase], + baseline_artifacts: Path, +) -> list[StagedDetectionCase]: + baseline = _read_baseline_artifacts(baseline_artifacts) + return [_case_with_comparison(case, baseline.get(case.row_index)) for case in cases] + + +def _case_with_comparison(case: StagedDetectionCase, baseline_row: dict[str, Any] | None) -> StagedDetectionCase: + if baseline_row is None or case.status != CaseStatus.completed or case.artifact is None: + return case + baseline_hashes = _baseline_signature_hashes(baseline_row) + if baseline_hashes is None: + return case + comparison = compare_signature_sets( + baseline_hashes=baseline_hashes, + baseline_labels=_signature_labels(baseline_row), + direct_hashes=set(case.artifact.final_entity_signature_hashes), + direct_labels=case.artifact.final_entity_signature_labels, + ) + return case.model_copy(update={"comparison": comparison}) + + +def _baseline_signature_hashes(row: dict[str, Any]) -> set[str] | None: + hashes = row.get("final_entity_signature_hashes") + if not isinstance(hashes, list): + return None + return 
{str(item) for item in hashes} + + +def _read_baseline_artifacts(path: Path) -> dict[int, dict[str, Any]]: + baseline: dict[int, dict[str, Any]] = {} + with path.open(encoding="utf-8") as source: + for line in source: + if not line.strip(): + continue + row = json.loads(line) + row_index = int(row.get("row_index", 0)) + if row_index in baseline: + raise ValueError(f"baseline artifacts has multiple rows for row_index={row_index}") + baseline[row_index] = row + return baseline + + +def _signature_labels(row: dict[str, Any]) -> dict[str, str]: + return { + key.removeprefix("final_entity_signature_labels."): str(value) + for key, value in row.items() + if key.startswith("final_entity_signature_labels.") and value is not None + } + + +def run_probe( + input_path: Path, + *, + text_column: str, + labels: list[str], + output: Path | None = None, + overwrite: bool = False, + endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", + model: str = "nvidia/nemotron-3-super", + seed_source: SeedSource = SeedSource.direct_llm, + gliner_endpoint: str = "https://integrate.api.nvidia.com/v1", + gliner_model: str = "nvidia/gliner-pii", + gliner_api_key_env: str = "NVIDIA_API_KEY", + gliner_threshold: float = 0.3, + skip_augmentation: bool = False, + skip_augmentation_when_rule_covered: bool = False, + validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, + validation_max_entities_per_call: int = 10, + validation_excerpt_window_chars: int = 160, + limit: int = 1, + offset: int = 0, + baseline_artifacts: Path | None = None, +) -> StagedDetectionRun: + requests = _load_requests(input_path, text_column=text_column, labels=labels, limit=limit, offset=offset) + config = _execution_config_from_params(locals()) + cases = _run_probe_cases(requests, config) + if baseline_artifacts is not None: + cases = apply_baseline_comparisons(cases, baseline_artifacts) + result = StagedDetectionRun( + input_path=str(input_path), text_column=text_column, endpoint=endpoint, 
model=model, labels=labels, rows=cases + ) + if output is not None: + write_outputs(result, output, overwrite=overwrite) + return result + + +def _run_probe_cases( + requests: list[StagedDetectionRequest], + config: StagedExecutionConfig, +) -> list[StagedDetectionCase]: + client = HttpxDirectDetectionClient() + seed_client = HttpxGlinerSeedClient() if config.seed_source == SeedSource.gliner else None + return [ + run_staged_detection_case( + request, + client=client, + seed_client=seed_client, + seed_source=config.seed_source, + endpoint=config.endpoint, + model=config.model, + gliner_endpoint=config.gliner_endpoint, + gliner_model=config.gliner_model, + gliner_api_key_env=config.gliner_api_key_env, + gliner_threshold=config.gliner_threshold, + skip_augmentation=config.skip_augmentation, + skip_augmentation_when_rule_covered=config.skip_augmentation_when_rule_covered, + validation_prompt_mode=config.validation_prompt_mode, + validation_max_entities_per_call=config.validation_max_entities_per_call, + validation_excerpt_window_chars=config.validation_excerpt_window_chars, + ) + for request in requests + ] + + +def _load_requests( + input_path: Path, *, text_column: str, labels: list[str], limit: int, offset: int +) -> list[StagedDetectionRequest]: + dataframe = pd.read_csv(input_path) + if text_column not in dataframe.columns: + raise ValueError(f"text column {text_column!r} not found in {input_path}") + selected = dataframe.iloc[offset : offset + limit] + return [ + StagedDetectionRequest( + case_id=f"{input_path.stem}-row-{int(index)}", + text=str(row[text_column]), + labels=labels, + row_index=int(index), + ) + for index, row in selected.iterrows() + ] + + +def write_outputs(result: StagedDetectionRun, output_dir: Path, *, overwrite: bool) -> None: + if output_dir.exists(): + if not overwrite: + raise ValueError(f"output directory already exists: {output_dir}") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True) + _write_jsonl(output_dir / 
"staged-detection-cases.jsonl", [_case_payload(case) for case in result.rows]) + _write_jsonl(output_dir / "staged-detection-artifacts.jsonl", [_artifact_payload(case) for case in result.rows]) + (output_dir / "summary.json").write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") + + +def _case_payload(case: StagedDetectionCase) -> dict[str, Any]: + payload = case.model_dump(exclude={"artifact"}) + payload["record_type"] = "staged_detection_case" + return payload + + +def _artifact_payload(case: StagedDetectionCase) -> dict[str, Any]: + payload = case.artifact.model_dump() if case.artifact is not None else {} + payload.update({"case_id": case.case_id, "row_index": case.row_index, "record_type": "staged_detection_artifact"}) + return payload + + +def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: + with path.open("w", encoding="utf-8") as target: + for row in rows: + target.write(json.dumps(row, ensure_ascii=True, sort_keys=True) + "\n") + + +def render_result(result: StagedDetectionRun, *, json_output: bool) -> str: + if json_output: + return result.model_dump_json(indent=2) + completed = len(result.rows) - result.error_count + return f"Ran {completed}/{len(result.rows)} staged detection case(s); errors={result.error_count}" + + +@app.default +def main( + input_path: Path, + *, + text_column: Annotated[str, cyclopts.Parameter("--text-column")], + labels: Annotated[str, cyclopts.Parameter("--labels")], + output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, + overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, + endpoint: Annotated[str, cyclopts.Parameter("--endpoint")] = "http://gpu-dev-pod-serve-svc:8000/v1", + model: Annotated[str, cyclopts.Parameter("--model")] = "nvidia/nemotron-3-super", + seed_source: Annotated[SeedSource, cyclopts.Parameter("--seed-source")] = SeedSource.direct_llm, + gliner_endpoint: Annotated[str, cyclopts.Parameter("--gliner-endpoint")] = 
"https://integrate.api.nvidia.com/v1", + gliner_model: Annotated[str, cyclopts.Parameter("--gliner-model")] = "nvidia/gliner-pii", + gliner_api_key_env: Annotated[str, cyclopts.Parameter("--gliner-api-key-env")] = "NVIDIA_API_KEY", + gliner_threshold: Annotated[float, cyclopts.Parameter("--gliner-threshold")] = 0.3, + skip_augmentation: Annotated[bool, cyclopts.Parameter("--skip-augmentation")] = False, + skip_augmentation_when_rule_covered: Annotated[ + bool, cyclopts.Parameter("--skip-augmentation-when-rule-covered") + ] = False, + validation_prompt_mode: Annotated[ + ValidationPromptMode, cyclopts.Parameter("--validation-prompt-mode") + ] = ValidationPromptMode.full_text, + validation_max_entities_per_call: Annotated[int, cyclopts.Parameter("--validation-max-entities-per-call")] = 10, + validation_excerpt_window_chars: Annotated[int, cyclopts.Parameter("--validation-excerpt-window-chars")] = 160, + limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, + offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, + baseline_artifacts: Annotated[Path | None, cyclopts.Parameter("--baseline-artifacts")] = None, + json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, + log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, +) -> None: + params = locals() + json_output = bool(params.pop("json_output")) + result = _run_main_probe(params) + sys.stdout.write(render_result(result, json_output=json_output) + "\n") + if result.error_count: + raise SystemExit(1) + + +def _run_main_probe(params: dict[str, Any]) -> StagedDetectionRun: + configure_logging(params.pop("log_format")) + params["labels"] = parse_labels(params["labels"]) + try: + return run_probe(**params) + except (ValueError, ValidationError, httpx.HTTPError) as exc: + log_bad_input(str(exc)) + raise SystemExit(125) from exc + + +if __name__ == "__main__": + app() From 756c41d960391c89180370302248a638a3d71f48 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 
8 Jun 2026 20:56:45 +0000 Subject: [PATCH 05/26] Harden benchmark runtime portability Signed-off-by: Aaron Gonzales --- tests/tools/test_benchmark_output_analysis.py | 1 - .../tools/test_detection_artifact_analysis.py | 2 +- tests/tools/test_detection_strategies.py | 16 +- tests/tools/test_measurement_tools.py | 219 +++++++++++++++ tools/measurement/README.md | 126 ++++++--- tools/measurement/analyze_benchmark_output.py | 14 +- .../analyze_detection_artifacts.py | 6 - tools/measurement/compare_strategy_pairs.py | 15 +- tools/measurement/detection_strategies.py | 251 ++++++++++++++---- tools/measurement/direct_detection_probe.py | 29 +- .../measurement/examples/repo-data-smoke.yaml | 2 + .../run-repo-data-smoke-with-dd-traces.sh | 7 +- tools/measurement/extract_signature_deltas.py | 45 ++-- tools/measurement/run_benchmarks.py | 158 ++++++++++- tools/measurement/staged_detection_probe.py | 72 +++-- 15 files changed, 805 insertions(+), 158 deletions(-) diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index 44ba8358..3ebae5eb 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -302,7 +302,6 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp "row_index": 0, "start_position": 0, "end_position": 5, - "value_hash": "hash-person", "value_length": 5, } } diff --git a/tests/tools/test_detection_artifact_analysis.py b/tests/tools/test_detection_artifact_analysis.py index 3bc25c51..f99fffbd 100644 --- a/tests/tools/test_detection_artifact_analysis.py +++ b/tests/tools/test_detection_artifact_analysis.py @@ -113,7 +113,7 @@ def test_detection_artifact_analysis_reports_augmentation_contribution(tmp_path: assert first_name_detail["start_position"] == 0 assert first_name_detail["end_position"] == 5 assert first_name_detail["value_length"] == 5 - assert len(first_name_detail["value_hash"]) == 16 + assert "value_hash" not in 
first_name_detail serialized = row.model_dump_json() assert "Alice" not in serialized diff --git a/tests/tools/test_detection_strategies.py b/tests/tools/test_detection_strategies.py index c5bda974..984bb025 100644 --- a/tests/tools/test_detection_strategies.py +++ b/tests/tools/test_detection_strategies.py @@ -299,6 +299,7 @@ def complete(self, _request): # type: ignore[no-untyped-def] with tool.experimental_detection_strategy_context( tool.ExperimentalDetectionStrategy.native_rules_router, native_client=SequencedClient(), + native_runtime=tool.NativeDetectionRuntime(model="test/native", provider="test-provider"), ): workflow = EntityDetectionWorkflow(adapter=Mock()) workflow.detect_and_validate_entities( @@ -320,7 +321,8 @@ def complete(self, _request): # type: ignore[no-untyped-def] assert record["observed_input_tokens"] == 30 assert record["observed_output_tokens"] == 12 assert record["observed_total_tokens"] == 42 - assert record["model_usage"]["native-direct"]["model_name"] == "nvidia/nemotron-3-super" + assert record["model_usage"]["native-direct"]["model_name"] == "test/native" + assert record["model_usage"]["native-direct"]["model_provider_name"] == "test-provider" def test_native_candidate_validate_no_augment_strategy_skips_data_designer_and_augmentation() -> None: @@ -738,6 +740,12 @@ def complete(self, request): # type: ignore[no-untyped-def] tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, native_client=validation_client, gliner_seed_client=seed_client, + native_runtime=tool.NativeDetectionRuntime( + model="test/native", + provider="test-native-provider", + gliner_model="test/gliner", + gliner_provider="test-gliner-provider", + ), ): workflow = EntityDetectionWorkflow(adapter=adapter) result = workflow.detect_and_validate_entities( @@ -763,9 +771,11 @@ def complete(self, request): # type: ignore[no-untyped-def] assert record["observed_total_requests"] == 2 assert record["observed_total_tokens"] == 73 assert 
sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] - assert record["model_usage"]["gliner-direct"]["model_name"] == "nvidia/gliner-pii" + assert record["model_usage"]["gliner-direct"]["model_name"] == "test/gliner" + assert record["model_usage"]["gliner-direct"]["model_provider_name"] == "test-gliner-provider" assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 - assert record["model_usage"]["native-direct"]["model_name"] == "nvidia/nemotron-3-super" + assert record["model_usage"]["native-direct"]["model_name"] == "test/native" + assert record["model_usage"]["native-direct"]["model_provider_name"] == "test-native-provider" assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 48 diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 80abba54..5d1f5a79 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -14,6 +14,7 @@ import pandas as pd import pytest +import yaml from pydantic import ValidationError from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT @@ -847,6 +848,7 @@ def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None """, encoding="utf-8", ) + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "biographies.csv", index=False) output_dir = tmp_path / "dry-run-output" result = tool.run_or_plan( @@ -1037,6 +1039,43 @@ def test_benchmark_preflight_accepts_local_structured_substitute_supported_label tool.preflight_suite(spec, spec_path=spec_path) +def test_benchmark_example_suites_are_portable() -> None: + example_paths = sorted((REPO_ROOT / "tools/measurement/examples").glob("*.yaml")) + assert example_paths + + machine_specific_fragments = ( + "/root/", + "/Users/", + "/stable-cache/", + "gpu-dev-pod", + "serve-svc", + ) + path_fields = {"source", "model_configs", "model_providers", "artifact_path"} + + def walk(value: Any) -> Iterator[tuple[str, Any]]: + 
if isinstance(value, dict): + for key, item in value.items(): + yield str(key), item + yield from walk(item) + elif isinstance(value, list): + for item in value: + yield from walk(item) + + for example_path in example_paths: + payload = yaml.safe_load(example_path.read_text(encoding="utf-8")) + assert isinstance(payload, dict) + + for key, value in walk(payload): + if isinstance(value, str): + assert not any(fragment in value for fragment in machine_specific_fragments), ( + f"{example_path} contains machine-specific value for {key}: {value}" + ) + if key in path_fields: + assert not Path(value).is_absolute(), f"{example_path} uses absolute path for {key}: {value}" + if key in {"endpoint", "gliner_endpoint"}: + raise AssertionError(f"{example_path} should use endpoint_env for {key}, not a literal endpoint") + + def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: tool = load_tool( "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" @@ -1101,6 +1140,186 @@ def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> Non tool.preflight_suite(spec, spec_path=spec_path) +def test_benchmark_preflight_rejects_native_strategy_without_runtime( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_required", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +""", + encoding="utf-8", + ) + spec = 
tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="native_runtime.runtime_id"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_native_runtime_config_override_merges_suite_defaults() -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_merge", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + config = tool.ConfigSpec( + id="native-single-pass", + experimental_detection_strategy="native_single_pass", + native_runtime=tool.NativeRuntimeSpec(model="config-model", max_workers=2), + replace="redact", + ) + spec = tool.BenchmarkSpec( + suite_id="native-runtime-suite", + native_runtime=tool.NativeRuntimeSpec( + runtime_id="suite-runtime", + endpoint="http://suite-endpoint/v1", + model="suite-model", + provider="suite-provider", + max_workers=4, + ), + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[config], + ) + + runtime = tool._native_detection_runtime(spec, config) + + assert runtime.endpoint == "http://suite-endpoint/v1" + assert runtime.model == "config-model" + assert runtime.provider == "suite-provider" + assert runtime.max_workers == 2 + + +def test_benchmark_preflight_rejects_native_strategy_without_endpoint_or_model( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_endpoint_required", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) + monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +native_runtime: + runtime_id: native-runtime +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: 
redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="native_runtime.endpoint"): + tool.preflight_suite(spec, spec_path=spec_path) + + +def test_benchmark_native_runtime_resolves_endpoint_and_model_from_env( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool( + "measurement_benchmark_tool_native_runtime_env", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", "http://runtime-from-env/v1") + monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_MODEL", "env-model") + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: native-runtime-suite +native_runtime: + runtime_id: env-runtime +workloads: + - id: input + source: input.csv +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + runtime = tool._native_detection_runtime(spec, spec.configs[0]) + tags = tool._run_tags( + tool.BenchmarkCase( + suite_id="native-runtime-suite", + workload_id="input", + config_id="native-single-pass", + repetition=0, + case_id="input__native-single-pass__r000", + ), + spec, + ) + + assert runtime.endpoint == "http://runtime-from-env/v1" + assert runtime.model == "env-model" + assert tags["native_runtime_id"] == "env-runtime" + assert "native_endpoint" not in tags + assert tags["native_endpoint_env"] == "ANONYMIZER_BENCH_NATIVE_ENDPOINT" + assert tags["native_model"] == "env-model" + + +def test_benchmark_preflight_skips_inactive_native_configs(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_inactive_native_runtime", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + input_path = tmp_path / 
"input.csv" + pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: inactive-native-suite +workloads: + - id: input + source: input.csv +configs: + - id: redact + replace: redact + - id: inactive-native + experimental_detection_strategy: native_single_pass + replace: redact +matrix: + - workload: input + config: redact +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + tool.preflight_suite(spec, spec_path=spec_path) + + def test_benchmark_case_passes_dd_trace_config_to_measurement_session( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 8a6a1d4f..f101d54d 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -108,6 +108,8 @@ even when the matrix has multiple configs or repetitions. Slicing is rejected for URL-like sources because the runner cannot safely materialize a local subset without downloading the whole dataset first. +Relative paths in suite files are resolved from the suite file's directory. + Use `case_retries` and `case_retry_backoff_sec` for long-running suites on shared model endpoints. Retries are disabled by default. When enabled, a failed case is retried with the same `case_id` and output paths; the final case still @@ -130,6 +132,30 @@ configs: digest_length: 12 ``` +Native benchmark strategies require an explicit runtime. Set top-level +`native_runtime.endpoint` and `native_runtime.model`, set the standard +`ANONYMIZER_BENCH_NATIVE_ENDPOINT` and `ANONYMIZER_BENCH_NATIVE_MODEL` +environment variables, or override runtime fields per config with +`configs[].native_runtime`. GLiNER-seeded native strategies also require +`native_runtime.gliner_endpoint` and `native_runtime.gliner_model`, or the +standard `ANONYMIZER_BENCH_GLINER_ENDPOINT` and +`ANONYMIZER_BENCH_GLINER_MODEL` environment variables. 
The runner records +runtime id, alias, provider, model, and env-variable names as run tags; raw +endpoint URLs are not emitted into measurement tables. + +```yaml +native_runtime: + runtime_id: local-vllm-json + endpoint_env: ANONYMIZER_BENCH_NATIVE_ENDPOINT + model_env: ANONYMIZER_BENCH_NATIVE_MODEL + provider: local-vllm + alias: native-direct +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +``` + Supported values: - `default`: run the normal Anonymizer detection pipeline. @@ -231,12 +257,11 @@ short-circuit for the former and falls back to the default pipeline for prose or legal labels. Use `native_rules_router` when you want the same routing shape without -DataDesigner orchestration. It defaults to the local OpenAI-compatible endpoint -used by the staged probe (`http://gpu-dev-pod-serve-svc:8000/v1`) and model -`nvidia/nemotron-3-super`. Treat it as a native-executor prototype: it can prove -that DataDesigner overhead is avoidable, but it must be compared against -baseline signatures and original-value leak metrics before any workload-specific -promotion decision. +DataDesigner orchestration. It uses the resolved native runtime endpoint/model +from `native_runtime` or the standard benchmark runtime environment variables. +Treat it as a native-executor prototype: it can prove that DataDesigner overhead +is avoidable, but it must be compared against baseline signatures and +original-value leak metrics before any workload-specific promotion decision. Use `native_candidate_validate_no_augment` when you want a narrower native executor diagnostic: direct seed candidates plus direct validation, with no @@ -260,15 +285,15 @@ validation and augmentation. Use `gliner_native_validate_no_augment` or `gliner_native_validate_native_augment` when the question is specifically "what if GLiNER did not run through DataDesigner?" 
These strategies use the -staged direct executor's GLiNER seed client, which defaults to -`https://integrate.api.nvidia.com/v1`, model `nvidia/gliner-pii`, and the -`NVIDIA_API_KEY` environment variable. The no-augmentation arm is a lower-cost -boundary; the native-augmentation arm is the quality-oriented no-DataDesigner -candidate. The integrated benchmark strategies execute staged direct rows with -bounded parallelism so hosted GLiNER and native validation/augmentation latency -is not serialized across records. These arms also normalize direct GLiNER -`date` seeds to `date_of_birth` only when the local seed context contains -birth/DOB language. +staged direct executor's GLiNER seed client using +`native_runtime.gliner_endpoint`, `native_runtime.gliner_model`, or the standard +GLiNER runtime environment variables; the API key env var defaults to +`NVIDIA_API_KEY`. The no-augmentation arm is a lower-cost boundary; the +native-augmentation arm is the quality-oriented no-DataDesigner candidate. The +integrated benchmark strategies execute staged direct rows with bounded +parallelism so GLiNER and native validation/augmentation latency is not +serialized across records. These arms also normalize direct GLiNER `date` seeds +to `date_of_birth` only when the local seed context contains birth/DOB language. Generic filing or event dates remain `date`. Both arms still need repeated signature, leak, label-mismatch, and reliability gates before any workload-specific promotion. @@ -320,8 +345,9 @@ The runner refuses to write into a non-empty output directory unless `tables/`; pass `--no-export` when you only want the raw measurement JSONL. Before starting a real run, the benchmark runner performs cheap preflight checks: suite/config parsing, local dataset existence, CSV/Parquet text-column -metadata, provider YAML shape, and active model-alias references. `--dry-run` -only expands the planned matrix and skips these file/config checks. 
+metadata, provider YAML shape, native runtime requirements, and active +model-alias references. `--dry-run` runs those same checks, expands the planned +matrix, and skips output-dir writes and model work. For debugging DataDesigner calls, pass `--dd-trace last-message` or `--dd-trace all-messages`. Trace records are written separately from sanitized @@ -380,6 +406,9 @@ input row, then reuses Anonymizer's existing span postprocessing, occurrence expansion, overlap resolution, and entity signature logic so results can be compared against normal detection artifacts. +Pass `--endpoint` and `--model`, or set `ANONYMIZER_BENCH_NATIVE_ENDPOINT` and +`ANONYMIZER_BENCH_NATIVE_MODEL`. + Example biography probe: ```bash @@ -387,8 +416,10 @@ uv run python tools/measurement/direct_detection_probe.py \ docs/data/NVIDIA_synthetic_biographies.csv \ --text-column biography \ --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ - --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/biographies-slice-1__biography-prose-default__r000.detection-artifacts.jsonl \ - --output /tmp/anonymizer-perf-explore/out-direct-detection-probe-biography \ + --endpoint http://your-openai-compatible-endpoint/v1 \ + --model your-model-id \ + --baseline-artifacts "$BASELINE_ARTIFACTS" \ + --output /tmp/direct-detection-probe-biography \ --overwrite \ --json ``` @@ -400,8 +431,10 @@ uv run python tools/measurement/direct_detection_probe.py \ docs/data/TAB_legal_sample25.csv \ --text-column text \ --labels application_number,city,country,date,date_of_birth,nationality,person \ - --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/legal-slice-2__legal-prose-default__r000.detection-artifacts.jsonl \ - --output 
/tmp/anonymizer-perf-explore/out-direct-detection-probe-legal \ + --endpoint http://your-openai-compatible-endpoint/v1 \ + --model your-model-id \ + --baseline-artifacts "$BASELINE_ARTIFACTS" \ + --output /tmp/direct-detection-probe-legal \ --overwrite \ --json ``` @@ -413,6 +446,9 @@ signature hashes, and optional baseline comparison counts. Artifact rows use the same opaque signature fields as `analyze_detection_artifacts.py` and omit raw entity values. For baseline comparison, pass a per-case sidecar or another artifact file with one row per `row_index`; duplicate row indexes are rejected so a combined multi-case artifact cannot silently select the wrong baseline. +Treat the probe `summary.json` as a sensitive debug artifact because +it records the resolved endpoint/model runtime used for the +probe. When this probe shape is promising, move it into a normal benchmark suite with @@ -455,8 +491,10 @@ uv run python tools/measurement/staged_detection_probe.py \ docs/data/NVIDIA_synthetic_biographies.csv \ --text-column biography \ --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ - --baseline-artifacts /tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/biographies-slice-1__biography-prose-default__r000.detection-artifacts.jsonl \ - --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography \ + --endpoint http://your-openai-compatible-endpoint/v1 \ + --model your-model-id \ + --baseline-artifacts "$BASELINE_ARTIFACTS" \ + --output /tmp/staged-detection-probe-biography \ --overwrite \ --json ``` @@ -468,16 +506,19 @@ uv run python tools/measurement/staged_detection_probe.py \ docs/data/TAB_legal_sample25.csv \ --text-column text \ --labels application_number,city,country,date,date_of_birth,nationality,person \ - --baseline-artifacts
/tmp/anonymizer-perf-explore/out-repo-data-sliced-local-vllm-json-strategies-labels/raw/legal-slice-2__legal-prose-default__r000.detection-artifacts.jsonl \ - --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-legal \ + --endpoint http://your-openai-compatible-endpoint/v1 \ + --model your-model-id \ + --baseline-artifacts "$BASELINE_ARTIFACTS" \ + --output /tmp/staged-detection-probe-legal \ --overwrite \ --json ``` To replace the LLM seed phase with a direct GLiNER call, add -`--seed-source gliner`. The default GLiNER endpoint is NVIDIA-hosted -`https://integrate.api.nvidia.com/v1` with model `nvidia/gliner-pii`; it reads -the API key from `NVIDIA_API_KEY`. +`--seed-source gliner` plus `--gliner-endpoint` and `--gliner-model`, or set +`ANONYMIZER_BENCH_GLINER_ENDPOINT` and `ANONYMIZER_BENCH_GLINER_MODEL`. The +probe reads the GLiNER API key from `--gliner-api-key-env`, which defaults to +`NVIDIA_API_KEY`. To replace the LLM seed phase with deterministic local rules, add `--seed-source rules`. This still sends rule candidates through the validator. @@ -516,6 +557,8 @@ wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, plus `model_phase_count`, `model_request_count`, total usage, and optional baseline signature deltas. Use these fields to distinguish local work, provider latency, and a provider that returned no token accounting. +Treat the staged probe `summary.json` as a sensitive debug artifact because it +records the resolved endpoint/model runtime used for the probe. 
For example, a fully local rule-covered run should show `model_phase_count: 0`, `model_request_count: 0`, `rule_covered_label_set: true`, and `phase_skip_reasons.augmentation: "rule_covered_labels"`; `elapsed_sec` should @@ -527,8 +570,8 @@ To summarize those staged probe outputs without hand-written `jq`, run: ```bash uv run python tools/measurement/analyze_staged_detection_output.py \ - /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography \ - --output /tmp/anonymizer-perf-explore/out-staged-detection-probe-biography/analysis \ + /tmp/staged-detection-probe-biography \ + --output /tmp/staged-detection-probe-biography/analysis \ --format csv ``` @@ -660,9 +703,13 @@ A benchmark run writes one raw measurement file per case, then combines them: ```text benchmark-runs/suite-id/ raw/ + inputs/ + biographies__redact-default__r000.csv biographies__redact-default__r000.jsonl biographies__redact-default__r000.detection-artifacts.jsonl support__hash-agent-labels__r000.jsonl + artifacts/ + biographies__redact-default__r000/ traces/ biographies__redact-default__r000.jsonl measurements.jsonl @@ -689,6 +736,13 @@ Use `measurements.jsonl` when you need the original structured records. Use Use `traces/` only when `--dd-trace` was enabled and you need raw DataDesigner message-level debugging. +Treat `summary.json`, `raw/inputs/`, `artifacts/`, +`raw/*.detection-artifacts.jsonl`, and `traces/` as sensitive outputs. They can +contain source text, entity values, replacement values, prompts, model +responses, exception messages, or other PII-bearing debug data. The exported +measurement tables and detection signature ids are designed for analysis +without raw values, but debug sidecars are not sanitized bundles. + Detection workflow artifacts can be analyzed separately when you need to know whether augmentation helped or only added cost. 
`run_benchmarks.py` writes `detection-artifacts.jsonl` automatically when export is enabled and detection @@ -793,8 +847,9 @@ detection artifacts, including `artifact_final_detector_entity_count`, strategy is still relying on contextual detector/validator spans, or whether it has shifted a workload entirely onto deterministic rules. They also include `artifact_final_entity_signature_count` and -`artifact_final_entity_signature_hashes`, which are opaque per-row hashes of the -final entity label, normalized value, and offsets. The companion +`artifact_final_entity_signature_hashes`, which are opaque per-row identifiers +derived from the final entity label and offsets. They do not include raw or +normalized entity values. The companion `artifact_final_entity_signature_labels` field maps each opaque hash to its entity label. These fields do not expose raw entity values, but they let analysis tools detect when two configs report the same entity count while @@ -935,7 +990,7 @@ by source-path fragments: ```bash uv run python tools/measurement/screen_strategy_comparisons.py \ - /tmp/anonymizer-perf-explore \ + /tmp/anonymizer-benchmark-scratch \ --source-include analysis-current-csv \ --source-include analysis-failure-aware-csv \ --output current-strategy-screen.csv \ @@ -1306,8 +1361,9 @@ Use `extract_signature_deltas.py` when a fast candidate has fewer, more, or different final entity signatures than a higher-recall reference run. The tool compares two `detection-artifacts.jsonl` files and recovers local context from the DataDesigner artifact parquet files. Entity values are masked by default: -the output stores label, source, span offsets, value length, value hash, and a -small context window with the entity replaced by a placeholder. +the output stores label, source, span offsets, value length, signature id, and a +small context window with the entity replaced by a placeholder. It does not +emit a hash derived from the raw entity value. 
Example: review spans found by a text/raw-parser reference but missed by a hybrid candidate for one workload/config pair: diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py index b5844523..ec93b0cb 100644 --- a/tools/measurement/analyze_benchmark_output.py +++ b/tools/measurement/analyze_benchmark_output.py @@ -28,6 +28,14 @@ logger = logging.getLogger("measurement.benchmark_output") _SYNC_CLIENT_UNAVAILABLE_ERROR = "SyncClientUnavailableError" +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} class ExportFormat(StrEnum): @@ -771,6 +779,8 @@ def _artifact_signature_details(artifact_rows: pd.DataFrame) -> dict[str, dict[s signature_hash, _, field = remainder.partition(".") if not signature_hash or not field: continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue for value in artifact_rows[column].dropna(): details.setdefault(signature_hash, {})[field] = _json_scalar(value) return dict(sorted(details.items())) @@ -787,7 +797,9 @@ def _coerce_detail_map(raw: object) -> dict[str, dict[str, Any]]: details: dict[str, dict[str, Any]] = {} for signature_hash, value in raw.items(): if isinstance(value, dict): - details[str(signature_hash)] = {str(key): _json_scalar(item) for key, item in value.items()} + details[str(signature_hash)] = { + str(key): _json_scalar(item) for key, item in value.items() if str(key) in _SIGNATURE_DETAIL_FIELDS + } return details diff --git a/tools/measurement/analyze_detection_artifacts.py b/tools/measurement/analyze_detection_artifacts.py index 583a2218..0c16da75 100644 --- a/tools/measurement/analyze_detection_artifacts.py +++ b/tools/measurement/analyze_detection_artifacts.py @@ -279,7 +279,6 @@ def _entity_signature_details(entities: list[EntitySchema], *, row_index: int) - "row_index": int(row_index), "start_position": entity.start_position, "end_position": entity.end_position, - "value_hash": 
_entity_value_hash(entity.value), "value_length": len(entity.value), } for entity in entities @@ -292,7 +291,6 @@ def _entity_signature_hash(entity: EntitySchema, *, row_index: int) -> str: { "row": row_index, "label": entity.label, - "value": _value_key(entity.value), "start": entity.start_position, "end": entity.end_position, }, @@ -302,10 +300,6 @@ def _entity_signature_hash(entity: EntitySchema, *, row_index: int) -> str: return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] -def _entity_value_hash(value: str) -> str: - return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] - - def _weak_api_key_shape_counts(entities: list[EntitySchema]) -> Counter[str]: counts: Counter[str] = Counter() for entity in entities: diff --git a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py index 42faf61f..a4d9550d 100644 --- a/tools/measurement/compare_strategy_pairs.py +++ b/tools/measurement/compare_strategy_pairs.py @@ -30,6 +30,15 @@ app = cyclopts.App(help=__doc__) logger = logging.getLogger("measurement.strategy_pairs") +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} + class ExportFormat(StrEnum): parquet = "parquet" @@ -1104,6 +1113,8 @@ def _signature_details(rows: pd.DataFrame) -> dict[str, dict[str, object]]: signature_hash, _, field = remainder.partition(".") if not signature_hash or not field: continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue for value in rows[column].tolist(): if not _is_missing_cell(value): details.setdefault(signature_hash, {})[field] = _json_scalar(value) @@ -1221,7 +1232,9 @@ def _coerce_signature_details(value: object) -> dict[str, dict[str, object]]: if not isinstance(raw_detail, dict): continue details[str(signature_hash)] = { - str(key): _json_scalar(item) for key, item in raw_detail.items() if not _is_missing_cell(item) + str(key): _json_scalar(item) + for key, item in raw_detail.items() + if 
str(key) in _SIGNATURE_DETAIL_FIELDS and not _is_missing_cell(item) } return dict(sorted(details.items())) if not isinstance(value, str): diff --git a/tools/measurement/detection_strategies.py b/tools/measurement/detection_strategies.py index e55e677a..a8955cd9 100644 --- a/tools/measurement/detection_strategies.py +++ b/tools/measurement/detection_strategies.py @@ -85,15 +85,12 @@ from anonymizer.measurement import record_model_workflow _NATIVE_DIRECT_MODEL_ALIAS = "native-direct" -_NATIVE_DIRECT_MODEL_NAME = "nvidia/nemotron-3-super" -_NATIVE_DIRECT_MODEL_PROVIDER = "local-vllm" -_NATIVE_DIRECT_ENDPOINT = "http://gpu-dev-pod-serve-svc:8000/v1" -_NATIVE_DIRECT_MAX_TOKENS = 4096 -_NATIVE_DIRECT_TIMEOUT_SEC = 180.0 _GLINER_DIRECT_MODEL_ALIAS = "gliner-direct" -_GLINER_DIRECT_MODEL_NAME = "nvidia/gliner-pii" -_GLINER_DIRECT_MODEL_PROVIDER = "nvidia" -_NATIVE_STAGED_MAX_WORKERS = 4 +_DIRECT_MODEL_NAME = "configured-native-model" +_DIRECT_MODEL_PROVIDER = "configured-native-provider" +_DIRECT_MAX_TOKENS = 4096 +_DIRECT_TIMEOUT_SEC = 180.0 +_DIRECT_MAX_WORKERS = 4 _STRUCTURED_ASSIGNMENT_RE = re.compile( r"(?" 
@@ -147,6 +144,25 @@ class ExperimentalDetectionStrategy(StrEnum): """ +@dataclass(frozen=True) +class NativeDetectionRuntime: + """Runtime settings for benchmark-only native detection strategies.""" + + endpoint: str | None = None + model: str = _DIRECT_MODEL_NAME + provider: str = _DIRECT_MODEL_PROVIDER + alias: str = _NATIVE_DIRECT_MODEL_ALIAS + max_tokens: int = _DIRECT_MAX_TOKENS + timeout_sec: float = _DIRECT_TIMEOUT_SEC + gliner_endpoint: str | None = None + gliner_model: str = _DIRECT_MODEL_NAME + gliner_provider: str = _DIRECT_MODEL_PROVIDER + gliner_alias: str = _GLINER_DIRECT_MODEL_ALIAS + gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = 0.3 + max_workers: int = _DIRECT_MAX_WORKERS + + @dataclass(frozen=True) class _NoAugmentOptions: include_rules: bool @@ -167,6 +183,7 @@ class _NativeStagedRunParams: labels: list[str] client: DirectDetectionClient gliner_seed_client: GlinerSeedClient | None + runtime: NativeDetectionRuntime data_summary: str | None validation_prompt_mode: ValidationPromptMode validation_max_entities_per_call: int @@ -189,6 +206,7 @@ class _NativeStagedRowResult: class _DetectorNativeValidationParams: labels: list[str] client: DirectDetectionClient + runtime: NativeDetectionRuntime data_summary: str | None validation_prompt_mode: ValidationPromptMode validation_max_entities_per_call: int @@ -202,6 +220,7 @@ class _DetectorNativeValidationRowResult: ordinal: int index: Any workflow_name: str + runtime: NativeDetectionRuntime output_row: dict[str, Any] | None failed_record: FailedRecord | None completion: DirectCompletion | None @@ -216,6 +235,7 @@ def experimental_detection_strategy_context( rule_labels: list[str] | None = None, native_client: DirectDetectionClient | None = None, gliner_seed_client: GlinerSeedClient | None = None, + native_runtime: NativeDetectionRuntime | None = None, ) -> Iterator[None]: """Temporarily apply a benchmark-only detection strategy.""" if strategy == 
ExperimentalDetectionStrategy.default: @@ -239,6 +259,7 @@ def experimental_detection_strategy_context( rule_labels=rule_labels, native_client=native_client, gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime or NativeDetectionRuntime(), ) try: yield @@ -276,7 +297,9 @@ def _method_for_strategy( rule_labels: list[str] | None = None, native_client: DirectDetectionClient | None = None, gliner_seed_client: GlinerSeedClient | None = None, + native_runtime: NativeDetectionRuntime | None = None, ) -> _DetectAndValidate: + runtime = native_runtime or NativeDetectionRuntime() if strategy == ExperimentalDetectionStrategy.compact_validation: if original is None: raise ValueError("compact_validation requires the original detection method") @@ -323,32 +346,39 @@ def _method_for_strategy( if strategy == ExperimentalDetectionStrategy.rules_only: return _detect_with_rules_only if strategy == ExperimentalDetectionStrategy.native_rules_router: - return _make_native_rules_router_method(native_client=native_client) + return _make_native_rules_router_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.native_candidate_validate_no_augment: - return _make_native_candidate_validate_no_augment_method(native_client=native_client) + return _make_native_candidate_validate_no_augment_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.detector_native_validate_no_augment: - return _make_detector_native_validate_no_augment_method(native_client=native_client) + return _make_detector_native_validate_no_augment_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.detector_native_validate_native_augment: - return _make_detector_native_validate_native_augment_method(native_client=native_client) + return _make_detector_native_validate_native_augment_method(native_client=native_client, native_runtime=runtime) if strategy == 
ExperimentalDetectionStrategy.gliner_native_validate_no_augment: return _make_gliner_native_validate_no_augment_method( native_client=native_client, gliner_seed_client=gliner_seed_client, + native_runtime=runtime, ) if strategy == ExperimentalDetectionStrategy.gliner_native_validate_native_augment: return _make_gliner_native_validate_native_augment_method( native_client=native_client, gliner_seed_client=gliner_seed_client, + native_runtime=runtime, ) if strategy == ExperimentalDetectionStrategy.native_single_pass: - return _make_native_single_pass_method(native_client=native_client) + return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.native_single_pass_recall: - return _make_native_single_pass_method(native_client=native_client, recall_prompt=True) + return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime, recall_prompt=True) if strategy == ExperimentalDetectionStrategy.native_single_pass_values: - return _make_native_single_pass_method(native_client=native_client, value_only_prompt=True) + return _make_native_single_pass_method( + native_client=native_client, + native_runtime=runtime, + value_only_prompt=True, + ) if strategy == ExperimentalDetectionStrategy.native_single_pass_values_recall: return _make_native_single_pass_method( native_client=native_client, + native_runtime=runtime, recall_prompt=True, value_only_prompt=True, ) @@ -1114,8 +1144,9 @@ def _entity_spans_from_payload(raw_payload: object) -> list[EntitySpan]: def _make_native_single_pass_method( - native_client: DirectDetectionClient | None, *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, recall_prompt: bool = False, value_only_prompt: bool = False, ) -> _DetectAndValidate: @@ -1134,11 +1165,12 @@ def detect_and_validate_entities( preview_num_records: int | None = None, ) -> dw.EntityDetectionResult: labels = 
dw._resolve_detection_labels(entity_labels) - client = native_client or HttpxDirectDetectionClient() + client = _native_client_or_default(native_client, native_runtime) return _run_native_single_pass_detection( dataframe, labels=labels, client=client, + runtime=native_runtime, data_summary=data_summary, preview_num_records=preview_num_records, recall_prompt=recall_prompt, @@ -1153,6 +1185,7 @@ def _run_native_single_pass_detection( *, labels: list[str], client: DirectDetectionClient, + runtime: NativeDetectionRuntime, data_summary: str | None, preview_num_records: int | None, recall_prompt: bool, @@ -1169,6 +1202,7 @@ def _run_native_single_pass_detection( index=index, labels=labels, client=client, + runtime=runtime, data_summary=data_summary, recall_prompt=recall_prompt, value_only_prompt=value_only_prompt, @@ -1194,6 +1228,7 @@ def _execute_native_single_pass_row( index: object, labels: list[str], client: DirectDetectionClient, + runtime: NativeDetectionRuntime, data_summary: str | None, recall_prompt: bool, value_only_prompt: bool, @@ -1205,19 +1240,32 @@ def _execute_native_single_pass_row( text=text, labels=labels, client=client, + runtime=runtime, data_summary=data_summary, recall_prompt=recall_prompt, value_only_prompt=value_only_prompt, ) except Exception as exc: # noqa: BLE001 - benchmark experiment records case-local failures - _record_native_single_pass_request_error(elapsed_sec=time.perf_counter() - started) + _record_native_single_pass_request_error(elapsed_sec=time.perf_counter() - started, runtime=runtime) return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") try: spans = _native_single_pass_spans(text, completion.content, labels=labels) except Exception as exc: # noqa: BLE001 - parser fragility is a benchmark failure mode - _record_native_single_pass_completion(completion, status="error", output_row_count=0, failed_record_count=1) + _record_native_single_pass_completion( + completion, + status="error", + 
output_row_count=0, + failed_record_count=1, + runtime=runtime, + ) return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") - _record_native_single_pass_completion(completion, status="completed", output_row_count=1, failed_record_count=0) + _record_native_single_pass_completion( + completion, + status="completed", + output_row_count=1, + failed_record_count=0, + runtime=runtime, + ) return _native_single_pass_result_row(row, spans=spans, labels=labels), None @@ -1226,14 +1274,15 @@ def _complete_native_single_pass( text: str, labels: list[str], client: DirectDetectionClient, + runtime: NativeDetectionRuntime, data_summary: str | None, recall_prompt: bool, value_only_prompt: bool, ) -> Any: return client.complete( DirectGenerationRequest( - endpoint=_NATIVE_DIRECT_ENDPOINT, - model=_NATIVE_DIRECT_MODEL_NAME, + endpoint=runtime.endpoint or "", + model=runtime.model, prompt=_native_single_pass_prompt( text=text, labels=labels, @@ -1241,8 +1290,8 @@ def _complete_native_single_pass( recall_prompt=recall_prompt, value_only_prompt=value_only_prompt, ), - max_tokens=_NATIVE_DIRECT_MAX_TOKENS, - timeout_sec=_NATIVE_DIRECT_TIMEOUT_SEC, + max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, ) ) @@ -1395,10 +1444,11 @@ def _record_native_single_pass_completion( status: str, output_row_count: int, failed_record_count: int, + runtime: NativeDetectionRuntime, ) -> None: record_model_workflow( workflow_name="entity-detection-native-single-pass", - model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + model_aliases=[runtime.alias], input_row_count=1, output_row_count=output_row_count, failed_record_count=failed_record_count, @@ -1408,20 +1458,26 @@ def _record_native_single_pass_completion( successful_requests=1, failed_requests=0, usage=dict(getattr(completion, "usage", {}) or {}), + runtime=runtime, ), ) -def _record_native_single_pass_request_error(*, elapsed_sec: float) -> None: +def _record_native_single_pass_request_error(*, 
elapsed_sec: float, runtime: NativeDetectionRuntime) -> None: record_model_workflow( workflow_name="entity-detection-native-single-pass", - model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + model_aliases=[runtime.alias], input_row_count=1, output_row_count=0, failed_record_count=1, elapsed_sec=elapsed_sec, status="error", - model_usage=_native_single_pass_model_usage(successful_requests=0, failed_requests=1, usage={}), + model_usage=_native_single_pass_model_usage( + successful_requests=0, + failed_requests=1, + usage={}, + runtime=runtime, + ), ) @@ -1430,13 +1486,14 @@ def _native_single_pass_model_usage( successful_requests: int, failed_requests: int, usage: dict[str, int], + runtime: NativeDetectionRuntime, ) -> dict[str, dict[str, Any]]: total_requests = successful_requests + failed_requests return { - _NATIVE_DIRECT_MODEL_ALIAS: { - "model_alias": _NATIVE_DIRECT_MODEL_ALIAS, - "model_name": _NATIVE_DIRECT_MODEL_NAME, - "model_provider_name": _NATIVE_DIRECT_MODEL_PROVIDER, + runtime.alias: { + "model_alias": runtime.alias, + "model_name": runtime.model, + "model_provider_name": runtime.provider, "request_usage": { "successful_requests": successful_requests, "failed_requests": failed_requests, @@ -1447,10 +1504,33 @@ def _native_single_pass_model_usage( } -def _make_native_rules_router_method(native_client: DirectDetectionClient | None) -> _DetectAndValidate: +def _native_client_or_default( + native_client: DirectDetectionClient | None, + runtime: NativeDetectionRuntime, +) -> DirectDetectionClient: + if native_client is not None: + return native_client + _require_native_endpoint(runtime) + return HttpxDirectDetectionClient() + + +def _require_native_endpoint(runtime: NativeDetectionRuntime) -> None: + if not runtime.endpoint or not runtime.model: + raise ValueError( + "native detection strategies require configured native endpoint and model; " + "set native_runtime.endpoint and native_runtime.model in the benchmark suite" + ) + + +def 
_make_native_rules_router_method( + *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, +) -> _DetectAndValidate: return _make_native_staged_method( native_client=native_client, gliner_seed_client=None, + native_runtime=native_runtime, seed_source=SeedSource.rules_router, workflow_name="entity-detection-native-rules-router", skip_augmentation=False, @@ -1458,11 +1538,14 @@ def _make_native_rules_router_method(native_client: DirectDetectionClient | None def _make_native_candidate_validate_no_augment_method( + *, native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, ) -> _DetectAndValidate: return _make_native_staged_method( native_client=native_client, gliner_seed_client=None, + native_runtime=native_runtime, seed_source=SeedSource.rules_plus_direct_llm, workflow_name="entity-detection-native-candidate-validate-no-augment", skip_augmentation=True, @@ -1473,10 +1556,12 @@ def _make_gliner_native_validate_no_augment_method( *, native_client: DirectDetectionClient | None, gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, ) -> _DetectAndValidate: return _make_native_staged_method( native_client=native_client, gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime, seed_source=SeedSource.gliner, workflow_name="entity-detection-gliner-native-validate-no-augment", skip_augmentation=True, @@ -1487,10 +1572,12 @@ def _make_gliner_native_validate_native_augment_method( *, native_client: DirectDetectionClient | None, gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, ) -> _DetectAndValidate: return _make_native_staged_method( native_client=native_client, gliner_seed_client=gliner_seed_client, + native_runtime=native_runtime, seed_source=SeedSource.gliner, workflow_name="entity-detection-gliner-native-validate-native-augment", skip_augmentation=False, @@ -1498,10 +1585,13 @@ def 
_make_gliner_native_validate_native_augment_method( def _make_detector_native_validate_no_augment_method( + *, native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, ) -> _DetectAndValidate: return _make_detector_native_validate_method( - native_client, + native_client=native_client, + native_runtime=native_runtime, workflow_name="entity-detection-detector-native-validate-no-augment", seed_workflow_name="entity-detection-detector-native-validate-no-augment-seed", skip_augmentation=True, @@ -1509,10 +1599,13 @@ def _make_detector_native_validate_no_augment_method( def _make_detector_native_validate_native_augment_method( + *, native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, ) -> _DetectAndValidate: return _make_detector_native_validate_method( - native_client, + native_client=native_client, + native_runtime=native_runtime, workflow_name="entity-detection-detector-native-validate-native-augment", seed_workflow_name="entity-detection-detector-native-validate-native-augment-seed", skip_augmentation=False, @@ -1520,8 +1613,9 @@ def _make_detector_native_validate_native_augment_method( def _make_detector_native_validate_method( - native_client: DirectDetectionClient | None, *, + native_client: DirectDetectionClient | None, + native_runtime: NativeDetectionRuntime, workflow_name: str, seed_workflow_name: str, skip_augmentation: bool, @@ -1541,7 +1635,7 @@ def detect_and_validate_entities( preview_num_records: int | None = None, ) -> dw.EntityDetectionResult: labels = dw._resolve_detection_labels(entity_labels) - client = native_client or HttpxDirectDetectionClient() + client = _native_client_or_default(native_client, native_runtime) return _run_detector_native_validate_detection( self, dataframe, @@ -1554,6 +1648,7 @@ def detect_and_validate_entities( validation_excerpt_window_chars=validation_excerpt_window_chars, validation_single_chunk_full_text=validation_single_chunk_full_text, client=client, + 
runtime=native_runtime, data_summary=data_summary, workflow_name=workflow_name, seed_workflow_name=seed_workflow_name, @@ -1576,6 +1671,7 @@ def _run_detector_native_validate_detection( validation_excerpt_window_chars: int, validation_single_chunk_full_text: bool, client: DirectDetectionClient, + runtime: NativeDetectionRuntime, data_summary: str | None, workflow_name: str, seed_workflow_name: str, @@ -1605,6 +1701,7 @@ def _run_detector_native_validate_detection( params = _DetectorNativeValidationParams( labels=labels, client=client, + runtime=runtime, data_summary=data_summary, validation_prompt_mode=validation_prompt_mode, validation_max_entities_per_call=validation_max_entities_per_call, @@ -1645,7 +1742,7 @@ def _execute_detector_native_validate_tasks( ) -> list[_DetectorNativeValidationRowResult]: if not tasks: return [] - worker_count = _native_staged_worker_count(len(tasks)) + worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) if worker_count == 1: return [_execute_detector_native_validate_task(task, params=params) for task in tasks] with ThreadPoolExecutor(max_workers=worker_count) as executor: @@ -1668,6 +1765,7 @@ def _execute_detector_native_validate_task( ordinal=task.ordinal, labels=params.labels, client=params.client, + runtime=params.runtime, data_summary=params.data_summary, validation_prompt_mode=params.validation_prompt_mode, validation_max_entities_per_call=params.validation_max_entities_per_call, @@ -1698,6 +1796,7 @@ def _execute_detector_native_validate_row( ordinal: int, labels: list[str], client: DirectDetectionClient, + runtime: NativeDetectionRuntime, data_summary: str | None, validation_prompt_mode: ValidationPromptMode, validation_max_entities_per_call: int, @@ -1708,6 +1807,10 @@ def _execute_detector_native_validate_row( output_row = row.to_dict() request = _native_staged_request(row, index=index, ordinal=ordinal, labels=labels, data_summary=data_summary) config = StagedExecutionConfig( + 
endpoint=runtime.endpoint or "", + model=runtime.model, + max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, validation_prompt_mode=validation_prompt_mode, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, @@ -1733,6 +1836,7 @@ def _execute_detector_native_validate_row( ordinal=ordinal, index=index, workflow_name=workflow_name, + runtime=runtime, output_row=None, failed_record=_native_failed_record( index, @@ -1747,6 +1851,7 @@ def _execute_detector_native_validate_row( ordinal=ordinal, index=index, workflow_name=workflow_name, + runtime=runtime, output_row=output_row, failed_record=None, completion=completion, @@ -1782,10 +1887,11 @@ def _record_detector_native_validation_completion( *, request_count: int, workflow_name: str, + runtime: NativeDetectionRuntime, ) -> None: record_model_workflow( workflow_name=workflow_name, - model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + model_aliases=[runtime.alias], input_row_count=1, output_row_count=1, failed_record_count=0, @@ -1794,6 +1900,7 @@ def _record_detector_native_validation_completion( successful_requests=request_count, failed_requests=0, usage=dict(completion.usage or {}), + runtime=runtime, ), ) @@ -1804,6 +1911,7 @@ def _record_detector_native_validation_result(result: _DetectorNativeValidationR result.completion, request_count=result.request_count, workflow_name=result.workflow_name, + runtime=result.runtime, ) return if result.failed_record is not None: @@ -1811,13 +1919,20 @@ def _record_detector_native_validation_result(result: _DetectorNativeValidationR elapsed_sec=result.elapsed_sec, request_count=result.request_count, workflow_name=result.workflow_name, + runtime=result.runtime, ) -def _record_detector_native_validation_error(*, elapsed_sec: float, request_count: int, workflow_name: str) -> None: +def _record_detector_native_validation_error( + *, + elapsed_sec: float, + request_count: int, + workflow_name: str, 
+ runtime: NativeDetectionRuntime, +) -> None: record_model_workflow( workflow_name=workflow_name, - model_aliases=[_NATIVE_DIRECT_MODEL_ALIAS], + model_aliases=[runtime.alias], input_row_count=1, output_row_count=0, failed_record_count=1, @@ -1827,6 +1942,7 @@ def _record_detector_native_validation_error(*, elapsed_sec: float, request_coun successful_requests=0, failed_requests=max(request_count, 1), usage={}, + runtime=runtime, ), ) @@ -1835,6 +1951,7 @@ def _make_native_staged_method( *, native_client: DirectDetectionClient | None, gliner_seed_client: GlinerSeedClient | None, + native_runtime: NativeDetectionRuntime, seed_source: SeedSource, workflow_name: str, skip_augmentation: bool, @@ -1855,12 +1972,13 @@ def detect_and_validate_entities( ) -> dw.EntityDetectionResult: _ = self, model_configs, selected_models, gliner_detection_threshold labels = dw._resolve_detection_labels(entity_labels) - client = native_client or HttpxDirectDetectionClient() + client = _native_client_or_default(native_client, native_runtime) return _run_native_staged_detection( dataframe, labels=labels, client=client, gliner_seed_client=gliner_seed_client, + runtime=native_runtime, data_summary=data_summary, preview_num_records=preview_num_records, validation_max_entities_per_call=validation_max_entities_per_call, @@ -1880,6 +1998,7 @@ def _run_native_staged_detection( labels: list[str], client: DirectDetectionClient, gliner_seed_client: GlinerSeedClient | None, + runtime: NativeDetectionRuntime, data_summary: str | None, preview_num_records: int | None, validation_max_entities_per_call: int, @@ -1898,6 +2017,7 @@ def _run_native_staged_detection( labels=labels, client=client, gliner_seed_client=gliner_seed_client, + runtime=runtime, data_summary=data_summary, validation_prompt_mode=validation_prompt_mode, validation_max_entities_per_call=validation_max_entities_per_call, @@ -1925,7 +2045,7 @@ def _run_native_staged_detection( ) continue if result.case is not None: - 
_record_native_direct_usage(result.case, workflow_name=workflow_name) + _record_native_direct_usage(result.case, workflow_name=workflow_name, runtime=runtime) output_rows.append(result.output_row) output_indices.append(result.index) @@ -1942,15 +2062,15 @@ def _execute_native_staged_tasks( ) -> list[_NativeStagedRowResult]: if not tasks: return [] - worker_count = _native_staged_worker_count(len(tasks)) + worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) if worker_count == 1: return [_execute_native_staged_task(task, params=params) for task in tasks] with ThreadPoolExecutor(max_workers=worker_count) as executor: return list(executor.map(lambda task: _execute_native_staged_task(task, params=params), tasks)) -def _native_staged_worker_count(task_count: int) -> int: - return max(1, min(task_count, _NATIVE_STAGED_MAX_WORKERS)) +def _native_staged_worker_count(task_count: int, *, runtime: NativeDetectionRuntime) -> int: + return max(1, min(task_count, runtime.max_workers)) def _execute_native_staged_task( @@ -1971,6 +2091,14 @@ def _execute_native_staged_task( client=params.client, seed_client=params.gliner_seed_client, seed_source=params.seed_source, + endpoint=params.runtime.endpoint or "", + model=params.runtime.model, + gliner_endpoint=params.runtime.gliner_endpoint or "", + gliner_model=params.runtime.gliner_model, + gliner_api_key_env=params.runtime.gliner_api_key_env, + gliner_threshold=params.runtime.gliner_threshold, + max_tokens=params.runtime.max_tokens, + timeout_sec=params.runtime.timeout_sec, skip_augmentation=params.skip_augmentation, validation_prompt_mode=params.validation_prompt_mode, validation_max_entities_per_call=params.validation_max_entities_per_call, @@ -2021,8 +2149,13 @@ def _native_failed_record(index: object, *, workflow_name: str, error: str | Non ) -def _record_native_direct_usage(case: StagedDetectionCase, *, workflow_name: str) -> None: - model_usage = _native_staged_model_usage(case) +def 
_record_native_direct_usage( + case: StagedDetectionCase, + *, + workflow_name: str, + runtime: NativeDetectionRuntime, +) -> None: + model_usage = _native_staged_model_usage(case, runtime=runtime) if not model_usage: return record_model_workflow( @@ -2036,17 +2169,21 @@ def _record_native_direct_usage(case: StagedDetectionCase, *, workflow_name: str ) -def _native_staged_model_usage(case: StagedDetectionCase) -> dict[str, dict[str, Any]]: +def _native_staged_model_usage( + case: StagedDetectionCase, + *, + runtime: NativeDetectionRuntime, +) -> dict[str, dict[str, Any]]: usage: dict[str, dict[str, Any]] = {} native_requests = 0 native_usage: list[dict[str, int]] = [] if case.phase_model_work.seed: if case.seed_source == SeedSource.gliner: - usage[_GLINER_DIRECT_MODEL_ALIAS] = _direct_model_usage_entry( - alias=_GLINER_DIRECT_MODEL_ALIAS, - model_name=_GLINER_DIRECT_MODEL_NAME, - provider_name=_GLINER_DIRECT_MODEL_PROVIDER, + usage[runtime.gliner_alias] = _direct_model_usage_entry( + alias=runtime.gliner_alias, + model_name=runtime.gliner_model, + provider_name=runtime.gliner_provider, successful_requests=case.phase_model_requests.seed, usage=case.phase_usage.seed, ) @@ -2060,10 +2197,10 @@ def _native_staged_model_usage(case: StagedDetectionCase) -> dict[str, dict[str, native_requests += case.phase_model_requests.augmentation native_usage.append(case.phase_usage.augmentation) if native_requests: - usage[_NATIVE_DIRECT_MODEL_ALIAS] = _direct_model_usage_entry( - alias=_NATIVE_DIRECT_MODEL_ALIAS, - model_name=_NATIVE_DIRECT_MODEL_NAME, - provider_name=_NATIVE_DIRECT_MODEL_PROVIDER, + usage[runtime.alias] = _direct_model_usage_entry( + alias=runtime.alias, + model_name=runtime.model, + provider_name=runtime.provider, successful_requests=native_requests, usage=_sum_usage_dicts(native_usage), ) diff --git a/tools/measurement/direct_detection_probe.py b/tools/measurement/direct_detection_probe.py index aefb09e5..8659fec0 100644 --- 
a/tools/measurement/direct_detection_probe.py +++ b/tools/measurement/direct_detection_probe.py @@ -13,6 +13,7 @@ import json import logging +import os import shutil import sys from collections import Counter @@ -34,6 +35,11 @@ app = cyclopts.App(help=__doc__) logger = logging.getLogger("measurement.direct_detection_probe") +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" +_UNCONFIGURED_MODEL = "configured-native-model" + class CaseStatus(StrEnum): completed = "completed" @@ -217,11 +223,13 @@ def run_direct_detection_case( request: DirectDetectionRequest, *, client: DirectDetectionClient, - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", - model: str = "nvidia/nemotron-3-super", + endpoint: str | None = None, + model: str | None = None, max_tokens: int = 4096, timeout_sec: float = 180.0, ) -> DirectDetectionCase: + endpoint = endpoint or _UNCONFIGURED_ENDPOINT + model = model or _UNCONFIGURED_MODEL try: completion = client.complete( DirectGenerationRequest( @@ -422,13 +430,15 @@ def run_probe( labels: list[str], output: Path | None = None, overwrite: bool = False, - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", - model: str = "nvidia/nemotron-3-super", + endpoint: str | None = None, + model: str | None = None, limit: int = 1, offset: int = 0, prompt_mode: PromptMode = PromptMode.compact, baseline_artifacts: Path | None = None, ) -> DirectDetectionRun: + endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) + model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) requests = _load_requests( input_path, text_column=text_column, labels=labels, limit=limit, offset=offset, prompt_mode=prompt_mode ) @@ -450,6 +460,13 @@ def run_probe( return result +def _required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: + value = explicit or 
os.environ.get(env_var) + if not value: + raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") + return value + + def _load_requests( input_path: Path, *, @@ -523,8 +540,8 @@ def main( labels: Annotated[str, cyclopts.Parameter("--labels")], output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, - endpoint: Annotated[str, cyclopts.Parameter("--endpoint")] = "http://gpu-dev-pod-serve-svc:8000/v1", - model: Annotated[str, cyclopts.Parameter("--model")] = "nvidia/nemotron-3-super", + endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, + model: Annotated[str | None, cyclopts.Parameter("--model")] = None, limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, prompt_mode: Annotated[PromptMode, cyclopts.Parameter("--prompt-mode")] = PromptMode.compact, diff --git a/tools/measurement/examples/repo-data-smoke.yaml b/tools/measurement/examples/repo-data-smoke.yaml index 5d0e5623..5009f054 100644 --- a/tools/measurement/examples/repo-data-smoke.yaml +++ b/tools/measurement/examples/repo-data-smoke.yaml @@ -6,9 +6,11 @@ workloads: - id: biographies source: ../../../docs/data/NVIDIA_synthetic_biographies.csv text_column: biography + row_limit: 5 - id: legal source: ../../../docs/data/TAB_legal_sample25.csv text_column: text + row_limit: 5 configs: - id: biographies-redact-default replace: redact diff --git a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh index cc521b92..9000f03f 100644 --- a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh +++ b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh @@ -6,9 +6,14 @@ set -euo pipefail output_dir="${1:-/tmp/anonymizer-repo-data-smoke-dd-traces}" trace_mode="${DD_TRACE_MODE:-last-message}" +script_dir="$(cd 
"$(dirname "${BASH_SOURCE[0]}")" && pwd)" +repo_root="$(cd "${script_dir}/../../.." && pwd)" +suite_file="${BENCHMARK_SUITE:-${script_dir}/repo-data-smoke.yaml}" + +cd "${repo_root}" uv run python tools/measurement/run_benchmarks.py \ - /stable-cache/anonymizer/repo-data-smoke.yaml \ + "${suite_file}" \ --output "${output_dir}" \ --overwrite \ --dd-trace "${trace_mode}" \ diff --git a/tools/measurement/extract_signature_deltas.py b/tools/measurement/extract_signature_deltas.py index 7a4bbaa2..66a8e40c 100644 --- a/tools/measurement/extract_signature_deltas.py +++ b/tools/measurement/extract_signature_deltas.py @@ -13,7 +13,6 @@ from __future__ import annotations -import hashlib import json import logging import sys @@ -58,6 +57,14 @@ class ContextResolution(StrEnum): _log_format = LogFormat.plain +_SIGNATURE_DETAIL_FIELDS = { + "label", + "source", + "row_index", + "start_position", + "end_position", + "value_length", +} class SignatureDeltaRow(BaseModel): @@ -70,7 +77,6 @@ class SignatureDeltaRow(BaseModel): source: str | None = None start_position: int | None = None end_position: int | None = None - value_hash: str | None = None value_length: int | None = None masked_context: str | None = None resolution: ContextResolution = ContextResolution.metadata_only @@ -317,7 +323,7 @@ def _parquet_entity_context( text, row_index, row = record for entity in EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES)).entities: if _entity_signature_hash(entity, row_index=row_index) == signature: - return _entity_context(entity, text, context_window, ContextResolution.parquet) + return _entity_context(entity, text, signature, context_window, ContextResolution.parquet) return None @@ -335,7 +341,7 @@ def _rule_entity_context( for span in detect_high_confidence_entities(text, labels=[label]): entity = EntitySchema.model_validate(span.as_dict()) if _entity_signature_hash(entity, row_index=row_index) == signature: - return _entity_context(entity, text, context_window, 
ContextResolution.rule) + return _entity_context(entity, text, signature, context_window, ContextResolution.rule) return None @@ -351,9 +357,8 @@ def _artifact_detail_context( return None start_position = _optional_int(details.get("start_position")) end_position = _optional_int(details.get("end_position")) - value_hash = _optional_string(details.get("value_hash")) resolved_label = _optional_string(details.get("label")) or label - if start_position is None or end_position is None or value_hash is None or resolved_label is None: + if start_position is None or end_position is None or resolved_label is None: return _metadata_context_from_details(details) record = _artifact_record(artifact_row, artifact_root) masked_context = None @@ -362,7 +367,7 @@ def _artifact_detail_context( masked_context = _masked_context_from_details( text, label=resolved_label, - value_hash=value_hash, + signature_hash=signature, start_position=start_position, end_position=end_position, window=context_window, @@ -371,7 +376,6 @@ def _artifact_detail_context( "source": _optional_string(details.get("source")), "start_position": start_position, "end_position": end_position, - "value_hash": value_hash, "value_length": _optional_int(details.get("value_length")), "masked_context": masked_context, "resolution": ContextResolution.artifact_details @@ -390,6 +394,8 @@ def _signature_details(row: dict[str, object]) -> dict[str, dict[str, object]]: signature_hash, _, field = remainder.partition(".") if not signature_hash or not field: continue + if field not in _SIGNATURE_DETAIL_FIELDS: + continue details.setdefault(signature_hash, {})[field] = value return details @@ -402,7 +408,11 @@ def _coerce_detail_map(raw: object) -> dict[str, dict[str, object]]: return {} if not isinstance(raw, dict): return {} - return {str(key): value for key, value in raw.items() if isinstance(value, dict)} + return { + str(key): {str(field): item for field, item in value.items() if str(field) in _SIGNATURE_DETAIL_FIELDS} + for 
key, value in raw.items() + if isinstance(value, dict) + } def _metadata_context_from_details(details: dict[str, object]) -> dict[str, object]: @@ -410,7 +420,6 @@ def _metadata_context_from_details(details: dict[str, object]) -> dict[str, obje "source": _optional_string(details.get("source")), "start_position": _optional_int(details.get("start_position")), "end_position": _optional_int(details.get("end_position")), - "value_hash": _optional_string(details.get("value_hash")), "value_length": _optional_int(details.get("value_length")), "resolution": ContextResolution.metadata_only, } @@ -431,6 +440,7 @@ def _artifact_record(artifact_row: dict[str, object], artifact_root: Path) -> tu def _entity_context( entity: EntitySchema, text: str, + signature_hash: str, context_window: int, resolution: ContextResolution, ) -> dict[str, object]: @@ -438,17 +448,16 @@ def _entity_context( "source": entity.source, "start_position": entity.start_position, "end_position": entity.end_position, - "value_hash": _value_hash(entity.value), "value_length": len(entity.value), - "masked_context": _masked_context(text, entity, context_window), + "masked_context": _masked_context(text, entity, signature_hash, context_window), "resolution": resolution, } -def _masked_context(text: str, entity: EntitySchema, window: int) -> str: +def _masked_context(text: str, entity: EntitySchema, signature_hash: str, window: int) -> str: before = text[max(0, entity.start_position - window) : entity.start_position] after = text[entity.end_position : entity.end_position + window] - placeholder = f"[{entity.label.upper()}:{_value_hash(entity.value)}]" + placeholder = f"[{entity.label.upper()}:{signature_hash}]" return (before + placeholder + after).replace("\n", " ") @@ -456,21 +465,17 @@ def _masked_context_from_details( text: str, *, label: str, - value_hash: str, + signature_hash: str, start_position: int, end_position: int, window: int, ) -> str: before = text[max(0, start_position - window) : 
start_position] after = text[end_position : end_position + window] - placeholder = f"[{label.upper()}:{value_hash}]" + placeholder = f"[{label.upper()}:{signature_hash}]" return (before + placeholder + after).replace("\n", " ") -def _value_hash(value: str) -> str: - return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16] - - def _optional_int(value: object) -> int | None: try: if pd.isna(value): diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index acb967d8..77b24493 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -10,6 +10,7 @@ import json import logging +import os import shutil import sys import time @@ -32,6 +33,7 @@ from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context from detection_strategies import ( ExperimentalDetectionStrategy, + NativeDetectionRuntime, experimental_detection_strategy_context, ) from export_measurements import ExportFormat, export_tables, read_measurements @@ -88,6 +90,35 @@ class DDTraceMode(StrEnum): all_messages = "all_messages" +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" +_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" + + +class NativeRuntimeSpec(BaseModel): + model_config = ConfigDict(extra="forbid") + + runtime_id: str | None = None + endpoint: str | None = None + endpoint_env: str | None = _NATIVE_ENDPOINT_ENV + model: str | None = None + model_env: str | None = _NATIVE_MODEL_ENV + provider: str = "native" + alias: str = "native-direct" + max_tokens: int = Field(default=4096, gt=0) + timeout_sec: float = Field(default=180.0, gt=0) + gliner_endpoint: str | None = None + gliner_endpoint_env: str | None = _GLINER_ENDPOINT_ENV + gliner_model: str | None = None + gliner_model_env: str | None = _GLINER_MODEL_ENV + gliner_provider: str = "gliner" + gliner_alias: str = "gliner-direct" + 
gliner_api_key_env: str = "NVIDIA_API_KEY" + gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) + max_workers: int = Field(default=4, ge=1) + + class ReplaceKind(StrEnum): redact = "redact" hash = "hash" @@ -140,6 +171,7 @@ class ConfigSpec(BaseModel): experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default experimental_replacement_strategy: ExperimentalReplacementStrategy = ExperimentalReplacementStrategy.default experimental_rule_labels: list[str] | None = None + native_runtime: NativeRuntimeSpec | None = None @model_validator(mode="after") def validate_mode(self) -> "ConfigSpec": @@ -176,6 +208,7 @@ class BenchmarkSpec(BaseModel): model_providers: str | None = None artifact_path: str | None = None dd_parser_compat: DDParserCompatMode = DDParserCompatMode.none + native_runtime: NativeRuntimeSpec | None = None case_retries: int = Field(default=0, ge=0) case_retry_backoff_sec: float = Field(default=0.0, ge=0.0) workloads: list[WorkloadSpec] = Field(min_length=1) @@ -270,6 +303,22 @@ class _CaseExecution: ExperimentalDetectionStrategy.rules_guardrail_detector_only, ExperimentalDetectionStrategy.rules_only, } +_NATIVE_RUNTIME_STRATEGIES = { + ExperimentalDetectionStrategy.native_rules_router, + ExperimentalDetectionStrategy.native_candidate_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_no_augment, + ExperimentalDetectionStrategy.detector_native_validate_native_augment, + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + ExperimentalDetectionStrategy.gliner_native_validate_native_augment, + ExperimentalDetectionStrategy.native_single_pass, + ExperimentalDetectionStrategy.native_single_pass_recall, + ExperimentalDetectionStrategy.native_single_pass_values, + ExperimentalDetectionStrategy.native_single_pass_values_recall, +} +_GLINER_NATIVE_RUNTIME_STRATEGIES = { + ExperimentalDetectionStrategy.gliner_native_validate_no_augment, + 
ExperimentalDetectionStrategy.gliner_native_validate_native_augment, +} _FINAL_ARTIFACT_KEYS = { "final_entity_count", @@ -392,7 +441,10 @@ def _preflight_workload_errors(spec: BenchmarkSpec, *, base_dir: Path) -> list[s def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) -> list[str]: errors: list[str] = [] + active_config_ids = _active_config_ids(spec) for config in spec.configs: + if config.id not in active_config_ids: + continue try: anonymizer_config = build_anonymizer_config(config) except Exception as exc: @@ -402,6 +454,10 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) _preflight_experimental_detection_strategy(config, anonymizer_config) except Exception as exc: errors.append(f"config '{config.id}' experimental_detection_strategy invalid: {exc}") + try: + _preflight_native_runtime(config, spec=spec) + except Exception as exc: + errors.append(f"config '{config.id}' native_runtime invalid: {exc}") try: _preflight_experimental_replacement_strategy(config, anonymizer_config) except Exception as exc: @@ -421,6 +477,12 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) return errors +def _active_config_ids(spec: BenchmarkSpec) -> set[str]: + if spec.matrix is None: + return {config.id for config in spec.configs} + return {entry.config for entry in spec.matrix} + + def _preflight_experimental_detection_strategy(config: ConfigSpec, anonymizer_config: AnonymizerConfig) -> None: _preflight_experimental_rule_labels(config) if config.experimental_detection_strategy != ExperimentalDetectionStrategy.rules_only: @@ -454,6 +516,65 @@ def _preflight_experimental_rule_labels(config: ConfigSpec) -> None: ) +def _preflight_native_runtime(config: ConfigSpec, *, spec: BenchmarkSpec) -> None: + strategy = config.experimental_detection_strategy + if strategy not in _NATIVE_RUNTIME_STRATEGIES: + return + runtime = _resolve_native_runtime_spec(spec, config) + if not runtime.runtime_id: + 
raise ValueError("native strategies require native_runtime.runtime_id") + if not runtime.endpoint or not runtime.model: + raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") + if strategy in _GLINER_NATIVE_RUNTIME_STRATEGIES: + if not runtime.gliner_endpoint or not runtime.gliner_model: + raise ValueError("GLiNER-native strategies require native_runtime.gliner_endpoint and gliner_model") + if not os.environ.get(runtime.gliner_api_key_env): + raise ValueError(f"{runtime.gliner_api_key_env} is not set for GLiNER-native strategy") + + +def _resolve_native_runtime_spec(spec: BenchmarkSpec, config: ConfigSpec) -> NativeRuntimeSpec: + runtime = spec.native_runtime or NativeRuntimeSpec() + if config.native_runtime is not None: + runtime = runtime.model_copy(update=config.native_runtime.model_dump(exclude_unset=True)) + return runtime.model_copy( + update={ + "endpoint": _resolve_runtime_value(runtime.endpoint, runtime.endpoint_env), + "model": _resolve_runtime_value(runtime.model, runtime.model_env), + "gliner_endpoint": _resolve_runtime_value(runtime.gliner_endpoint, runtime.gliner_endpoint_env), + "gliner_model": _resolve_runtime_value(runtime.gliner_model, runtime.gliner_model_env), + } + ) + + +def _resolve_runtime_value(explicit: str | None, env_var: str | None) -> str | None: + if explicit: + return explicit + return os.environ.get(env_var) if env_var else None + + +def _native_detection_runtime(spec: BenchmarkSpec, config: ConfigSpec) -> NativeDetectionRuntime: + runtime = _resolve_native_runtime_spec(spec, config) + if not runtime.runtime_id: + raise ValueError("native strategies require native_runtime.runtime_id") + if not runtime.endpoint or not runtime.model: + raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") + return NativeDetectionRuntime( + endpoint=runtime.endpoint, + model=runtime.model, + provider=runtime.provider, + alias=runtime.alias, + 
max_tokens=runtime.max_tokens, + timeout_sec=runtime.timeout_sec, + gliner_endpoint=runtime.gliner_endpoint, + gliner_model=runtime.gliner_model or "", + gliner_provider=runtime.gliner_provider, + gliner_alias=runtime.gliner_alias, + gliner_api_key_env=runtime.gliner_api_key_env, + gliner_threshold=runtime.gliner_threshold, + max_workers=runtime.max_workers, + ) + + def _preflight_experimental_replacement_strategy( config: ConfigSpec, anonymizer_config: AnonymizerConfig, @@ -489,7 +610,7 @@ def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: if not candidate.is_file(): raise FileNotFoundError(f"Providers config file not found: {candidate}") config_source = candidate - config_dict = load_config_file(config_source) + config_dict = yaml.safe_load(raw) if isinstance(raw, str) and "\n" in raw else load_config_file(config_source) raw_providers = config_dict.get("providers") if isinstance(config_dict, dict) else None if not isinstance(raw_providers, list): raise ValueError("model_providers YAML must contain a top-level 'providers' list.") @@ -1027,9 +1148,12 @@ def _execute_case( ) with configured_measurement_session(measurement): with dd_parser_compat_context(dd_parser_compat): + detection_context_kwargs: dict[str, Any] = {"rule_labels": config.experimental_rule_labels} + if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: + detection_context_kwargs["native_runtime"] = _native_detection_runtime(spec, config) with experimental_detection_strategy_context( config.experimental_detection_strategy, - rule_labels=config.experimental_rule_labels, + **detection_context_kwargs, ): with experimental_replacement_strategy_context(config.experimental_replacement_strategy): result = anonymizer.run( @@ -1358,7 +1482,7 @@ def render_result(result: BenchmarkResult, *, json_output: bool) -> str: def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: config = next(item for item in spec.configs if item.id == case.config_id) 
- return { + tags = { "suite_id": spec.suite_id, "workload_id": case.workload_id, "config_id": case.config_id, @@ -1369,6 +1493,32 @@ def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: "experimental_rule_labels": config.experimental_rule_labels, "dd_parser_compat": spec.dd_parser_compat.value, } + if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: + tags.update(_native_runtime_tags(_resolve_native_runtime_spec(spec, config))) + return tags + + +def _native_runtime_tags(runtime: NativeRuntimeSpec) -> dict[str, Any]: + return _present( + { + "native_runtime_id": runtime.runtime_id, + "native_endpoint_env": runtime.endpoint_env, + "native_model": runtime.model, + "native_model_env": runtime.model_env, + "native_provider": runtime.provider, + "native_alias": runtime.alias, + "native_max_tokens": runtime.max_tokens, + "native_timeout_sec": runtime.timeout_sec, + "native_max_workers": runtime.max_workers, + "gliner_endpoint_env": runtime.gliner_endpoint_env, + "gliner_model": runtime.gliner_model, + "gliner_model_env": runtime.gliner_model_env, + "gliner_provider": runtime.gliner_provider, + "gliner_alias": runtime.gliner_alias, + "gliner_api_key_env": runtime.gliner_api_key_env, + "gliner_threshold": runtime.gliner_threshold, + } + ) def _present(values: dict[str, Any]) -> dict[str, Any]: @@ -1481,6 +1631,7 @@ def run_or_plan( output_dir = output or Path("benchmark-runs") / benchmark_spec.suite_id if trace_dir is not None and dd_trace == DDTraceMode.none: raise ValueError("--trace-dir requires --dd-trace") + preflight_suite(benchmark_spec, spec_path=spec_path) if dry_run: return dry_run_result( benchmark_spec, @@ -1489,7 +1640,6 @@ def run_or_plan( dd_trace=dd_trace, trace_dir=trace_dir, ) - preflight_suite(benchmark_spec, spec_path=spec_path) prepare_output_dir(output_dir, overwrite=overwrite, dry_run=dry_run) return run_suite( benchmark_spec, diff --git a/tools/measurement/staged_detection_probe.py 
b/tools/measurement/staged_detection_probe.py index 5751e3f1..6da611da 100644 --- a/tools/measurement/staged_detection_probe.py +++ b/tools/measurement/staged_detection_probe.py @@ -98,6 +98,15 @@ app = cyclopts.App(help=__doc__) logger = logging.getLogger("measurement.staged_detection_probe") + +_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" +_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" +_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" +_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" +_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" +_UNCONFIGURED_MODEL = "configured-native-model" +_UNCONFIGURED_GLINER_ENDPOINT = "configured-gliner-endpoint" +_UNCONFIGURED_GLINER_MODEL = "configured-gliner-model" _log_format = LogFormat.plain _DATE_OF_BIRTH_CONTEXT_RE = re.compile(r"\b(born|birth|date of birth|dob)\b", re.IGNORECASE) @@ -121,7 +130,7 @@ class GlinerDetectionRequest(BaseModel): model: str text: str labels: list[str] = Field(min_length=1) - threshold: float = 0.3 + threshold: float = Field(default=0.3, ge=0.0, le=1.0) max_tokens: int = Field(default=4096, gt=0) timeout_sec: float = Field(default=120.0, gt=0) api_key_env: str = "NVIDIA_API_KEY" @@ -132,13 +141,13 @@ def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: ... 
class StagedExecutionConfig(BaseModel): - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1" - model: str = "nvidia/nemotron-3-super" + endpoint: str = _UNCONFIGURED_ENDPOINT + model: str = _UNCONFIGURED_MODEL seed_source: SeedSource = SeedSource.direct_llm - gliner_endpoint: str = "https://integrate.api.nvidia.com/v1" - gliner_model: str = "nvidia/gliner-pii" + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT + gliner_model: str = _UNCONFIGURED_GLINER_MODEL gliner_api_key_env: str = "NVIDIA_API_KEY" - gliner_threshold: float = 0.3 + gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) max_tokens: int = Field(default=4096, gt=0) timeout_sec: float = Field(default=180.0, gt=0) skip_augmentation: bool = False @@ -297,10 +306,10 @@ def run_staged_detection_case( client: DirectDetectionClient, seed_client: GlinerSeedClient | None = None, seed_source: SeedSource = SeedSource.direct_llm, - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", - model: str = "nvidia/nemotron-3-super", - gliner_endpoint: str = "https://integrate.api.nvidia.com/v1", - gliner_model: str = "nvidia/gliner-pii", + endpoint: str = _UNCONFIGURED_ENDPOINT, + model: str = _UNCONFIGURED_MODEL, + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, + gliner_model: str = _UNCONFIGURED_GLINER_MODEL, gliner_api_key_env: str = "NVIDIA_API_KEY", gliner_threshold: float = 0.3, max_tokens: int = 4096, @@ -338,10 +347,10 @@ def execute_staged_detection_case( client: DirectDetectionClient, seed_client: GlinerSeedClient | None = None, seed_source: SeedSource = SeedSource.direct_llm, - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", - model: str = "nvidia/nemotron-3-super", - gliner_endpoint: str = "https://integrate.api.nvidia.com/v1", - gliner_model: str = "nvidia/gliner-pii", + endpoint: str = _UNCONFIGURED_ENDPOINT, + model: str = _UNCONFIGURED_MODEL, + gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, + gliner_model: str = _UNCONFIGURED_GLINER_MODEL, gliner_api_key_env: str = 
"NVIDIA_API_KEY", gliner_threshold: float = 0.3, max_tokens: int = 4096, @@ -1312,11 +1321,11 @@ def run_probe( labels: list[str], output: Path | None = None, overwrite: bool = False, - endpoint: str = "http://gpu-dev-pod-serve-svc:8000/v1", - model: str = "nvidia/nemotron-3-super", + endpoint: str | None = None, + model: str | None = None, seed_source: SeedSource = SeedSource.direct_llm, - gliner_endpoint: str = "https://integrate.api.nvidia.com/v1", - gliner_model: str = "nvidia/gliner-pii", + gliner_endpoint: str | None = None, + gliner_model: str | None = None, gliner_api_key_env: str = "NVIDIA_API_KEY", gliner_threshold: float = 0.3, skip_augmentation: bool = False, @@ -1328,6 +1337,18 @@ def run_probe( offset: int = 0, baseline_artifacts: Path | None = None, ) -> StagedDetectionRun: + endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) + model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) + if seed_source == SeedSource.gliner: + gliner_endpoint = _required_runtime_value( + "gliner_endpoint", explicit=gliner_endpoint, env_var=_GLINER_ENDPOINT_ENV + ) + gliner_model = _required_runtime_value("gliner_model", explicit=gliner_model, env_var=_GLINER_MODEL_ENV) + if not os.environ.get(gliner_api_key_env): + raise ValueError(f"{gliner_api_key_env} is not set") + else: + gliner_endpoint = gliner_endpoint or _UNCONFIGURED_GLINER_ENDPOINT + gliner_model = gliner_model or _UNCONFIGURED_GLINER_MODEL requests = _load_requests(input_path, text_column=text_column, labels=labels, limit=limit, offset=offset) config = _execution_config_from_params(locals()) cases = _run_probe_cases(requests, config) @@ -1341,6 +1362,13 @@ def run_probe( return result +def _required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: + value = explicit or os.environ.get(env_var) + if not value: + raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") + return value + 
+ def _run_probe_cases( requests: list[StagedDetectionRequest], config: StagedExecutionConfig, @@ -1431,11 +1459,11 @@ def main( labels: Annotated[str, cyclopts.Parameter("--labels")], output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, - endpoint: Annotated[str, cyclopts.Parameter("--endpoint")] = "http://gpu-dev-pod-serve-svc:8000/v1", - model: Annotated[str, cyclopts.Parameter("--model")] = "nvidia/nemotron-3-super", + endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, + model: Annotated[str | None, cyclopts.Parameter("--model")] = None, seed_source: Annotated[SeedSource, cyclopts.Parameter("--seed-source")] = SeedSource.direct_llm, - gliner_endpoint: Annotated[str, cyclopts.Parameter("--gliner-endpoint")] = "https://integrate.api.nvidia.com/v1", - gliner_model: Annotated[str, cyclopts.Parameter("--gliner-model")] = "nvidia/gliner-pii", + gliner_endpoint: Annotated[str | None, cyclopts.Parameter("--gliner-endpoint")] = None, + gliner_model: Annotated[str | None, cyclopts.Parameter("--gliner-model")] = None, gliner_api_key_env: Annotated[str, cyclopts.Parameter("--gliner-api-key-env")] = "NVIDIA_API_KEY", gliner_threshold: Annotated[float, cyclopts.Parameter("--gliner-threshold")] = 0.3, skip_augmentation: Annotated[bool, cyclopts.Parameter("--skip-augmentation")] = False, From b157a1dd19d2f301478b1112a8e0a77134fde995 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 8 Jun 2026 21:54:23 +0000 Subject: [PATCH 06/26] Add richer benchmark quality metrics Signed-off-by: Aaron Gonzales --- src/anonymizer/measurement.py | 140 ++++++- tests/test_measurement.py | 59 +++ tests/tools/test_benchmark_output_analysis.py | 144 +++++++ tools/measurement/README.md | 39 +- tools/measurement/analyze_benchmark_output.py | 394 +++++++++++++++++- 5 files changed, 756 insertions(+), 20 deletions(-) diff --git a/src/anonymizer/measurement.py 
b/src/anonymizer/measurement.py index 510eb988..d6c77be3 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -37,6 +37,28 @@ default=None, ) _GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") +_ENTITY_LABEL_EQUIVALENCE_CLASSES = ( + frozenset( + { + "access_token", + "api_key", + "auth_token", + "bearer_token", + "password", + "secret_key", + "session_id", + "unique_id", + "user_id", + } + ), + frozenset({"full_name", "person_name", "user", "user_name", "username"}), + frozenset({"phone", "phone_number", "telephone"}), + frozenset({"email", "email_address"}), + frozenset({"cookie", "http_cookie", "session_cookie"}), +) +_ENTITY_LABEL_EQUIVALENCE: dict[str, str] = { + label: sorted(labels)[0] for labels in _ENTITY_LABEL_EQUIVALENCE_CLASSES for label in labels +} MEASUREMENT_SCHEMA_VERSION = 1 DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" DD_TRACE_MODES = {"none", "last_message", "all_messages"} @@ -1102,6 +1124,16 @@ def _entity_ground_truth_metrics( "entity_precision": None, "entity_recall": None, "entity_f1": None, + "entity_relaxed_gt_found_count": None, + "entity_relaxed_detected_tp_count": None, + "entity_relaxed_label_compatible_gt_found_count": None, + "entity_relaxed_label_compatible_detected_tp_count": None, + "entity_relaxed_precision": None, + "entity_relaxed_recall": None, + "entity_relaxed_f1": None, + "entity_relaxed_label_compatible_precision": None, + "entity_relaxed_label_compatible_recall": None, + "entity_relaxed_label_compatible_f1": None, } predicted = _entity_identity_set(final_entities) @@ -1111,11 +1143,6 @@ def _entity_ground_truth_metrics( false_negative = len(expected - predicted) precision = _safe_ratio(true_positive, true_positive + false_positive) recall = _safe_ratio(true_positive, true_positive + false_negative) - f1 = ( - None - if precision is None or recall is None or precision + recall == 0 - else 2 * precision * recall / (precision + recall) - ) 
return { "ground_truth_entity_count": len(ground_truth_entities), "ground_truth_entity_label_counts": dict( @@ -1126,7 +1153,8 @@ def _entity_ground_truth_metrics( "entity_false_negative_count": false_negative, "entity_precision": precision, "entity_recall": recall, - "entity_f1": f1, + "entity_f1": _f1(precision, recall), + **_entity_relaxed_ground_truth_metrics(final_entities, ground_truth_entities), } @@ -1141,6 +1169,100 @@ def _entity_identity_set(entities: list[dict[str, Any]]) -> set[tuple[str, str]] return identities +def _entity_relaxed_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]], +) -> dict[str, Any]: + gt_found = sum( + 1 for ground_truth_entity in ground_truth_entities if _has_relaxed_entity_match(final_entities, ground_truth_entity) + ) + detected_tp = sum( + 1 for final_entity in final_entities if _has_relaxed_entity_match(ground_truth_entities, final_entity) + ) + label_compatible_gt_found = sum( + 1 + for ground_truth_entity in ground_truth_entities + if _has_relaxed_entity_match(final_entities, ground_truth_entity, require_label_compatible=True) + ) + label_compatible_detected_tp = sum( + 1 + for final_entity in final_entities + if _has_relaxed_entity_match(ground_truth_entities, final_entity, require_label_compatible=True) + ) + precision = _safe_ratio(detected_tp, len(final_entities)) + recall = _safe_ratio(gt_found, len(ground_truth_entities)) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, len(final_entities)) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, len(ground_truth_entities)) + return { + "entity_relaxed_gt_found_count": gt_found, + "entity_relaxed_detected_tp_count": detected_tp, + "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, + "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, + "entity_relaxed_precision": precision, + "entity_relaxed_recall": recall, + 
"entity_relaxed_f1": _f1(precision, recall), + "entity_relaxed_label_compatible_precision": label_compatible_precision, + "entity_relaxed_label_compatible_recall": label_compatible_recall, + "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), + } + + +def _has_relaxed_entity_match( + candidates: list[dict[str, Any]], + target: dict[str, Any], + *, + require_label_compatible: bool = False, +) -> bool: + return any( + _entities_match_relaxed(candidate, target, require_label_compatible=require_label_compatible) + for candidate in candidates + ) + + +def _entities_match_relaxed( + left: dict[str, Any], + right: dict[str, Any], + *, + require_label_compatible: bool, +) -> bool: + if require_label_compatible and not _entity_labels_compatible(left.get("label"), right.get("label")): + return False + left_span = _entity_span(left) + right_span = _entity_span(right) + if left_span is not None and right_span is not None: + return left_span[0] < right_span[1] and right_span[0] < left_span[1] + left_value = left.get("value") + right_value = right.get("value") + return left_value is not None and right_value is not None and str(left_value) == str(right_value) + + +def _entity_span(entity: dict[str, Any]) -> tuple[int, int] | None: + start = _coerce_float(entity.get("start_position", entity.get("start"))) + end = _coerce_float(entity.get("end_position", entity.get("end"))) + if start is None or end is None: + return None + start_int = int(start) + end_int = int(end) + if start_int < 0 or end_int <= start_int: + return None + return start_int, end_int + + +def _entity_labels_compatible(left: object, right: object) -> bool: + left_key = _entity_label_key(left) + right_key = _entity_label_key(right) + return left_key is not None and right_key is not None and left_key == right_key + + +def _entity_label_key(label: object) -> str | None: + if label is None: + return None + normalized = str(label).strip().lower() + if not normalized: + return 
None + return _ENTITY_LABEL_EQUIVALENCE.get(normalized, normalized) + + def _replacement_map_metrics(raw: object) -> dict[str, Any]: replacement_maps = _replacement_maps_from_raw(raw) synthetic_values = [] @@ -1291,6 +1413,12 @@ def _safe_ratio(numerator: int | float | None, denominator: int | float | None) return float(numerator) / float(denominator) +def _f1(precision: float | None, recall: float | None) -> float | None: + if precision is None or recall is None or precision + recall == 0: + return None + return 2 * precision * recall / (precision + recall) + + def _size_bucket(value: int) -> str: if value == 0: return "0" diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 64f154d4..ebcbe95c 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -800,6 +800,16 @@ def test_record_metrics_capture_generic_counts_without_raw_values() -> None: assert record["entity_precision"] == 0.5 assert record["entity_recall"] == 0.5 assert record["entity_f1"] == 0.5 + assert record["entity_relaxed_gt_found_count"] == 2 + assert record["entity_relaxed_detected_tp_count"] == 2 + assert record["entity_relaxed_label_compatible_gt_found_count"] == 2 + assert record["entity_relaxed_label_compatible_detected_tp_count"] == 2 + assert record["entity_relaxed_precision"] == 1.0 + assert record["entity_relaxed_recall"] == 1.0 + assert record["entity_relaxed_f1"] == 1.0 + assert record["entity_relaxed_label_compatible_precision"] == 1.0 + assert record["entity_relaxed_label_compatible_recall"] == 1.0 + assert record["entity_relaxed_label_compatible_f1"] == 1.0 assert record["replacement_count"] == 2 assert record["replacement_label_counts"] == {"company_name": 1, "first_name": 1} assert record["replacement_duplicate_value_count"] == 1 @@ -824,6 +834,55 @@ def test_record_metrics_capture_generic_counts_without_raw_values() -> None: assert "Maya" not in serialized +def test_record_metrics_capture_relaxed_gt_label_equivalence_without_raw_values() -> None: 
+ final_entities = { + "entities": [ + {"value": "builduser42", "label": "user_name", "start_position": 4, "end_position": 15}, + ] + } + ground_truth_entities = { + "entities": [ + {"value": "legacy-user", "label": "username", "start_position": 6, "end_position": 14}, + ] + } + dataframe = pd.DataFrame( + { + COL_TEXT: ["ssh builduser42@host"], + COL_FINAL_ENTITIES: [final_entities], + "ground_truth_entities": [ground_truth_entities], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=2, + ) + + record = collector.records[0] + assert record["entity_true_positive_count"] == 0 + assert record["entity_false_positive_count"] == 1 + assert record["entity_false_negative_count"] == 1 + assert record["entity_relaxed_gt_found_count"] == 1 + assert record["entity_relaxed_detected_tp_count"] == 1 + assert record["entity_relaxed_label_compatible_gt_found_count"] == 1 + assert record["entity_relaxed_label_compatible_detected_tp_count"] == 1 + assert record["entity_relaxed_precision"] == 1.0 + assert record["entity_relaxed_recall"] == 1.0 + assert record["entity_relaxed_f1"] == 1.0 + assert record["entity_relaxed_label_compatible_precision"] == 1.0 + assert record["entity_relaxed_label_compatible_recall"] == 1.0 + assert record["entity_relaxed_label_compatible_f1"] == 1.0 + + serialized = json.dumps(collector.records) + assert "builduser42" not in serialized + assert "legacy-user" not in serialized + + def test_record_metrics_counts_missing_replacement_map_entries_without_raw_values() -> None: final_entities = { "entities": [ diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index 3ebae5eb..9f309c73 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -84,10 +84,41 @@ 
def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp "run_tags": { "suite_id": "suite", "workload_id": "bio", + "workload_category": "synthetic_biography", "config_id": "default", "experimental_detection_strategy": "default", "experimental_replacement_strategy": "default", "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "topology_endpoint_count": 2, + "topology_gpu_count": 4, + "topology_tensor_parallelism": 2, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "stage", + "run_id": "bio__default__r000", + "stage": "Anonymizer._run_internal", + "elapsed_sec": 10.0, + "status": "completed", + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "topology_endpoint_count": 2, + "topology_gpu_count": 4, + "topology_tensor_parallelism": 2, "repetition": 0, "case_id": "bio__default__r000", }, @@ -95,7 +126,16 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp { "record_type": "record", "run_id": "bio__default__r000", + "text_length_tokens": 1200, "final_entity_count": 14, + "ground_truth_entity_count": 20, + "entity_true_positive_count": 10, + "entity_false_positive_count": 4, + "entity_false_negative_count": 10, + "entity_relaxed_gt_found_count": 15, + "entity_relaxed_detected_tp_count": 14, + "entity_relaxed_label_compatible_gt_found_count": 13, + "entity_relaxed_label_compatible_detected_tp_count": 12, "replacement_count": 12, "replacement_missing_final_entity_count": 2, "replacement_missing_final_entity_label_counts": {"date": 2}, @@ -108,10 +148,51 @@ def 
test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp "run_tags": { "suite_id": "suite", "workload_id": "bio", + "workload_category": "synthetic_biography", + "config_id": "default", + "experimental_detection_strategy": "default", + "experimental_replacement_strategy": "default", + "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, + "repetition": 0, + "case_id": "bio__default__r000", + }, + }, + { + "record_type": "record", + "run_id": "bio__default__r000", + "text_length_tokens": 300, + "final_entity_count": 0, + "ground_truth_entity_count": 2, + "entity_true_positive_count": 0, + "entity_false_positive_count": 0, + "entity_false_negative_count": 2, + "entity_relaxed_gt_found_count": 0, + "entity_relaxed_detected_tp_count": 0, + "entity_relaxed_label_compatible_gt_found_count": 0, + "entity_relaxed_label_compatible_detected_tp_count": 0, + "replacement_count": 0, + "replacement_missing_final_entity_count": 0, + "replacement_missing_final_entity_label_counts": {}, + "replacement_missing_final_value_count": 0, + "replacement_synthetic_original_collision_count": 0, + "replacement_synthetic_original_collision_label_counts": {}, + "replacement_synthetic_original_collision_value_count": 0, + "original_value_leak_count": 0, + "original_value_leak_label_counts": {}, + "run_tags": { + "suite_id": "suite", + "workload_id": "bio", + "workload_category": "synthetic_biography", "config_id": "default", "experimental_detection_strategy": "default", "experimental_replacement_strategy": "default", "dd_parser_compat": "raw_json", + "entity_label_set_id": "agent", + "entity_label_count": 4, + "gliner_threshold": 0.3, "repetition": 0, "case_id": "bio__default__r000", }, @@ -119,6 +200,7 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp { "record_type": "record", "run_id": "shell__rules-only__r000", + "text_length_tokens": 750, "final_entity_count": 8, 
"replacement_count": 8, "replacement_missing_final_entity_count": 0, @@ -261,6 +343,10 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert result.model_usage_count == 2 assert result.model_usage_group_count == 2 cases = {row.case_id: row for row in result.cases} + assert cases["bio__default__r000"].workload_category == "synthetic_biography" + assert cases["bio__default__r000"].entity_label_set_id == "agent" + assert cases["bio__default__r000"].entity_label_count == 4 + assert cases["bio__default__r000"].gliner_threshold == pytest.approx(0.3) assert cases["bio__default__r000"].experimental_replacement_strategy == "default" assert cases["bio__default__r000"].observed_total_requests == 4 assert cases["bio__default__r000"].observed_successful_requests == 3 @@ -275,6 +361,36 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert cases["bio__default__r000"].observed_non_bridge_total_requests == 3 assert cases["bio__default__r000"].observed_non_bridge_failed_requests == 0 assert cases["bio__default__r000"].observed_non_bridge_failed_request_rate == 0 + assert cases["bio__default__r000"].record_count == 2 + assert cases["bio__default__r000"].input_text_tokens_total == 1500 + assert cases["bio__default__r000"].records_per_pipeline_sec == pytest.approx(0.2) + assert cases["bio__default__r000"].records_per_ndd_sec == pytest.approx(2 / 8.5) + assert cases["bio__default__r000"].input_text_tokens_per_pipeline_sec == 150 + assert cases["bio__default__r000"].input_text_tokens_per_ndd_sec == pytest.approx(1500 / 8.5) + assert cases["bio__default__r000"].topology_endpoint_count == 2 + assert cases["bio__default__r000"].topology_gpu_count == 4 + assert cases["bio__default__r000"].topology_tensor_parallelism == 2 + assert cases["bio__default__r000"].input_text_tokens_per_endpoint_sec == 75 + assert cases["bio__default__r000"].input_text_tokens_per_gpu_sec == 37.5 + assert 
cases["bio__default__r000"].empty_detection_count == 1 + assert cases["bio__default__r000"].empty_detection_rate == 0.5 + assert cases["bio__default__r000"].empty_detection_with_ground_truth_count == 1 + assert cases["bio__default__r000"].empty_detection_with_ground_truth_rate == 0.5 + assert cases["bio__default__r000"].ground_truth_record_count == 2 + assert cases["bio__default__r000"].ground_truth_entity_count == 22 + assert cases["bio__default__r000"].entity_true_positive_count == 10 + assert cases["bio__default__r000"].entity_false_positive_count == 4 + assert cases["bio__default__r000"].entity_false_negative_count == 12 + assert cases["bio__default__r000"].entity_precision == pytest.approx(10 / 14) + assert cases["bio__default__r000"].entity_recall == pytest.approx(10 / 22) + assert cases["bio__default__r000"].entity_relaxed_gt_found_count == 15 + assert cases["bio__default__r000"].entity_relaxed_detected_tp_count == 14 + assert cases["bio__default__r000"].entity_relaxed_label_compatible_gt_found_count == 13 + assert cases["bio__default__r000"].entity_relaxed_label_compatible_detected_tp_count == 12 + assert cases["bio__default__r000"].entity_relaxed_precision == 1.0 + assert cases["bio__default__r000"].entity_relaxed_recall == pytest.approx(15 / 22) + assert cases["bio__default__r000"].entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) + assert cases["bio__default__r000"].entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) assert cases["bio__default__r000"].validation_max_entities_per_call == 10 assert cases["bio__default__r000"].original_value_leak_count == 0 assert cases["bio__default__r000"].original_value_leak_record_count == 0 @@ -342,6 +458,10 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert nemotron_group.observed_failed_request_rate == pytest.approx(1 / 3) assert nemotron_group.median_observed_total_requests == 3 bio_group = next(group for group in result.groups if 
group.workload_id == "bio") + assert bio_group.workload_category == "synthetic_biography" + assert bio_group.entity_label_set_id == "agent" + assert bio_group.entity_label_count == 4 + assert bio_group.gliner_threshold == pytest.approx(0.3) assert bio_group.experimental_replacement_strategy == "default" assert bio_group.median_observed_bridge_fallback_requests == 1 assert bio_group.median_observed_non_bridge_total_requests == 3 @@ -353,6 +473,30 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert bio_group.median_replacement_synthetic_original_collision_count == 1 assert bio_group.median_replacement_synthetic_original_collision_value_count == 1 assert bio_group.replacement_synthetic_original_collision_label_counts == {"date": 1} + assert bio_group.total_record_count == 2 + assert bio_group.total_input_text_tokens == 1500 + assert bio_group.median_input_text_tokens_per_pipeline_sec == 150 + assert bio_group.median_input_text_tokens_per_endpoint_sec == 75 + assert bio_group.median_input_text_tokens_per_gpu_sec == 37.5 + assert bio_group.total_empty_detection_count == 1 + assert bio_group.empty_detection_rate == 0.5 + assert bio_group.total_empty_detection_with_ground_truth_count == 1 + assert bio_group.empty_detection_with_ground_truth_rate == 0.5 + assert bio_group.total_ground_truth_record_count == 2 + assert bio_group.sum_ground_truth_entity_count == 22 + assert bio_group.sum_entity_true_positive_count == 10 + assert bio_group.sum_entity_false_positive_count == 4 + assert bio_group.sum_entity_false_negative_count == 12 + assert bio_group.micro_entity_precision == pytest.approx(10 / 14) + assert bio_group.micro_entity_recall == pytest.approx(10 / 22) + assert bio_group.sum_entity_relaxed_gt_found_count == 15 + assert bio_group.sum_entity_relaxed_detected_tp_count == 14 + assert bio_group.sum_entity_relaxed_label_compatible_gt_found_count == 13 + assert bio_group.sum_entity_relaxed_label_compatible_detected_tp_count == 12 + 
assert bio_group.micro_entity_relaxed_precision == 1.0 + assert bio_group.micro_entity_relaxed_recall == pytest.approx(15 / 22) + assert bio_group.micro_entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) + assert bio_group.micro_entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) shell_group = next(group for group in result.groups if group.workload_id == "shell") assert shell_group.experimental_replacement_strategy == "local_structured_substitute" assert shell_group.sum_original_value_leak_count == 1 diff --git a/tools/measurement/README.md b/tools/measurement/README.md index f101d54d..f0cc8ffe 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -828,13 +828,22 @@ and token counts still contribute to case, group, and model summaries. When the benchmark was run with current sidecar export, `rules_only` also has artifact-derived signatures and source counts; older runs may only have record-level entity counts. The joined case/group tables include -successful/failed request counts, input/output token splits, -`seed_validation_candidate_count`, `estimated_seed_validation_chunk_count`, and -`observed_failed_request_rate`; use these when testing +successful/failed request counts, input/output token splits, record counts, +dataset input-token throughput, `seed_validation_candidate_count`, +`estimated_seed_validation_chunk_count`, and `observed_failed_request_rate`; +use these when testing `detect.validation_max_entities_per_call` so you can distinguish a real chunk count change from provider retry variance. The model tables split the same usage by `workflow_name` and `model_name`, which is useful for separating local detector cost from validator, augmenter, substitute, or rewrite model cost. +When record-level measurements include ground-truth entities, the joined tables +also expose exact and relaxed entity-quality metrics. 
The relaxed metrics count +span overlap, with small label-equivalence groups for common aliases such as +`user_name` / `username` and `api_key` / `auth_token`. Case and group tables +also count empty detections, including empty records that had ground-truth +entities. If your suite adds portable topology tags such as `endpoint_count` +or `gpu_count`, the analysis computes per-endpoint and per-GPU input-token +throughput from them (`tensor_parallelism` is reported but not used for normalization); otherwise those normalized fields remain null. The case/group tables also surface incomplete benchmark cases with `case_failed`, `error_stage_count`, `error_ndd_workflow_count`, `error_model_workflow_count`, `failed_case_count`, and `failed_case_rate`. @@ -1842,6 +1851,16 @@ Latency and throughput: - `tokens_per_sec`: observed total tokens per second when token usage exists. - `text_length_tokens_bucket`: a coarse text-size bucket for comparing similar inputs without storing text. +- `record_count` and `input_text_tokens_total`: case-level workload size + derived from record measurements. These are independent of provider-reported + token usage. +- `records_per_pipeline_sec` and `input_text_tokens_per_pipeline_sec`: dataset + throughput normalized by the measured Anonymizer pipeline stage. The matching + `*_per_ndd_sec` fields use summed DataDesigner workflow wall time instead. +- `input_text_tokens_per_endpoint_sec` and + `input_text_tokens_per_gpu_sec`: optional topology-normalized dataset + throughput. These are populated only when benchmark run tags provide portable + topology counts such as `endpoint_count` or `gpu_count`. LLM usage: @@ -1942,8 +1961,18 @@ Entity and quality metrics: it can replace a DataDesigner-backed baseline. - `final_entity_label_counts`: per-label entity counts serialized as JSON in exported tabular files. -- `ground_truth_*`: precision, recall, F1, false positives, and false negatives - when the input includes one of the supported ground-truth entity columns. 
+- `ground_truth_*` and `entity_*`: exact value+label precision, recall, F1, + false positives, and false negatives when the input includes one of the + supported ground-truth entity columns. +- `entity_relaxed_*`: span-overlap precision, recall, and F1. The + label-compatible variants require both span overlap and equivalent labels, + while the non-label-compatible relaxed metrics only ask whether a + ground-truth span was protected by any detected span. +- `empty_detection_count`, `empty_detection_rate`, + `empty_detection_with_ground_truth_count`, and + `empty_detection_with_ground_truth_rate`: diagnostics for records where the + detector returned no final entities. The ground-truth-specific fields are the + important safety signal when a benchmark includes labels. - `utility_score`, `leakage_mass`, `weighted_leakage_rate`, `needs_repair`, and `needs_human_review`: rewrite-mode evaluation fields. These are null for replace-mode runs. diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py index ec93b0cb..f4768de6 100644 --- a/tools/measurement/analyze_benchmark_output.py +++ b/tools/measurement/analyze_benchmark_output.py @@ -18,7 +18,7 @@ import sys from enum import StrEnum from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast import cyclopts import pandas as pd @@ -55,10 +55,14 @@ class LogFormat(StrEnum): class CaseAnalysisRow(BaseModel): suite_id: str | None = None workload_id: str | None = None + workload_category: str | None = None config_id: str | None = None experimental_detection_strategy: str | None = None experimental_replacement_strategy: str | None = None dd_parser_compat: str | None = None + entity_label_set_id: str | None = None + entity_label_count: int | None = None + gliner_threshold: float | None = None repetition: int | None = None case_id: str run_id: str @@ -83,10 +87,44 @@ class CaseAnalysisRow(BaseModel): observed_non_bridge_total_requests: 
int | None = None observed_non_bridge_failed_requests: int | None = None observed_non_bridge_failed_request_rate: float | None = None + record_count: int = 0 + input_text_tokens_total: int | None = None + records_per_pipeline_sec: float | None = None + records_per_ndd_sec: float | None = None + input_text_tokens_per_pipeline_sec: float | None = None + input_text_tokens_per_ndd_sec: float | None = None + topology_endpoint_count: float | None = None + topology_gpu_count: float | None = None + topology_tensor_parallelism: float | None = None + topology_shard_count: float | None = None + input_text_tokens_per_endpoint_sec: float | None = None + input_text_tokens_per_gpu_sec: float | None = None route_total_row_count: float | None = None route_rule_row_count: float | None = None route_fallback_row_count: float | None = None final_entity_count: float | None = None + empty_detection_count: int = 0 + empty_detection_rate: float | None = None + empty_detection_with_ground_truth_count: int = 0 + empty_detection_with_ground_truth_rate: float | None = None + ground_truth_record_count: int = 0 + ground_truth_entity_count: float | None = None + entity_true_positive_count: float | None = None + entity_false_positive_count: float | None = None + entity_false_negative_count: float | None = None + entity_precision: float | None = None + entity_recall: float | None = None + entity_f1: float | None = None + entity_relaxed_gt_found_count: float | None = None + entity_relaxed_detected_tp_count: float | None = None + entity_relaxed_label_compatible_gt_found_count: float | None = None + entity_relaxed_label_compatible_detected_tp_count: float | None = None + entity_relaxed_precision: float | None = None + entity_relaxed_recall: float | None = None + entity_relaxed_f1: float | None = None + entity_relaxed_label_compatible_precision: float | None = None + entity_relaxed_label_compatible_recall: float | None = None + entity_relaxed_label_compatible_f1: float | None = None replacement_count: 
float | None = None replacement_missing_final_entity_count: float | None = None replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) @@ -116,9 +154,13 @@ class CaseAnalysisRow(BaseModel): class GroupAnalysisRow(BaseModel): workload_id: str | None = None + workload_category: str | None = None config_id: str | None = None experimental_detection_strategy: str | None = None experimental_replacement_strategy: str | None = None + entity_label_set_id: str | None = None + entity_label_count: int | None = None + gliner_threshold: float | None = None case_count: int failed_case_count: int = 0 failed_case_rate: float | None = None @@ -138,10 +180,48 @@ class GroupAnalysisRow(BaseModel): median_observed_non_bridge_total_requests: float | None = None median_observed_non_bridge_failed_requests: float | None = None median_observed_non_bridge_failed_request_rate: float | None = None + total_record_count: int = 0 + median_record_count: float | None = None + total_input_text_tokens: int | None = None + median_input_text_tokens_total: float | None = None + median_records_per_pipeline_sec: float | None = None + median_records_per_ndd_sec: float | None = None + median_input_text_tokens_per_pipeline_sec: float | None = None + median_input_text_tokens_per_ndd_sec: float | None = None + median_topology_endpoint_count: float | None = None + median_topology_gpu_count: float | None = None + median_topology_tensor_parallelism: float | None = None + median_topology_shard_count: float | None = None + median_input_text_tokens_per_endpoint_sec: float | None = None + median_input_text_tokens_per_gpu_sec: float | None = None median_route_total_row_count: float | None = None median_route_rule_row_count: float | None = None median_route_fallback_row_count: float | None = None median_final_entity_count: float | None = None + total_empty_detection_count: int = 0 + empty_detection_rate: float | None = None + total_empty_detection_with_ground_truth_count: int = 0 + 
empty_detection_with_ground_truth_rate: float | None = None + total_ground_truth_record_count: int = 0 + sum_ground_truth_entity_count: float | None = None + sum_entity_true_positive_count: float | None = None + sum_entity_false_positive_count: float | None = None + sum_entity_false_negative_count: float | None = None + micro_entity_precision: float | None = None + micro_entity_recall: float | None = None + micro_entity_f1: float | None = None + sum_entity_relaxed_gt_found_count: float | None = None + sum_entity_relaxed_detected_tp_count: float | None = None + sum_entity_relaxed_label_compatible_gt_found_count: float | None = None + sum_entity_relaxed_label_compatible_detected_tp_count: float | None = None + micro_entity_relaxed_precision: float | None = None + micro_entity_relaxed_recall: float | None = None + micro_entity_relaxed_f1: float | None = None + micro_entity_relaxed_label_compatible_precision: float | None = None + micro_entity_relaxed_label_compatible_recall: float | None = None + micro_entity_relaxed_label_compatible_f1: float | None = None + median_entity_relaxed_f1: float | None = None + median_entity_relaxed_label_compatible_f1: float | None = None median_replacement_missing_final_entity_count: float | None = None median_replacement_missing_final_value_count: float | None = None replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) @@ -373,11 +453,27 @@ def _build_case_row( pipeline_rows = _pipeline_stage_rows(measurement_rows) validation_max_entities_per_call = _first_int([measurement_rows], ["detect.validation_max_entities_per_call"]) request_metrics = _case_request_metrics(model_rows) + pipeline_elapsed_sec = _sum_or_none(pipeline_rows, "elapsed_sec") + ndd_elapsed_sec_total = _sum_or_zero(ndd_rows, "elapsed_sec") + record_count = len(record_rows) + input_text_tokens_total = _sum_int_or_none(record_rows, "text_length_tokens") + records_per_pipeline_sec = _safe_rate(record_count, pipeline_elapsed_sec) + 
records_per_ndd_sec = _safe_rate(record_count, ndd_elapsed_sec_total) + input_text_tokens_per_pipeline_sec = _safe_rate(input_text_tokens_total, pipeline_elapsed_sec) + input_text_tokens_per_ndd_sec = _safe_rate(input_text_tokens_total, ndd_elapsed_sec_total) + final_entity_count = _coalesce_number( + _sum_or_none(record_rows, "final_entity_count"), + _sum_or_none(artifact_rows, "final_entity_count"), + ) return CaseAnalysisRow( suite_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.suite_id", "suite_id"]), workload_id=_first_value( [measurement_rows, artifact_rows, trace_rows], ["run_tags.workload_id", "workload_id"] ), + workload_category=_first_value( + [measurement_rows, artifact_rows, trace_rows], + ["run_tags.workload_category", "run_tags.dataset_category", "workload_category", "dataset_category"], + ), config_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_tags.config_id", "config_id"]), experimental_detection_strategy=_first_value([measurement_rows], ["run_tags.experimental_detection_strategy"]), experimental_replacement_strategy=_first_value( @@ -385,22 +481,42 @@ def _build_case_row( ["run_tags.experimental_replacement_strategy"], ), dd_parser_compat=_first_value([measurement_rows], ["run_tags.dd_parser_compat"]), + entity_label_set_id=_first_value( + [measurement_rows], + [ + "run_tags.entity_label_set_id", + "run_tags.entity_label_set", + "run_tags.label_set", + "detect.entity_label_source", + ], + ), + entity_label_count=_first_int([measurement_rows], ["run_tags.entity_label_count", "detect.entity_label_count"]), + gliner_threshold=_first_float([measurement_rows], ["run_tags.gliner_threshold", "detect.gliner_threshold"]), repetition=_first_int([measurement_rows, artifact_rows, trace_rows], ["run_tags.repetition", "repetition"]), case_id=case_id, run_id=_first_value([measurement_rows, artifact_rows, trace_rows], ["run_id"]) or case_id, **_case_failure_metrics(stage_rows=stage_rows, ndd_rows=ndd_rows, 
model_rows=model_rows), - pipeline_elapsed_sec=_sum_or_none(pipeline_rows, "elapsed_sec"), + pipeline_elapsed_sec=pipeline_elapsed_sec, ndd_workflow_count=len(ndd_rows), - ndd_elapsed_sec_total=_sum_or_zero(ndd_rows, "elapsed_sec"), + ndd_elapsed_sec_total=ndd_elapsed_sec_total, **request_metrics, **_case_trace_metrics(trace_rows, request_metrics=request_metrics), + record_count=record_count, + input_text_tokens_total=input_text_tokens_total, + records_per_pipeline_sec=records_per_pipeline_sec, + records_per_ndd_sec=records_per_ndd_sec, + input_text_tokens_per_pipeline_sec=input_text_tokens_per_pipeline_sec, + input_text_tokens_per_ndd_sec=input_text_tokens_per_ndd_sec, + **_case_topology_metrics( + measurement_rows, + input_text_tokens_per_pipeline_sec=input_text_tokens_per_pipeline_sec, + ), route_total_row_count=_sum_or_none(model_rows, "route_total_row_count"), route_rule_row_count=_sum_or_none(model_rows, "route_rule_row_count"), route_fallback_row_count=_sum_or_none(model_rows, "route_fallback_row_count"), - final_entity_count=_coalesce_number( - _sum_or_none(record_rows, "final_entity_count"), - _sum_or_none(artifact_rows, "final_entity_count"), - ), + final_entity_count=final_entity_count, + **_case_empty_detection_metrics(record_rows, record_count=record_count), + **_case_ground_truth_metrics(record_rows, final_entity_count=final_entity_count), replacement_count=_sum_or_none(record_rows, "replacement_count"), replacement_missing_final_entity_count=_sum_or_none(record_rows, "replacement_missing_final_entity_count"), replacement_missing_final_entity_label_counts=_sum_prefixed_ints( @@ -504,6 +620,115 @@ def _case_failure_metrics( } +def _case_topology_metrics( + measurement_rows: pd.DataFrame, + *, + input_text_tokens_per_pipeline_sec: float | None, +) -> dict[str, float | None]: + endpoint_count = _first_float( + [measurement_rows], + [ + "run_tags.topology_endpoint_count", + "run_tags.endpoint_count", + "run_tags.n_endpoints", + "run_tags.n_llm_endpoints", 
+ ], + ) + gpu_count = _first_float( + [measurement_rows], + [ + "run_tags.topology_gpu_count", + "run_tags.gpu_count", + "run_tags.n_gpus", + "run_tags.n_llm_gpus", + ], + ) + tensor_parallelism = _first_float( + [measurement_rows], + [ + "run_tags.topology_tensor_parallelism", + "run_tags.tensor_parallelism", + "run_tags.gpus_per_endpoint", + "run_tags.tp", + ], + ) + shard_count = _first_float( + [measurement_rows], + ["run_tags.topology_shard_count", "run_tags.shard_count", "run_tags.n_shards"], + ) + return { + "topology_endpoint_count": endpoint_count, + "topology_gpu_count": gpu_count, + "topology_tensor_parallelism": tensor_parallelism, + "topology_shard_count": shard_count, + "input_text_tokens_per_endpoint_sec": _safe_ratio(input_text_tokens_per_pipeline_sec, endpoint_count), + "input_text_tokens_per_gpu_sec": _safe_ratio(input_text_tokens_per_pipeline_sec, gpu_count), + } + + +def _case_empty_detection_metrics(record_rows: pd.DataFrame, *, record_count: int) -> dict[str, int | float | None]: + empty_detection_count = _zero_count(record_rows, "final_entity_count") + ground_truth_record_count = _non_null_count(record_rows, "ground_truth_entity_count") + empty_detection_with_gt_count = _zero_with_positive_count( + record_rows, + zero_column="final_entity_count", + positive_column="ground_truth_entity_count", + ) + return { + "empty_detection_count": empty_detection_count, + "empty_detection_rate": _safe_ratio(empty_detection_count, record_count), + "empty_detection_with_ground_truth_count": empty_detection_with_gt_count, + "empty_detection_with_ground_truth_rate": _safe_ratio( + empty_detection_with_gt_count, + ground_truth_record_count, + ), + "ground_truth_record_count": ground_truth_record_count, + } + + +def _case_ground_truth_metrics( + record_rows: pd.DataFrame, + *, + final_entity_count: float | None, +) -> dict[str, float | None]: + ground_truth_entity_count = _sum_or_none(record_rows, "ground_truth_entity_count") + true_positive = 
_sum_or_none(record_rows, "entity_true_positive_count") + false_positive = _sum_or_none(record_rows, "entity_false_positive_count") + false_negative = _sum_or_none(record_rows, "entity_false_negative_count") + relaxed_gt_found = _sum_or_none(record_rows, "entity_relaxed_gt_found_count") + relaxed_detected_tp = _sum_or_none(record_rows, "entity_relaxed_detected_tp_count") + label_compatible_gt_found = _sum_or_none(record_rows, "entity_relaxed_label_compatible_gt_found_count") + label_compatible_detected_tp = _sum_or_none( + record_rows, + "entity_relaxed_label_compatible_detected_tp_count", + ) + strict_precision = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_positive)) + strict_recall = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_negative)) + relaxed_precision = _safe_ratio(relaxed_detected_tp, final_entity_count) + relaxed_recall = _safe_ratio(relaxed_gt_found, ground_truth_entity_count) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, final_entity_count) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, ground_truth_entity_count) + return { + "ground_truth_entity_count": ground_truth_entity_count, + "entity_true_positive_count": true_positive, + "entity_false_positive_count": false_positive, + "entity_false_negative_count": false_negative, + "entity_precision": strict_precision, + "entity_recall": strict_recall, + "entity_f1": _f1(strict_precision, strict_recall), + "entity_relaxed_gt_found_count": relaxed_gt_found, + "entity_relaxed_detected_tp_count": relaxed_detected_tp, + "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, + "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, + "entity_relaxed_precision": relaxed_precision, + "entity_relaxed_recall": relaxed_recall, + "entity_relaxed_f1": _f1(relaxed_precision, relaxed_recall), + "entity_relaxed_label_compatible_precision": label_compatible_precision, + 
"entity_relaxed_label_compatible_recall": label_compatible_recall, + "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), + } + + def _error_status_count(rows: pd.DataFrame) -> int: if "status" not in rows.columns: return 0 @@ -735,6 +960,11 @@ def _first_int(frames: list[pd.DataFrame], columns: list[str]) -> int | None: return int(float(value)) if value is not None else None +def _first_float(frames: list[pd.DataFrame], columns: list[str]) -> float | None: + value = _first_value(frames, columns) + return float(value) if value is not None else None + + def _coalesce_number(*values: float | None) -> float | None: for value in values: if value is not None: @@ -852,6 +1082,27 @@ def _positive_count(dataframe: pd.DataFrame, column: str) -> int: return int((values > 0).sum()) +def _zero_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + return int((values == 0).sum()) + + +def _non_null_count(dataframe: pd.DataFrame, column: str) -> int: + if column not in dataframe.columns: + return 0 + return int(pd.to_numeric(dataframe[column], errors="coerce").notna().sum()) + + +def _zero_with_positive_count(dataframe: pd.DataFrame, *, zero_column: str, positive_column: str) -> int: + if zero_column not in dataframe.columns or positive_column not in dataframe.columns: + return 0 + zero_values = pd.to_numeric(dataframe[zero_column], errors="coerce") + positive_values = pd.to_numeric(dataframe[positive_column], errors="coerce") + return int(((zero_values == 0) & (positive_values > 0)).sum()) + + def _sum_prefixed_ints(dataframe: pd.DataFrame, prefix: str) -> dict[str, int]: totals: dict[str, int] = {} base_column = prefix.removesuffix(".") @@ -890,9 +1141,13 @@ def build_group_rows(cases: list[CaseAnalysisRow]) -> list[GroupAnalysisRow]: rows: list[GroupAnalysisRow] = [] group_columns = [ "workload_id", + 
"workload_category", "config_id", "experimental_detection_strategy", "experimental_replacement_strategy", + "entity_label_set_id", + "entity_label_count", + "gliner_threshold", ] for keys, group in table.groupby(group_columns, dropna=False): rows.append(_build_group_row(keys, group)) @@ -961,15 +1216,48 @@ def _build_model_usage_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> ) -def _build_group_row(keys: tuple[Any, Any, Any, Any], group: pd.DataFrame) -> GroupAnalysisRow: - workload_id, config_id, detection_strategy, replacement_strategy = keys +def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysisRow: + ( + workload_id, + workload_category, + config_id, + detection_strategy, + replacement_strategy, + entity_label_set_id, + entity_label_count, + gliner_threshold, + ) = keys case_count = int(group["case_id"].nunique()) failed_case_count = _sum_bool_or_zero(group, "case_failed") + total_record_count = _sum_int_or_zero(group, "record_count") + total_input_text_tokens = _sum_int_or_none(group, "input_text_tokens_total") + total_empty_detection_count = _sum_int_or_zero(group, "empty_detection_count") + total_ground_truth_record_count = _sum_int_or_zero(group, "ground_truth_record_count") + total_empty_detection_with_gt_count = _sum_int_or_zero(group, "empty_detection_with_ground_truth_count") + final_entity_count = _sum_or_none(group, "final_entity_count") + ground_truth_entity_count = _sum_or_none(group, "ground_truth_entity_count") + true_positive = _sum_or_none(group, "entity_true_positive_count") + false_positive = _sum_or_none(group, "entity_false_positive_count") + false_negative = _sum_or_none(group, "entity_false_negative_count") + strict_precision = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_positive)) + strict_recall = _safe_ratio(true_positive, _sum_optional_numbers(true_positive, false_negative)) + relaxed_gt_found = _sum_or_none(group, "entity_relaxed_gt_found_count") + relaxed_detected_tp = 
_sum_or_none(group, "entity_relaxed_detected_tp_count") + label_compatible_gt_found = _sum_or_none(group, "entity_relaxed_label_compatible_gt_found_count") + label_compatible_detected_tp = _sum_or_none(group, "entity_relaxed_label_compatible_detected_tp_count") + relaxed_precision = _safe_ratio(relaxed_detected_tp, final_entity_count) + relaxed_recall = _safe_ratio(relaxed_gt_found, ground_truth_entity_count) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, final_entity_count) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, ground_truth_entity_count) return GroupAnalysisRow( workload_id=_none_if_nan(workload_id), + workload_category=_none_if_nan(workload_category), config_id=_none_if_nan(config_id), experimental_detection_strategy=_none_if_nan(detection_strategy), experimental_replacement_strategy=_none_if_nan(replacement_strategy), + entity_label_set_id=_none_if_nan(entity_label_set_id), + entity_label_count=_int_if_not_nan(entity_label_count), + gliner_threshold=_float_if_not_nan(gliner_threshold), case_count=case_count, failed_case_count=failed_case_count, failed_case_rate=_request_failure_rate(failed=failed_case_count, total=case_count), @@ -992,10 +1280,54 @@ def _build_group_row(keys: tuple[Any, Any, Any, Any], group: pd.DataFrame) -> Gr group, "observed_non_bridge_failed_request_rate", ), + total_record_count=total_record_count, + median_record_count=_median_or_none(group, "record_count"), + total_input_text_tokens=total_input_text_tokens, + median_input_text_tokens_total=_median_or_none(group, "input_text_tokens_total"), + median_records_per_pipeline_sec=_median_or_none(group, "records_per_pipeline_sec"), + median_records_per_ndd_sec=_median_or_none(group, "records_per_ndd_sec"), + median_input_text_tokens_per_pipeline_sec=_median_or_none(group, "input_text_tokens_per_pipeline_sec"), + median_input_text_tokens_per_ndd_sec=_median_or_none(group, "input_text_tokens_per_ndd_sec"), + 
median_topology_endpoint_count=_median_or_none(group, "topology_endpoint_count"), + median_topology_gpu_count=_median_or_none(group, "topology_gpu_count"), + median_topology_tensor_parallelism=_median_or_none(group, "topology_tensor_parallelism"), + median_topology_shard_count=_median_or_none(group, "topology_shard_count"), + median_input_text_tokens_per_endpoint_sec=_median_or_none(group, "input_text_tokens_per_endpoint_sec"), + median_input_text_tokens_per_gpu_sec=_median_or_none(group, "input_text_tokens_per_gpu_sec"), median_route_total_row_count=_median_or_none(group, "route_total_row_count"), median_route_rule_row_count=_median_or_none(group, "route_rule_row_count"), median_route_fallback_row_count=_median_or_none(group, "route_fallback_row_count"), median_final_entity_count=_median_or_none(group, "final_entity_count"), + total_empty_detection_count=total_empty_detection_count, + empty_detection_rate=_safe_ratio(total_empty_detection_count, total_record_count), + total_empty_detection_with_ground_truth_count=total_empty_detection_with_gt_count, + empty_detection_with_ground_truth_rate=_safe_ratio( + total_empty_detection_with_gt_count, + total_ground_truth_record_count, + ), + total_ground_truth_record_count=total_ground_truth_record_count, + sum_ground_truth_entity_count=ground_truth_entity_count, + sum_entity_true_positive_count=true_positive, + sum_entity_false_positive_count=false_positive, + sum_entity_false_negative_count=false_negative, + micro_entity_precision=strict_precision, + micro_entity_recall=strict_recall, + micro_entity_f1=_f1(strict_precision, strict_recall), + sum_entity_relaxed_gt_found_count=relaxed_gt_found, + sum_entity_relaxed_detected_tp_count=relaxed_detected_tp, + sum_entity_relaxed_label_compatible_gt_found_count=label_compatible_gt_found, + sum_entity_relaxed_label_compatible_detected_tp_count=label_compatible_detected_tp, + micro_entity_relaxed_precision=relaxed_precision, + micro_entity_relaxed_recall=relaxed_recall, + 
micro_entity_relaxed_f1=_f1(relaxed_precision, relaxed_recall), + micro_entity_relaxed_label_compatible_precision=label_compatible_precision, + micro_entity_relaxed_label_compatible_recall=label_compatible_recall, + micro_entity_relaxed_label_compatible_f1=_f1(label_compatible_precision, label_compatible_recall), + median_entity_relaxed_f1=_median_or_none(group, "entity_relaxed_f1"), + median_entity_relaxed_label_compatible_f1=_median_or_none( + group, + "entity_relaxed_label_compatible_f1", + ), median_replacement_missing_final_entity_count=_median_or_none( group, "replacement_missing_final_entity_count", @@ -1039,6 +1371,18 @@ def _none_if_nan(value: object) -> str | None: return str(value) +def _int_if_not_nan(value: object) -> int | None: + if pd.isna(value): + return None + return int(float(cast(Any, value))) + + +def _float_if_not_nan(value: object) -> float | None: + if pd.isna(value): + return None + return float(cast(Any, value)) + + def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: values = pd.to_numeric(dataframe[column], errors="coerce").dropna() if values.empty: @@ -1069,6 +1413,35 @@ def _request_failure_rate(*, failed: object, total: object) -> float | None: return failed_value / total_value +def _safe_rate(numerator: object, elapsed_sec: object) -> float | None: + numerator_value = _optional_number(numerator) + elapsed_value = _optional_number(elapsed_sec) + if numerator_value is None or elapsed_value is None or elapsed_value <= 0: + return None + return numerator_value / elapsed_value + + +def _safe_ratio(numerator: object, denominator: object) -> float | None: + numerator_value = _optional_number(numerator) + denominator_value = _optional_number(denominator) + if numerator_value is None or denominator_value is None or denominator_value <= 0: + return None + return numerator_value / denominator_value + + +def _sum_optional_numbers(*values: object) -> float | None: + numeric_values = [_optional_number(value) for value in 
values] + if any(value is None for value in numeric_values): + return None + return sum(cast(float, value) for value in numeric_values) + + +def _f1(precision: float | None, recall: float | None) -> float | None: + if precision is None or recall is None or precision + recall == 0: + return None + return 2 * precision * recall / (precision + recall) + + def _optional_number(value: object) -> float | None: if value is None or pd.isna(value): return None @@ -1161,6 +1534,9 @@ def render_result(result: BenchmarkOutputAnalysis, *, json_output: bool) -> str: f"- {label}: cases={group.case_count}, median_entities={group.median_final_entity_count}, " f"failed_cases={group.failed_case_count}/{group.case_count}, " f"median_requests={group.median_observed_total_requests}, median_tokens={group.median_observed_total_tokens}, " + f"median_input_tok_s={group.median_input_text_tokens_per_pipeline_sec}, " + f"micro_relaxed_f1={group.micro_entity_relaxed_f1}, " + f"empty_with_gt={group.total_empty_detection_with_ground_truth_count}, " f"median_failed_request_rate={group.median_observed_failed_request_rate}, " f"median_aug_new_final={group.median_augmented_new_final_value_count}" ) From b3485829e4bd6346640af30d7366ae80aaeb9ba6 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 8 Jun 2026 22:26:27 +0000 Subject: [PATCH 07/26] Split regex detection into stacked branch Signed-off-by: Aaron Gonzales --- .../engine/detection/detection_workflow.py | 75 +- src/anonymizer/engine/detection/rules.py | 274 --- tests/engine/test_detection_rules.py | 318 --- tests/tools/test_benchmark_output_analysis.py | 154 +- tests/tools/test_compare_strategy_pairs.py | 107 +- tests/tools/test_detection_strategies.py | 1198 +--------- tests/tools/test_extract_signature_deltas.py | 16 +- tests/tools/test_measurement_tools.py | 510 +---- .../tools/test_screen_strategy_comparisons.py | 130 +- .../test_staged_detection_output_analysis.py | 108 +- tests/tools/test_staged_detection_probe.py | 233 -- 
tools/measurement/README.md | 1946 ++--------------- tools/measurement/analyze_benchmark_output.py | 16 - .../analyze_staged_detection_output.py | 59 - tools/measurement/compare_strategy_pairs.py | 15 +- tools/measurement/detection_strategies.py | 908 +------- tools/measurement/extract_signature_deltas.py | 23 - tools/measurement/run_benchmarks.py | 139 +- .../screen_strategy_comparisons.py | 32 - tools/measurement/staged_detection_probe.py | 117 +- 20 files changed, 507 insertions(+), 5871 deletions(-) delete mode 100644 src/anonymizer/engine/detection/rules.py delete mode 100644 tests/engine/test_detection_rules.py diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index 153eb50d..c0a34f83 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -49,12 +49,7 @@ parse_detected_entities, prepare_validation_inputs, ) -from anonymizer.engine.detection.postprocess import EntitySpan, build_tagged_text, group_entities_by_value -from anonymizer.engine.detection.rules import ( - STRUCTURED_RULE_FAST_LANE_LABELS, - SUPPORTED_RULE_LABELS, - detect_high_confidence_entities, -) +from anonymizer.engine.detection.postprocess import EntitySpan, group_entities_by_value from anonymizer.engine.ndd.adapter import FailedRecord, NddAdapter from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases from anonymizer.engine.prompt_utils import substitute_placeholders @@ -91,33 +86,6 @@ class EntityDetectionWorkflow: def __init__(self, adapter: NddAdapter) -> None: self._adapter = adapter - def detect_with_high_confidence_rules( - self, - dataframe: pd.DataFrame, - *, - entity_labels: list[str] | None = None, - ) -> EntityDetectionResult: - """Detect only deterministic high-confidence rule spans without DataDesigner. - - This is an internal fast-lane primitive for benchmark probes and - future routing work. 
It is intentionally limited to labels with narrow - deterministic coverage and does not attempt contextual PII detection. - """ - labels = _resolve_detection_labels(entity_labels) - _ensure_high_confidence_rule_labels(labels) - output = dataframe.copy() - output[COL_DETECTED_ENTITIES] = output[COL_TEXT].apply( - lambda text: _high_confidence_rule_payload(text, labels=labels) - ) - output[COL_TAGGED_TEXT] = output.apply( - lambda row: _tagged_text_from_entities( - text=row.get(COL_TEXT, ""), - raw_entities=row.get(COL_DETECTED_ENTITIES, {}), - ), - axis=1, - ) - return EntityDetectionResult(dataframe=output, failed_records=[]) - def detect_and_validate_entities( self, dataframe: pd.DataFrame, @@ -389,47 +357,6 @@ def _resolve_detection_labels(entity_labels: list[str] | None) -> list[str]: return list(entity_labels) -def labels_are_supported_by_high_confidence_rules(labels: list[str]) -> bool: - """Return True when every label can be handled by deterministic rules.""" - return set(labels).issubset(SUPPORTED_RULE_LABELS) - - -def labels_are_supported_by_structured_rule_fast_lane(labels: list[str]) -> bool: - """Return True when every label is safe for the structured no-model fast lane.""" - return set(labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) - - -def _ensure_high_confidence_rule_labels(labels: list[str]) -> None: - unsupported = sorted(set(labels) - SUPPORTED_RULE_LABELS) - if unsupported: - supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) - raise ValueError( - f"unsupported high-confidence rule labels: {', '.join(unsupported)}; supported labels: {supported}" - ) - - -def _high_confidence_rule_payload(text: object, *, labels: list[str]) -> dict: - spans = detect_high_confidence_entities(str(text), labels=labels) - return EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump(mode="json") - - -def _tagged_text_from_entities(*, text: object, raw_entities: object) -> str: - parsed = EntitiesSchema.from_raw(raw_entities) - spans = [ - 
EntitySpan( - entity_id=e.id, - value=e.value, - label=e.label, - start_position=e.start_position, - end_position=e.end_position, - score=e.score, - source=e.source, - ) - for e in parsed.entities - ] - return build_tagged_text(text=str(text), entities=spans) - - def _materialize_final_entities(raw: object, *, allowed_labels: set[str] | None) -> dict: """Build COL_FINAL_ENTITIES, optionally filtering to *allowed_labels*.""" parsed = EntitiesSchema.from_raw(raw) diff --git a/src/anonymizer/engine/detection/rules.py b/src/anonymizer/engine/detection/rules.py deleted file mode 100644 index 1b96748b..00000000 --- a/src/anonymizer/engine/detection/rules.py +++ /dev/null @@ -1,274 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import re -from collections.abc import Iterable -from dataclasses import dataclass - -from anonymizer.engine.detection.postprocess import EntitySpan, resolve_overlaps - -_RULE_SCORE = 1.0 -_RULE_SOURCE = "rule" -_RELIGIOUS_BELIEF_TERMS = ( - "agnostic", - "atheist", - "baptist", - "buddhist", - "catholic", - "christian", - "hindu", - "jewish", - "mormon", - "muslim", - "protestant", - "secular", -) -_RELIGIOUS_BELIEF_RE = "|".join(re.escape(term) for term in _RELIGIOUS_BELIEF_TERMS) -_COOKIE_PAIR_RE = r"[A-Za-z][A-Za-z0-9_-]*=[^;'\s\"\r\n]+" -_COOKIE_VALUE_RE = rf"({_COOKIE_PAIR_RE}(?:;\s*{_COOKIE_PAIR_RE})*)" -_STRUCTURED_ID_VALUE_RE = ( - r"(?:[A-Za-z][A-Za-z0-9]{1,20}[-_][A-Za-z0-9][A-Za-z0-9_-]{5,}|" - r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})" -) - - -@dataclass(frozen=True) -class _RulePattern: - label: str - pattern: re.Pattern[str] - group: int = 0 - - -_RULES: tuple[_RulePattern, ...] 
= ( - _RulePattern( - label="api_key", - pattern=re.compile(r"sk-(?:test|ant-api03|proj|prod)-[A-Za-z0-9_-]{16,}"), - ), - _RulePattern(label="api_key", pattern=re.compile(r"ghp_[A-Za-z0-9_]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"hf_[A-Za-z0-9]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"pat-[A-Za-z0-9_-]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"xoxb-[A-Za-z0-9-]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"AIza[A-Za-z0-9_-]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"ya29\.[A-Za-z0-9_-]{20,}")), - _RulePattern(label="api_key", pattern=re.compile(r"AKIA[A-Z0-9]{16,}")), - _RulePattern( - label="api_key", - pattern=re.compile( - r"\b(?:api[_-]?key|token|auth[_-]?token|session[_-]?id|aws_access_key_id|access_key_id)=" - r"([^\s;'\"\\]{8,})", - flags=re.IGNORECASE, - ), - group=1, - ), - _RulePattern( - label="api_key", - pattern=re.compile(r"Authorization:\s*Bearer\s+([A-Za-z0-9._-]{16,})", flags=re.IGNORECASE), - group=1, - ), - _RulePattern( - label="http_cookie", - pattern=re.compile(rf"\bCookie:\s*{_COOKIE_VALUE_RE}", flags=re.IGNORECASE), - group=1, - ), - _RulePattern( - label="http_cookie", - pattern=re.compile(rf"\bcookie\s*=\s*{_COOKIE_VALUE_RE}", flags=re.IGNORECASE), - group=1, - ), - _RulePattern( - label="pin", - pattern=re.compile(r"(?]+", - flags=re.IGNORECASE, - ), - ), - _RulePattern( - label="email", - pattern=re.compile(r"(?]+")), - _RulePattern( - label="date_of_birth", - pattern=re.compile( - r"\b(?:born|date\s+of\s+birth|dob)\s*(?:[:=-]|\bin\b|\bon\b)?\s*" - r"(\d{4}|\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}-\d{2}-\d{2})\b", - flags=re.IGNORECASE, - ), - group=1, - ), - _RulePattern( - label="religious_belief", - pattern=re.compile( - rf"\b(?:describes?\s+(?:himself|herself|themself|themselves)\s+as|" - rf"identif(?:y|ies)\s+as|raised\s+in\s+the|practicing)\s+" - rf"(?:a|an|the)?\s*({_RELIGIOUS_BELIEF_RE})\b", - flags=re.IGNORECASE, - ), - group=1, - ), - 
_RulePattern( - label="street_address", - pattern=re.compile( - r"\b(?:lives?\s+at|living\s+at|house\s+on|home\s+on)\s+" - r"([A-Z0-9][A-Za-z0-9.\s-]{1,60}?\b" - r"(?:Street|St\.?|Avenue|Ave\.?|Road|Rd\.?|Drive|Dr\.?|Trail|Boulevard|Blvd\.?|Lane|Ln\.?|Court|Ct\.?))", - ), - group=1, - ), - _RulePattern( - label="organization_name", - pattern=re.compile( - r"\b(?:at|from|with|joining|joined)\s+" - r"([A-Z][A-Za-z0-9&.'\u2019 -]{2,90}?\b" - r"(?:Center|Hospital|Clinic|University|College|Institute|Bank|Builders|Construction|Woodworks|Health))" - r"\b", - ), - group=1, - ), -) - -SUPPORTED_RULE_LABELS = frozenset(rule.label for rule in _RULES) -STRUCTURED_RULE_FAST_LANE_LABELS = frozenset( - { - "api_key", - "email", - "http_cookie", - "password", - "pin", - "unique_id", - "url", - "user_name", - } -) - - -def detect_high_confidence_entities(text: str, labels: Iterable[str] | None = None) -> list[EntitySpan]: - """Detect deterministic high-confidence PII and secret spans in raw text. - - These rules intentionally cover narrow, high-signal command/log and prose - patterns. They are suitable as a local seed detector or benchmark probe, - not as a complete replacement for model-backed contextual detection. 
- """ - allowed_labels = set(labels) if labels is not None else None - spans: list[EntitySpan] = [] - - for rule in _RULES: - if allowed_labels is not None and rule.label not in allowed_labels: - continue - for match in rule.pattern.finditer(text): - start, end = match.span(rule.group) - if start < 0 or end <= start: - continue - value = text[start:end] - value, end = _trim_rule_value(label=rule.label, value=value, end=end) - if not value: - continue - spans.append( - EntitySpan( - entity_id=_build_rule_entity_id(label=rule.label, start=start, end=end), - value=value, - label=rule.label, - start_position=start, - end_position=end, - score=_RULE_SCORE, - source=_RULE_SOURCE, - ) - ) - - return resolve_overlaps(_deduplicate(spans)) - - -def _trim_rule_value(*, label: str, value: str, end: int) -> tuple[str, int]: - if label != "http_cookie": - return value, end - trimmed = value.rstrip(".,") - return trimmed, end - (len(value) - len(trimmed)) - - -def _deduplicate(entities: list[EntitySpan]) -> list[EntitySpan]: - seen: set[tuple[str, int, int]] = set() - deduplicated: list[EntitySpan] = [] - for entity in entities: - key = (entity.label, entity.start_position, entity.end_position) - if key in seen: - continue - seen.add(key) - deduplicated.append(entity) - return deduplicated - - -def _build_rule_entity_id(*, label: str, start: int, end: int) -> str: - return f"{label}_{start}_{end}" diff --git a/tests/engine/test_detection_rules.py b/tests/engine/test_detection_rules.py deleted file mode 100644 index d7640ed5..00000000 --- a/tests/engine/test_detection_rules.py +++ /dev/null @@ -1,318 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections import Counter -from unittest.mock import Mock - -import pandas as pd -import pytest - -from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TAGGED_TEXT, COL_TEXT -from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow -from anonymizer.engine.detection.rules import ( - STRUCTURED_RULE_FAST_LANE_LABELS, - SUPPORTED_RULE_LABELS, - detect_high_confidence_entities, -) -from anonymizer.engine.schemas import EntitiesSchema - -SHELL_TEXT = """$ curl -H 'Authorization: Bearer sk-test-AAAAAAAAAAAAAAAAAAAAAAAA' https://internal.example.test/api -$ export AWS_ACCESS_KEY_ID=AKIATEST1234567890FAKE -$ export AWS_SECRET_ACCESS_KEY=fakeSecretValue1234567890! -$ docker run -e DATABASE_URL='postgres://app_user:fakeDbPass123!@db.example.test:5432/app' -e API_KEY=ghp_FAKEtoken1234567890abcdef myapp:latest -$ ssh jane.doe@example.test@host-01.example.test -Password: fakeSshPass123! 
-""" - - -def test_detect_high_confidence_entities_extracts_shell_secret_values() -> None: - entities = detect_high_confidence_entities( - SHELL_TEXT, - labels=["api_key", "password", "email", "url"], - ) - - assert Counter(entity.label for entity in entities) == { - "api_key": 3, - "password": 2, - "email": 1, - "url": 2, - } - values_by_label = {(entity.label, entity.value) for entity in entities} - assert ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA") in values_by_label - assert ("api_key", "AKIATEST1234567890FAKE") in values_by_label - assert ("api_key", "ghp_FAKEtoken1234567890abcdef") in values_by_label - assert ("password", "fakeSecretValue1234567890!") in values_by_label - assert ("password", "fakeSshPass123!") in values_by_label - assert ("email", "jane.doe@example.test") in values_by_label - assert ("url", "https://internal.example.test/api") in values_by_label - assert ("url", "postgres://app_user:fakeDbPass123!@db.example.test:5432/app") in values_by_label - - values = [entity.value for entity in entities] - assert all(not value.startswith(("Authorization", "Bearer", "API_KEY=", "Password:")) for value in values) - - -def test_detect_high_confidence_entities_extracts_email_before_sentence_punctuation() -> None: - entities = detect_high_confidence_entities( - "Email alice@example.com. 
Then contact bob@example.co.uk, if needed.", - labels=["email"], - ) - - assert [entity.value for entity in entities] == ["alice@example.com", "bob@example.co.uk"] - - -def test_detect_high_confidence_entities_excludes_config_url_separators() -> None: - text = ( - "DATABASE_URL=postgres://svc_user:DbSecretPass2026!@db.example.test:5432/app; " - "endpoint: https://internal.example.test/admin;" - ) - - entities = detect_high_confidence_entities(text, labels=["url"]) - - assert [entity.value for entity in entities] == [ - "postgres://svc_user:DbSecretPass2026!@db.example.test:5432/app", - "https://internal.example.test/admin", - ] - - -def test_supported_rule_labels_match_detected_label_families() -> None: - assert SUPPORTED_RULE_LABELS == { - "api_key", - "date_of_birth", - "email", - "http_cookie", - "organization_name", - "password", - "pin", - "religious_belief", - "street_address", - "unique_id", - "url", - "user_name", - } - - -def test_structured_rule_fast_lane_excludes_narrow_prose_labels() -> None: - assert STRUCTURED_RULE_FAST_LANE_LABELS == { - "api_key", - "email", - "http_cookie", - "password", - "pin", - "unique_id", - "url", - "user_name", - } - assert {"date_of_birth", "organization_name", "religious_belief", "street_address"}.isdisjoint( - STRUCTURED_RULE_FAST_LANE_LABELS - ) - - -def test_detect_high_confidence_entities_respects_label_filter() -> None: - entities = detect_high_confidence_entities(SHELL_TEXT, labels=["password"]) - - assert Counter(entity.label for entity in entities) == {"password": 3} - assert {entity.value for entity in entities} == { - "fakeSecretValue1234567890!", - "fakeDbPass123!", - "fakeSshPass123!", - } - - -def test_detect_high_confidence_entities_extracts_sudo_stdin_password() -> None: - text = '$ echo "P@ssw0rd-local-2026!" 
| sudo -S systemctl restart nginx' - - entities = detect_high_confidence_entities(text, labels=["password"]) - - assert [(entity.label, entity.value) for entity in entities] == [("password", "P@ssw0rd-local-2026!")] - - -def test_detect_high_confidence_entities_does_not_treat_generic_echo_as_password() -> None: - text = '$ echo "P@ssw0rd-local-2026!" | grep local' - - assert detect_high_confidence_entities(text, labels=["password"]) == [] - - -def test_detect_high_confidence_entities_does_not_emit_secret_false_positives_for_prose() -> None: - prose = ( - "Alice Johnson filed Case No. 2025-CV-12345 in Superior Court. " - "The opinion cites Section 10(b), Exhibit A-17, and docket trace order_390974. " - "A biography says Jordan Patel joined NVIDIA in 2021 and later moved to Seattle." - ) - - entities = detect_high_confidence_entities(prose, labels=["api_key", "password", "email", "url"]) - - assert entities == [] - - -def test_detect_high_confidence_entities_extracts_contextual_date_of_birth() -> None: - text = "The applicant was born in 1978 and later moved to Berlin. Another report cites 2024." - - entities = detect_high_confidence_entities(text, labels=["date_of_birth"]) - - assert [(entity.label, entity.value) for entity in entities] == [("date_of_birth", "1978")] - - -def test_detect_high_confidence_entities_ignores_standalone_year_for_date_of_birth() -> None: - text = "The report cites filings from 1978, 2021, and 2024." - - assert detect_high_confidence_entities(text, labels=["date_of_birth"]) == [] - - -def test_detect_high_confidence_entities_extracts_narrow_prose_patterns() -> None: - text = ( - "After graduation he spent three years at NASA's Goddard Space Flight Center before joining a lab. " - "Idilio describes himself as secular and leans progressive on most political issues. " - "Outside the lab, Idilio shares a modest house on West Roberts Drive with his wife." 
- ) - - entities = detect_high_confidence_entities( - text, - labels=["organization_name", "religious_belief", "street_address"], - ) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("organization_name", "NASA's Goddard Space Flight Center"), - ("religious_belief", "secular"), - ("street_address", "West Roberts Drive"), - ] - - -def test_detect_high_confidence_entities_avoids_generic_prose_belief_false_positive() -> None: - text = "Jordan describes himself as careful and later worked at a local lab near Roberts Drive." - - assert ( - detect_high_confidence_entities( - text, - labels=["organization_name", "religious_belief", "street_address"], - ) - == [] - ) - - -def test_detect_high_confidence_entities_returns_sorted_non_overlapping_spans() -> None: - entities = detect_high_confidence_entities( - "token=sk-test-BBBBBBBBBBBBBBBBBBBBBBBB and Auth: ignored\nPassword: fakePass123!", - labels=["api_key", "password"], - ) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "sk-test-BBBBBBBBBBBBBBBBBBBBBBBB"), - ("password", "fakePass123!"), - ] - assert entities[0].end_position < entities[1].start_position - - -def test_detect_high_confidence_entities_extracts_session_id_assignments() -> None: - text = "Cookie: session_id=abc123xyz; auth_token=xoxb-STRUCTURED-Slack-Token-000000" - - entities = detect_high_confidence_entities(text, labels=["api_key"]) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "abc123xyz"), - ("api_key", "xoxb-STRUCTURED-Slack-Token-000000"), - ] - - -def test_detect_high_confidence_entities_extracts_structured_identifier_labels() -> None: - text = ( - "POST /audit HTTP/1.1\n" - "Cookie: session_id=abc123xyz; user_id=26762; auth_token=token-abcdef\n" - "trace-id: req_KA5k78XNwT0yUNZkPpwq\n" - "pin=97294\n" - "user_name=sloanenguy217\n" - ) - - entities = detect_high_confidence_entities( - text, - labels=["http_cookie", "pin", "unique_id", "user_name"], - ) - 
- assert [(entity.label, entity.value) for entity in entities] == [ - ("http_cookie", "session_id=abc123xyz; user_id=26762; auth_token=token-abcdef"), - ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), - ("pin", "97294"), - ("user_name", "sloanenguy217"), - ] - - -def test_detect_high_confidence_entities_extracts_quoted_structured_identifier_keys() -> None: - text = '{"user": "avery_khan", "pin": "4921", "callback": "https://internal.example.test/admin"}' - - entities = detect_high_confidence_entities(text, labels=["pin", "url", "user_name"]) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("user_name", "avery_khan"), - ("pin", "4921"), - ("url", "https://internal.example.test/admin"), - ] - - -def test_detect_high_confidence_entities_excludes_cookie_sentence_punctuation() -> None: - text = "Cookie: session_id=abc123xyz; auth_token=token-abcdef. Recovery flow starts." - - entities = detect_high_confidence_entities(text, labels=["http_cookie"]) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("http_cookie", "session_id=abc123xyz; auth_token=token-abcdef"), - ] - - -def test_detect_high_confidence_entities_extracts_service_principal_user_and_tenant_id() -> None: - text = "$ az login --service-principal -u skylerlee985 -p fakePass123! --tenant trace-1b7278d77a73ef4e" - - entities = detect_high_confidence_entities(text, labels=["user_name", "unique_id"]) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("user_name", "skylerlee985"), - ("unique_id", "trace-1b7278d77a73ef4e"), - ] - - -def test_detect_high_confidence_entities_extracts_audit_user_and_trace_id() -> None: - text = "Audit record: user skylerlee985 opened session with trace-id req_KA5k78XNwT0yUNZkPpwq." 
- - entities = detect_high_confidence_entities(text, labels=["user_name", "unique_id"]) - - assert [(entity.label, entity.value) for entity in entities] == [ - ("user_name", "skylerlee985"), - ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), - ] - - -def test_detect_high_confidence_entities_does_not_extract_structured_identifiers_from_generic_prose() -> None: - text = "The order_390974 filing mentions user research, cookie recipes, and a five digit docket page." - - assert detect_high_confidence_entities(text, labels=["http_cookie", "pin", "unique_id", "user_name"]) == [] - - -def test_workflow_can_detect_with_high_confidence_rules_without_adapter_calls() -> None: - adapter = Mock() - workflow = EntityDetectionWorkflow(adapter=adapter) - - result = workflow.detect_with_high_confidence_rules( - pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), - entity_labels=["api_key", "password"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), - ("password", "fakePass123!"), - ] - tagged_text = result.dataframe[COL_TAGGED_TEXT].iloc[0] - assert "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" in tagged_text - assert "fakePass123!" 
in tagged_text - assert result.failed_records == [] - - -def test_workflow_rule_detection_rejects_unsupported_labels() -> None: - workflow = EntityDetectionWorkflow(adapter=Mock()) - - with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): - workflow.detect_with_high_confidence_rules( - pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), - entity_labels=["api_key", "person"], - ) diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index 9f309c73..76edb177 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -199,7 +199,7 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp }, { "record_type": "record", - "run_id": "shell__rules-only__r000", + "run_id": "shell__native-local__r000", "text_length_tokens": 750, "final_entity_count": 8, "replacement_count": 8, @@ -214,12 +214,12 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp "run_tags": { "suite_id": "suite", "workload_id": "shell", - "config_id": "rules-only", - "experimental_detection_strategy": "rules_only", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", "experimental_replacement_strategy": "local_structured_substitute", "dd_parser_compat": "raw_json", "repetition": 0, - "case_id": "shell__rules-only__r000", + "case_id": "shell__native-local__r000", }, }, ], @@ -259,23 +259,23 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp { "suite_id": "suite", "workload_id": "shell", - "config_id": "rules-only", + "config_id": "native-local", "repetition": 0, - "case_id": "shell__rules-only__r000", - "run_id": "shell__rules-only__r000", - "workflow_name": "rules-only", + "case_id": "shell__native-local__r000", + "run_id": "shell__native-local__r000", + "workflow_name": "native-single-pass", "seed_entity_count": 
8, "seed_validation_candidate_count": 0, "augmented_entity_count": 0, "augmented_new_final_value_count": 0, "final_entity_count": 8, - "final_source_counts": {"rule": 8}, + "final_source_counts": {"augmenter": 8}, "final_entity_signature_hashes": ["shell-hash-a"], "final_entity_signature_labels": {"shell-hash-a": "api_key"}, "final_entity_signature_details": { "shell-hash-a": { "label": "api_key", - "source": "rule", + "source": "native", "row_index": 0, "start_position": 12, "end_position": 32, @@ -422,25 +422,25 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp } } assert cases["bio__default__r000"].artifact_final_entity_signature_count == 2 - assert cases["shell__rules-only__r000"].observed_total_requests == 0 - assert cases["shell__rules-only__r000"].experimental_replacement_strategy == "local_structured_substitute" - assert cases["shell__rules-only__r000"].observed_failed_request_rate is None - assert cases["shell__rules-only__r000"].observed_bridge_fallback_requests is None - assert cases["shell__rules-only__r000"].observed_non_bridge_failed_requests is None - assert cases["shell__rules-only__r000"].final_entity_count == 8 - assert cases["shell__rules-only__r000"].replacement_missing_final_entity_count == 0 - assert cases["shell__rules-only__r000"].replacement_missing_final_entity_label_counts == {} - assert cases["shell__rules-only__r000"].replacement_missing_final_value_count == 0 - assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_count == 0 - assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_label_counts == {} - assert cases["shell__rules-only__r000"].replacement_synthetic_original_collision_value_count == 0 - assert cases["shell__rules-only__r000"].original_value_leak_count == 1 - assert cases["shell__rules-only__r000"].original_value_leak_record_count == 1 - assert cases["shell__rules-only__r000"].original_value_leak_label_counts == {"api_key": 1} - assert 
cases["shell__rules-only__r000"].artifact_final_rule_entity_count == 8 - assert cases["shell__rules-only__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] - assert cases["shell__rules-only__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} - assert cases["shell__rules-only__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "rule" + assert cases["shell__native-local__r000"].observed_total_requests == 0 + assert cases["shell__native-local__r000"].experimental_replacement_strategy == "local_structured_substitute" + assert cases["shell__native-local__r000"].observed_failed_request_rate is None + assert cases["shell__native-local__r000"].observed_bridge_fallback_requests is None + assert cases["shell__native-local__r000"].observed_non_bridge_failed_requests is None + assert cases["shell__native-local__r000"].final_entity_count == 8 + assert cases["shell__native-local__r000"].replacement_missing_final_entity_count == 0 + assert cases["shell__native-local__r000"].replacement_missing_final_entity_label_counts == {} + assert cases["shell__native-local__r000"].replacement_missing_final_value_count == 0 + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_count == 0 + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_label_counts == {} + assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_value_count == 0 + assert cases["shell__native-local__r000"].original_value_leak_count == 1 + assert cases["shell__native-local__r000"].original_value_leak_record_count == 1 + assert cases["shell__native-local__r000"].original_value_leak_label_counts == {"api_key": 1} + assert cases["shell__native-local__r000"].artifact_final_augmenter_entity_count == 8 + assert cases["shell__native-local__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] + assert 
cases["shell__native-local__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} + assert cases["shell__native-local__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" model_rows = {row.model_name: row for row in result.model_usage} assert model_rows["nvidia/gliner-pii"].observed_failed_requests == 0 assert model_rows["nvidia/gliner-pii"].observed_failed_request_rate == 0 @@ -521,7 +521,7 @@ def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path { "record_type": "model_workflow", "run_id": "bio__native__r000", - "workflow_name": "entity-detection-native-rules-router", + "workflow_name": "entity-detection-native-single-pass", "elapsed_sec": 0.25, "observed_total_requests": 3, "observed_successful_requests": 3, @@ -550,7 +550,7 @@ def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path "suite_id": "suite", "workload_id": "bio", "config_id": "native", - "experimental_detection_strategy": "native_rules_router", + "experimental_detection_strategy": "native_single_pass", "experimental_replacement_strategy": "default", "dd_parser_compat": "raw_json", "repetition": 0, @@ -568,7 +568,7 @@ def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path "suite_id": "suite", "workload_id": "bio", "config_id": "native", - "experimental_detection_strategy": "native_rules_router", + "experimental_detection_strategy": "native_single_pass", "experimental_replacement_strategy": "default", "dd_parser_compat": "raw_json", "repetition": 0, @@ -587,7 +587,7 @@ def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path assert case.observed_failed_request_rate == 0 assert result.model_usage_count == 1 model_row = result.model_usage[0] - assert model_row.workflow_name == "entity-detection-native-rules-router" + assert model_row.workflow_name == "entity-detection-native-single-pass" assert model_row.model_alias == "native-direct" assert 
model_row.model_name == "nvidia/nemotron-3-super" assert model_row.observed_total_tokens == 42 @@ -690,21 +690,21 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> tool.CaseAnalysisRow( suite_id="suite", workload_id="shell", - config_id="rules", - experimental_detection_strategy="rules_only", + config_id="native", + experimental_detection_strategy="native_single_pass", experimental_replacement_strategy="local_structured_substitute", dd_parser_compat="raw_json", repetition=0, - case_id="shell__rules__r000", - run_id="shell__rules__r000", + case_id="shell__native__r000", + run_id="shell__native__r000", final_entity_count=8, ) ], groups=[ tool.GroupAnalysisRow( workload_id="shell", - config_id="rules", - experimental_detection_strategy="rules_only", + config_id="native", + experimental_detection_strategy="native_single_pass", experimental_replacement_strategy="local_structured_substitute", case_count=1, median_final_entity_count=8, @@ -713,18 +713,17 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> median_observed_output_tokens=0, median_observed_failed_request_rate=0, median_artifact_final_entity_count=8, - median_artifact_final_rule_entity_count=8, ) ], model_usage=[ tool.ModelUsageAnalysisRow( workload_id="shell", - config_id="rules", - experimental_detection_strategy="rules_only", + config_id="native", + experimental_detection_strategy="native_single_pass", experimental_replacement_strategy="local_structured_substitute", dd_parser_compat="raw_json", - case_id="shell__rules__r000", - run_id="shell__rules__r000", + case_id="shell__native__r000", + run_id="shell__native__r000", workflow_name="entity-detection", model_name="nvidia/gliner-pii", observed_total_requests=1, @@ -735,8 +734,8 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> model_usage_groups=[ tool.ModelUsageGroupAnalysisRow( workload_id="shell", - config_id="rules", - 
experimental_detection_strategy="rules_only", + config_id="native", + experimental_detection_strategy="native_single_pass", experimental_replacement_strategy="local_structured_substitute", dd_parser_compat="raw_json", workflow_name="entity-detection", @@ -755,7 +754,7 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> output_dir = tmp_path / "tables" tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) - assert pd.read_csv(output_dir / "case_analysis.csv")["case_id"].tolist() == ["shell__rules__r000"] + assert pd.read_csv(output_dir / "case_analysis.csv")["case_id"].tolist() == ["shell__native__r000"] assert pd.read_csv(output_dir / "case_analysis.csv")["experimental_replacement_strategy"].tolist() == [ "local_structured_substitute" ] @@ -813,7 +812,7 @@ def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_p "run_tags": { "workload_id": "secrets", "config_id": "candidate", - "experimental_detection_strategy": "rules_covered_or_default", + "experimental_detection_strategy": "native_single_pass", "experimental_replacement_strategy": "default", "case_id": "secrets__candidate__r000", }, @@ -825,7 +824,7 @@ def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_p "run_tags": { "workload_id": "secrets", "config_id": "candidate", - "experimental_detection_strategy": "rules_covered_or_default", + "experimental_detection_strategy": "native_single_pass", "experimental_replacement_strategy": "local_structured_substitute", "case_id": "secrets__candidate__r001", }, @@ -842,57 +841,6 @@ def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_p } -def test_analyze_benchmark_output_surfaces_route_counts(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_output_analysis_route_counts", - REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", - ) - benchmark_dir = tmp_path / "benchmark" - benchmark_dir.mkdir() - _write_jsonl( - 
benchmark_dir / "measurements.jsonl", - [ - { - "record_type": "model_workflow", - "run_id": "mixed__router__r000", - "workflow_name": "entity-detection-rules-covered-router", - "status": "completed", - "input_row_count": 2, - "output_row_count": 2, - "failed_record_count": 0, - "elapsed_sec": 0.01, - "observed_total_requests": 0, - "observed_successful_requests": 0, - "observed_failed_requests": 0, - "observed_input_tokens": 0, - "observed_output_tokens": 0, - "observed_total_tokens": 0, - "route_total_row_count": 2, - "route_rule_row_count": 1, - "route_fallback_row_count": 1, - "run_tags": { - "workload_id": "mixed", - "config_id": "router", - "experimental_detection_strategy": "rules_covered_or_default", - "experimental_replacement_strategy": "default", - "case_id": "mixed__router__r000", - }, - } - ], - ) - - result = tool.analyze_benchmark_output(benchmark_dir) - - case = result.cases[0] - assert case.route_total_row_count == 2 - assert case.route_rule_row_count == 1 - assert case.route_fallback_row_count == 1 - group = result.groups[0] - assert group.median_route_total_row_count == 2 - assert group.median_route_rule_row_count == 1 - assert group.median_route_fallback_row_count == 1 - - def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: tool = load_tool( "measurement_benchmark_output_analysis_failures", @@ -912,7 +860,7 @@ def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: "run_tags": { "workload_id": "shell", "config_id": "candidate", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "experimental_detection_strategy": "detector_only", "repetition": 0, "case_id": "shell__candidate__r000", }, @@ -926,7 +874,7 @@ def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: "run_tags": { "workload_id": "shell", "config_id": "candidate", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "experimental_detection_strategy": "detector_only", 
"repetition": 1, "case_id": "shell__candidate__r001", }, @@ -940,7 +888,7 @@ def test_analyze_benchmark_output_surfaces_failed_cases(tmp_path: Path) -> None: "run_tags": { "workload_id": "shell", "config_id": "candidate", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "experimental_detection_strategy": "detector_only", "repetition": 1, "case_id": "shell__candidate__r001", }, diff --git a/tests/tools/test_compare_strategy_pairs.py b/tests/tools/test_compare_strategy_pairs.py index be9f455a..ad0ab436 100644 --- a/tests/tools/test_compare_strategy_pairs.py +++ b/tests/tools/test_compare_strategy_pairs.py @@ -46,8 +46,8 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N }, { "workload_id": "shell-1", - "config_id": "shell-filter", - "experimental_detection_strategy": "rules_filter_guardrail_no_augment", + "config_id": "shell-native", + "experimental_detection_strategy": "native_candidate_validate_no_augment", "experimental_replacement_strategy": "local_structured_substitute", "case_id": "shell__candidate", "pipeline_elapsed_sec": 0.8, @@ -58,7 +58,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N "seed_validation_candidate_count": 0, "augmented_entity_count": 0, "augmented_new_final_value_count": 0, - "artifact_final_rule_entity_count": 8, + "artifact_final_augmenter_entity_count": 8, }, { "workload_id": "legal-1", @@ -76,8 +76,8 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N }, { "workload_id": "legal-1", - "config_id": "legal-filter", - "experimental_detection_strategy": "rules_filter_guardrail_no_augment", + "config_id": "legal-native", + "experimental_detection_strategy": "native_candidate_validate_no_augment", "experimental_replacement_strategy": "local_structured_substitute", "case_id": "legal__candidate", "pipeline_elapsed_sec": 20.9, @@ -94,7 +94,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N 
rows = tool.compare_case_analysis( table, baseline_strategy="no_augment", - candidate_strategy="rules_filter_guardrail_no_augment", + candidate_strategy="native_candidate_validate_no_augment", ) by_workload = {row.workload_id: row for row in rows} @@ -111,11 +111,11 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N assert shell.baseline_augmented_new_final_value_count == 1 assert shell.candidate_augmented_new_final_value_count == 0 assert shell.augmented_new_final_value_count_delta == -1 - assert shell.candidate_rule_entity_count == 8 + assert shell.candidate_augmenter_entity_count == 8 assert shell.safety_verdict == "review" assert shell.performance_verdict == "improved" assert shell.candidate_verdict == "review" - assert shell.flags == ["no_candidate_detector_entities", "candidate_uses_rule_entities"] + assert shell.flags == ["no_candidate_detector_entities"] legal = by_workload["legal-1"] assert legal.baseline_replacement_strategy == "default" @@ -151,14 +151,14 @@ def test_compare_case_analysis_rejects_ambiguous_strategy_selector() -> None: { "workload_id": "shell-1", "config_id": "candidate", - "experimental_detection_strategy": "rules_only", + "experimental_detection_strategy": "detector_only", "case_id": "c", }, ] ) with pytest.raises(ValueError, match="baseline selector matched multiple configs"): - tool.compare_case_analysis(table, baseline_strategy="no_augment", candidate_strategy="rules_only") + tool.compare_case_analysis(table, baseline_strategy="no_augment", candidate_strategy="detector_only") def test_compare_case_analysis_rejects_candidate_synthetic_original_collisions() -> None: @@ -279,8 +279,8 @@ def test_compare_case_tables_allows_candidate_from_separate_run() -> None: [ { "workload_id": "legal-5", - "config_id": "legal-rules-guardrail", - "experimental_detection_strategy": "rules_guardrail_no_augment", + "config_id": "legal-native-validate", + "experimental_detection_strategy": 
"detector_native_validate_no_augment", "case_id": "legal__candidate", "observed_total_tokens": 55805, "final_entity_count": 193, @@ -293,13 +293,13 @@ def test_compare_case_tables_allows_candidate_from_separate_run() -> None: baseline, candidate, baseline_strategy="no_augment", - candidate_strategy="rules_guardrail_no_augment", + candidate_strategy="detector_native_validate_no_augment", ) assert len(rows) == 1 assert rows[0].workload_id == "legal-5" assert rows[0].baseline_config_id == "legal-no-augment" - assert rows[0].candidate_config_id == "legal-rules-guardrail" + assert rows[0].candidate_config_id == "legal-native-validate" assert rows[0].observed_total_tokens_delta == 15 assert rows[0].flags == ["token_increase"] @@ -348,7 +348,7 @@ def test_compare_case_analysis_preserves_augmentation_contribution_deltas() -> N assert row.candidate_augmenter_entity_count == 0 -def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: +def test_compare_case_analysis_review_gates_detector_only_candidate_shell_case() -> None: tool = load_tool( "measurement_compare_strategy_pairs_detector_only", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", @@ -392,9 +392,9 @@ def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: assert row.flags == ["candidate_skips_llm_validation"] -def test_compare_case_analysis_review_gates_rule_detector_only_candidates() -> None: +def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: tool = load_tool( - "measurement_compare_strategy_pairs_rule_detector_only", + "measurement_compare_strategy_pairs_detector_only", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", ) table = pd.DataFrame( @@ -413,15 +413,14 @@ def test_compare_case_analysis_review_gates_rule_detector_only_candidates() -> N }, { "workload_id": "shell-1", - "config_id": "rule-detector-only", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "config_id": "detector-only", + 
"experimental_detection_strategy": "detector_only", "case_id": "shell__candidate", "pipeline_elapsed_sec": 1, "observed_total_requests": 1, "observed_total_tokens": 200, "final_entity_count": 2, - "artifact_final_detector_entity_count": 1, - "artifact_final_rule_entity_count": 1, + "artifact_final_detector_entity_count": 2, "artifact_final_entity_signature_hashes": ["a", "b"], }, ] @@ -430,7 +429,7 @@ def test_compare_case_analysis_review_gates_rule_detector_only_candidates() -> N rows = tool.compare_case_analysis( table, baseline_strategy="default", - candidate_strategy="rules_guardrail_detector_only", + candidate_strategy="detector_only", ) assert len(rows) == 1 @@ -438,19 +437,19 @@ def test_compare_case_analysis_review_gates_rule_detector_only_candidates() -> N assert row.safety_verdict == "review" assert row.performance_verdict == "improved" assert row.candidate_verdict == "review" - assert row.flags == ["candidate_uses_rule_entities", "candidate_skips_llm_validation"] + assert row.flags == ["candidate_skips_llm_validation"] -def test_compare_case_analysis_review_gates_rules_covered_or_default_when_signatures_match() -> None: +def test_compare_case_analysis_review_gates_non_detector_sources_when_signatures_match() -> None: tool = load_tool( - "measurement_compare_strategy_pairs_rules_covered_or_default", + "measurement_compare_strategy_pairs_non_detector_sources", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", ) table = pd.DataFrame( [ { "workload_id": "shell-1", - "config_id": "rule-labels-default", + "config_id": "native-source-default", "experimental_detection_strategy": "default", "case_id": "shell__default", "pipeline_elapsed_sec": 21.4, @@ -463,14 +462,14 @@ def test_compare_case_analysis_review_gates_rules_covered_or_default_when_signat }, { "workload_id": "shell-1", - "config_id": "rule-labels-covered-or-default", - "experimental_detection_strategy": "rules_covered_or_default", + "config_id": "native-source-candidate", + 
"experimental_detection_strategy": "native_single_pass", "case_id": "shell__candidate", "pipeline_elapsed_sec": 0.001, "observed_total_requests": 0, "observed_total_tokens": 0, "final_entity_count": 2, - "artifact_final_rule_entity_count": 2, + "artifact_final_augmenter_entity_count": 2, "artifact_final_entity_signature_hashes": ["a", "b"], "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, }, @@ -479,8 +478,8 @@ def test_compare_case_analysis_review_gates_rules_covered_or_default_when_signat rows = tool.compare_case_analysis( table, - baseline_config="rule-labels-default", - candidate_config="rule-labels-covered-or-default", + baseline_config="native-source-default", + candidate_config="native-source-candidate", ) assert len(rows) == 1 @@ -493,7 +492,7 @@ def test_compare_case_analysis_review_gates_rules_covered_or_default_when_signat assert row.safety_verdict == "review" assert row.performance_verdict == "improved" assert row.candidate_verdict == "review" - assert row.flags == ["no_candidate_detector_entities", "candidate_uses_rule_entities"] + assert row.flags == ["no_candidate_detector_entities"] def test_compare_case_analysis_flags_signature_loss_even_when_counts_match() -> None: @@ -587,20 +586,20 @@ def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> }, { "workload_id": "structured-identifiers", - "config_id": "rules-local", - "experimental_detection_strategy": "rules_covered_or_default", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", "case_id": "candidate-r0", "pipeline_elapsed_sec": 0.01, "observed_total_requests": 0, "observed_total_tokens": 0, "final_entity_count": 2, - "artifact_final_rule_entity_count": 2, + "artifact_final_augmenter_entity_count": 2, "artifact_final_entity_signature_hashes": ["cookie", "pin"], "artifact_final_entity_signature_labels": {"cookie": "http_cookie", "pin": "pin"}, "artifact_final_entity_signature_details": { "cookie": { 
"label": "http_cookie", - "source": "rule", + "source": "native", "row_index": 0, "start_position": 30, "end_position": 80, @@ -609,7 +608,7 @@ def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> }, "pin": { "label": "pin", - "source": "rule", + "source": "native", "row_index": 0, "start_position": 90, "end_position": 95, @@ -621,7 +620,7 @@ def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> ] ) - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") assert len(rows) == 1 row = rows[0] @@ -744,20 +743,20 @@ def test_compare_case_analysis_treats_high_overlap_candidate_span_as_covered() - }, { "workload_id": "structured-identifiers", - "config_id": "rules-local", - "experimental_detection_strategy": "rules_covered_or_default", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", "case_id": "candidate-r0", "pipeline_elapsed_sec": 0.01, "observed_total_requests": 0, "observed_total_tokens": 0, "final_entity_count": 1, - "artifact_final_rule_entity_count": 1, + "artifact_final_augmenter_entity_count": 1, "artifact_final_entity_signature_hashes": ["token-value"], "artifact_final_entity_signature_labels": {"token-value": "api_key"}, "artifact_final_entity_signature_details": { "token-value": { "label": "api_key", - "source": "rule", + "source": "native", "row_index": 0, "start_position": 26, "end_position": 69, @@ -769,7 +768,7 @@ def test_compare_case_analysis_treats_high_overlap_candidate_span_as_covered() - ] ) - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") assert len(rows) == 1 row = rows[0] @@ -820,20 +819,20 @@ def 
test_compare_case_analysis_treats_small_assignment_prefix_gap_as_boundary_ov }, { "workload_id": "structured-identifiers", - "config_id": "rules-local", - "experimental_detection_strategy": "rules_covered_or_default", + "config_id": "native-local", + "experimental_detection_strategy": "native_single_pass", "case_id": "candidate-r0", "pipeline_elapsed_sec": 0.01, "observed_total_requests": 0, "observed_total_tokens": 0, "final_entity_count": 1, - "artifact_final_rule_entity_count": 1, + "artifact_final_augmenter_entity_count": 1, "artifact_final_entity_signature_hashes": ["login-value"], "artifact_final_entity_signature_labels": {"login-value": "user_name"}, "artifact_final_entity_signature_details": { "login-value": { "label": "user_name", - "source": "rule", + "source": "native", "row_index": 0, "start_position": 26, "end_position": 38, @@ -845,7 +844,7 @@ def test_compare_case_analysis_treats_small_assignment_prefix_gap_as_boundary_ov ] ) - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="rules-local") + rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") assert len(rows) == 1 row = rows[0] @@ -957,7 +956,7 @@ def test_compare_case_analysis_rejects_candidate_original_value_leaks() -> None: { "workload_id": "structured-secrets", "config_id": "candidate", - "experimental_detection_strategy": "rules_covered_or_default", + "experimental_detection_strategy": "native_single_pass", "case_id": "structured__candidate", "pipeline_elapsed_sec": 1, "observed_total_tokens": 0, @@ -1155,7 +1154,7 @@ def test_compare_case_analysis_rejects_candidate_case_failures() -> None: { "workload_id": "shell-5", "config_id": "candidate", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "experimental_detection_strategy": "detector_only", "case_id": "candidate-r0", "pipeline_elapsed_sec": 2, "case_failed": False, @@ -1164,7 +1163,7 @@ def 
test_compare_case_analysis_rejects_candidate_case_failures() -> None: { "workload_id": "shell-5", "config_id": "candidate", - "experimental_detection_strategy": "rules_guardrail_detector_only", + "experimental_detection_strategy": "detector_only", "case_id": "candidate-r1", "pipeline_elapsed_sec": 0.2, "case_failed": True, @@ -1386,7 +1385,7 @@ def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: baseline_final_entity_count=4, candidate_final_entity_count=8, final_entity_count_delta=4, - flags=["candidate_uses_rule_entities"], + flags=["candidate_skips_llm_validation"], ) ] @@ -1397,7 +1396,7 @@ def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: assert exported["workload_id"].tolist() == ["shell-1"] assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] assert exported["final_entity_count_delta"].tolist() == [4] - assert exported["flags"].tolist() == ['["candidate_uses_rule_entities"]'] + assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] def test_compare_strategy_pairs_summarizes_candidate_verdicts() -> None: @@ -1433,7 +1432,7 @@ def test_compare_strategy_pairs_summarizes_candidate_verdicts() -> None: tool.ComparisonRow( workload_id="shell-1", baseline_config_id="shell-default", - candidate_config_id="shell-rules", + candidate_config_id="shell-detector-only", baseline_case_count=1, candidate_case_count=1, value_protection_verdict="review", diff --git a/tests/tools/test_detection_strategies.py b/tests/tools/test_detection_strategies.py index 984bb025..5ef9affd 100644 --- a/tests/tools/test_detection_strategies.py +++ b/tests/tools/test_detection_strategies.py @@ -13,24 +13,15 @@ from unittest.mock import Mock import pandas as pd -import pytest from anonymizer.engine.constants import ( - COL_AUGMENTED_ENTITIES, COL_DETECTED_ENTITIES, - COL_FINAL_ENTITIES, - COL_INITIAL_TAGGED_TEXT, - COL_MERGED_ENTITIES, COL_RAW_DETECTED, COL_SEED_ENTITIES, COL_SEED_ENTITIES_JSON, - 
COL_SEED_VALIDATION_CANDIDATES, COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, - COL_VALIDATED_ENTITIES, - COL_VALIDATED_SEED_ENTITIES, - COL_VALIDATION_DECISIONS, ) from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow from anonymizer.engine.ndd.model_loader import load_default_model_selection @@ -51,280 +42,6 @@ def load_tool(module_name: str, path: Path) -> ModuleType: return module -def test_rules_only_strategy_detects_rule_spans_and_restores_workflow_method() -> None: - tool = load_tool("measurement_detection_strategies", REPO_ROOT / "tools/measurement/detection_strategies.py") - original = EntityDetectionWorkflow.detect_and_validate_entities - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_only): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "password"], - ) - - assert EntityDetectionWorkflow.detect_and_validate_entities is original - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), - ("password", "fakePass123!"), - ] - - -def test_rules_only_strategy_rejects_unsupported_labels_at_runtime() -> None: - tool = load_tool( - "measurement_detection_strategies_runtime_guard", REPO_ROOT / "tools/measurement/detection_strategies.py" - ) - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_only): - workflow = EntityDetectionWorkflow(adapter=Mock()) - with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): - 
workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "person"], - ) - - -def test_rules_covered_or_default_short_circuits_structured_fast_lane_labels() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_covered_short_circuit", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "password"], - ) - - adapter.run_workflow.assert_not_called() - assert "_anonymizer_row_order" not in result.dataframe.columns - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), - ("password", "fakePass123!"), - ] - - -def test_rules_covered_or_default_falls_back_for_uncovered_structured_assignments() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_covered_row_fallback", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - calls: list[pd.DataFrame] = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - calls.append(dataframe.copy()) - output = dataframe.copy() - 
output[COL_DETECTED_ENTITIES] = [ - EntitiesSchema(entities=[]).model_dump(mode="json") for _ in range(len(output)) - ] - output[COL_TAGGED_TEXT] = output[COL_TEXT] - return SimpleNamespace(dataframe=output, failed_records=[]) - - rows = [ - "token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", - '{"password": "SecretNoRule123!"}', - ] - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: rows}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "password"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert len(calls) == 1 - assert calls[0][COL_TEXT].tolist() == ['{"password": "SecretNoRule123!"}'] - assert "_anonymizer_row_order" not in result.dataframe.columns - assert result.dataframe[COL_TEXT].tolist() == rows - first_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - second_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[1]).entities - assert [(entity.label, entity.value) for entity in first_entities] == [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA") - ] - assert second_entities == [] - - -def test_rules_covered_or_default_records_route_counts() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_covered_route_counts", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - output = 
dataframe.copy() - output[COL_DETECTED_ENTITIES] = [ - EntitiesSchema(entities=[]).model_dump(mode="json") for _ in range(len(output)) - ] - output[COL_TAGGED_TEXT] = output[COL_TEXT] - return SimpleNamespace(dataframe=output, failed_records=[]) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - collector = MeasurementCollector(record_hash_key="test-key") - try: - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.rules_covered_or_default - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame( - { - COL_TEXT: [ - "token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", - '{"password": "SecretNoRule123!"}', - ] - } - ), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "password"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-rules-covered-router" - assert record["route_total_row_count"] == 2 - assert record["route_rule_row_count"] == 1 - assert record["route_fallback_row_count"] == 1 - assert record["observed_total_requests"] == 0 - assert record["observed_total_tokens"] == 0 - - -def test_native_rules_router_strategy_runs_staged_detection_without_data_designer() -> None: - tool = load_tool( - "measurement_detection_strategies_native_rules_router", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SequencedClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - self.outputs = [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - 
'{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace(content=self.outputs.pop(0), elapsed_sec=0.1, usage={}) - - adapter = Mock() - client = SequencedClient() - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_rules_router, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "direct_seed"), - ("organization_name", "NVIDIA", "augmenter"), - ] - assert len(client.prompts) == 3 - - -def test_native_rules_router_strategy_records_direct_model_usage() -> None: - tool = load_tool( - "measurement_detection_strategies_native_rules_router_usage", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SequencedClient: - def __init__(self) -> None: - self.outputs = [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ] - - def complete(self, _request): # type: ignore[no-untyped-def] - return SimpleNamespace( - content=self.outputs.pop(0), - elapsed_sec=0.1, - 
usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, - ) - - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_rules_router, - native_client=SequencedClient(), - native_runtime=tool.NativeDetectionRuntime(model="test/native", provider="test-provider"), - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["first_name", "organization_name"], - ) - - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-native-rules-router" - assert record["observed_total_requests"] == 3 - assert record["observed_successful_requests"] == 3 - assert record["observed_failed_requests"] == 0 - assert record["observed_input_tokens"] == 30 - assert record["observed_output_tokens"] == 12 - assert record["observed_total_tokens"] == 42 - assert record["model_usage"]["native-direct"]["model_name"] == "test/native" - assert record["model_usage"]["native-direct"]["model_provider_name"] == "test-provider" - - def test_native_candidate_validate_no_augment_strategy_skips_data_designer_and_augmentation() -> None: tool = load_tool( "measurement_detection_strategies_native_candidate_validate", @@ -1181,9 +898,9 @@ def complete(self, _request): # type: ignore[no-untyped-def] assert record["observed_total_tokens"] == 27 -def test_native_single_pass_strategy_adds_non_overlapping_rule_spans() -> None: +def test_native_single_pass_strategy_uses_only_native_spans() -> None: tool = load_tool( - 
"measurement_detection_strategies_native_single_pass_rule_guardrail", + "measurement_detection_strategies_native_single_pass_native_spans", REPO_ROOT / "tools/measurement/detection_strategies.py", ) text = "Alice logged in.\nPassword: SuperSecret123!\n" @@ -1212,7 +929,6 @@ def complete(self, _request): # type: ignore[no-untyped-def] entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities assert [(entity.label, entity.value, entity.source) for entity in entities] == [ ("person", "Alice", "direct_single_pass"), - ("password", "SuperSecret123!", "rule"), ] @@ -1259,9 +975,66 @@ def complete(self, _request): # type: ignore[no-untyped-def] assert record["observed_failed_requests"] == 0 -def test_rules_covered_or_default_uses_default_pipeline_for_contextual_labels() -> None: +def test_detector_only_strategy_finalizes_gliner_spans_without_validation_or_augmentation() -> None: + tool = load_tool( + "measurement_detection_strategies_detector_only", + REPO_ROOT / "tools/measurement/detection_strategies.py", + ) + adapter = Mock() + text = "Alice met Alice at the lab." 
+ start = text.index("Alice") + + def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: + assert [column.name for column in columns] == [ + COL_RAW_DETECTED, + COL_SEED_ENTITIES, + COL_SEED_ENTITIES_JSON, + COL_DETECTED_ENTITIES, + ] + assert kwargs["workflow_name"] == "entity-detection-detector-only" + row = { + COL_TEXT: dataframe[COL_TEXT].iloc[0], + COL_RAW_DETECTED: json.dumps( + { + "entities": [ + { + "text": "Alice", + "label": "person", + "start": start, + "end": start + len("Alice"), + "score": 0.99, + } + ] + } + ), + } + row = columns[1].generator_function(row) + row = columns[2].generator_function(row) + seed_entities = json.loads(row[COL_SEED_ENTITIES_JSON]) + assert [(entity["label"], entity["value"]) for entity in seed_entities] == [("person", "Alice")] + row = columns[3].generator_function(row) + return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) + + adapter.run_workflow.side_effect = fake_run_workflow + + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.detector_only): + workflow = EntityDetectionWorkflow(adapter=adapter) + result = workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: [text]}), + model_configs=[], + selected_models=load_default_model_selection().detection, + gliner_detection_threshold=0.3, + entity_labels=["person"], + ) + + entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities + assert [(entity.label, entity.value) for entity in entities] == [("person", "Alice"), ("person", "Alice")] + assert "Alice" in result.dataframe[COL_TAGGED_TEXT].iloc[0] + + +def test_compact_validation_strategy_disables_full_text_single_chunk_validation() -> None: tool = load_tool( - "measurement_detection_strategies_rules_covered_default_fallback", + "measurement_detection_strategies_compact_validation", REPO_ROOT / "tools/measurement/detection_strategies.py", ) original = 
EntityDetectionWorkflow.detect_and_validate_entities @@ -1287,447 +1060,28 @@ def fake_original( EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): + with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.compact_validation): workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}), + workflow.detect_and_validate_entities( + pd.DataFrame({COL_TEXT: ["Alice works at Acme."]}), model_configs=[], selected_models=load_default_model_selection().detection, gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["api_key", "person"], + entity_labels=["first_name"], ) finally: EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - assert len(calls) == 1 - assert calls[0]["entity_labels"] == ["api_key", "person"] - assert calls[0]["validation_single_chunk_full_text"] is False assert EntityDetectionWorkflow.detect_and_validate_entities is original - assert EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities == [] + assert calls[0]["validation_single_chunk_full_text"] is False -def test_rules_covered_or_default_uses_default_pipeline_for_narrow_prose_rule_labels() -> None: +def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: tool = load_tool( - "measurement_detection_strategies_rules_covered_prose_rule_fallback", + "measurement_detection_strategies_prose_augment_focus", REPO_ROOT / "tools/measurement/detection_strategies.py", ) - original = EntityDetectionWorkflow.detect_and_validate_entities - calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - 
calls.append(kwargs) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_covered_or_default): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Jordan worked at Acme Research Center and lived on Maple Street."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["organization_name", "street_address"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert len(calls) == 1 - assert calls[0]["entity_labels"] == ["organization_name", "street_address"] - assert EntityDetectionWorkflow.detect_and_validate_entities is original - assert EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities == [] - - -def test_detector_only_strategy_finalizes_gliner_spans_without_validation_or_augmentation() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_only", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - text = "Alice met Alice at the lab." 
- start = text.index("Alice") - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: - assert [column.name for column in columns] == [ - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_DETECTED_ENTITIES, - ] - assert kwargs["workflow_name"] == "entity-detection-detector-only" - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "person", - "start": start, - "end": start + len("Alice"), - "score": 0.99, - } - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = json.loads(row[COL_SEED_ENTITIES_JSON]) - assert [(entity["label"], entity["value"]) for entity in seed_entities] == [("person", "Alice")] - row = columns[3].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.detector_only): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("person", "Alice"), ("person", "Alice")] - assert "Alice" in result.dataframe[COL_TAGGED_TEXT].iloc[0] - - -def test_rules_guardrail_detector_only_adds_rule_spans_without_validation_or_augmentation() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_guardrail_detector_only", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - text = 
f"Alice exported token={token}" - alice_start = text.index("Alice") - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: - assert [column.name for column in columns] == [ - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_DETECTED_ENTITIES, - ] - assert kwargs["workflow_name"] == "entity-detection-rules-guardrail-detector-only" - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "person", - "start": alice_start, - "end": alice_start + len("Alice"), - "score": 0.99, - } - ] - } - ), - } - for column in columns[1:]: - row = column.generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail_detector_only): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["api_key", "person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("person", "Alice", "detector"), - ("api_key", token, "rule"), - ] - - -def test_rules_guardrail_keeps_default_pipeline_and_adds_rule_spans() -> None: - tool = load_tool( - "measurement_detection_strategies_default_rules_guardrail", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - text = "The applicant was born in 1978 and later moved to Berlin." 
- calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **_: object, - ) -> object: - calls.append(self) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["date_of_birth"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert len(calls) == 1 - assert EntityDetectionWorkflow.detect_and_validate_entities is original - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [("date_of_birth", "1978", "rule")] - - -def test_compact_validation_strategy_disables_full_text_single_chunk_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_compact_validation", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - calls.append(kwargs) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - 
EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.compact_validation): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at Acme."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert EntityDetectionWorkflow.detect_and_validate_entities is original - assert calls[0]["validation_single_chunk_full_text"] is False - - -def test_rules_guardrail_compact_validation_combines_rule_guardrail_and_compact_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_guardrail_compact_validation", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - text = "The applicant was born in 1978 and later moved to Berlin." 
- calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - calls.append(kwargs) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.rules_guardrail_compact_validation - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["date_of_birth"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert calls[0]["validation_single_chunk_full_text"] is False - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [("date_of_birth", "1978", "rule")] - - -def test_rules_guardrail_prefers_rule_label_for_exact_span_overlap() -> None: - tool = load_tool( - "measurement_detection_strategies_default_rules_guardrail_exact_overlap", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - text = "Idilio describes himself as secular and leans progressive on most political issues." 
- start = text.index("secular") - calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **_: object, - ) -> object: - calls.append(self) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema( - entities=[ - { - "id": "political_view_0", - "value": "secular", - "label": "political_view", - "start_position": start, - "end_position": start + len("secular"), - "score": 1.0, - "source": "detector", - } - ] - ).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["political_view", "religious_belief"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert len(calls) == 1 - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("religious_belief", "secular", "rule") - ] - - -def test_rules_guardrail_can_apply_explicit_rule_labels_outside_model_labels() -> None: - tool = load_tool( - "measurement_detection_strategies_default_rules_guardrail_rule_labels", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - text = "Outside the lab, Idilio shares a modest house on West Roberts Drive with his wife." 
- calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - calls.append(kwargs["entity_labels"]) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.rules_guardrail, - rule_labels=["street_address"], - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.run( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name"], - privacy_goal=None, - tag_latent_entities=False, - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert calls == [["first_name"]] - detected_entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - final_entities = EntitiesSchema.from_raw(result.dataframe[COL_FINAL_ENTITIES].iloc[0]).entities - expected = [("street_address", "West Roberts Drive", "rule")] - assert [(entity.label, entity.value, entity.source) for entity in detected_entities] == expected - assert [(entity.label, entity.value, entity.source) for entity in final_entities] == expected - - -def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: - tool = load_tool( - "measurement_detection_strategies_prose_augment_focus", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - before = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) + before = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], 
strict_labels=True) with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.prose_augment_focus): inside = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) @@ -1737,417 +1091,3 @@ def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: assert "Contextual prose recall focus" in inside assert "organization and institution names" in inside assert after == before - - -def test_rules_guardrail_no_augment_adds_rule_spans_after_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_guardrail", REPO_ROOT / "tools/measurement/detection_strategies.py" - ) - adapter = Mock() - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - assert [column.name for column in columns][-3:] == [ - COL_AUGMENTED_ENTITIES, - COL_MERGED_ENTITIES, - COL_DETECTED_ENTITIES, - ] - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_MERGED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - COL_VALIDATED_ENTITIES: {"decisions": []}, - } - row = columns[-1].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_guardrail_no_augment): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA\nPassword: fakePass123!"]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["api_key", "password"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), - ("password", 
"fakePass123!"), - ] - assert adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-guardrail-no-augment" - - -def test_rules_filter_guardrail_no_augment_filters_rule_spans_before_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_filter_guardrail", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - text = f"Alice used token={token}" - alice_start = text.index("Alice") - token_start = text.index(token) - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "person", - "start": alice_start, - "end": alice_start + len("Alice"), - "score": 0.99, - }, - { - "text": token, - "label": "api_key", - "start": token_start, - "end": token_start + len(token), - "score": 0.99, - }, - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities - assert [(entity.label, entity.value) for entity in seed_entities] == [("person", "Alice")] - assert [candidate["label"] for candidate in row[COL_SEED_VALIDATION_CANDIDATES]["candidates"]] == ["person"] - row[COL_MERGED_ENTITIES] = EntitiesSchema(entities=[]).model_dump(mode="json") - row[COL_VALIDATED_ENTITIES] = {"decisions": []} - row = columns[-1].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - 
selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["api_key", "person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("api_key", token)] - assert ( - adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-filter-guardrail-no-augment" - ) - - -def test_rules_filter_guardrail_keeps_augmentation_but_skips_rule_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_filter_guardrail_with_augmentation", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - token = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - text = f"Alice used token={token}" - alice_start = text.index("Alice") - token_start = text.index(token) - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - assert [column.name for column in columns] == [ - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_VALIDATION_CANDIDATES, - COL_VALIDATION_DECISIONS, - COL_VALIDATED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_AUGMENTED_ENTITIES, - COL_MERGED_ENTITIES, - COL_DETECTED_ENTITIES, - ] - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "person", - "start": alice_start, - "end": alice_start + len("Alice"), - "score": 0.99, - }, - { - "text": token, - "label": "api_key", - "start": token_start, - "end": token_start + len(token), - "score": 0.99, - }, - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities - assert [(entity.label, entity.value) for entity in seed_entities] == [("person", "Alice")] - assert [candidate["label"] for candidate in row[COL_SEED_VALIDATION_CANDIDATES]["candidates"]] == ["person"] - - 
row[COL_VALIDATION_DECISIONS] = { - "decisions": [ - { - "id": "person_0_5", - "decision": "keep", - "proposed_label": "", - "reason": "test keep", - } - ] - } - row = columns[4].generator_function(row) - row = columns[5].generator_function(row) - validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities - assert [(entity.label, entity.value, entity.source) for entity in validated_seed] == [ - ("person", "Alice", "detector"), - ("api_key", token, "rule"), - ] - assert f"{token}" in row[COL_INITIAL_TAGGED_TEXT] - - row[COL_AUGMENTED_ENTITIES] = {"entities": []} - row = columns[7].generator_function(row) - row = columns[8].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["api_key", "person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("person", "Alice", "detector"), - ("api_key", token, "rule"), - ] - assert adapter.run_workflow.call_args.kwargs["workflow_name"] == "entity-detection-rules-filter-guardrail" - - -def test_rules_filter_guardrail_preserves_different_label_rule_overlap() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_filter_guardrail_preserve_contextual_overlap", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - phrase = "Christian Democrat" - rule_value = "Christian" - text = f"He identifies as a {phrase}." 
- phrase_start = text.index(phrase) - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": phrase, - "label": "political_view", - "start": phrase_start, - "end": phrase_start + len(phrase), - "score": 0.99, - }, - { - "text": rule_value, - "label": "religious_belief", - "start": phrase_start, - "end": phrase_start + len(rule_value), - "score": 0.99, - }, - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities - assert [(entity.label, entity.value) for entity in seed_entities] == [("political_view", phrase)] - assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["label"] == "political_view" - - row[COL_VALIDATION_DECISIONS] = { - "decisions": [ - { - "id": f"political_view_{phrase_start}_{phrase_start + len(phrase)}", - "decision": "keep", - "proposed_label": "", - "reason": "test keep", - } - ] - } - row = columns[4].generator_function(row) - row = columns[5].generator_function(row) - validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities - assert any(entity.label == "political_view" and entity.value == phrase for entity in validated_seed) - assert f"{phrase}" in row[COL_INITIAL_TAGGED_TEXT] - - row[COL_AUGMENTED_ENTITIES] = {"entities": []} - row = columns[7].generator_function(row) - row = columns[8].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - 
selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["political_view", "religious_belief"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert any(entity.label == "political_view" and entity.value == phrase for entity in entities) - - -def test_rules_filter_guardrail_preserves_longer_same_label_detector_overlap() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_filter_guardrail_preserve_longer_same_label", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - rule_value = "Great Health" - detector_value = f"{rule_value} and Mountain Timber" - text = f"After apprenticeships with {detector_value}, Darwin started his own shop." - detector_start = text.index(detector_value) - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": detector_value, - "label": "organization_name", - "start": detector_start, - "end": detector_start + len(detector_value), - "score": 0.99, - } - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = EntitiesSchema.from_raw(row[COL_SEED_ENTITIES]).entities - assert [(entity.label, entity.value) for entity in seed_entities] == [("organization_name", detector_value)] - assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["value"] == detector_value - - row[COL_VALIDATION_DECISIONS] = { - "decisions": [ - { - "id": f"organization_name_{detector_start}_{detector_start + len(detector_value)}", - "decision": "keep", - "proposed_label": "", - "reason": "test keep", - } - ] - } - row = columns[4].generator_function(row) - row = columns[5].generator_function(row) - row[COL_AUGMENTED_ENTITIES] = {"entities": []} - row = columns[7].generator_function(row) - row = 
columns[8].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["organization_name"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("organization_name", detector_value)] - - -def test_rules_filter_guardrail_does_not_shadow_validated_different_label_exact_span() -> None: - tool = load_tool( - "measurement_detection_strategies_rules_filter_guardrail_no_shadow_exact_span", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - value = "Bowdoin College" - text = f"He completed his MLS at {value}, and his early career followed." 
- start = text.index(value) - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **_: object) -> SimpleNamespace: - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": value, - "label": "university", - "start": start, - "end": start + len(value), - "score": 0.99, - } - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - assert row[COL_SEED_VALIDATION_CANDIDATES]["candidates"][0]["label"] == "university" - - row[COL_VALIDATION_DECISIONS] = { - "decisions": [ - { - "id": f"university_{start}_{start + len(value)}", - "decision": "keep", - "proposed_label": "", - "reason": "test keep", - } - ] - } - row = columns[4].generator_function(row) - row = columns[5].generator_function(row) - validated_seed = EntitiesSchema.from_raw(row[COL_VALIDATED_SEED_ENTITIES]).entities - assert [(entity.label, entity.value) for entity in validated_seed] == [("university", value)] - - row[COL_AUGMENTED_ENTITIES] = {"entities": []} - row = columns[7].generator_function(row) - row = columns[8].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.rules_filter_guardrail): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["organization_name", "university"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("university", value)] diff --git a/tests/tools/test_extract_signature_deltas.py b/tests/tools/test_extract_signature_deltas.py index 
5e442932..922d73eb 100644 --- a/tests/tools/test_extract_signature_deltas.py +++ b/tests/tools/test_extract_signature_deltas.py @@ -113,21 +113,21 @@ def test_extract_signature_deltas_masks_candidate_only_context(tmp_path: Path) - assert "[ORGANIZATION_NAME:" in row.masked_context -def test_extract_signature_deltas_recovers_guardrail_rule_context(tmp_path: Path) -> None: +def test_extract_signature_deltas_recovers_artifact_detail_context(tmp_path: Path) -> None: analyzer = load_tool( - "measurement_detection_artifact_rule_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" + "measurement_detection_artifact_context_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" ) tool = load_tool( - "measurement_extract_signature_deltas_rule", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" + "measurement_extract_signature_deltas_context", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" ) baseline_root = tmp_path / "baseline" candidate_root = tmp_path / "candidate" baseline = _write_artifact_case(baseline_root, analyzer, [], "The applicant was born in 1990.") candidate = _write_artifact_case(candidate_root, analyzer, [], "The applicant was born in 1990.") - rule_entity = analyzer.EntitySchema( + detail_entity = analyzer.EntitySchema( value="1990", label="date_of_birth", start_position=26, end_position=30, source="rule" ) - rule_row = analyzer.build_detection_artifact_row_from_entities( + detail_row = analyzer.build_detection_artifact_row_from_entities( workflow_name="entity-detection", batch_file="entity-detection/parquet-files/batch_00000.parquet", row_index=0, @@ -135,9 +135,9 @@ def test_extract_signature_deltas_recovers_guardrail_rule_context(tmp_path: Path seed_validation_candidate_count=0, merged_validation_candidate_count=0, augmented_entities=[], - final_entities=[rule_entity], + final_entities=[detail_entity], ).model_dump() - pd.json_normalize([{**_case_metadata(), **rule_row}], 
sep=".").to_json(candidate, orient="records", lines=True) + pd.json_normalize([{**_case_metadata(), **detail_row}], sep=".").to_json(candidate, orient="records", lines=True) result = tool.extract_signature_deltas( baseline, @@ -150,7 +150,7 @@ def test_extract_signature_deltas_recovers_guardrail_rule_context(tmp_path: Path row = result.rows[0] assert row.label == "date_of_birth" assert row.source == "rule" - assert row.resolution == "rule" + assert row.resolution == "artifact_details" assert "1990" not in row.masked_context assert "[DATE_OF_BIRTH:" in row.masked_context diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 5d1f5a79..99642535 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -98,152 +98,6 @@ def test_benchmark_exports_detection_artifact_analysis(tmp_path: Path) -> None: assert "Alice" not in output_path.read_text(encoding="utf-8") -def test_benchmark_exports_rules_only_synthetic_detection_artifacts(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_only_synthetic_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - pd.DataFrame({"text": [f"export API_KEY={secret}"]}).to_csv(input_path, index=False) - config = tool.ConfigSpec( - id="rules-only-redact", - replace="redact", - detect={"entity_labels": ["api_key", "email", "password", "url"]}, - experimental_detection_strategy="rules_only", - ) - case = tool.BenchmarkCase( - suite_id="rules-suite", - workload_id="input", - config_id="rules-only-redact", - repetition=0, - case_id="input__rules-only-redact__r000", - ) - output_path = tmp_path / "raw" / "input__rules-only-redact__r000.detection-artifacts.jsonl" - - result = tool.export_rules_only_case_detection_artifacts( - config, - tool.AnonymizerInput(source=str(input_path), text_column="text"), - output_path, - case=case, - ) - - 
assert result == output_path - text = output_path.read_text(encoding="utf-8") - assert secret not in text - row = json.loads(text) - assert row["workflow_name"] == "entity-detection-rules-only" - assert row["final_entity_count"] == 1 - assert row["final_entity_signature_count"] == 1 - assert row["final_label_counts.api_key"] == 1 - assert row["final_source_counts.rule"] == 1 - assert any(key.startswith("final_entity_signature_labels.") for key in row) - - -def test_benchmark_exports_rules_covered_or_default_synthetic_artifacts_for_structured_fast_lane_labels( - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_covered_synthetic_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - pd.DataFrame({"text": [f"export API_KEY={secret}"]}).to_csv(input_path, index=False) - config = tool.ConfigSpec( - id="rules-covered-redact", - replace="redact", - detect={"entity_labels": ["api_key", "email", "password", "url"]}, - experimental_detection_strategy="rules_covered_or_default", - ) - case = tool.BenchmarkCase( - suite_id="rules-suite", - workload_id="input", - config_id="rules-covered-redact", - repetition=0, - case_id="input__rules-covered-redact__r000", - ) - output_path = tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl" - - result = tool.export_rules_only_case_detection_artifacts( - config, - tool.AnonymizerInput(source=str(input_path), text_column="text"), - output_path, - case=case, - ) - - assert result == output_path - row = json.loads(output_path.read_text(encoding="utf-8")) - assert row["workflow_name"] == "entity-detection-rules-only" - assert row["final_entity_count"] == 1 - assert row["final_label_counts.api_key"] == 1 - - -def test_benchmark_does_not_export_rules_covered_or_default_artifacts_for_contextual_labels(tmp_path: Path) -> None: - tool = load_tool( - 
"measurement_benchmark_tool_rules_covered_contextual_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - config = tool.ConfigSpec( - id="rules-covered-redact", - replace="redact", - detect={"entity_labels": ["api_key", "person"]}, - experimental_detection_strategy="rules_covered_or_default", - ) - case = tool.BenchmarkCase( - suite_id="rules-suite", - workload_id="input", - config_id="rules-covered-redact", - repetition=0, - case_id="input__rules-covered-redact__r000", - ) - - result = tool.export_rules_only_case_detection_artifacts( - config, - tool.AnonymizerInput(source=str(input_path), text_column="text"), - tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl", - case=case, - ) - - assert result is None - - -def test_benchmark_does_not_export_rules_covered_artifacts_for_narrow_prose_rule_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_covered_prose_rule_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Jordan worked at Acme Research Center and lived on Maple Street."]}).to_csv( - input_path, - index=False, - ) - config = tool.ConfigSpec( - id="rules-covered-redact", - replace="redact", - detect={"entity_labels": ["organization_name", "street_address"]}, - experimental_detection_strategy="rules_covered_or_default", - ) - case = tool.BenchmarkCase( - suite_id="rules-suite", - workload_id="input", - config_id="rules-covered-redact", - repetition=0, - case_id="input__rules-covered-redact__r000", - ) - - result = tool.export_rules_only_case_detection_artifacts( - config, - tool.AnonymizerInput(source=str(input_path), text_column="text"), - tmp_path / "raw" / "input__rules-covered-redact__r000.detection-artifacts.jsonl", - case=case, - ) - - assert 
result is None - - def test_benchmark_detection_artifact_analysis_ignores_stale_artifacts(tmp_path: Path) -> None: tool = load_tool("measurement_benchmark_tool_artifact_delta", REPO_ROOT / "tools/measurement/run_benchmarks.py") artifact_root = tmp_path / "artifacts" @@ -410,89 +264,6 @@ def test_benchmark_patches_detection_artifacts_from_final_trace_dataframe(tmp_pa assert "final_entity_signature_labels.stale" not in row -def test_rules_covered_or_default_detection_artifacts_use_final_trace_dataframe(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_covered_trace_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - case = tool.BenchmarkCase( - suite_id="suite-a", - workload_id="input", - config_id="rules-covered", - repetition=0, - case_id="input__rules-covered__r000", - ) - config = tool.ConfigSpec( - id="rules-covered", - replace="redact", - detect={"entity_labels": ["api_key", "password"]}, - experimental_detection_strategy="rules_covered_or_default", - ) - trace_dataframe = pd.DataFrame( - { - COL_FINAL_ENTITIES: [ - { - "entities": [ - { - "value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", - "label": "api_key", - "start_position": 6, - "end_position": 38, - "source": "rule", - } - ] - }, - { - "entities": [ - { - "value": "SecretNoRule123!", - "label": "password", - "start_position": 14, - "end_position": 30, - "source": "detector", - } - ] - }, - ] - } - ) - paths = tool._CaseRunPaths( - raw_path=tmp_path / "raw" / "case.jsonl", - artifact_output_path=tmp_path / "raw" / "case.detection-artifacts.jsonl", - trace_path=None, - artifact_snapshot={}, - ) - tool.write_detection_artifact_payloads([_stale_detection_artifact_payload()], paths.artifact_output_path) - contexts = {"artifact_path": tmp_path / "artifacts"} - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - execution = tool._CaseExecution( - 
input_data=tool.AnonymizerInput(source=str(input_path), text_column="text"), - trace_dataframe=trace_dataframe, - ) - - result = tool._case_detection_artifact_path( - contexts, - paths, - case=case, - config=config, - execution=execution, - ) - - assert result == paths.artifact_output_path - rows = [json.loads(line) for line in paths.artifact_output_path.read_text(encoding="utf-8").splitlines()] - assert len(rows) == 2 - assert [row["workflow_name"] for row in rows] == [ - "entity-detection-final-trace", - "entity-detection-final-trace", - ] - assert [row["row_index"] for row in rows] == [0, 1] - assert [row["final_source_counts.rule"] for row in rows] == [1.0, None] - assert [row["final_source_counts.detector"] for row in rows] == [None, 1.0] - assert "sk-test" not in paths.artifact_output_path.read_text(encoding="utf-8") - assert "SecretNoRule123!" not in paths.artifact_output_path.read_text(encoding="utf-8") - - def test_run_suite_records_detection_artifact_analysis_path( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -1379,16 +1150,6 @@ def test_benchmark_config_accepts_experimental_detection_strategy() -> None: "measurement_benchmark_tool_detection_strategy_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" ) - config = tool.ConfigSpec( - id="rules-only", - replace="redact", - experimental_detection_strategy="rules_only", - ) - - assert config.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.rules_only - anonymizer_config = tool.build_anonymizer_config(config) - assert not hasattr(anonymizer_config.detect, "experimental_detection_strategy") - detector_only = tool.ConfigSpec( id="detector-only", replace="redact", @@ -1396,22 +1157,8 @@ def test_benchmark_config_accepts_experimental_detection_strategy() -> None: ) assert detector_only.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.detector_only - - rules_covered = tool.ConfigSpec( - id="rules-covered", - replace="redact", - 
experimental_detection_strategy="rules_covered_or_default", - ) - - assert rules_covered.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.rules_covered_or_default - - native_rules_router = tool.ConfigSpec( - id="native-rules-router", - replace="redact", - experimental_detection_strategy="native_rules_router", - ) - - assert native_rules_router.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.native_rules_router + anonymizer_config = tool.build_anonymizer_config(detector_only) + assert not hasattr(anonymizer_config.detect, "experimental_detection_strategy") native_candidate_validate = tool.ConfigSpec( id="native-candidate-validate", @@ -1510,30 +1257,6 @@ def test_benchmark_config_accepts_experimental_detection_strategy() -> None: ) -def test_benchmark_config_accepts_experimental_rule_labels() -> None: - tool = load_tool("measurement_benchmark_tool_rule_labels_config", REPO_ROOT / "tools/measurement/run_benchmarks.py") - - config = tool.ConfigSpec( - id="rules-guardrail", - replace="redact", - experimental_detection_strategy="rules_guardrail", - experimental_rule_labels=["street_address"], - ) - - assert config.experimental_rule_labels == ["street_address"] - anonymizer_config = tool.build_anonymizer_config(config) - assert not hasattr(anonymizer_config.detect, "experimental_rule_labels") - - detector_only = tool.ConfigSpec( - id="rules-guardrail-detector-only", - replace="redact", - experimental_detection_strategy="rules_guardrail_detector_only", - experimental_rule_labels=["api_key"], - ) - - assert detector_only.experimental_rule_labels == ["api_key"] - - def test_benchmark_spec_accepts_dd_parser_compat() -> None: tool = load_tool( "measurement_benchmark_tool_dd_parser_compat_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" @@ -1549,200 +1272,6 @@ def test_benchmark_spec_accepts_dd_parser_compat() -> None: assert spec.dd_parser_compat == tool.DDParserCompatMode.raw_json -def 
test_benchmark_preflight_rejects_rules_only_without_explicit_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_only_without_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rules-only-no-labels -workloads: - - id: input - source: input.csv -configs: - - id: rules-only-redact - experimental_detection_strategy: rules_only - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="requires explicit detect.entity_labels"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_rejects_rules_only_unsupported_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_only_unsupported_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rules-only-unsupported-labels -workloads: - - id: input - source: input.csv -configs: - - id: rules-only-redact - experimental_detection_strategy: rules_only - detect: - entity_labels: [api_key, person] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="unsupported high-confidence rule labels.*person"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_accepts_rules_only_supported_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_only_supported_labels", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": 
["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rules-only-supported-labels -workloads: - - id: input - source: input.csv -configs: - - id: rules-only-redact - experimental_detection_strategy: rules_only - detect: - entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_accepts_rules_covered_or_default_contextual_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rules_covered_contextual_labels", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rules-covered-contextual-labels -workloads: - - id: input - source: input.csv -configs: - - id: rules-covered-redact - experimental_detection_strategy: rules_covered_or_default - detect: - entity_labels: [api_key, person] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_rejects_experimental_rule_labels_for_non_rule_strategy(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rule_labels_non_rule_strategy", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rule-labels-non-rule-strategy -workloads: - - id: input - source: input.csv -configs: - - id: redact - experimental_detection_strategy: prose_augment_focus - 
experimental_rule_labels: [street_address] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="experimental_rule_labels requires a rule-backed strategy"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_accepts_experimental_rule_labels_for_compact_rule_guardrail( - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_rule_labels_compact_rule_guardrail", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice lives on West Roberts Drive."]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rule-labels-compact-rule-guardrail -workloads: - - id: input - source: input.csv -configs: - - id: redact - experimental_detection_strategy: rules_guardrail_compact_validation - experimental_rule_labels: [street_address] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_rejects_unsupported_experimental_rule_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_rule_labels_unsupported", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: rule-labels-unsupported -workloads: - - id: input - source: input.csv -configs: - - id: redact - experimental_detection_strategy: rules_guardrail - experimental_rule_labels: [person] - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="unsupported experimental_rule_labels.*person"): - tool.preflight_suite(spec, spec_path=spec_path) - - def 
test_benchmark_case_enters_experimental_detection_strategy_context( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -1753,7 +1282,7 @@ def test_benchmark_case_enters_experimental_detection_strategy_context( captured_measurements: list[Any] = [] captured_parser_compat: list[Any] = [] captured_strategies: list[Any] = [] - captured_rule_labels: list[Any] = [] + captured_context_kwargs: list[dict[str, Any]] = [] @contextmanager def fake_measurement_session(config: Any) -> Iterator[None]: @@ -1761,9 +1290,9 @@ def fake_measurement_session(config: Any) -> Iterator[None]: yield None @contextmanager - def fake_detection_strategy_context(strategy: Any, *, rule_labels: list[str] | None = None) -> Iterator[None]: + def fake_detection_strategy_context(strategy: Any, **kwargs: Any) -> Iterator[None]: captured_strategies.append(strategy) - captured_rule_labels.append(rule_labels) + captured_context_kwargs.append(kwargs) yield None @contextmanager @@ -1781,32 +1310,36 @@ def run(self, *, config: Any, data: Any) -> None: monkeypatch.setattr(tool, "experimental_detection_strategy_context", fake_detection_strategy_context) spec = tool.BenchmarkSpec( - suite_id="rules-suite", + suite_id="native-suite", dd_parser_compat="raw_json", + native_runtime=tool.NativeRuntimeSpec( + runtime_id="native-test", + endpoint="http://runtime.example/v1", + model="test-model", + ), workloads=[tool.WorkloadSpec(id="input", source="input.csv")], configs=[ tool.ConfigSpec( - id="rules-only-redact", + id="native-single-pass-redact", replace="redact", - experimental_detection_strategy="rules_only", - experimental_rule_labels=["api_key"], + experimental_detection_strategy="native_single_pass", ) ], ) pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(tmp_path / "input.csv", index=False) case = tool.BenchmarkCase( - suite_id="rules-suite", + suite_id="native-suite", workload_id="input", - config_id="rules-only-redact", + config_id="native-single-pass-redact", repetition=0, - 
case_id="input__rules-only-redact__r000", + case_id="input__native-single-pass-redact__r000", ) tool._execute_case( FakeAnonymizer(), spec.workloads[0], spec.configs[0], - raw_path=tmp_path / "raw" / "input__rules-only-redact__r000.jsonl", + raw_path=tmp_path / "raw" / "input__native-single-pass-redact__r000.jsonl", trace_path=None, case=case, spec=spec, @@ -1816,8 +1349,9 @@ def run(self, *, config: Any, data: Any) -> None: ) assert captured_parser_compat == [tool.DDParserCompatMode.raw_json] - assert captured_strategies == [tool.ExperimentalDetectionStrategy.rules_only] - assert captured_rule_labels == [["api_key"]] - assert captured_measurements[0].run_tags["experimental_detection_strategy"] == "rules_only" - assert captured_measurements[0].run_tags["experimental_rule_labels"] == ["api_key"] + assert captured_strategies == [tool.ExperimentalDetectionStrategy.native_single_pass] + assert captured_context_kwargs[0]["native_runtime"].endpoint == "http://runtime.example/v1" + assert captured_context_kwargs[0]["native_runtime"].model == "test-model" + assert captured_measurements[0].run_tags["experimental_detection_strategy"] == "native_single_pass" + assert captured_measurements[0].run_tags["native_runtime_id"] == "native-test" assert captured_measurements[0].run_tags["dd_parser_compat"] == "raw_json" diff --git a/tests/tools/test_screen_strategy_comparisons.py b/tests/tools/test_screen_strategy_comparisons.py index 95d1c3f2..8b2fa237 100644 --- a/tests/tools/test_screen_strategy_comparisons.py +++ b/tests/tools/test_screen_strategy_comparisons.py @@ -35,9 +35,9 @@ def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) { "workload_id": "shell-3", "baseline_config_id": "default", - "candidate_config_id": "rules-only", + "candidate_config_id": "detector-only", "baseline_strategy": "default", - "candidate_strategy": "rules_only", + "candidate_strategy": "detector_only", "baseline_replacement_strategy": "default", 
"candidate_replacement_strategy": "local_structured_substitute", "baseline_case_count": 3, @@ -91,7 +91,7 @@ def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) "source_path": "analysis/default-vs-candidates.csv", "workload_id": "shell-3", "baseline_config_id": "default", - "candidate_config_id": "rules-only", + "candidate_config_id": "detector-only", "safety_verdict": "review", "performance_verdict": "improved", "candidate_verdict": "review", @@ -121,13 +121,13 @@ def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) assert shell.baseline_case_count == 3 assert shell.candidate_case_count == 3 assert shell.shared_stable_final_entity_signature_count == 12 - rules_local = next( + detector_local = next( group for group in result.groups - if group.group_key == "strategy:rules_only|replacement:local_structured_substitute" + if group.group_key == "strategy:detector_only|replacement:local_structured_substitute" ) - assert rules_local.candidate_replacement_strategy == "local_structured_substitute" - assert rules_local.row_count == 1 + assert detector_local.candidate_replacement_strategy == "local_structured_substitute" + assert detector_local.row_count == 1 no_augment = next(group for group in result.groups if group.group_key == "strategy:no_augment") assert no_augment.row_count == 1 assert no_augment.reject_count == 1 @@ -141,10 +141,10 @@ def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: ) rows = [ tool.ScreenRow( - source_path="analysis/default-vs-rules.csv", + source_path="analysis/default-vs-detector-only.csv", workload_id="shell", baseline_config_id="default", - candidate_config_id="rules", + candidate_config_id="detector-only", baseline_replacement_strategy="default", candidate_replacement_strategy="local_structured_substitute", safety_verdict="review", @@ -248,8 +248,8 @@ def test_screen_strategy_comparisons_surfaces_candidate_original_value_leaks(tmp { "workload_id": "structured-secrets", 
"baseline_config_id": "default", - "candidate_config_id": "rules-covered", - "candidate_strategy": "rules_covered_or_default", + "candidate_config_id": "native-single-pass", + "candidate_strategy": "native_single_pass", "safety_verdict": "fail", "performance_verdict": "improved", "candidate_verdict": "reject", @@ -291,78 +291,6 @@ def test_screen_strategy_comparisons_surfaces_candidate_original_value_leaks(tmp assert "collision_labels=date:1" in rendered -def test_screen_strategy_comparisons_marks_rule_fast_lane_review_when_only_provenance_flags_remain( - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_fast_lane_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-secrets", - "baseline_config_id": "default", - "candidate_config_id": "rules-covered", - "candidate_strategy": "rules_covered_or_default", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "candidate_original_value_leak_count": 0, - "candidate_original_value_leak_record_count": 0, - "flags": '["no_candidate_detector_entities", "candidate_uses_rule_entities"]', - } - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - assert result.groups[0].recommendation == "fast_lane_review" - - -def test_screen_strategy_comparisons_treats_covered_boundary_deltas_as_fast_lane_review( - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_fast_lane_boundary_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "baseline_config_id": "default", - "candidate_config_id": "rules-local", - "candidate_strategy": "rules_covered_or_default", - "candidate_replacement_strategy": "local_structured_substitute", - "safety_verdict": "review", - "performance_verdict": 
"improved", - "candidate_verdict": "review", - "candidate_original_value_leak_count": 0, - "candidate_original_value_leak_record_count": 0, - "flags": ( - '["entity_count_loss", "span_boundary_mismatch", ' - '"no_candidate_detector_entities", "candidate_uses_rule_entities"]' - ), - "baseline_only_final_entity_signature_label_counts.api_key": 2, - "baseline_only_final_entity_signature_label_counts.http_cookie": 3, - "baseline_only_candidate_covered_signature_label_counts.api_key": 2, - "baseline_only_candidate_covered_signature_label_counts.http_cookie": 3, - "baseline_only_candidate_overlapping_signature_label_counts.http_cookie": 1, - "baseline_only_candidate_uncovered_signature_count": 0, - "baseline_stable_candidate_uncovered_signature_count": 0, - } - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - row = result.rows[0] - assert row.baseline_only_label_counts == {} - assert result.groups[0].baseline_only_label_counts == {} - assert result.groups[0].recommendation == "fast_lane_review" - - def test_screen_strategy_comparisons_surfaces_label_policy_review(tmp_path: Path) -> None: tool = load_tool( "measurement_screen_strategy_comparisons_label_policy_review", @@ -568,18 +496,18 @@ def test_screen_strategy_comparisons_groups_default_detection_by_replacement_str assert tool.group_base_for_row(row, config_aliases={}) == "replacement:local_structured_substitute" -def test_screen_strategy_comparisons_keeps_stale_rule_review_generic_without_leak_metrics() -> None: +def test_screen_strategy_comparisons_keeps_generic_review_without_leak_metrics() -> None: tool = load_tool( - "measurement_screen_strategy_comparisons_fast_lane_review_stale", + "measurement_screen_strategy_comparisons_generic_review", REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", ) group = tool.ScreenGroup( - group_key="strategy:rules_covered_or_default", - candidate_strategy="rules_covered_or_default", + 
group_key="strategy:detector_only", + candidate_strategy="detector_only", row_count=1, review_count=1, performance_verdict_counts={"improved": 1}, - flag_counts={"candidate_uses_rule_entities": 1, "no_candidate_detector_entities": 1}, + flag_counts={"candidate_skips_llm_validation": 1}, ) assert tool.group_recommendation(group) == "review_only" @@ -680,8 +608,8 @@ def test_screen_strategy_comparisons_groups_candidate_strategy_conflicts(tmp_pat { "workload_id": "shell", "baseline_config_id": "default", - "candidate_config_id": "rules", - "candidate_strategy": "rules_only", + "candidate_config_id": "detector-only", + "candidate_strategy": "detector_only", "safety_verdict": "review", "performance_verdict": "improved", "candidate_verdict": "review", @@ -696,7 +624,7 @@ def test_screen_strategy_comparisons_groups_candidate_strategy_conflicts(tmp_pat result = tool.screen_comparison_paths([tmp_path]) groups = {group.group_key: group for group in result.groups} - assert list(groups) == ["strategy:rules_only", "strategy:no_augment"] + assert list(groups) == ["strategy:detector_only", "strategy:no_augment"] no_augment = groups["strategy:no_augment"] assert no_augment.row_count == 2 assert no_augment.viable_count == 1 @@ -836,20 +764,20 @@ def test_screen_strategy_comparisons_can_group_by_strategy_and_workload_family(t { "workload_id": "shell-secrets-3", "baseline_config_id": "default", - "candidate_config_id": "rules-only-shell", - "candidate_strategy": "rules_only", + "candidate_config_id": "detector-only-shell", + "candidate_strategy": "detector_only", "safety_verdict": "review", "performance_verdict": "improved", "candidate_verdict": "review", "pipeline_elapsed_sec_delta_pct": -99.9, "observed_total_tokens_delta": -11000, - "flags": '["candidate_uses_rule_entities"]', + "flags": '["candidate_skips_llm_validation"]', }, { "workload_id": "biographies-r5-offset5", "baseline_config_id": "default", - "candidate_config_id": "rules-only-bio", - "candidate_strategy": 
"rules_only", + "candidate_config_id": "detector-only-bio", + "candidate_strategy": "detector_only", "safety_verdict": "fail", "performance_verdict": "improved", "candidate_verdict": "reject", @@ -865,11 +793,11 @@ def test_screen_strategy_comparisons_can_group_by_strategy_and_workload_family(t result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) groups = {group.group_key: group for group in result.groups} - assert list(groups) == ["strategy:rules_only|family:shell-secrets", "strategy:rules_only|family:biographies"] - assert groups["strategy:rules_only|family:shell-secrets"].recommendation == "review_only" - assert groups["strategy:rules_only|family:shell-secrets"].workload_families == ["shell-secrets"] - assert groups["strategy:rules_only|family:biographies"].recommendation == "reject" - assert groups["strategy:rules_only|family:biographies"].baseline_only_label_counts == {"first_name": 4} + assert list(groups) == ["strategy:detector_only|family:shell-secrets", "strategy:detector_only|family:biographies"] + assert groups["strategy:detector_only|family:shell-secrets"].recommendation == "review_only" + assert groups["strategy:detector_only|family:shell-secrets"].workload_families == ["shell-secrets"] + assert groups["strategy:detector_only|family:biographies"].recommendation == "reject" + assert groups["strategy:detector_only|family:biographies"].baseline_only_label_counts == {"first_name": 4} def test_workload_family_normalizes_slice_and_offset_suffixes() -> None: diff --git a/tests/tools/test_staged_detection_output_analysis.py b/tests/tools/test_staged_detection_output_analysis.py index f2ccb7f9..4c791501 100644 --- a/tests/tools/test_staged_detection_output_analysis.py +++ b/tests/tools/test_staged_detection_output_analysis.py @@ -44,13 +44,12 @@ def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_p "record_type": "staged_detection_case", "case_id": "shell-row-0", "row_index": 0, - 
"seed_source": "rules_router", + "seed_source": "gliner", "status": "completed", "elapsed_sec": 0.002, "model_elapsed_sec": 0.0, "model_phase_count": 0, "model_request_count": 0, - "rule_covered_label_set": True, "final_entity_count": 5, "final_entity_signature_count": 5, "final_label_counts": {"api_key": 2, "email": 1, "password": 1, "url": 1}, @@ -68,13 +67,12 @@ def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_p "record_type": "staged_detection_case", "case_id": "bio-row-0", "row_index": 0, - "seed_source": "rules_plus_direct_llm", + "seed_source": "direct_llm", "status": "completed", "elapsed_sec": 10.0, "model_elapsed_sec": 9.5, "model_phase_count": 3, "model_request_count": 3, - "rule_covered_label_set": False, "final_entity_count": 3, "final_entity_signature_count": 3, "final_label_counts": {"person": 2, "api_key": 1}, @@ -92,13 +90,12 @@ def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_p "record_type": "staged_detection_case", "case_id": "bio-row-1", "row_index": 1, - "seed_source": "rules_plus_direct_llm", + "seed_source": "direct_llm", "status": "error", "elapsed_sec": 1.0, "model_elapsed_sec": 0.8, "model_phase_count": 1, "model_request_count": 1, - "rule_covered_label_set": False, "final_entity_count": 0, "final_entity_signature_count": 0, "total_usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}, @@ -113,89 +110,32 @@ def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_p assert result.case_count == 3 assert result.group_count == 2 groups = {row.seed_source: row for row in result.groups} - assert groups["rules_router"].case_count == 1 - assert groups["rules_router"].completed_case_count == 1 - assert groups["rules_router"].model_elapsed_sec_sum == 0.0 - assert groups["rules_router"].model_request_count_sum == 0 - assert groups["rules_router"].rule_covered_case_count == 1 - assert groups["rules_router"].baseline_shared_signature_rate == 1.0 - 
assert groups["rules_router"].fast_lane_verdict == "review" - assert groups["rules_router"].flags == ["too_few_cases"] - assert groups["rules_plus_direct_llm"].case_count == 2 - assert groups["rules_plus_direct_llm"].completed_case_count == 1 - assert groups["rules_plus_direct_llm"].error_case_count == 1 - assert groups["rules_plus_direct_llm"].elapsed_sec_sum == pytest.approx(11.0) - assert groups["rules_plus_direct_llm"].model_elapsed_sec_sum == pytest.approx(10.3) - assert groups["rules_plus_direct_llm"].model_request_count_sum == 4 - assert groups["rules_plus_direct_llm"].total_tokens_sum == 132 - assert groups["rules_plus_direct_llm"].baseline_final_entity_signature_count_sum == 4 - assert groups["rules_plus_direct_llm"].shared_final_entity_signature_count_sum == 2 - assert groups["rules_plus_direct_llm"].baseline_only_final_entity_signature_count_sum == 2 - assert groups["rules_plus_direct_llm"].direct_only_final_entity_signature_count_sum == 1 - assert groups["rules_plus_direct_llm"].baseline_shared_signature_rate == pytest.approx(0.5) - assert groups["rules_plus_direct_llm"].fast_lane_verdict == "reject" - assert groups["rules_plus_direct_llm"].flags == [ - "too_few_cases", - "case_errors", - "baseline_signature_loss", - "uses_model", - "not_fully_rule_covered", - ] + assert groups["gliner"].case_count == 1 + assert groups["gliner"].completed_case_count == 1 + assert groups["gliner"].model_elapsed_sec_sum == 0.0 + assert groups["gliner"].model_request_count_sum == 0 + assert groups["gliner"].baseline_shared_signature_rate == 1.0 + assert groups["direct_llm"].case_count == 2 + assert groups["direct_llm"].completed_case_count == 1 + assert groups["direct_llm"].error_case_count == 1 + assert groups["direct_llm"].elapsed_sec_sum == pytest.approx(11.0) + assert groups["direct_llm"].model_elapsed_sec_sum == pytest.approx(10.3) + assert groups["direct_llm"].model_request_count_sum == 4 + assert groups["direct_llm"].total_tokens_sum == 132 + assert 
groups["direct_llm"].baseline_final_entity_signature_count_sum == 4 + assert groups["direct_llm"].shared_final_entity_signature_count_sum == 2 + assert groups["direct_llm"].baseline_only_final_entity_signature_count_sum == 2 + assert groups["direct_llm"].direct_only_final_entity_signature_count_sum == 1 + assert groups["direct_llm"].baseline_shared_signature_rate == pytest.approx(0.5) label_deltas = {(row.seed_source, row.delta_type, row.label): row.count for row in result.label_deltas} assert label_deltas == { - ("rules_plus_direct_llm", "baseline_only", "city"): 1, - ("rules_plus_direct_llm", "baseline_only", "person"): 1, - ("rules_plus_direct_llm", "direct_only", "api_key"): 1, + ("direct_llm", "baseline_only", "city"): 1, + ("direct_llm", "baseline_only", "person"): 1, + ("direct_llm", "direct_only", "api_key"): 1, } -def test_staged_detection_output_analysis_requires_repeated_cases_for_fast_lane(tmp_path: Path) -> None: - tool = load_tool( - "measurement_staged_detection_output_analysis_repeated_gate", - REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", - ) - output_dir = tmp_path / "staged" - output_dir.mkdir() - _write_jsonl( - output_dir / "staged-detection-cases.jsonl", - [ - { - "record_type": "staged_detection_case", - "case_id": f"shell-row-{index}", - "row_index": index, - "seed_source": "rules_router", - "status": "completed", - "elapsed_sec": 0.002, - "model_elapsed_sec": 0.0, - "model_phase_count": 0, - "model_request_count": 0, - "rule_covered_label_set": True, - "final_entity_count": 5, - "final_entity_signature_count": 5, - "total_usage": {}, - "comparison": { - "baseline_final_entity_signature_count": 5, - "shared_final_entity_signature_count": 5, - "baseline_only_final_entity_signature_count": 0, - "direct_only_final_entity_signature_count": 0, - "baseline_only_label_counts": {}, - "direct_only_label_counts": {}, - }, - } - for index in range(3) - ], - ) - - result = tool.analyze_staged_detection_output(output_dir) - - group 
= result.groups[0] - assert group.seed_source == "rules_router" - assert group.case_count == 3 - assert group.fast_lane_verdict == "fast_lane_candidate" - assert group.flags == [] - - def test_staged_detection_output_analysis_writes_csv_tables(tmp_path: Path) -> None: tool = load_tool( "measurement_staged_detection_output_analysis_export", @@ -208,7 +148,7 @@ def test_staged_detection_output_analysis_writes_csv_tables(tmp_path: Path) -> N { "case_id": "case-0", "row_index": 0, - "seed_source": "rules_router", + "seed_source": "gliner", "status": "completed", "elapsed_sec": 0.01, "model_elapsed_sec": 0.0, diff --git a/tests/tools/test_staged_detection_probe.py b/tests/tools/test_staged_detection_probe.py index c380cba2..29987f90 100644 --- a/tests/tools/test_staged_detection_probe.py +++ b/tests/tools/test_staged_detection_probe.py @@ -507,86 +507,6 @@ def test_staged_detection_augmentation_prompt_discourages_grouped_person_and_sur assert "also return the surname substring as last_name" in prompt -def test_staged_detection_can_seed_from_rules_without_llm_seed_prompt() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient( - tool, - [ - '{"decisions": [{"id": "email_6_23", "decision": "keep", "reason": "email address"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Email alice@example.com at NVIDIA.", - labels=["email", "organization_name"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.rules - assert result.phase_usage.seed == {} - assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=True, augmentation=True) - assert 
result.phase_skip_reasons.seed == "deterministic_rules" - assert result.phase_skip_reasons.validation is None - assert result.model_phase_count == 2 - assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=1, augmentation=1) - assert result.model_request_count == 2 - assert result.seed_suggestion_count == 1 - assert result.seed_entity_count == 1 - assert result.final_label_counts == {"email": 1, "organization_name": 1} - assert result.artifact.final_source_counts == {"augmenter": 1, "rule": 1} - assert result.total_usage == {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30} - assert len(llm_client.prompts) == 2 - - -def test_staged_detection_can_add_rules_to_direct_llm_seed_without_validating_rules() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_plus_direct_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient( - tool, - [ - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - '{"decisions": [{"id": "organization_name_27_33", "decision": "keep", "reason": "employer"}]}', - '{"entities": []}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Email alice@example.com at NVIDIA.", - labels=["email", "organization_name"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules_plus_direct_llm, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.rules_plus_direct_llm - assert result.seed_suggestion_count == 2 - assert result.seed_entity_count == 2 - assert result.validation_candidate_count == 1 - assert result.validation_decision_count == 1 - assert result.final_label_counts == {"email": 1, "organization_name": 1} - assert result.artifact.final_source_counts == {"direct_seed": 1, "rule": 1} - assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, 
augmentation=1) - assert result.model_request_count == 3 - assert '"label":"email"' not in llm_client.prompts[1] - assert '"label":"organization_name"' in llm_client.prompts[1] - - def test_staged_detection_baseline_comparison_skips_rows_without_signature_hashes() -> None: tool = load_tool( "measurement_staged_detection_probe_missing_baseline_signatures", @@ -615,159 +535,6 @@ def test_staged_detection_baseline_comparison_skips_rows_without_signature_hashe assert compared.comparison is None -def test_staged_detection_can_trust_rules_without_validation_prompt() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_trusted_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient(tool, ['{"entities": []}']) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text=( - "$ docker run -e DATABASE_URL='postgres://app_user:fakeDbPass123!@db.example.test:5432/app' " - "-e API_KEY=ghp_FAKEtoken1234567890abcdef myapp:latest\nPassword: fakeLoginPass!" 
- ), - labels=["api_key", "password", "email", "url"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules_trusted, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.rules_trusted - assert result.phase_usage.seed == {} - assert result.phase_usage.validation == {} - assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=True) - assert result.phase_skip_reasons.seed == "deterministic_rules" - assert result.phase_skip_reasons.validation == "trusted_rules" - assert result.phase_skip_reasons.augmentation is None - assert result.model_phase_count == 1 - assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=1) - assert result.model_request_count == 1 - assert result.rule_covered_label_set is True - assert result.validation_decision_count == 3 - assert result.final_label_counts == {"api_key": 1, "password": 1, "url": 1} - assert result.artifact.final_source_counts == {"rule": 3} - assert result.total_usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} - assert len(llm_client.prompts) == 1 - - -def test_staged_detection_can_skip_augmentation_when_all_labels_are_rule_covered() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_trusted_no_augment", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient(tool, []) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Email alice@example.com", - labels=["email"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules_trusted, - skip_augmentation_when_rule_covered=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.phase_usage.augmentation == {} - assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=False) - assert result.phase_skip_reasons 
== tool.PhaseSkipReasons( - seed="deterministic_rules", - validation="trusted_rules", - augmentation="rule_covered_labels", - ) - assert result.model_phase_count == 0 - assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=0) - assert result.model_request_count == 0 - assert result.rule_covered_label_set is True - assert result.augmented_suggestion_count == 0 - assert result.final_label_counts == {"email": 1} - assert result.total_usage == {} - assert len(llm_client.prompts) == 0 - - -def test_staged_detection_rules_router_short_circuits_rule_covered_labels() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_router_short_circuit", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient(tool, []) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Email alice@example.com and token ghp_FAKEtoken1234567890abcdef", - labels=["email", "api_key"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules_router, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.rules_router - assert result.phase_model_work == tool.PhaseModelWork(seed=False, validation=False, augmentation=False) - assert result.phase_skip_reasons == tool.PhaseSkipReasons( - seed="deterministic_rules", - validation="trusted_rules", - augmentation="rule_covered_labels", - ) - assert result.model_phase_count == 0 - assert result.phase_model_requests == tool.PhaseModelRequests(seed=0, validation=0, augmentation=0) - assert result.model_request_count == 0 - assert result.elapsed_sec is not None and result.elapsed_sec > 0.0 - assert result.model_elapsed_sec == 0.0 - assert result.rule_covered_label_set is True - assert result.final_label_counts == {"api_key": 1, "email": 1} - assert result.artifact.final_source_counts == {"rule": 2} - assert result.total_usage == {} - assert 
len(llm_client.prompts) == 0 - - -def test_staged_detection_rules_router_uses_direct_seed_for_contextual_labels() -> None: - tool = load_tool( - "measurement_staged_detection_probe_rules_router_mixed_labels", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "person name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "person name"}]}', - '{"entities": []}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice emails alice@example.com.", - labels=["email", "first_name"], - row_index=0, - ), - client=llm_client, - seed_source=tool.SeedSource.rules_router, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.rules_router - assert result.rule_covered_label_set is False - assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=True) - assert result.phase_skip_reasons == tool.PhaseSkipReasons() - assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) - assert result.model_request_count == 3 - assert result.final_label_counts == {"email": 1, "first_name": 1} - assert result.artifact.final_source_counts == {"direct_seed": 1, "rule": 1} - assert '"label":"email"' not in llm_client.prompts[1] - assert '"label":"first_name"' in llm_client.prompts[1] - - def test_staged_detection_can_chunk_validation_into_local_excerpts() -> None: tool = load_tool( "measurement_staged_detection_probe_chunked_validation", diff --git a/tools/measurement/README.md b/tools/measurement/README.md index f0cc8ffe..69a017af 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -1,30 +1,31 @@ -# Measurement Tools +# Measurement tools -`export_measurements.py` converts Anonymizer measurement JSONL into one table -per 
`record_type`. - -Run these tools inside the project environment, either with an activated venv -or through `uv run`. +This directory contains developer tools for measuring Anonymizer runs, exporting +measurement JSONL to tables, and comparing benchmark strategies. Run the tools +inside the project environment, either with an activated venv or through +`uv run`. ```bash uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables ``` -By default it writes Parquet files plus `manifest.json`: +By default, `export_measurements.py` writes Parquet files plus +`manifest.json`: - `run.parquet` - `stage.parquet` - `record.parquet` -- `ndd_workflow.parquet` when adapter records are present -- `model_workflow.parquet` when non-DataDesigner model workflow records are - present +- `ndd_workflow.parquet` when DataDesigner adapter records are present +- `model_workflow.parquet` when direct model workflow records are present Use `--format csv` or `--format jsonl` for non-Parquet output, and `--overwrite` to replace existing output files. +## Benchmark runner + `run_benchmarks.py` runs repeatable Anonymizer workloads and writes the same measurement JSONL format, one raw file per benchmark case plus a combined `measurements.jsonl`. @@ -37,7 +38,7 @@ uv run python tools/measurement/run_benchmarks.py suite.yaml \ --dd-trace last-message ``` -To rerun the repo-data smoke suite with DataDesigner traces enabled: +The repo-data smoke suite can be run with DataDesigner traces enabled: ```bash bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh @@ -45,7 +46,7 @@ bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh The script writes to `/tmp/anonymizer-repo-data-smoke-dd-traces` by default. 
Pass a different output directory as the first argument, or set -`DD_TRACE_MODE=all-messages` when you need full chat history: +`DD_TRACE_MODE=all-messages` when full chat history is needed: ```bash DD_TRACE_MODE=all-messages \ @@ -63,7 +64,7 @@ Benchmark suites are YAML files with three parts: Example: ```yaml -suite_id: shell-and-biography-smoke +suite_id: biography-smoke model_configs: ./model-configs.yaml model_providers: ./providers.yaml dd_parser_compat: none @@ -109,47 +110,29 @@ for URL-like sources because the runner cannot safely materialize a local subset without downloading the whole dataset first. Relative paths in suite files are resolved from the suite file's directory. +The runner refuses to write into a non-empty output directory unless +`--overwrite` is set. By default it also exports Parquet tables into `tables/`; +pass `--no-export` when only raw measurement JSONL is needed. + +Before starting a real run, the benchmark runner performs cheap preflight +checks: suite/config parsing, local dataset existence, CSV/Parquet text-column +metadata, provider YAML shape, native runtime requirements, and active +model-alias references. `--dry-run` runs those same checks, expands the planned +matrix, and skips output-dir writes and model work. Use `case_retries` and `case_retry_backoff_sec` for long-running suites on shared model endpoints. Retries are disabled by default. When enabled, a failed case is retried with the same `case_id` and output paths; the final case still records `attempt_count` and `attempt_errors` in `summary.json`. `--fail-fast` -remains fail-fast and bypasses retries. Treat retried cases as reliability -signals during analysis, especially when failures come from provider health -checks or rate limits. +remains fail-fast and bypasses retries. 
-Configs may also set `experimental_detection_strategy` for benchmark-only -pipeline probes: +## Benchmark-only detection strategies -```yaml -configs: - - id: shell-rules-only - experimental_detection_strategy: rules_only - detect: - entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, user_name] - replace: - strategy: hash - digest_length: 12 -``` - -Native benchmark strategies require an explicit runtime. Set top-level -`native_runtime.endpoint` and `native_runtime.model`, set the standard -`ANONYMIZER_BENCH_NATIVE_ENDPOINT` and `ANONYMIZER_BENCH_NATIVE_MODEL` -environment variables, or override runtime fields per config with -`configs[].native_runtime`. GLiNER-seeded native strategies also require -`native_runtime.gliner_endpoint` and `native_runtime.gliner_model`, or the -standard `ANONYMIZER_BENCH_GLINER_ENDPOINT` and -`ANONYMIZER_BENCH_GLINER_MODEL` environment variables. The runner records -runtime id, alias, provider, model, and env-variable names as run tags; raw -endpoint URLs are not emitted into measurement tables. +Configs may set `experimental_detection_strategy` for benchmark-only pipeline +probes. These values are not public `Detect` config fields, and they should not +be treated as safe defaults across arbitrary data. ```yaml -native_runtime: - runtime_id: local-vllm-json - endpoint_env: ANONYMIZER_BENCH_NATIVE_ENDPOINT - model_env: ANONYMIZER_BENCH_NATIVE_MODEL - provider: local-vllm - alias: native-direct configs: - id: native-single-pass experimental_detection_strategy: native_single_pass @@ -159,71 +142,29 @@ configs: Supported values: - `default`: run the normal Anonymizer detection pipeline. -- `rules_guardrail`: run the normal Anonymizer detection pipeline, then union - deterministic high-confidence rule spans into the final entity set. 
-- `rules_filter_guardrail`: remove GLiNER candidates that are fully covered by - same-label deterministic high-confidence rule spans before validation, add - non-overlapping rule spans back before augmentation so the augmenter sees them - as already tagged, then add non-overlapping rule spans into the final entity - set. Different-label overlaps and longer detector spans remain validation - candidates so contextual spans such as a multi-token political view, - university, or organization name are not shadowed by a shorter or differently - labeled rule span. - `no_augment`: run GLiNER detection and validation, but skip LLM augmentation. -- `rules_seed_no_augment`: add deterministic high-confidence secret spans to - the GLiNER seed set, validate those seeds, and skip LLM augmentation. -- `rules_guardrail_no_augment`: run GLiNER detection and validation, skip LLM - augmentation, then union deterministic high-confidence rule spans into the - final entity set. -- `rules_filter_guardrail_no_augment`: remove GLiNER candidates that are fully - covered by same-label deterministic high-confidence rule spans before - validation, skip LLM augmentation, then add non-overlapping rule spans into - the final entity set. -- `rules_guardrail_detector_only`: run only GLiNER detection and local - finalization, then union deterministic high-confidence rule spans into the - final entity set. - `detector_only`: run only GLiNER detection and local finalization. This skips LLM validation and LLM augmentation. -- `rules_only`: use only deterministic high-confidence rules for the detection - stage. -- `rules_covered_or_default`: if explicit `detect.entity_labels` are entirely - inside the structured-secret fast lane (`api_key`, `email`, `http_cookie`, - `password`, `pin`, `unique_id`, `url`, `user_name`), use deterministic rules - for rows whose structured assignments are covered and route suspicious - uncovered rows through the normal Anonymizer detection pipeline. 
Label sets - outside the fast lane always use normal detection. -- `native_rules_router`: run a benchmark-only native staged detector without - DataDesigner. Rule-covered label sets short-circuit through deterministic - rules with no model calls; other label sets use direct OpenAI-compatible - provider calls for seed extraction, validation, and augmentation. - `native_candidate_validate_no_augment`: run a benchmark-only native staged detector without DataDesigner using direct OpenAI-compatible calls for seed - extraction and validation, then skip augmentation. This isolates the cost and - recall impact of removing the augmentation phase from the native executor. + extraction and validation, then skip augmentation. - `detector_native_validate_no_augment`: run the normal GLiNER detector seed through Anonymizer/DataDesigner, then bypass DataDesigner validation and - augmentation with direct OpenAI-compatible validation calls. This isolates - whether native validation can replace DataDesigner validation when candidate - quality is held closer to the default detector path. -- `detector_native_validate_native_augment`: run the normal GLiNER detector - seed through Anonymizer/DataDesigner, then bypass DataDesigner validation and + augmentation with direct OpenAI-compatible validation calls. +- `detector_native_validate_native_augment`: run the normal GLiNER detector seed + through Anonymizer/DataDesigner, then bypass DataDesigner validation and augmentation with direct OpenAI-compatible validation and augmentation calls. - This keeps the default detector candidate source while testing whether direct - provider calls can replace the two downstream DataDesigner LLM phases. - `gliner_native_validate_no_augment`: run a direct hosted-GLiNER seed without DataDesigner, validate those detector candidates with direct - OpenAI-compatible calls, and skip augmentation. This isolates DataDesigner - detector orchestration overhead while keeping a GLiNER-style candidate source. 
+ OpenAI-compatible calls, and skip augmentation. - `gliner_native_validate_native_augment`: run a direct hosted-GLiNER seed without DataDesigner, validate those detector candidates with direct - OpenAI-compatible calls, then run direct native augmentation. This is the - fully staged no-DataDesigner detector/validator/augmenter lane for contextual - recall experiments. + OpenAI-compatible calls, then run direct native augmentation. - `native_single_pass`: run a benchmark-only native detector without DataDesigner using one direct OpenAI-compatible provider call per row. The - model must return exact values plus `start`/`end` offsets; local code validates - offsets, unions non-overlapping deterministic rule spans, resolves overlaps, - and records parser/runtime failures as `model_workflow` errors. + model must return exact values plus `start`/`end` offsets; local code + validates offsets, resolves overlaps, and records parser/runtime failures as + `model_workflow` errors. - `native_single_pass_recall`: the same one-call native detector with a recall-oriented prompt that includes Anonymizer's label examples and stronger high-recall guidance. @@ -234,88 +175,52 @@ Supported values: - `native_single_pass_values_recall`: the value-only one-call detector with the recall-oriented prompt from `direct_detection_probe.py`. -These strategies exist to compare performance options. They are not public -`Detect` config fields, and they should not be treated as safe defaults across -arbitrary data. The rule-backed strategies only cover deterministic -high-confidence spans for `api_key`, `date_of_birth`, `email`, -`http_cookie`, `organization_name`, `password`, `pin`, `religious_belief`, -`street_address`, `unique_id`, `url`, and `user_name`; they will not replace -contextual detection for prose identifiers such as names in biographies or -legal documents. 
The prose rules (`date_of_birth`, `organization_name`, -`religious_belief`, and `street_address`) are narrow contextual patterns and -are not enough to opt into `rules_covered_or_default`; those labels fall back -to default detection unless `rules_only` is explicitly selected. The structured -identifier rules require keyed or command-style syntax such as -`Cookie:`, `pin=`, `trace-id:`, `user_name=`, or service-principal flags. They -are not general entity recognizers. `detector_only` is also unsafe as a default -because it skips the LLM validation pass that drops false positives and -reclassifies ambiguous spans. `rules_only` requires explicit `entity_labels`, -and every label must be covered by those deterministic rules. Use -`rules_covered_or_default` when a benchmark suite may include both fully -structured-secret scans and contextual workloads; it keeps the no-DataDesigner -short-circuit for the former and falls back to the default pipeline for prose -or legal labels. - -Use `native_rules_router` when you want the same routing shape without -DataDesigner orchestration. It uses the resolved native runtime endpoint/model -from `native_runtime` or the standard benchmark runtime environment variables. -Treat it as a native-executor prototype: it can prove that DataDesigner overhead -is avoidable, but it must be compared against baseline signatures and -original-value leak metrics before any workload-specific promotion decision. - -Use `native_candidate_validate_no_augment` when you want a narrower native -executor diagnostic: direct seed candidates plus direct validation, with no -augmentation. It is useful for proving how much speed comes from removing a -phase, but a faster run that loses baseline signatures is still a rejection. - -Use `detector_native_validate_no_augment` when you want to keep the production -detector seed while testing a direct-provider validation path. 
It is not a -no-DataDesigner strategy because the detector seed still runs through the -adapter, but it tells you whether DataDesigner validation/augmentation is the -load-bearing part of a workload. The native validation shim preserves -`date_of_birth` over broader `date` reclassifications only when the local -candidate context contains birth/DOB language; generic filing or event dates can -still be reclassified to `date`. - -Use `detector_native_validate_native_augment` for the same detector-seed -question when augmentation recall is expected to be load-bearing. This arm still -uses DataDesigner for the detector seed, but direct provider calls own both -validation and augmentation. - -Use `gliner_native_validate_no_augment` or -`gliner_native_validate_native_augment` when the question is specifically -"what if GLiNER did not run through DataDesigner?" These strategies use the -staged direct executor's GLiNER seed client using -`native_runtime.gliner_endpoint`, `native_runtime.gliner_model`, or the standard -GLiNER runtime environment variables; the API key env var defaults to -`NVIDIA_API_KEY`. The no-augmentation arm is a lower-cost boundary; the -native-augmentation arm is the quality-oriented no-DataDesigner candidate. The -integrated benchmark strategies execute staged direct rows with bounded -parallelism so GLiNER and native validation/augmentation latency is not -serialized across records. These arms also normalize direct GLiNER `date` seeds -to `date_of_birth` only when the local seed context contains birth/DOB language. -Generic filing or event dates remain `date`. Both arms still need repeated -signature, leak, label-mismatch, and reliability gates before any -workload-specific promotion. - -Use `native_single_pass`, `native_single_pass_recall`, -`native_single_pass_values`, or `native_single_pass_values_recall` for the more -aggressive "collapse detection to one call" experiment. 
The first pair asks the -model for `start`/`end` offsets and validates them before falling back to exact -value matching. The value-only pair uses the standalone direct-probe prompt and -lets local code recover spans from exact returned values. Recall variants spend -more prompt tokens on label examples and high-recall guidance. All one-call -variants are expected to be faster than staged native detection when the prompt -works, but they are also more parser- and recall-sensitive. Any malformed JSON -response becomes a failed case in analysis, and any missed baseline signature -should be treated as a rejection rather than a latency win. +Native benchmark strategies require an explicit runtime. Set top-level +`native_runtime.endpoint` and `native_runtime.model`, set the standard +`ANONYMIZER_BENCH_NATIVE_ENDPOINT` and `ANONYMIZER_BENCH_NATIVE_MODEL` +environment variables, or override runtime fields per config with +`configs[].native_runtime`. GLiNER-seeded native strategies also require +`native_runtime.gliner_endpoint` and `native_runtime.gliner_model`, or the +standard `ANONYMIZER_BENCH_GLINER_ENDPOINT` and +`ANONYMIZER_BENCH_GLINER_MODEL` environment variables. The runner records +runtime id, alias, provider, model, and env-variable names as run tags; raw +endpoint URLs are not emitted into measurement tables. + +```yaml +native_runtime: + runtime_id: local-vllm-json + endpoint_env: ANONYMIZER_BENCH_NATIVE_ENDPOINT + model_env: ANONYMIZER_BENCH_NATIVE_MODEL + provider: local-vllm + alias: native-direct +configs: + - id: native-single-pass + experimental_detection_strategy: native_single_pass + replace: redact +``` + +Use `detector_only` only as a lower-bound ablation. It skips the LLM validation +pass that drops false positives and reclassifies ambiguous spans. A faster run +that loses baseline signatures is a rejection. + +Use staged native strategies when the question is "can direct provider calls +replace part of DataDesigner orchestration?" 
They still need repeated signature, +leak, label-mismatch, parser, and reliability gates before any workload-specific +promotion. + +Use one-call native strategies for the more aggressive "collapse detection to +one call" experiment. They are often faster when the prompt works, but they are +more parser- and recall-sensitive. Any malformed JSON response becomes a failed +case in analysis, and any missed baseline signature should be treated as a +rejection rather than a latency win. + +## Replacement strategy probes Replacement-map generation has a separate benchmark-only knob: ```yaml configs: - id: structured-local-substitute - experimental_detection_strategy: rules_covered_or_default experimental_replacement_strategy: local_structured_substitute detect: entity_labels: [api_key, email, password, url] @@ -340,14 +245,7 @@ is deliberate. The local substitute map generator does not understand names, social relations, cultural consistency, or prose semantics; use the default DataDesigner-backed `Substitute` path for those workloads. -The runner refuses to write into a non-empty output directory unless -`--overwrite` is set. By default it also exports Parquet tables into -`tables/`; pass `--no-export` when you only want the raw measurement JSONL. -Before starting a real run, the benchmark runner performs cheap preflight -checks: suite/config parsing, local dataset existence, CSV/Parquet text-column -metadata, provider YAML shape, native runtime requirements, and active -model-alias references. `--dry-run` runs those same checks, expands the planned -matrix, and skips output-dir writes and model work. +## DataDesigner traces For debugging DataDesigner calls, pass `--dd-trace last-message` or `--dd-trace all-messages`. Trace records are written separately from sanitized @@ -360,8 +258,8 @@ values, replacement values, secrets, and PII. Treat them as debug artifacts: keep them out of shared benchmark bundles unless they have been reviewed or redacted. 
-To summarize traced calls without copying raw prompts or responses into the -analysis output, run: +Summarize traced calls without copying raw prompts or responses into analysis +output: ```bash uv run python tools/measurement/analyze_dd_traces.py \ @@ -373,205 +271,65 @@ uv run python tools/measurement/analyze_dd_traces.py \ This writes `trace_analysis.*` and `trace_group_analysis.*`. The row table captures run tags, workflow/model metadata, status, elapsed time, prompt and response lengths, token counts, and response-shape flags such as `raw_json`, -`fenced_json`, `embedded_json`, `text`, and `none`. The grouped table rolls those -fields up by workload, config, workflow, model, provider, status, error type, -and response shape. Use this when diagnosing local provider behavior, parser -compatibility, unexpected thinking text, or retry-heavy workflows. +`fenced_json`, `embedded_json`, `text`, and `none`. -Some OpenAI-compatible local endpoints return raw JSON when their model config -uses `response_format: {type: json_object}`. DataDesigner structured recipes -currently prompt for markdown-fenced JSON, so those raw JSON responses can be -valid but still fail parsing. Set top-level `dd_parser_compat: raw_json` when a -benchmark suite needs this provider compatibility mode: +## Direct probes -```yaml -dd_parser_compat: raw_json -``` - -This is benchmark-only behavior. The runner patches DataDesigner structured -parser builders for the duration of a case, restores them afterward, and records -the mode in `run_tags.dd_parser_compat`. The fallback accepts either pure raw -JSON or a JSON object/array embedded after model reasoning text, then still -validates the extracted object against the requested schema. Keep the default -`none` unless a local provider or vLLM endpoint needs raw-JSON structured-output -compatibility. 
- -## DD-Free Direct Detection Probe - -Use `direct_detection_probe.py` to test a deliberately DD-free extraction path -against an OpenAI-compatible endpoint. This is a benchmark-only diagnostic: it -does not call DataDesigner, does not run GLiNER, and does not execute the -production detection graph. It sends one direct chat-completions request per -input row, then reuses Anonymizer's existing span postprocessing, occurrence -expansion, overlap resolution, and entity signature logic so results can be -compared against normal detection artifacts. - -Pass `--endpoint` and `--model`, or set `ANONYMIZER_BENCH_NATIVE_ENDPOINT` and -`ANONYMIZER_BENCH_NATIVE_MODEL`. - -Example biography probe: +`direct_detection_probe.py` calls a local OpenAI-compatible endpoint directly +for a small slice of records. It is useful for prompt experiments before adding +a benchmark strategy. ```bash uv run python tools/measurement/direct_detection_probe.py \ docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column biography \ - --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ - --endpoint http://your-openai-compatible-endpoint/v1 \ - --model your-model-id \ - --baseline-artifacts "$BASELINE_ARTIFACTS" \ - --output /tmp/direct-detection-probe-biography \ - --overwrite \ - --json -``` - -Example legal probe: - -```bash -uv run python tools/measurement/direct_detection_probe.py \ - docs/data/TAB_legal_sample25.csv \ --text-column text \ - --labels application_number,city,country,date,date_of_birth,nationality,person \ - --endpoint http://your-openai-compatible-endpoint/v1 \ - --model your-model-id \ - --baseline-artifacts "$BASELINE_ARTIFACTS" \ - --output /tmp/direct-detection-probe-legal \ - --overwrite \ - --json + --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ + --model nvidia/nemotron-3-super \ + --labels 
person,email,api_key,password \ + --row-limit 5 \ + --output /tmp/direct-detection-probe ``` -The tool writes `direct-detection-cases.jsonl`, -`direct-detection-artifacts.jsonl`, and `summary.json`. Case rows include model -usage, elapsed time, raw/allowed suggestion counts, final label counts, final -signature hashes, and optional baseline comparison counts. Artifact rows use -the same opaque signature fields as `analyze_detection_artifacts.py` and omit -raw entity values. For baseline comparison, pass a per-case sidecar or another -artifact file with one row per `row_index`; duplicate row indexes are rejected -to avoid ambiguous comparisons. Treat the probe `summary.json` as a sensitive -debug artifact because it records the resolved endpoint/model runtime used for -the probe. -so a combined multi-case artifact cannot silently select the wrong baseline. - -When this probe shape is promising, move it into a normal benchmark suite with -`experimental_detection_strategy: native_single_pass_values` or -`native_single_pass_values_recall`. Those strategies use the same value-only -prompt family but run through `run_benchmarks.py`, measurement collection, case -retries, artifact capture, and pairwise strategy comparison. - -Interpret this probe as a lower-friction model-call experiment, not a safe -replacement for detection. A local one-row smoke against -`nvidia/nemotron-3-super` with vLLM JSON mode and thinking disabled produced: - -- Biography: 4.1s, 906 total tokens, 19 final signatures, 18/22 baseline - signatures shared; misses included `field_of_study` and `place_name`. -- Legal: 4.9s, 1,308 total tokens, 21 final signatures, 19/22 baseline - signatures shared; misses included `date`, `date_of_birth`, and - `nationality`. - -That result makes a DD-free native executor worth exploring, but only if it -preserves the production safety decomposition (`GLiNER/rules -> validate -> -augment -> finalize`). 
The one-shot direct prompt is useful as a speed/quality -boundary, not as a production candidate. - -## DD-Free Staged Detection Probe - -Use `staged_detection_probe.py` to test a more conservative DD-free route. This -probe still avoids DataDesigner, but it does not collapse detection into one -model response. It can run direct LLM seed extraction, direct GLiNER seeding, -deterministic rule seeding, trusted deterministic rule seeding, or rule-routed -DD-free execution. It then runs direct validation and direct augmentation unless -trusted rules or the rule router short-circuit are selected, where rule spans -bypass validation. It reuses Anonymizer's existing row-level postprocessing -helpers for validation application, augmentation merge, occurrence expansion, -overlap resolution, and artifact signatures. - -Example biography probe: +`staged_detection_probe.py` runs a no-DataDesigner staged detector outside the +main benchmark harness. It can compare seed extraction, validation, and +augmentation boundaries before integrating a strategy into `run_benchmarks.py`. 
```bash uv run python tools/measurement/staged_detection_probe.py \ docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column biography \ - --labels age,city,company_name,degree,education_level,field_of_study,first_name,language,last_name,occupation,organization_name,place_name,political_view,race_ethnicity,religious_belief,state,university \ - --endpoint http://your-openai-compatible-endpoint/v1 \ - --model your-model-id \ - --baseline-artifacts "$BASELINE_ARTIFACTS" \ - --output /tmp/staged-detection-probe-biography \ - --overwrite \ - --json + --text-column text \ + --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ + --model nvidia/nemotron-3-super \ + --labels person,email,api_key,password \ + --row-limit 5 \ + --output /tmp/staged-detection-probe ``` -Example legal probe: +Useful staged options: -```bash -uv run python tools/measurement/staged_detection_probe.py \ - docs/data/TAB_legal_sample25.csv \ - --text-column text \ - --labels application_number,city,country,date,date_of_birth,nationality,person \ - --endpoint http://your-openai-compatible-endpoint/v1 \ - --model your-model-id \ - --baseline-artifacts "$BASELINE_ARTIFACTS" \ - --output /tmp/staged-detection-probe-legal \ - --overwrite \ - --json -``` +- `--seed-source direct-llm`: use direct LLM seed extraction. +- `--seed-source gliner`: use direct hosted GLiNER seeding. +- `--skip-augmentation`: disable augmentation for any seed source. This is an + ablation for measuring how much recall the augmentation phase carries. +- `--validation-prompt-mode chunked-excerpt`: split seed validation candidates + into chunks of `--validation-max-entities-per-call` and send each chunk with a + tagged local excerpt bounded by `--validation-excerpt-window-chars`. -To replace the LLM seed phase with a direct GLiNER call, add -`--seed-source gliner` plus `--gliner-endpoint` and `--gliner-model`, or set -`ANONYMIZER_BENCH_GLINER_ENDPOINT` and `ANONYMIZER_BENCH_GLINER_MODEL`. 
The -probe reads the GLiNER API key from `--gliner-api-key-env`, which defaults to -`NVIDIA_API_KEY`. - -To replace the LLM seed phase with deterministic local rules, add -`--seed-source rules`. This still sends rule candidates through the validator. -Use `--seed-source rules-trusted` to bypass validation for high-confidence rule -spans and run only augmentation afterward. The trusted mode is a diagnostic for -rule-covered workloads; it is not a general prose/legal safety default. -Use `--seed-source rules-plus-direct-llm` to add deterministic rule spans to -direct LLM seed spans while validating only the direct LLM seed candidates. This -tests a mixed native path where obvious structured secrets are trusted locally -without giving up contextual model seeding for the rest of the record. -Use `--seed-source rules-router` to make that split explicit: if every requested -label is supported by deterministic rules, the probe runs trusted local rules -with no model calls; otherwise it falls back to `rules-plus-direct-llm`. -When the requested labels are all covered by deterministic rules, add -`--skip-augmentation-when-rule-covered` to measure a fully local short-circuit -with no model calls. -Use `--skip-augmentation` to disable augmentation for any seed source. This is -only a diagnostic for measuring how much recall the augmentation phase carries; -signature loss should reject the candidate even when latency improves. - -To test whether direct validation can preserve the phase boundary with less -prompt text, add `--validation-prompt-mode chunked-excerpt`. This splits seed -validation candidates into chunks of `--validation-max-entities-per-call` and -sends each chunk with a tagged local excerpt bounded by -`--validation-excerpt-window-chars`. The default remains `full-text`, which -keeps the prior one-call behavior. 
Treat this as a request-count/token tradeoff: -chunked excerpts can reduce prompt payload, but they also create more validator -requests and can remove context needed for labels such as legal roles, -education, demographics, or prose locations. - -The tool writes `staged-detection-cases.jsonl`, +The staged tool writes `staged-detection-cases.jsonl`, `staged-detection-artifacts.jsonl`, and `summary.json`. Case rows include per-phase usage for seed extraction, validation, and augmentation, true case -wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, plus +wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, `phase_model_work`, `phase_skip_reasons`, `phase_model_requests`, `model_phase_count`, `model_request_count`, total usage, and optional baseline -signature deltas. Use these fields to distinguish local work, provider latency, -and a provider that returned no token accounting. -Treat the staged probe `summary.json` as a sensitive debug artifact because it +signature deltas. Treat `summary.json` as a sensitive debug artifact because it records the resolved endpoint/model runtime used for the probe. -For example, a fully local rule-covered run should show `model_phase_count: 0`, -`model_request_count: 0`, `rule_covered_label_set: true`, and -`phase_skip_reasons.augmentation: "rule_covered_labels"`; `elapsed_sec` should -still capture the local rule/postprocess wall time while `model_elapsed_sec` -remains `0.0`. A chunked-excerpt validation run should usually keep -`model_phase_count` unchanged while raising `phase_model_requests.validation`. 
-To summarize those staged probe outputs without hand-written `jq`, run: +Summarize staged probe outputs: ```bash uv run python tools/measurement/analyze_staged_detection_output.py \ - /tmp/staged-detection-probe-biography \ - --output /tmp/staged-detection-probe-biography/analysis \ + /tmp/staged-detection-probe \ + --output /tmp/staged-detection-probe/analysis \ --format csv ``` @@ -580,219 +338,13 @@ The analyzer accepts either the staged output directory or the source group, and label-delta tables. Use `group_analysis.csv` for latency, token, request, and signature-overlap totals; use `label_delta_analysis.csv` to see which labels account for baseline-only misses or direct-only additions. The -analysis tables still omit raw text and raw entity values. - -The grouped table also includes a conservative `fast_lane_verdict`: - -- `fast_lane_candidate`: every case completed, every case was fully - rule-covered, the seed-source group has at least three cases, model requests - were zero, and baseline comparison found no missing signatures. -- `reject`: at least one case errored or the candidate lost any baseline - signature. -- `review`: baseline comparison is missing, fewer than three cases were - analyzed, the candidate still used model calls, or not every case was fully - rule-covered. - -Use `fast_lane_candidate` only as a workload-scoped promotion signal. It does -not prove that the same no-DataDesigner path is safe for prose/legal labels or -for data shapes outside the sampled suite. - -A refreshed local one-row smoke against `nvidia/nemotron-3-super` with vLLM JSON -mode and thinking disabled produced: - -- Biography: 13.7s, 4,550 total tokens, 24 final signatures, 20/22 baseline - signatures shared. The staged path recovered two signatures missed by the - one-shot direct probe, but still missed an `age` and a `place_name` signature - and added four direct-only signatures. 
-- Legal: 17.5s, 6,425 total tokens, 21 final signatures, 19/22 baseline - signatures shared. This did not improve signature overlap over the one-shot - direct probe and was materially slower. - -A direct hosted GLiNER seed smoke reached NVIDIA's endpoint but failed before -local validation with `DEGRADED function cannot be invoked` for -`nvidia/gliner-pii`. Keep the `--seed-source gliner` mode as a native executor -option, but do not treat hosted GLiNER availability as stable for local -performance conclusions. - -Rules seeding changed the tradeoff. On biography row 0, `rules` took 6.1s and -1,565 tokens but shared only 17/22 baseline signatures; `rules-trusted` took -5.2s and 1,019 tokens and shared 18/22. On legal row 0, `rules` took 7.1s and -2,213 tokens with 20/22 shared signatures; `rules-trusted` took 6.4s and 1,431 -tokens with the same 20/22 shared signatures. On the three-row shell-secrets -slice, `rules` exposed a validation regression: the validator reclassified a -database URL as a password, leaving row 1 with 2/3 shared baseline signatures. -`rules-trusted` preserved all shell baseline signatures and reduced each row to -one augmentation call, but that no-op augmentation still consumed 398-533 tokens -per row. With `--skip-augmentation-when-rule-covered`, the same trusted-rules -shell run preserved all 12 baseline signatures with zero model usage. Use this -as evidence for a native executor with rule-covered short circuiting, not as -evidence that trusted rules are safe for arbitrary text. - -Interpret this as evidence for native orchestration, not as a ready strategy. -The staged shape is closer to Anonymizer's safety model than one-shot -extraction, but the naive direct prompts spend too many tokens. 
The next useful -experiment is a native executor that preserves the same phase boundaries while -using compact production-equivalent prompts, direct provider clients, and a -cheap deterministic or detector-backed seed phase instead of LLM-seeded -extraction. - -## No-DataDesigner Strategy Pivot - -The strongest current performance signal comes from not invoking -DataDesigner at all for records whose requested labels and text shape are -covered by deterministic structured-secret extractors. On a local shell/structured-secret slice, -the staged `rules-router` path preserved every compared baseline signature with -zero model requests and millisecond-level elapsed time. In full Anonymizer -benchmarks, `rules_covered_or_default` plus `local_structured_substitute` -reduced structured substitute workloads by 38-99% wall time and removed most or -all observed model tokens, depending on whether the run still fell back to -default detection. - -The benchmark harness now has several integrated native strategies for that -next experiment. `native_rules_router` reuses the staged DD-free executor inside -Anonymizer's detection workflow, so benchmark cases still exercise the normal -replacement and measurement plumbing. `native_candidate_validate_no_augment` -removes augmentation to isolate the recall cost of that phase. -`detector_native_validate_no_augment` keeps the default detector seed and -switches only validation to direct provider calls. `native_single_pass` is the -more radical variant: it asks the local provider for all spans in one JSON -response and then lets Anonymizer validate offsets and finalize entities -locally. Use these arms to compare native provider calls against the -DataDesigner-backed `default` strategy on the same workloads. - -Treat that as a workload router, not a global replacement. 
The same DD-free -direct LLM approach on biography and legal prose still lost roughly a quarter -to a third of baseline signatures in repeated local probes, even though it -avoided DataDesigner. That is not an anonymization-safe trade by itself. The -current evidence points to three separate lanes: - -- **Structured fast lane:** if the explicit labels are all deterministic-rule - labels and rule extraction covers the workload, skip DataDesigner, skip model - calls, and use local redact/hash/substitute. This is the most promising path - for shell history, secrets, config files, audit logs, and similarly keyed - records. -- **Native model lane:** for prose or mixed records, preserve the production - detection decomposition but call providers directly: seed, validate, augment, - finalize. The prototype exists as `staged_detection_probe.py`, and the - benchmark harness includes detector-seeded and native-seeded variants, but - their current prompts are still research prompts and are too lossy/costly to - promote. -- **Single-pass model lane:** for a sharper boundary test, collapse prose or - mixed detection into one direct JSON span extraction call. This only becomes - interesting if it preserves baseline signatures; parser errors, invalid - offsets, or missed signatures should send the workload back to the default - pipeline. -- **Safety fallback:** route unsupported labels, uncertain text shapes, direct - parser failures, and signature-loss evidence back to the normal - DataDesigner-backed pipeline until a native executor proves equal or better - recall on repeated workload-specific comparisons. - -This changes the performance strategy from "make every DataDesigner phase -faster" to "avoid DataDesigner when the safety case is trivial, and use -DataDesigner as the fallback for hard cases." The benchmark interpretation -should therefore privilege signature coverage, original-value leak checks, -source provenance, and reliability flags over raw latency wins. 
A no-DD result -that is faster but loses baseline signatures remains a reject; a no-DD result -that is fully rule-covered, leak-free, and stable across repetitions is a -candidate for a production fast lane. - -## Output Layout - -A benchmark run writes one raw measurement file per case, then combines them: - -```text -benchmark-runs/suite-id/ - raw/ - inputs/ - biographies__redact-default__r000.csv - biographies__redact-default__r000.jsonl - biographies__redact-default__r000.detection-artifacts.jsonl - support__hash-agent-labels__r000.jsonl - artifacts/ - biographies__redact-default__r000/ - traces/ - biographies__redact-default__r000.jsonl - measurements.jsonl - summary.json - detection-artifacts.jsonl - tables/ - manifest.json - run.parquet - stage.parquet - record.parquet - ndd_workflow.parquet -``` - -Raw per-case JSONL files are streamed as measurement events are recorded, so a -long run leaves inspectable partial output before the case exits. The combined -`measurements.jsonl` is written after the completed and errored case files are -collected. - -Use `summary.json` to inspect case status, retry attempts, and errors. If a -case succeeds after retry, the combined `measurements.jsonl` contains the final -successful attempt while `summary.json` preserves the earlier failure messages. -Use `measurements.jsonl` when you need the original structured records. Use -`tables/` for analysis. -Use `traces/` only when `--dd-trace` was enabled and you need raw -DataDesigner message-level debugging. - -Treat `summary.json`, `raw/inputs/`, `artifacts/`, -`raw/*.detection-artifacts.jsonl`, and `traces/` as sensitive outputs. They can -contain source text, entity values, replacement values, prompts, model -responses, exception messages, or other PII-bearing debug data. The exported -measurement tables and detection signature ids are designed for analysis -without raw values, but debug sidecars are not sanitized bundles. 
- -Detection workflow artifacts can be analyzed separately when you need to know -whether augmentation helped or only added cost. `run_benchmarks.py` writes -`detection-artifacts.jsonl` automatically when export is enabled and detection -artifacts are present. The automatic export analyzes each case immediately after -it runs, then combines per-case sidecars from `raw/`; rows include `suite_id`, -`workload_id`, `config_id`, `repetition`, `case_id`, and `run_id` so they can be -joined to `measurements.jsonl` and exported tables. `rules_only` cases do not -produce DataDesigner parquet artifacts, so the runner writes a synthetic -rules-only sidecar from the same deterministic rules. That sidecar includes -counts, source=`rule`, and opaque entity signatures, but not raw entity values. -Routed strategies whose final entity set can differ from raw DataDesigner -artifacts, including row-aware `rules_covered_or_default`, write sidecars from -the final trace dataframe so rule-routed and fallback-routed rows are both -represented. - -Row-aware routed strategies also emit sanitized route telemetry into -`measurements.jsonl`, and `analyze_benchmark_output.py` surfaces it in -`case_analysis.*` and `group_analysis.*`. Use `route_total_row_count`, -`route_rule_row_count`, and `route_fallback_row_count` to confirm how many rows -used the zero-model rules lane versus the normal detection fallback before -interpreting request, token, or latency deltas. -You can also run the analyzer by hand against an artifact directory: - -```bash -uv run python tools/measurement/analyze_detection_artifacts.py \ - benchmark-runs/suite-id/artifacts \ - --output benchmark-runs/suite-id/detection-artifacts.jsonl -``` - -The analyzer reads `entity-detection*` parquet artifacts and emits one row per -artifact row. 
It reports seed, augmentation, and final entity counts; duplicate -augmentation suggestions; new augmented values that survived into final -entities; final label/source counts; and weak `api_key` shape warnings. The -output intentionally omits raw entity values. +analysis tables omit raw text and raw entity values. -Use this alongside the exported measurement tables when comparing -`default` against `no_augment`: +## Benchmark analysis -- High `augmented_duplicate_seed_value_count` with low - `augmented_new_final_value_count` means augmentation probably added cost - without improving that case. -- High `augmented_new_final_value_count` means augmentation found spans that - the detector+validator path missed. -- High `weak_api_key_shape_count` usually means the label set is mismatched to - the workload. For example, legal prose constrained to - `[person, email, api_key, password]` can force dates or case identifiers into - `api_key` because better prose labels are unavailable. - -For a ready-made case and grouped summary that joins `measurements.jsonl` with -`detection-artifacts.jsonl`, use: +`analyze_benchmark_output.py` joins `measurements.jsonl`, optional +DataDesigner traces, and detection artifact sidecars into richer case/group +tables: ```bash uv run python tools/measurement/analyze_benchmark_output.py \ @@ -801,1207 +353,159 @@ uv run python tools/measurement/analyze_benchmark_output.py \ --format csv ``` -By default this joins `benchmark-runs/suite-id/measurements.jsonl` with -`benchmark-runs/suite-id/detection-artifacts.jsonl`. 
To use a refreshed or -relocated sidecar that still contains benchmark case metadata, pass it -explicitly: - -```bash -uv run python tools/measurement/analyze_benchmark_output.py \ - benchmark-runs/suite-id \ - --detection-artifacts benchmark-runs/suite-id/current-analysis/detection-artifacts.jsonl \ - --output benchmark-runs/suite-id/current-analysis \ - --format csv -``` +Important outputs: -The override sidecar must include `case_id` or `run_id` values that match the -measurement rows. A raw artifact scan produced from only the DataDesigner -parquet directory can summarize detection artifacts, but it cannot be safely -joined to benchmark measurements unless benchmark case metadata is preserved. - -This writes `case_analysis.*`, `group_analysis.*`, `model_analysis.*`, and -`model_group_analysis.*`. It keeps fully local cases with no model workflow -rows, such as rule-covered `rules_only` or `native_rules_router` cases, in the -comparison with zero observed requests/tokens. Native direct-call strategies -that bypass DataDesigner write `model_workflow` rows, so their provider request -and token counts still contribute to case, group, and model summaries. When the -benchmark was run with current sidecar export, `rules_only` also has -artifact-derived signatures and source counts; older runs may only have -record-level entity counts. The joined case/group tables include -successful/failed request counts, input/output token splits, record counts, -dataset input-token throughput, `seed_validation_candidate_count`, -`estimated_seed_validation_chunk_count`, and `observed_failed_request_rate`; -use these when testing -`detect.validation_max_entities_per_call` so you can distinguish a real chunk -count change from provider retry variance. The model tables split the same -usage by `workflow_name` and `model_name`, which is useful for separating local -detector cost from validator, augmenter, substitute, or rewrite model cost. 
-When record-level measurements include ground-truth entities, the joined tables -also expose exact and relaxed entity-quality metrics. The relaxed metrics count -span overlap, with small label-equivalence groups for common aliases such as -`user_name` / `username` and `api_key` / `auth_token`. Case and group tables -also count empty detections, including empty records that had ground-truth -entities. If your suite adds portable topology tags such as `endpoint_count`, -`gpu_count`, or `tensor_parallelism`, the analysis computes per-endpoint and -per-GPU input-token throughput; otherwise those normalized fields remain null. -The case/group tables also surface incomplete benchmark cases with -`case_failed`, `error_stage_count`, `error_ndd_workflow_count`, -`error_model_workflow_count`, `failed_case_count`, and `failed_case_rate`. -Check these before interpreting a fast candidate as a safe improvement; a -failed repetition can otherwise look like entity instability or a latency win. -The joined case/group tables also expose final entity source counts from -detection artifacts, including `artifact_final_detector_entity_count`, -`artifact_final_rule_entity_count`, and -`artifact_final_augmenter_entity_count`. Use these to verify whether a faster -strategy is still relying on contextual detector/validator spans, or whether it -has shifted a workload entirely onto deterministic rules. -They also include `artifact_final_entity_signature_count` and -`artifact_final_entity_signature_hashes`, which are opaque per-row identifiers -derived from the final entity label and offsets. They do not include raw or -normalized entity values. The companion -`artifact_final_entity_signature_labels` field maps each opaque hash to its -entity label. These fields do not expose raw entity values, but they let -analysis tools detect when two configs report the same entity count while -protecting different spans. 
- -To compare a baseline and candidate strategy across common workloads, use: +- `case_analysis.*`: one row per benchmark case. +- `group_analysis.*`: median and aggregate metrics grouped by workload/config. +- `model_usage.*`: one row per measured model usage entry. +- `model_usage_group_analysis.*`: model usage rolled up by workflow/model. -```bash -uv run python tools/measurement/compare_strategy_pairs.py \ - benchmark-runs/suite-id/analysis/case_analysis.csv \ - --baseline-strategy no_augment \ - --candidate-strategy rules_filter_guardrail_no_augment \ - --output benchmark-runs/suite-id/analysis/strategy_comparison.csv -``` +Use `--detection-artifacts` to provide an explicit detection artifact JSONL +sidecar. Otherwise, the analyzer reads `detection-artifacts.jsonl` in the +benchmark directory when present. -If the candidate was run in a separate benchmark directory, pass a second case -analysis file: - -```bash -uv run python tools/measurement/compare_strategy_pairs.py \ - benchmark-runs/baseline-suite/analysis/case_analysis.csv \ - --candidate-case-analysis benchmark-runs/candidate-suite/analysis/case_analysis.csv \ - --baseline-strategy no_augment \ - --candidate-strategy rules_guardrail_no_augment -``` - -The comparison reports latency, request, token, entity-count, validation -candidate-count, augmentation-count, final source-count, and opaque -entity-signature deltas. It also reports original-value leak deltas from -`original_value_leak_count` and `original_value_leak_record_count`. The -`augmented_entity_count_delta` and -`augmented_new_final_value_count_delta` columns are especially useful for -no-augmentation and model-routing ablations: a faster candidate that removes -new final values from augmentation needs signature checks before promotion. -When signature labels are available, it also reports label counts for -baseline-only, candidate-only, and shared signatures. 
For repeated selector -runs, it also compares signatures that are stable across every repetition, -which catches cases where a candidate finds a sensitive span only -intermittently. It adds conservative flags such as -`baseline_case_failures`, `candidate_case_failures`, `entity_count_loss`, -`entity_signature_loss`, `span_boundary_mismatch`, -`covered_label_mismatch`, -`candidate_original_value_leak`, -`candidate_replacement_missing_final_entity`, -`candidate_duplicate_synthetic_replacement`, -`failed_request_increase`, `bridge_fallback_increase`, -`stable_entity_signature_loss`, `no_candidate_detector_entities`, -`candidate_uses_rule_entities`, `candidate_skips_llm_validation`, and -`replacement_only_detection_instability`, plus five verdict fields: - -- `value_protection_verdict`: `pass`, `review`, or `fail`. This axis focuses on - whether the candidate still protects the sensitive values. Candidate case - failures, candidate original-value leaks, missing replacement-map entries, - replacement collisions, and uncovered baseline signatures fail. Rule - provenance, validation skipping, - provider retry pressure, and covered boundary or label mismatches do not fail - this axis by themselves; they are represented in the semantic and overall - safety verdicts. -- `signature_parity_verdict`: `pass`, `review`, or `fail`. This axis focuses on - exact baseline signature semantics. Covered label or boundary mismatches stay - review-gated even when `value_protection_verdict` passes. -- `safety_verdict`: `pass`, `review`, or `fail`. Candidate case failures and - entity/signature loss fail. Candidate original-value leaks also fail, even - when entity signatures match. Baseline case failures, baseline - original-value leaks, rule-only, rule-heavy, or validation-skipping - candidates require review. Candidate provider failed-request increases or - bridge-fallback increases also require review: they are reliability signals, - not anonymization leaks. 
-- `performance_verdict`: `improved`, `mixed`, `regressed`, `unchanged`, or - `unknown`, based on available latency, request, and token deltas. -- `candidate_verdict`: `candidate_viable`, `review`, or `reject`. A candidate - is viable only when safety passes and measured performance improves. - -Use verdicts for triage, then inspect the underlying flags and label-count -deltas before promoting a strategy beyond benchmark experiments. -For replacement-only comparisons where the detection strategy is unchanged, -`replacement_only_detection_instability` means the candidate and baseline were -still run through independent detection passes and their detection artifacts -drifted. Treat that as a prompt to consult fixed-trace replacement replay before -blaming or promoting the replacement-map backend. -In fixed-trace replacement replay, -`candidate_duplicate_synthetic_replacement` means the local replacement backend -protected every original value but collapsed at least two replacements in the -same row to the same synthetic value. That is review-gated as a substitute -quality and relational-consistency concern rather than treated as an immediate -privacy leak. -When the replay CSV contains -`candidate_covers_baseline_replacement_missing_final_entity`, -`candidate_covers_baseline_original_value_leak`, or -`candidate_covers_baseline_replacement_synthetic_original_collision`, the -candidate removed a defect observed in the DataDesigner-backed substitute arm -on the same fixed detection trace. In that case `value_protection_verdict` can -pass while `signature_parity_verdict` remains review-gated, because the -candidate covered more of the final-entity set than the flawed baseline. 
- -For `rules_covered_or_default`, compare rule-covered configs by config ID so -the zero-model lane is checked against the same explicit label set: +`compare_strategy_pairs.py` compares baseline/candidate case rows: ```bash uv run python tools/measurement/compare_strategy_pairs.py \ benchmark-runs/suite-id/analysis/case_analysis.csv \ - --baseline-config rule-labels-default \ - --candidate-config rule-labels-covered-or-default \ - --output benchmark-runs/suite-id/analysis/rules-covered-comparison.csv + --baseline-config default \ + --candidate-config native-single-pass \ + --output benchmark-runs/suite-id/analysis/default-vs-native-single-pass.csv ``` -Promote the fast path only when -`baseline_only_candidate_uncovered_signature_count` is zero on the target -workload, `candidate_original_value_leak_count` is zero, `candidate_verdict` is -at least `review`, and the review flags are expected rule fast-lane flags such -as `candidate_uses_rule_entities`, `no_candidate_detector_entities`, -`entity_count_loss`, or `span_boundary_mismatch`. Exact -`baseline_only_final_entity_signature_count` can be nonzero when a candidate -protects the same sensitive value with a wider or slightly narrower keyed span; -use the covered/overlapping/uncovered columns to decide whether that is an -acceptable workload policy. A run that has uncovered signatures or leaks -original detected values should reject: -in the June 8, 2026 sudo-password smoke run, the pre-fix comparison rejected the -candidate with `lost_labels=password:1`; after the narrow sudo rule was added, -the same comparison had no baseline-only signatures and remained review-gated -only because the final spans were rule-sourced. - -The command output also includes a rollup summary with verdict counts and the -workloads in each candidate-verdict bucket, which is useful for repeated runs -over larger suites. 
- -To screen many comparison CSVs from one or more benchmark directories, use: +When one CSV does not contain both arms, pass `--candidate-case-analysis`: ```bash -uv run python tools/measurement/screen_strategy_comparisons.py \ - benchmark-runs/ \ - --output benchmark-runs/strategy-screen.csv \ - --group-output benchmark-runs/strategy-groups.csv -``` - -When screening a scratch directory that contains older analysis outputs, filter -by source-path fragments: - -```bash -uv run python tools/measurement/screen_strategy_comparisons.py \ - /tmp/anonymizer-benchmark-scratch \ - --source-include analysis-current-csv \ - --source-include analysis-failure-aware-csv \ - --output current-strategy-screen.csv \ - --group-output current-strategy-groups.csv -``` - -Use `--source-exclude` to omit known stale or exploratory subdirectories. -For example, if a scratch directory contains a pre-fix comparison and a rerun, -screen only current evidence by excluding the stale source-path fragment: - -```bash -uv run python tools/measurement/screen_strategy_comparisons.py \ - /tmp/anonymizer-perf-goal \ - --source-include comparison \ - --source-exclude before-sudo \ - --source-exclude structured-secrets-varied-comparison.csv \ - --output /tmp/anonymizer-perf-goal/strategy-screen-current.csv \ - --group-output /tmp/anonymizer-perf-goal/strategy-screen-current-groups.csv -``` - -The screen walks CSV files recursively, ignores non-comparison tables such as -`case_analysis.csv` and `group_analysis.csv`, and combines rows produced by -`compare_strategy_pairs.py`. It deduplicates exact repeated rows from copied -analysis directories, then sorts viable candidates first, then review and reject -rows, preserving latency/token deltas, flags, lost-label summaries, and -augmentation deltas. 
It also preserves baseline/candidate case counts, -baseline/candidate detection strategies, baseline/candidate replacement -strategies, stable-signature evidence counts, and candidate original-value leak -counts and labels. For DataDesigner-free experiments, it also preserves -`value_protection_verdict`, `signature_parity_verdict`, and label-mismatch -label counts, so one-off candidate rows are visible as weak evidence even before -opening the source comparison CSV. This is the quickest way to check whether a -benchmark directory contains any candidate worth rerunning on a larger workload -slice. - -Use the `evidence_level` column to separate current safety evidence from older -or weaker comparison rows. `split_verdicts` means the row has separate value -protection and signature-parity verdicts, `stable_signatures` means it has -stable-signature counts but not split verdicts, `signature_counts` means it only -has raw signature counts, and `legacy` means the screen can only use the older -aggregate verdict columns. The group output includes `evidence_level_counts` so -mixed scratch directories do not make a legacy row look as strong as a current -split-verdict rerun. - -The optional group output aggregates rows by candidate strategy when the -candidate used a non-default experimental strategy, or by candidate config -otherwise. This keeps ordinary config experiments, such as model routing or -prompt-parameter changes, from being collapsed under `strategy:default`. 
When -the same experiment used multiple config IDs, pass a JSON alias map: - -```json -{ - "biography-hybrid-augment-temp07": "biography-temp07-routing", - "biography-augment-temp07": "biography-temp07-routing" -} -``` - -```bash -uv run python tools/measurement/screen_strategy_comparisons.py \ - benchmark-runs/ \ - --group-by strategy_workload_family \ - --config-aliases config-aliases.json \ - --group-output benchmark-runs/strategy-family-groups.csv -``` - -Aliases only affect default-strategy, default-replacement config grouping. -Non-default experimental detection strategies still group by strategy; when a -candidate also uses a non-default replacement strategy, the group key appends -`replacement:<strategy>`. If detection is default and only replacement changes, -the group key is `replacement:<strategy>`. Use the group output to find -candidates with conflicting evidence, such as a no-augmentation candidate that -passes one slice and rejects on another. The -group table includes both best and worst latency, token, and request deltas so a -single fast slice does not hide a slower or unsafe repeat. It also includes -minimum baseline/candidate case counts and the minimum shared stable-signature -count observed in the group, plus summed candidate original-value leak counts -and leak labels. 
The -`recommendation` column is deliberately conservative: -`single_slice_viable` means one viable row exists but needs repeat evidence, -`candidate_family_viable` requires two or more viable rows and no review or -reject rows, `promising_needs_review` means viable rows exist but review-gated -rows remain and at least one split-verdict row is also viable, -`needs_split_verdict_rerun` means viable-looking and review-gated rows exist but -the group has only older signature-count or stable-signature evidence, or a -review-only group mixes current split-verdict rows with older comparison rows -that should be rerun under the current verdict schema, -`needs_viable_split_verdict` means older viable rows exist and split-verdict -evidence exists, but every split-verdict row is still review- or reject-gated, -`replacement_replay_review` means an improved replacement-strategy group is -review-gated by detection artifact drift even though the detection strategy did -not change; use fixed-trace replacement replay to isolate replacement-map -behavior, -`reliability_review` means every row improved performance but one or more rows -are review-gated by provider reliability signals such as failed-request or -sync-bridge fallback increases, -`fast_lane_review` means a `rules_only` or -`rules_covered_or_default` group improved performance, had explicit zero -candidate original-value leaks, had no uncovered baseline signatures, and is -review-gated only by expected rule fast-lane provenance or span-boundary flags, -`label_policy_review` means every row improved performance, passed -`value_protection_verdict`, and was review-gated on `signature_parity_verdict` -because the candidate protected a baseline value under a different label, -`review_only` means the family has no failures, still needs manual review, and -every review-gated row is `improved`, -`review_mixed_performance` -means the family has no failures but has mixed performance evidence, -`no_performance_win` means 
review-gated rows exist without an improvement -signal, `reject` means no viable rows survived, and `conflicting_evidence` means -at least one viable row and at least one rejected row exist for the same -candidate family. - -When a strategy's safety depends on workload shape, group by workload family: - -```bash -uv run python tools/measurement/screen_strategy_comparisons.py \ - benchmark-runs/ \ - --group-by strategy_workload_family \ - --output benchmark-runs/strategy-screen.csv \ - --group-output benchmark-runs/strategy-family-groups.csv +uv run python tools/measurement/compare_strategy_pairs.py \ + baseline/analysis/case_analysis.csv \ + --candidate-case-analysis candidate/analysis/case_analysis.csv \ + --baseline-strategy default \ + --candidate-strategy detector_native_validate_no_augment \ + --output comparison.csv ``` -This keeps evidence from families such as shell-secret command logs, legal -records, and biographies separate. Use this mode before claiming a broad -performance improvement from a strategy that may only be safe on rule-covered -secret workloads. Use `--group-by strategy_workload` for an even stricter -per-workload grouping. - -The exporter groups records by `record_type`: - -- `run`: one row per Anonymizer run, with sanitized config, workload, model, and - runtime metadata. -- `stage`: one row per measured pipeline stage, with elapsed time, row counts, - and throughput fields. -- `record`: one row per input row when record-level measurement is enabled, - with text-size buckets, entity counts, replacement counts, rewrite scores, - and estimated nominal LLM call counts. -- `ndd_workflow`: one row per DataDesigner adapter call, with model aliases, - elapsed time, row counts, failed-record counts, and observed token/request - usage when DataDesigner exposes it. 
-- `model_workflow`: one row per non-DataDesigner model-backed workflow, such as - `native_rules_router`, `native_candidate_validate_no_augment`, - `detector_native_validate_no_augment`, - `detector_native_validate_native_augment`, `native_single_pass`, and the - other `native_single_pass*` strategies, with the same sanitized usage fields - as `ndd_workflow`. - -The tables never store raw text, prompts, generated outputs, entity values, or -replacement maps. `record_hash` is a run-scoped HMAC, so it can join rows within -one run but should not be treated as a durable dataset identifier. - -## Analysis Patterns - -Start with these questions: - -- Which workload/config pair is fastest at the same quality target? -- Which stage dominates wall time: detection, replacement, rewrite, or a - DataDesigner sub-workflow? -- Does latency scale with text length, entity count, or rewrite repair work? -- Do token counts, request counts, and failed records explain latency outliers? -- Are quality metrics worse on one data shape, such as legal text, biographies, - support tickets, shell history, or mixed natural-language/code records? - -Most analyses join `stage`, `record`, `ndd_workflow`, and `model_workflow` back -to `run` through `run_id`, then group by run tags: - -- `run_tags.suite_id` -- `run_tags.workload_id` -- `run_tags.config_id` -- `run_tags.experimental_detection_strategy` -- `run_tags.experimental_replacement_strategy` -- `run_tags.dd_parser_compat` -- `run_tags.repetition` -- `run_tags.case_id` - -Prefer medians and percentiles over averages when comparing latency. LLM calls -usually have long tails, and one retry or provider stall can distort a mean. 
- -For staged DD-free detection probes, convert the probe output first: +`screen_strategy_comparisons.py` screens many comparison CSVs: ```bash -uv run python tools/measurement/analyze_staged_detection_output.py \ - /tmp/anonymizer-perf-goal/no-dd-rules-plus-direct-biography-r5-current \ - --output /tmp/anonymizer-perf-goal/no-dd-rules-plus-direct-biography-r5-current/analysis \ - --format csv +uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ \ + --output benchmark-runs/strategy-screen.csv ``` -Then read `analysis/group_analysis.csv` to compare `elapsed_sec_sum`, -`model_elapsed_sec_sum`, `model_request_count_sum`, `total_tokens_sum`, -`baseline_shared_signature_rate`, and -`baseline_only_final_entity_signature_count_sum`. Use `fast_lane_verdict` as -the first gate: `reject` means stop and inspect losses before running larger -slices; `fast_lane_candidate` means the sampled workload is a plausible -zero-model rule-covered lane with repeated evidence; `review` means the output -is incomplete, has too few cases, or still uses model work. The staged analyzer -requires at least three cases in a seed-source group before a clean zero-model -run can become `fast_lane_candidate`; one-row smokes remain `review` even when -they preserve all compared signatures. Read -`analysis/label_delta_analysis.csv` when the shared-signature rate is low; it -shows which labels drove the baseline-only losses or direct-only additions. +Use `--group-by strategy_workload_family` when the same candidate behaves +differently across workload families. Use `--config-aliases aliases.json` to +group related config IDs, such as temperature or validation-window variants of +the same strategy. -## Pandas Examples +## Pandas patterns -Load exported tables: +Analysis tables are regular CSV/Parquet files. 
A typical local workflow: ```python -from pathlib import Path - import pandas as pd -tables = Path("benchmark-runs/shell-and-biography-smoke/tables") -run = pd.read_parquet(tables / "run.parquet") -stage = pd.read_parquet(tables / "stage.parquet") -record = pd.read_parquet(tables / "record.parquet") -ndd = pd.read_parquet(tables / "ndd_workflow.parquet") -``` - -Compare end-to-end stage latency by workload and config: - -```python -stage_group_cols = ["run_tags.workload_id", "run_tags.config_id", "stage"] - -stage_summary = ( - stage - .groupby(stage_group_cols) - .agg( - runs=("run_id", "nunique"), - median_sec=("elapsed_sec", "median"), - p95_sec=("elapsed_sec", lambda s: s.quantile(0.95)), - rows_per_sec=("rows_per_sec", "median"), - ) - .reset_index() - .sort_values(["run_tags.workload_id", "stage", "median_sec"]) -) - -print(stage_summary) -``` - -Find slow records and relate them to text size and entity count: - -```python -record_view = record[ - [ - "run_tags.workload_id", - "run_tags.config_id", - "record_hash", - "text_length_tokens", - "text_length_tokens_bucket", - "final_entity_count", - "nominal_llm_call_count", - "utility_score", - "leakage_mass", - ] -].copy() - -shape_group_cols = [ - "run_tags.workload_id", - "run_tags.config_id", - "text_length_tokens_bucket", +cases = pd.read_parquet("benchmark-runs/suite/analysis/case_analysis.parquet") +groups = pd.read_parquet("benchmark-runs/suite/analysis/group_analysis.parquet") + +cols = [ + "workload_id", + "config_id", + "experimental_detection_strategy", + "median_pipeline_elapsed_sec", + "median_observed_total_requests", + "median_observed_total_tokens", + "median_artifact_final_entity_signature_count", ] +print(groups[cols].sort_values(["workload_id", "median_pipeline_elapsed_sec"])) -by_shape = ( - record_view - .groupby(shape_group_cols) - .agg( - records=("record_hash", "count"), - median_entities=("final_entity_count", "median"), - median_nominal_calls=("nominal_llm_call_count", "median"), - 
median_utility=("utility_score", "median"), - median_leakage=("leakage_mass", "median"), - ) - .reset_index() -) - -print(by_shape) -``` - -Summarize DataDesigner token and request usage: - -```python -workflow_group_cols = [ - "run_tags.workload_id", - "run_tags.config_id", - "run_tags.experimental_detection_strategy", - "run_tags.experimental_replacement_strategy", - "run_tags.dd_parser_compat", - "workflow_name", +failures = cases[ + (cases["case_failed"]) | + (cases["observed_failed_requests"] > 0) | + (cases["dd_trace_error_count"] > 0) ] - -token_summary = ( - ndd - .groupby(workflow_group_cols) - .agg( - calls=("workflow_name", "count"), - median_sec=("elapsed_sec", "median"), - total_input_tokens=("observed_input_tokens", "sum"), - total_output_tokens=("observed_output_tokens", "sum"), - total_requests=("observed_total_requests", "sum"), - failed_records=("failed_record_count", "sum"), - ) - .reset_index() - .sort_values(["run_tags.workload_id", "run_tags.config_id", "median_sec"]) -) - -print(token_summary) +print(failures[["case_id", "config_id", "observed_failed_requests", "dd_trace_error_count"]]) ``` -Summarize provider usage by workflow and model: +Compare a candidate against a baseline: ```python -model_usage = pd.read_csv("benchmark-runs/suite-id/analysis/model_group_analysis.csv") - -retry_sources = ( - model_usage - .sort_values( - ["sum_observed_failed_requests", "sum_observed_total_tokens"], - ascending=[False, False], - ) - [ - [ - "workload_id", - "config_id", - "workflow_name", - "model_name", - "sum_observed_total_requests", - "sum_observed_failed_requests", - "observed_failed_request_rate", - "sum_observed_total_tokens", - ] - ] -) - -print(retry_sources) +comparison = pd.read_csv("benchmark-runs/suite/analysis/default-vs-native.csv") +candidate_rows = comparison[ + ["workload_id", "candidate_verdict", "safety_verdict", "performance_verdict", "flags"] +] +print(candidate_rows) ``` -Join run metadata to stage timing: +Find candidate-specific 
misses: ```python -run_meta = run[ - [ - "run_id", - "mode", - "strategy", - "detect.entity_label_count", - "detect.validation_max_entities_per_call", - ] +loss_cols = [ + column for column in comparison.columns + if column.startswith("baseline_only_final_entity_signature_label_counts.") ] - -stage_with_config = stage.merge(run_meta, on="run_id", how="left") - -config_group_cols = ["mode", "strategy", "detect.entity_label_count", "stage"] - -print(stage_with_config.groupby(config_group_cols)["elapsed_sec"].median()) +print(comparison[["workload_id", *loss_cols]].fillna(0)) ``` -For quick interactive work, CSV can be easier than Parquet: - -```bash -uv run python tools/measurement/export_measurements.py \ - benchmark-runs/suite-id/measurements.jsonl \ - --output /tmp/suite-csv \ - --format csv \ - --overwrite -``` - -## Signature Delta Review - -Use `extract_signature_deltas.py` when a fast candidate has fewer, more, or -different final entity signatures than a higher-recall reference run. The tool -compares two `detection-artifacts.jsonl` files and recovers local context from -the DataDesigner artifact parquet files. Entity values are masked by default: -the output stores label, source, span offsets, value length, signature id, and a -small context window with the entity replaced by a placeholder. It does not -emit a hash derived from the raw entity value. 
- -Example: review spans found by a text/raw-parser reference but missed by a -hybrid candidate for one workload/config pair: - -```bash -uv run python tools/measurement/extract_signature_deltas.py \ - /tmp/reference/detection-artifacts.jsonl \ - /tmp/candidate/detection-artifacts.jsonl \ - --baseline-artifact-root /tmp/reference/artifacts \ - --candidate-artifact-root /tmp/candidate/artifacts \ - --baseline-config legal-default \ - --candidate-config legal-hybrid-rules-guardrail \ - --workload legal-r2 \ - --output /tmp/legal-signature-deltas.csv \ - --format csv -``` - -Interpretation: - -- `baseline_only` rows are spans the candidate missed relative to the - reference. -- `candidate_only` rows are spans the candidate found that the reference did - not. -- `resolution=parquet` means the span was recovered from DataDesigner's final - detection artifacts. -- `resolution=artifact_details` means the span was reconstructed from - sanitized final signature details plus the artifact row's source text. This - is common for benchmark-only strategies that patch final entities from an - in-memory dataframe after a seed-stage artifact is written. -- `resolution=rule` means the span was reconstructed from deterministic - rule-guardrail logic because it was added after DataDesigner wrote parquet. -- `resolution=metadata_only` means only the opaque signature metadata was - available; use this as a signal to rerun with trace/artifact capture if the - delta matters. - -## Current Local Findings - -These findings come from small local vLLM runs against -`nvidia/nemotron-3-super`; treat them as triage signals, not defaults. - -| Strategy | Latest local result | Status | Implication | -| --- | --- | --- | --- | -| `rules_only` on the three-row shell-secrets slice | Preserved all 12 stable signatures; median latency moved from 7.2s to 0.004s, requests from 12 to 0, and tokens from 11,019 to 0. 
| Review | Viable only for bounded secret scans where every requested label is covered by deterministic rules. | -| `rules_guardrail_detector_only` on the same shell-secrets slice | Preserved stable signatures and reduced model work, but one candidate repetition failed during GLiNER health checks. | Review | Useful as a structured-secret diagnostic, but less attractive than `rules_only` when labels are fully rule-covered. | -| `rules_filter_guardrail` on the same shell-secrets slice | Retry-enabled rerun completed all 6 cases. It preserved all 12 signatures, reduced seed validation candidates from 11 to 0, median pipeline latency from 8.0s to 3.9s, requests from 12.5 to 7.0, and tokens from 10,966 to 3,647. | Review | Useful as a mixed-workload probe; keep it review-gated because final entity provenance is rule-only for this slice. | -| `rules_filter_guardrail` on a mixed biography/legal/shell probe | After changing rule filtering to preserve different-label overlaps, repeated two-row biography, one-row legal, and three-row shell runs had no stable or unstable signature loss. Median pipeline latency moved from 28.5s to 20.0s on biography, 19.4s to 18.2s on legal, and 8.7s to 6.1s on shell. | Review | Historical positive probe only; the larger five-row non-shell repeat below did not preserve this signal. | -| `rules_filter_guardrail` on offset biography/legal slices | After hardening rule filtering so only fully covered same-label spans are skipped and rule reinsertion is additive, the five-row biography offset slice had no signature loss but moved into review because requests increased slightly while tokens decreased. The richer two-row legal offset slice rejected: latency, requests, and tokens regressed and one repetition missed three `court_name` signatures while adding one rule-backed `date_of_birth`. | Mixed | The hardened strategy is safer than the first version, but it still needs per-workload gates and is not a broad legal/prose default. 
| -| `rules_filter_guardrail` on current five-row biography/legal repeats | Biography preserved stable and unstable signatures but regressed latency from 37.8s to 45.9s and requests from 20.5 to 22.5. Legal improved latency from 60.6s to 51.2s and tokens from 63,072 to 61,568, but lost five stable `date` signatures and made two stable `person` signatures unstable. | Reject | Do not promote this as a prose/legal default; the safety and latency tradeoff is workload-dependent and fails the legal signature gate. | -| `rules_guardrail` on a five-row legal slice | Same-suite repeated comparison against default preserved stable and unstable signatures, but latency regressed from 39.6s to 47.1s, requests rose from 20.0 to 20.5, and tokens were roughly flat at 60,998 to 60,757. | Mixed | Deterministic date guardrails can improve coverage without signature loss, but they are not a legal-prose performance win on this slice. | -| `detector_only` and `rules_guardrail_detector_only` on prose/legal slices | Faster on one-row smoke checks, but lost baseline signatures on biography and legal samples. A current detector-only isolation rerun moved biography 27.3s → 0.9s and 8,416 → 526 tokens, but lost two `first_name` and one `organization_name` signatures. Legal moved 52.0s → 1.0s and 14,095 → 1,078 tokens, kept `date_of_birth`, but still lost one `date` and one `nationality` signature while adding many extra spans. | Reject | Local finalization alone is not a safe replacement for validation and augmentation on contextual text. The legal rerun is useful diagnostically because raw detector output kept `date_of_birth`, so a later native-validation miss likely came from validation behavior rather than detector seeding. | -| One-shot DD-free direct detection on biography/legal row 0 | Biography completed in 5.1s with 902 tokens but shared only 18/22 baseline signatures. Legal completed in 5.8s with 1,308 tokens but shared only 19/22 baseline signatures. 
| Reject as replacement | This is a useful speed boundary and prompt experiment, but a single extraction prompt drops core detections on non-shell workloads. | -| Standalone direct-detection five-row probe | A fresh local probe compared one direct extraction call per row against the current staged direct reference. On legal, compact direct detection moved from 62.3s and 15 requests to 17.1s and 5 requests, but shared only 75/147 reference signatures and missed 72. Recall prompting improved legal to 31.1s, 109 final entities, 102 shared, and 45 missed. On biographies, compact direct detection moved from 85.7s and 15 requests to 21.2s and 5 requests, with 91/102 shared signatures, 11 missed, and no extras; recall prompting regressed to 62 shared and 40 missed. Outputs: `/tmp/anonymizer-perf-goal/direct-detection-legal-r5-compact-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-legal-r5-recall-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-biography-r5-compact-after-guard`, `/tmp/anonymizer-perf-goal/direct-detection-biography-r5-recall-after-guard`. The benchmark harness can now run this value-only prompt shape through `native_single_pass_values` and `native_single_pass_values_recall`. | Mixed diagnostic | The one-call path is the clearest lower-bound latency test, but it is not a general anonymization-safe replacement. Compact one-call extraction may deserve workload-specific follow-up for biographies; legal still needs augmentation or a stronger candidate source. Recall prompting is not monotonic across domains. | -| Staged DD-free detection on biography/legal row 0 | Biography improved to 20/22 shared signatures but took 13.7s and 4,550 tokens. Legal stayed at 19/22 shared signatures while taking 17.5s and 6,425 tokens. Hosted GLiNER seeding was unavailable due to a `DEGRADED function cannot be invoked` response for `nvidia/gliner-pii`. 
| Mixed diagnostic | A native no-DataDesigner executor is still plausible, but only if it preserves phase boundaries with much cheaper seed/validation prompts or deterministic code. Naive direct LLM phases are not enough. | -| Chunked-excerpt validation in staged DD-free detection | On current one-row reruns, biography preserved the same 20/22 shared signatures as full-text validation but moved from 10.8s, 4,527 tokens, and 3 model requests to 13.3s, 5,648 tokens, and 6 requests. Legal preserved the same 19/22 shared signatures but moved from 14.7s, 6,425 tokens, and 3 requests to 17.2s, 7,727 tokens, and 7 requests. | Reject | Splitting direct validation into local excerpts increases repeated instruction overhead and request count on these non-shell rows. Do not pursue validator excerpting as a standalone no-DD speedup unless longer records show a different request/token crossover. | -| Rules-seeded staged DD-free detection | `rules` improved biography/legal latency but still lost baseline signatures; legal row 0 reached 20/22 shared signatures at 7.1s and 2,213 tokens. On shell-secrets, validation reclassified a database URL as a password and lost one baseline URL signature. `rules-trusted` fixed that shell loss and preserved all 12 shell signatures with one augmentation call per row, but still missed biography/legal signatures. With `--skip-augmentation-when-rule-covered`, trusted rules preserved all 12 shell signatures with zero model usage. | Mixed diagnostic | Deterministic seed spans are useful, but rule-covered spans should not always go through LLM validation. A native executor needs workload gates and should short-circuit locally when every requested label is rule-covered. | -| Rules + direct LLM staged DD-free detection | `rules-plus-direct-llm` preserved all 12 shell-secrets signatures while avoiding validation, but still used two model calls per row and 726-938 tokens because the direct seed and augmentation phases still ran. 
On row-0 smokes it looked like the most plausible mixed no-DD path: biography shared 20/22 signatures at 10.8s and 4,465 tokens, and legal shared 20/22 at 11.5s and 5,929 tokens. The five-row gate rejected it: biography shared only 80/114 baseline signatures, lost 34 signatures, and took 85.7s versus the DD baseline's 32.9-47.8s; legal shared 108/145, lost 37 signatures, and took 62.3s versus the DD baseline's ~39.5s. | Reject for contextual workloads | Trusting deterministic structured spans locally is still useful, but direct LLM seed/validation/augmentation is not a safe or faster replacement for DataDesigner-backed contextual detection on prose/legal slices. Keep no-DD promotion limited to fully rule-covered structured-secret lanes unless a new native executor passes repeated signature gates. | -| Rules router staged DD-free detection | `rules-router` preserved all 12 shell-secrets signatures with no seed, validation, or augmentation model calls. The mixed/contextual fall-through did not generalize: on five biographies it shared 96/114 default signatures and lost 18 baseline signatures; on five legal rows it shared 86/145 default signatures and lost 59 baseline signatures. The benchmark-safe expression of this result is `rules_covered_or_default`, which short-circuits only fully rule-covered label sets and otherwise runs default detection. | Mixed | Keep the router only for the zero-model rule-covered structured-secret lane. Do not use the direct local LLM fall-through as a prose/legal replacement; use default Anonymizer or another signature-gated strategy for contextual rows. | -| Integrated `native_rules_router` benchmark with corrected direct-call metering | A five-row benchmark-harness run on biography/legal confirmed the staged finding. Biography moved from 32.9s to 85.6s, requests from 20 to 15, and tokens from 43,354 to 26,644, but entities fell from 114 to 102 and 34 baseline signatures were uncovered. 
Legal moved from 54.3s to 62.3s, requests from 21 to 15, and tokens from 60,649 to 31,894, but 37 baseline signatures were uncovered. Both workloads rejected. | Reject for contextual workloads | Direct native calls can reduce request and token counts while still losing safety and wall-time. Treat lower token counts as insufficient evidence; contextual promotion requires signature preservation and latency improvement together. | -| Integrated `native_candidate_validate_no_augment` smoke | One-row biography/legal benchmark-harness smoke proved the no-augmentation native executor is much cheaper but unsafe. Biography moved from 24.8s to 5.9s, requests from 4 to 2, and tokens from 8,092 to 2,000, but entities fell from 15 to 12 and lost `age`, `first_name`, and `organization_name` signatures. Legal moved from 49.8s to 10.9s, requests from 4 to 2, and tokens from 13,791 to 3,823, but entities fell from 23 to 21 and lost `date`, `date_of_birth`, and `nationality` signatures. Both rows had zero original-value leaks. | Reject for contextual workloads | Removing augmentation from the native executor gives the expected speed boundary, but augmentation or a stronger candidate source remains load-bearing for contextual recall. Keep this arm as a diagnostic, not a promotion candidate. | -| Integrated `detector_native_validate_no_augment` smoke | Keeping the detector seed and replacing DataDesigner validation/augmentation with direct validation is much cheaper, but quality remains workload-dependent. Biography still rejects: latest one-row rerun moved 26.6s -> 6.7s and 8,398 -> 2,347 tokens, but entities fell 15 -> 14 and two augmenter-sourced child `first_name` signatures were uncovered. A focused one-row legal repeat improved median latency from 15.0s to 11.0s, requests from 4 to 3, and tokens from 9,516 to 4,150 with zero leaks. 
After row-parallel direct validation plus deterministic DOB-context label normalization, a wider three-row, two-repeat legal gate moved median elapsed from 40.6s to 21.3s, requests from 12.5 to 6.5, and tokens from 37,972 to 17,902 with zero original-value leaks. The split verdicts were `value_protection=pass`, `signature_parity=review`, and `performance=improved`: a filing-date span that baseline labeled `date_of_birth` was protected as `date`, while separate birth-context years were added as `date_of_birth`. | Mixed: biography reject, legal label-policy review | The promising shape is not "remove DataDesigner everywhere"; it is "keep DD as fallback, use deterministic fast lanes where provably covered, and only replace validation when a native validator preserves both coverage and label semantics across repeated gates." The legal repeats now show a real latency win, but a DD-free candidate may protect values correctly while disagreeing with DataDesigner label semantics. That should stay review-gated until label policy says whether such covered reclassification is acceptable. | -| Integrated `detector_native_validate_no_augment` substitute gate | A one-row legal substitute smoke first showed the same review shape as the redact gate: latency moved from 21.1s to 15.2s, requests from 5 to 3, and tokens from 12,192 to 6,871 with zero original-value leaks. The wider three-row, two-repeat substitute gate still improved performance, but rejected on safety: median pipeline latency moved from 44.0s to 33.4s, requests from 15.0 to 9.0, tokens from 47,958 to 28,465.5, and failed requests from 3.0 to 0.0, while both baseline and candidate leaked two original `date` values across two row-runs. Replacement-map coverage was complete; local replay showed the leak was a substitute collision where one synthetic date reused another protected original date in the same record. 
The candidate added 11 stable signatures, but had covered label mismatches including a stable `date_of_birth` -> `date` mismatch. | Reject for substitute promotion | Native validation reduces detection cost even when substitute still uses normal replacement-map generation, but speed cannot promote a substitute strategy while original values survive in replaced output. The leak appears in the default substitute arm too, so this is a baseline substitute safety issue separate from the native validator. | -| Integrated `gliner_native_validate_*` no-DataDesigner gate | A one-row biography/legal smoke tested direct hosted GLiNER seeding outside DataDesigner plus direct native validation, with and without direct native augmentation. Biography no-augment rejected despite improving latency/tokens because it lost two `first_name` signatures. Biography with native augmentation passed the one-row gate: latency 13.7s -> 10.2s, requests 4 -> 3, tokens 8,033 -> 5,035, entities 22 -> 24, zero leaks, and only candidate-only additions. After bounded per-row parallelism and targeted label-boundary guidance in the integrated no-DD executor, a repeat-3 five-row biography gate improved median wall time 40.7s -> 25.5s, requests 21 -> 15, and tokens 43,371 -> 27,643 with zero original-value leaks and no case failures. The guidance removed the earlier `first_name` label mismatches, but repeat comparison rejected the candidate: four baseline signatures were only covered with mismatched labels (`degree`: 1, `last_name`: 2, `place_name`: 1), and six stable baseline signatures became unstable (`degree`: 1, `last_name`: 2, `organization_name`: 2, `place_name`: 1). Legal improved latency/tokens in both one-row arms, but stayed review-gated because a generic filing date that baseline labeled `date_of_birth` was protected as `date` by the candidate; the seed guardrail correctly does not promote dates without birth/DOB context. 
| Reject for contextual biographies | Direct GLiNER outside DataDesigner is a useful performance diagnostic, but repeated stable-signature gates block promotion on this biography slice. Lower requests/tokens plus faster wall time are insufficient if label semantics are unstable. | -| Integrated `native_single_pass` benchmark smoke | One-row benchmark-harness smoke on biography/legal showed the speed boundary for collapsing detection into one direct provider call. Biography improved latency 10.3s → 1.7s, requests 4 → 1, and tokens 5,059 → 597, but found 4 entities versus 7 and lost three `person` signatures, so it rejected. Legal improved latency 19.2s → 1.1s, requests 5 → 1, and tokens 7,107 → 838 while preserving both signatures, so that single row was viable. | Mixed diagnostic | The one-call native extractor is worth keeping as a benchmark arm, but it is not safe for broad contextual use. Promotion needs repeated workload-specific signature gates; a legal-row win does not cancel the biography miss. | -| Integrated `native_single_pass` five-row gate | After adding a local deterministic rule guardrail, the larger biography/legal run still rejected both contextual workloads. Biography moved from 24.1s to 8.3s, requests from 21 to 5, and tokens from 26,759 to 3,078, but entities fell from 36 to 21 and it lost 16 `person` signatures. Legal moved from 35.7s to 6.1s, requests from 21 to 5, and tokens from 38,569 to 5,781, but entities fell from 14 to 12 and it lost three `person` signatures. | Reject for contextual workloads | Local rules can cheaply protect deterministic secret shapes, but they do not fix contextual recall. Collapsing detection to one direct call remains a useful lower-bound latency experiment, not a safe contextual replacement. | -| Integrated `native_single_pass_recall` five-row gate | The recall prompt improved raw recall, especially on legal text, but still rejected both workloads. 
Biography moved from 23.0s to 10.2s, requests from 21 to 5, and tokens from 26,730 to 4,072, but entities fell from 36 to 26 and it still lost 16 `person` signatures. Legal moved from 32.2s to 8.7s, requests from 21 to 5, and tokens from 38,085 to 6,885; entity count rose from 14 to 20, but two baseline `person` signatures were still uncovered. | Reject for contextual workloads | Prompt recall can improve counts without satisfying anonymization safety. One-call contextual extraction remains below the signature gate even when it is much faster and cheaper than default. | -| Integrated `native_single_pass_values*` value-only five-row gate | Two repetitions on five NVIDIA biography rows and five TAB legal rows confirmed the value-only one-call prompt shape is only a speed boundary. Compact values mode improved latency by 55.6% on biographies and 68.9% on legal, with 15-15.5 fewer requests and 31,770-60,491 fewer tokens, but rejected both workloads after losing 45 biography and 123 legal baseline-only signatures. Recall values mode still rejected: it improved latency by 31.9% on biographies and 38.7% on legal, but lost 40 biography and 96 legal baseline-only signatures. Output: `/tmp/anonymizer-native-values-paired-r5`. | Reject for contextual workloads | Returning values instead of offsets makes parsing cheaper but does not solve contextual recall. Keep this arm in the harness as a lower-bound diagnostic; do not promote one-call extraction on biographies or legal text without a different seed source or repeated signature parity. | -| Structured fast-lane router tightening | `rules_covered_or_default` now short-circuits only the structured-secret labels `api_key`, `email`, `http_cookie`, `password`, `pin`, `unique_id`, `url`, and `user_name`. Narrow prose rule labels such as `date_of_birth`, `organization_name`, `religious_belief`, and `street_address` fall back to default detection unless `rules_only` is explicitly selected. 
A shell-secret smoke still found 12 entities across 3 records with 0 model rows, 0 requests, and 0 tokens. | Review | This preserves the no-DataDesigner win without assuming all inputs are shell logs. Local prose rules remain useful as explicit experiments or guardrails, but they are not complete enough for automatic contextual anonymization. | -| Narrow prose-label augmentation skip probe | On one synthetic `organization_name` + `street_address` row, `rules_covered_or_default` correctly fell back to model-backed detection instead of using the zero-model fast lane. A repeat-3 comparison then found `rules_guardrail_no_augment` preserved the same two signatures with zero leaks while moving median latency 3.0s → 2.6s, requests 4 → 3, and tokens 3,069 → 2,133. | Candidate | Skipping augmentation can be viable for tightly scoped prose-label slices when detector+validator already recover the needed spans. This is not a broad prose default; promote only through repeated signature gates, especially on biographies and legal text where augmentation may carry recall. | -| Real biography/legal no-augmentation check | On two NVIDIA biography rows, pure `no_augment` rejected: latency regressed 24.1s → 28.8s, entities fell 48 → 46, and two `first_name` signatures were lost. `rules_guardrail_no_augment` improved biography latency/tokens (24.1s → 18.3s, 17,992 → 11,905 tokens) but still rejected after losing the same two `first_name` signatures and using rule-sourced spans. On two TAB legal rows at offset 2, `no_augment` preserved signatures and reduced tokens but regressed latency (27.2s → 38.6s) and increased failed-request rate; `rules_guardrail_no_augment` preserved signatures with modest latency/token gains but remained review-gated because it introduced rule-sourced spans. | Mixed: biography reject, legal review | The synthetic augmentation-skip win does not generalize to biography prose. 
Augmentation remains load-bearing for contextual name recall, and legal gains need repeated runs plus failed-request scrutiny before promotion. | -| `rules_covered_or_default` mixed benchmark harness run | A two-row synthetic shell-secret run initially exposed a rule hole: default found one sudo-stdin password that the rule-only path missed. After adding a narrow `echo "..." | sudo -S` password rule, the rerun preserved all 9 shell signatures with detection latency 21.4s → 0.004s, requests 8 → 0, and tokens 9,854 → 0. One-row biography and legal contextual configs included `person`, so they fell back to model-backed detection and matched default entity counts. | Review | This is the safest implementation shape for the no-DataDesigner idea: use local rules only where labels and observed signatures prove coverage, and treat every missed signature as a rule-quality bug or a reason to fall back. | -| `rules_covered_or_default` current mixed fallback run | Current-code rerun completed all 6 cases. Shell secrets preserved all 9 signatures with pipeline latency 23.1s → 0.005s, requests 8 → 0, and tokens 10,173 → 0. The biography and legal configs included `person`, so both candidate cases fell back to model-backed detection and matched default entity counts and signatures: biography 7/7, legal 2/2. | Review | The router is behaving as designed after the rule-only `tagged_text` contract fix: structured secret configs can short-circuit locally, while contextual non-shell configs stay on default detection. | -| `rules_covered_or_default` repeated shell-secret run | A three-repetition shell-only suite completed all 6 cases. The candidate preserved all 9 final signatures in every repetition with median detection latency 29.4s → 0.004s, requests 9 → 0, and tokens 10,112 → 0. Default detection was unstable on this tiny slice: one repetition missed one `api_key`, so stable signatures were 8 for default and 9 for the rules path. 
The comparison remained review-gated, not viable, because all candidate spans were rule-sourced. | Review | Repeated evidence strengthens the structured-secret fast path but also shows why promotion should use stable-signature comparisons rather than treating default as perfectly deterministic on every repetition. | -| `rules_covered_or_default` on non-shell structured secrets | A four-row JSON/env/HTTP-header/YAML-style suite initially rejected after exposing two deterministic-rule gaps: URLs swallowed trailing semicolon separators and `session_id=...` cookie values were not protected. After tightening URL boundaries and adding a narrow `session_id` assignment rule, the rerun preserved all 17 default signatures while moving detection latency 25.8s → 0.010s, requests 16 → 0, and tokens 19,167 → 0. A repeat-3 run then kept all candidate signatures stable: default produced 15, 15, and 16 entities with median 18,822 tokens, while the rules path produced 17 entities every time with zero model requests and zero tokens. | Review | The no-DataDesigner fast lane is not shell-specific, but it must remain rule-coverage and signature-gated. Treat every structured-secret miss as either a narrow rule bug with tests or a reason to fall back to default detection. | -| `local_structured_substitute` on non-shell structured secrets | A four-row JSON/env/HTTP-header/YAML-style substitute suite preserved the same 17 final entities with zero original-value leaks. In a repeat-3 run, DataDesigner-backed substitute had median pipeline latency 38.1s, 4 requests, and 13,967 tokens for replacement-map generation; individual DD-backed runs ranged from 30.7s to 62.4s. `local_structured_substitute` had median latency 0.005s, 0 requests, and 0 tokens while preserving the same 17 replacements. | Review | Replacement-map generation is now another defensible no-DataDesigner lane for structured labels. 
Keep it benchmark-only until repeated gates and a policy decision define which structured labels deserve public API support. | -| `local_structured_substitute` with model-backed detection fallback | A one-row audit-style structured-identifier suite requested `api_key`, `http_cookie`, `pin`, `unique_id`, and `user_name`, so `rules_covered_or_default` fell back to normal model-backed detection in both arms. Both arms found the same 5 final entities with zero original-value leaks. Local replacement removed the replacement-map workflow, moving pipeline latency 53.6s → 33.0s, requests 5 → 4, and tokens 11,547 → 7,694. The pairwise comparison marked the candidate viable. | Candidate | Local replacement-map generation can help even when detection still needs DataDesigner. This is a cleaner promotion path than rule-only detection because contextual detection provenance is preserved; keep rejecting contextual replacement labels such as `person`. | -| `local_structured_substitute` with default detection on varied audit/config/HTTP identifiers | A four-row repeat-3 suite isolated replacement-map generation by keeping default model-backed detection in both arms. After adding a local synthetic-original collision guard, the guarded rerun kept value protection clean: zero original leaks, zero missing replacement-map entries, and zero synthetic-original collisions. Local substitute moved median pipeline latency 18.8s -> 12.7s, requests 21 -> 17, and tokens 24,324 -> 17,015. A current fixed-trace replay held detection constant at 21 entities and measured replacement only: DataDesigner substitute took 6.15s while local structured substitute took 0.003s, with 21/21 replacements and zero leaks/collisions in both arms. Regenerating the older repeat comparison with split verdicts moved the strategy-screen group out of `needs_split_verdict_rerun`; adding the fixed-trace replay comparison moved it to `promising_needs_review`. 
All three rows have `value_protection=pass`; the replay row has `signature_parity=pass` and `candidate_verdict=candidate_viable`, while one full-pipeline pairwise row has `signature_parity=review` because two covered signatures used different labels (`api_key`, `unique_id`). The comparison now tags this drift as `replacement_only_detection_instability` because detection strategy did not change. | Promising needs review | This is the cleanest structured-label promotion path because detector provenance stays model-backed in full-pipeline runs and the replacement backend passes fixed-trace replay. It is not fully promoted because normal pairwise runs still need monitoring for provider reliability and detection-run label drift. | -| `local_structured_substitute` fixed-trace replay on biography structured labels | A five-row NVIDIA synthetic biography replay used model-backed detection for `date_of_birth`, `organization_name`, `religious_belief`, and `street_address`, then replayed both substitute backends three times on the same 56 detected entities. After making local replacement-map generation avoid per-record duplicate synthetic values, both arms produced 159 replacements across 15 replacement attempts with zero duplicate synthetics, zero missing replacement-map entries, zero original-value leaks, and zero synthetic-original collisions. DataDesigner substitute took 23.59s for replacement-map generation and local structured substitute took 0.006s. The replay comparison marks `value_protection=pass`, `signature_parity=pass`, `safety=pass`, and `candidate_verdict=candidate_viable`. Output: `/tmp/anonymizer-perf-goal/biography-supported-structured-replacement-replay-repeat3.json`; comparison: `/tmp/anonymizer-perf-goal/biography-supported-structured-replacement-replay-repeat3-comparison.csv`; screen: `/tmp/anonymizer-perf-goal/strategy-screen-local-substitute-with-biography-replay-groups.csv`. 
| Candidate | This broadens the replacement-only result beyond shell or config logs without claiming DD-free contextual detection. The speed and leak profile are strong, the duplicate-collapse issue is fixed for this slice, and repeated replacement-only evidence shows the local backend can preserve replacement-map coverage when detection is held fixed. The remaining gate is policy: decide which structured labels and text shapes are eligible for deterministic substitute generation in production-facing configuration. | -| Expanded `rules_covered_or_default` + `local_structured_substitute` on an audit-style structured identifier record | After adding narrow keyed coverage for `http_cookie`, `pin`, `unique_id`, and `user_name`, the candidate protected all baseline signatures, found one additional `unique_id`, had zero original-value leaks, and moved pipeline latency 9.2s → 0.005s, requests 5 → 0, and tokens 6,075 → 0. | Review | This extends the no-DataDesigner fast lane beyond shell logs into keyed audit/config/HTTP-style structured records. It remains review-gated because every final span is rule-sourced and this run used one row. | -| Expanded `rules_covered_or_default` + `local_structured_substitute` on varied audit/config/HTTP identifiers | A four-row repeat-3 suite preserved every baseline-only signature through containing or overlapping candidate spans, with zero original-value leaks. Median pipeline latency moved 21.1s → 0.006s, requests 21 → 0, and tokens 24,332 → 0. The comparison records 8 exact baseline-only signatures, 8 candidate-covered signatures, 2 span-boundary mismatches, and 0 uncovered signatures. | Review | This is the strongest no-DataDesigner result so far for non-shell structured records. It is still not a default: all final spans are rule-sourced, and two protected values had different span boundaries such as `token=` versus ``, so promotion needs a workload policy gate. 
| -| Row-aware `rules_covered_or_default` + local substitute smoke | A four-row JSON/env/HTTP-header/YAML-style suite initially rejected because quoted JSON `user`/`pin` keys were not rule-covered. After adding quoted-key coverage and changing the router to fall back per row on suspicious uncovered structured assignments, the structured candidate moved pipeline latency 9.7s -> 0.0s, requests 20 -> 0, and tokens 20,080 -> 0 while matching entity count 10 -> 10 with zero original-value leaks. One-row biography and legal controls included `person`, used default detection in both arms, and passed comparison gates. | Review | The no-DataDesigner path is now safer: eligible labels are necessary but not sufficient, and rows with uncovered structured fields go through normal detection. The structured candidate still stays review-gated because one `HF_TOKEN` value was protected under a different label/boundary than the default `http_cookie` span. | -| Row-aware `rules_covered_or_default` + local substitute repeat gate | A focused repeat-3 split-verdict suite reran the same four structured rows after the row-aware router change. All 6 cases completed. Default substitute had median pipeline latency 12.4s, 21 requests, and 20,071 tokens; the row-aware rules/local candidate had median latency 0.006s, 0 requests, and 0 tokens. Both arms found 10 entities in every repetition and had zero original-value leaks or synthetic-original collisions. The split-verdict comparison has `value_protection=pass` but remained `safety=review` and `signature_parity=review`: one stable baseline `http_cookie` signature was protected by the candidate under an `api_key` label with a span/boundary mismatch. Output: `/tmp/anonymizer-perf-goal/structured-fastlane-split-r3`. | Needs viable split verdict | This is a large structured fast-lane performance win, but not promotion-ready. 
The next decision is whether the covered `http_cookie` -> `api_key` mismatch is acceptable value protection for this workload or whether the deterministic rules need to match baseline label semantics more closely. | -| `bio-vmax10-w80` validator window tuning | Rejected on biography rows 6-10: latency, requests, and tokens regressed, and stable `field_of_study` and `state` signatures were lost. | Reject | Smaller validation windows need per-workload proof; prompt-size savings can be outweighed by more calls and lost context. | -| Text augmenter routing at `temperature: 0.3` | A one-row biography smoke test passed, but repeated five-row slices did not: rows 0-4 preserved signatures while latency regressed from 40.4s to 45.9s and requests from 21.0 to 21.5; rows 5-9 rejected after latency regressed from 41.0s to 52.1s and two stable `state` signatures became unstable. | Reject | JSON-validator/text-augmenter routing at the default text temperature is not a reliable prose speedup on these slices. | -| Text augmenter routing at `temperature: 0.7` | Passed the first biography slice, then failed on rows 6-10 by losing a stable `university` signature and regressing latency. | Reject | Do not promote the routing pattern from a single positive slice. | -| `rules_guardrail_no_augment` on legal prose | Improved latency/tokens on legal rows 2-3, but lost two stable `first_name` signatures. | Reject | Augmentation remains load-bearing for contextual names, even when aggregate entity counts look acceptable. | - -No broad replacement for the default prose/legal detection path has passed the -current repeated signature checks. The only strong performance result so far is -workload-scoped: deterministic rules for tightly bounded, rule-covered secret -scans. - -When DataDesigner message traces are enabled, interpret failed request counts -through `observed_non_bridge_*` metrics before drawing provider-reliability -conclusions. 
Across 13 local trace files, the local-vLLM -`SyncClientUnavailableError` rows were 104 near-zero-latency sync-to-async -bridge fallbacks with zero token usage; they are adapter accounting, not model -work. GLiNER `ProviderError` rows are different: the same trace set had 20 real -detector failures, which can invalidate otherwise faster detector-heavy -candidates. - -Do not expand deterministic rules into contextual names merely to recover the -failed candidates above. The rejected prose and legal runs lost labels such as -`first_name`, `field_of_study`, `state`, and `university`; these require context -and separate precision evidence. The rule layer should stay narrow unless a new -label has high-confidence syntax and false-positive tests. - -## Validator Chunk Tuning - -The detector validator can dominate replace-mode latency on records with many -candidate entities. Tune `Detect.validation_max_entities_per_call` and -`Detect.validation_excerpt_window_chars` together: - -- `validation_max_entities_per_call` controls how many candidate entities go - into each validator call. Lower values create more calls, but Anonymizer can - overlap those calls through the validator pool. -- `validation_excerpt_window_chars` controls how much text surrounds each - validation chunk. Lower values reduce prompt size, but can hide context the - validator needs for labels such as `date_of_birth`, `race_ethnicity`, or - legal roles. - -Run these sweeps per workload. A window that is safe for short biographies may -drop legal identifiers, and a legal-safe window may erase the speedup on short -records. 
- -Example config fragment: - -```yaml -configs: - - id: legal-vmax10-w160 - detect: - validation_max_entities_per_call: 10 - validation_excerpt_window_chars: 160 - entity_labels: [first_name, last_name, court_name, date, date_of_birth] - replace: - strategy: hash - digest_length: 12 -``` - -Use the aggregate analysis first: - -```bash -uv run python tools/measurement/analyze_benchmark_output.py \ - benchmark-runs/legal-window-sweep \ - --json -``` - -Then compare every faster candidate against a higher-context reference: - -```bash -uv run python tools/measurement/extract_signature_deltas.py \ - /tmp/reference/legal__default-window__r000.detection-artifacts.jsonl \ - /tmp/candidate/legal__vmax10-w160__r000.detection-artifacts.jsonl \ - --baseline-artifact-root /tmp/reference/artifacts \ - --candidate-artifact-root /tmp/candidate/artifacts \ - --baseline-config default-window \ - --candidate-config vmax10-w160 \ - --workload legal \ - --output /tmp/legal-vmax10-w160-deltas.csv \ - --format csv -``` - -Treat a candidate as unsafe until signature deltas are clean on repeated runs. -In one local vLLM check with two repetitions, a biography sample went from -24.6s with the default window to 17.8s with `vmax10/w80`, with all 50 stable -entity signatures preserved. A one-row legal sample went from 21.2s with the -default window to 13.2s with `vmax10/w160`, with all 28 stable signatures -preserved. Both candidates increased request and token counts, so the comparison -tool marks them for review instead of as automatic wins. - -The biography `vmax10/w80` result did not hold on the next five biography rows. -With `row_offset: 5`, median latency regressed from 31.8s to 33.6s, requests -from 20.0 to 43.0, and tokens from 44,367.0 to 68,407.5. The comparison also -lost stable `field_of_study` and `state` signatures, with an additional -unstable `university` loss, so the tool rejected the candidate. 
Recheck this -tuning on the target dataset because smaller windows can miss sensitive -attributes and because the extra parallel validator calls can overwhelm any -prompt-size savings. - -## Augmentation Ablation - -Use `experimental_detection_strategy: rules_guardrail_no_augment` to measure -what happens when the detector keeps GLiNER, validation, and deterministic rule -guardrails, but skips LLM augmentation. Treat this as an ablation, not as a -replacement for the default pipeline. - -In a local vLLM check with two repetitions, removing augmentation from the -two-row biography sample reduced work but consistently lost two stable -`first_name` signatures. The comparison tool rejected both the default-window -and `vmax10/w80` no-augmentation candidates. This indicates augmentation is -load-bearing for prose records where contextual names and quasi-identifiers -matter. - -The same ablation preserved all 28 stable signatures on a one-row legal sample. -With the default validation window, latency moved from 21.2s to 18.3s, requests -from 5 to 4, and tokens from 11,327.5 to 7,654. With `vmax10/w160`, latency -moved from 13.2s to 9.5s, requests from 8 to 7, and tokens from 16,604 to -12,881. Compared directly against the default-window baseline, the combined -legal candidate is faster but still needs review because validator chunking -increases requests and tokens. - -That legal no-augmentation result also failed to generalize to the next two -legal records. On `row_offset: 1` with two rows and two repetitions, comparing -`legal-noaugment-vmax10-w160` against the same-window full augmentation baseline -improved latency from 23.9s to 21.5s, requests from 28.0 to 26.0, and tokens -from 61,780.5 to 50,905.0, but the candidate lost two stable `first_name` -signatures and one unstable `date` signature. The comparison rejected it despite -the performance improvement. 
- -Use this ablation when `augmented_new_final_value_count` is near zero for the -target workload and repeated signature deltas are clean. Do not generalize a -single legal row to the rest of a legal dataset, and do not generalize legal -results to biography, support-ticket, shell-history, or mixed prose data without -rerunning the comparison on that workload. - -## Augmenter Routing and Temperature - -The detection validator and augmenter do different jobs. Keep them separable in -model configs when testing local endpoints: - -- validators benefit from deterministic JSON-oriented settings; -- augmenters may work better through a text alias, because DataDesigner - structured parsing can be fragile on local OpenAI-compatible endpoints; -- augmenter temperature changes can alter retry pressure and output shape, so - evaluate them with repeated signature comparison, not only entity counts. - -In one local vLLM biography run with two repetitions, keeping the validator on -`local-nemotron-json` while routing the augmenter to a text alias with -`temperature: 0.7` was the first prose candidate that passed the current safety -gate and improved performance. Median latency moved from 24.2s to 21.6s, -requests from 8 to 6, and tokens from 17,938.5 to 11,921. The comparison had no -baseline-only or unstable-lost signatures across 48 stable signatures, so the -tool marked it `candidate_viable`. - -The same routing/temperature candidate also held on a five-row biography slice -with two repetitions, though the gain was smaller. Median latency moved from -40.4s to 38.0s, requests from 21.0 to 20.5, and tokens from 43,367.5 to -43,043.0. It preserved all 114 stable baseline signatures; one candidate-only -`place_name` appeared in one repetition, so the comparison still marked the -candidate `candidate_viable`. - -This result did not generalize cleanly to the next five biography rows. 
On a -second slice using `row_offset: 5`, the same candidate was rejected: median -latency moved from 41.0s to 47.5s, requests from 21.0 to 21.5, and tokens were -effectively unchanged at 44,708.0 to 44,670.0. More importantly, the comparison -lost one stable `university` signature and had unstable losses for -`field_of_study` and `university`. Treat this routing as an experiment to -retest on each workload, not as a default candidate yet. -When the two temp-0.7 config IDs are grouped with `--config-aliases`, the -biography family result is `conflicting_evidence`: three comparison rows, two -viable rows, one reject row, best latency -10.4%, worst latency +16.0%, and -stable losses for `field_of_study` and `university`. - -On a two-row legal slice with two repetitions, the same augmenter routing did -not materially improve latency or requests: median latency moved from 27.3s to -27.5s, requests stayed at 8, and tokens moved from 24,460.0 to 24,296.5. It -preserved stable signatures, but the rule-guardrail legal strategy remains -review-gated and this routing should be treated as neutral for that sample. -Also compare it against prompt-only changes such as -`prose_augment_focus`: in the same biography slice, prose-focused augmentation -preserved signatures and reduced requests/tokens, but wall time increased from -24.2s to 26.4s, so the tool kept it in review. - -Parser compatibility is a separate concern. A text-model suite without -`dd_parser_compat: raw_json` produced a failed biography case in local testing; -the raw-parser compatibility mode fixed that failure, but increased latency and -tokens on both biography and legal slices. Treat raw-parser compatibility as an -endpoint interoperability fix, not as a performance optimization. - -## Detector-Only Ablation - -Use `experimental_detection_strategy: detector_only` to measure the lower bound -of the detection phase when GLiNER output is trusted directly and only local -span finalization runs afterward. 
Use -`experimental_detection_strategy: rules_guardrail_detector_only` to measure the -same path with deterministic high-confidence rule spans unioned into the final -entity set. Both remove LLM validation and LLM augmentation from the detection -phase, so they are diagnostic ablations rather than deployable strategies. - -The comparison tool marks these candidates with -`candidate_skips_llm_validation`, which forces `safety_verdict: review` even -when entity signatures match on the sampled records. The rule-guardrail variant -also gets `candidate_uses_rule_entities` when rule spans survive. Promote either -path only if independent precision checks show false positives are acceptable -for the target workflow and repeated signature deltas remain clean. - -In a one-row cross-workload smoke check, detector-only improved latency and -token counts on biography, legal, and shell-secrets slices, but all three -candidates were rejected by signature comparison. Biography moved from 13.7s to -0.9s and lost two baseline `first_name` signatures; legal moved from 15.3s to -1.0s and lost one `nationality` signature while increasing final entity count -from 22 to 39; shell-secrets moved from 6.6s to 0.8s and still lost one -baseline `api_key` signature. This is a useful lower-bound measurement, but it -shows why validation/augmentation or deterministic rule coverage remain -load-bearing for anonymization. - -The `rules_guardrail_detector_only` variant did not fix prose/legal losses in -the same one-row check: biography still lost two `first_name` signatures and -legal still lost one `nationality` signature. It did preserve all shell-secret -baseline signatures while moving latency from 4.6s to 0.8s, requests from 4 to -1, and tokens from 3,969 to 85. Treat that as a narrow structured-secret -candidate. It remains review-gated because it skips LLM validation and relies -on deterministic rules. 
- -On the three-row shell-secrets slice with three successful candidate -repetitions, `rules_guardrail_detector_only` preserved all stable baseline -signatures while moving median latency from 7.2s to 3.2s, requests from 12 to 4, -and tokens from 11,019 to 198. The final entity set came from 9 detector spans -and 3 rule spans. It still had local GLiNER `ProviderError` health-check -failures and remains slower than `rules_only`, which used zero model calls and -zero tokens on the same fully rule-covered labels. - -## Deterministic Rules for Structured Secrets - -Use `experimental_detection_strategy: rules_only` only when the workload is a -bounded secret-scanning task and every requested label is covered by the -deterministic rules. Current rule coverage is intentionally narrow: -`api_key`, `date_of_birth`, `email`, `http_cookie`, `organization_name`, -`password`, `pin`, `religious_belief`, `street_address`, `unique_id`, `url`, -and `user_name`. The `http_cookie`, `pin`, `unique_id`, and `user_name` rules -cover keyed or command-style structured patterns only. They do not recognize -free-form names, narrative identifiers, or arbitrary prose mentions. - -The zero-model detector is implemented by -`EntityDetectionWorkflow.detect_with_high_confidence_rules()`. The benchmark -strategy delegates to that internal engine method, but no user-facing config -selects it outside the benchmark harness. - -Use `experimental_detection_strategy: rules_covered_or_default` for mixed -benchmark suites where some configs are structured-secret scans and others -include contextual labels such as `person`, `organization_name`, or -`street_address`. It runs the same zero-model path for structured fast-lane -cases, but does not attempt a DataDesigner-free replacement for contextual -prose or legal records. - -A mixed local-vLLM smoke run on June 8, 2026 used two synthetic shell-secret -rows plus one biography and one legal row. 
The first shell run found that -`rules_covered_or_default` missed a sudo stdin password that default -augmentation caught; after adding a narrow `echo "..." | sudo -S` rule, the -rerun preserved all 9 shell signatures with zero model requests and zero tokens. -The biography and legal configs requested `person`, so they correctly fell back -to model-backed detection and matched default entity counts. Keep this strategy -signature-gated: a missed default signature is a rule-quality bug or a fallback -signal, not acceptable drift. - -A follow-up three-repetition shell-only run kept all 9 candidate signatures -stable while default detection had 8 stable signatures because one `api_key` -was absent from one default repetition. The comparison still returned -`candidate_verdict=review` because the candidate had no detector-sourced final -spans. This is the intended behavior: repeated clean signatures can justify a -workload-scoped fast lane, but rule-only provenance should remain an explicit -review decision. - -For substitute workloads, use -`experimental_replacement_strategy: local_structured_substitute` to bypass the -DataDesigner replacement-generator call. The local substitute generator writes a -normal replacement map and stamps `_replacement_map_source=local_structured` so -measurement estimates do not count a replacement-map LLM call. It only supports -structured labels. Pair it with `rules_covered_or_default` when all requested -labels are also rule-covered; otherwise detection can still use the default -model-backed path while replacement-map generation stays local. If a config -includes `person` or another contextual label, preflight fails instead of -silently producing poor local substitutes. - -On the current four-row non-shell structured-secret suite, -DataDesigner-backed substitute preserved 17 entities with zero original-value -leaks but had repeat-3 median latency 38.1s, 4 requests, and 13,967 tokens in -replacement-map generation. 
The local structured substitute arm preserved the -same 17 entities, had zero original-value leaks, and had repeat-3 median latency -0.005s with 0 requests and 0 tokens. The repeat output used for this result is -`/tmp/anonymizer-perf-goal/structured-secrets-local-substitute-repeat3`. - -The local substitute backend can also combine with model-backed detection. In -the first one-row audit-style structured-identifier smoke, `api_key`, -`http_cookie`, `pin`, `unique_id`, and `user_name` were not all rule-covered, so -detection fell back to the default model path in both arms. The local substitute -arm still removed the replacement-map DataDesigner workflow, moving pipeline -latency from 53.6s to 33.0s, requests from 5 to 4, and tokens from 11,547 to -7,694 while preserving the same 5 final entities and zero original-value leaks. -The output used for that result is -`/tmp/anonymizer-perf-goal/structured-identifiers-local-substitute`. - -Use `replay_replacement_strategies.py` when you need to hold detection fixed and -isolate replacement-map generation: - -```bash -uv run python tools/measurement/replay_replacement_strategies.py \ - /tmp/anonymizer-perf-goal/structured_identifiers_varied.csv \ - --text-column text \ - --labels api_key,http_cookie,password,pin,unique_id,user_name \ - --nrows 5 \ - --replacement-repetitions 3 \ - --model-configs /stable-cache/anonymizer/local-vllm-json-models.yaml \ - --model-providers /stable-cache/anonymizer/local-vllm-providers.yaml \ - --dd-parser-compat raw_json \ - --comparison-output /tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay-comparison.csv \ - --json -``` - -The current fixed-trace replay detected 21 entities once, then ran both -substitute backends on that same trace. DataDesigner substitute took 6.04s for -replacement-map generation; local structured substitute took 0.003s. 
Both arms -produced 21 replacements, zero missing replacement-map entries, zero -original-value leaks, and zero synthetic-original collisions. The JSON output -used for this result is -`/tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay.json`. -A rerun after adding an LLM replacement-map collision guard produced the same -21/21 complete, leak-free, collision-free result. In that rerun, DataDesigner -substitute took 6.22s and local structured substitute took 0.003s; the updated -JSON output is -`/tmp/anonymizer-perf-goal/structured-identifiers-replacement-replay-after-llm-guard.json`. -When `--replacement-repetitions` is greater than one, detection still runs once -and only the substitute backends repeat. The summary rows aggregate replacement -latency, missing-map counts, leaks, collisions, duplicate synthetics, and source -counts across those repeated backend passes. When `--comparison-output` is set, -the replay tool also writes a one-row comparison CSV with -`value_protection_verdict`, `signature_parity_verdict`, `safety_verdict`, -`performance_verdict`, and `candidate_verdict`. This lets -`screen_strategy_comparisons.py` include fixed-trace replacement evidence -alongside normal pairwise benchmark comparisons. Missing local replacement-map -entries, original-value leaks, and synthetic-original collisions fail the replay -candidate even if the elapsed-time delta is large. -If the DD substitute baseline misses replacement-map entries or leaks original -values while the local backend covers them, the replay comparison emits -candidate-covers-baseline flags and the strategy screen recommends -`candidate_covers_baseline_defects` for all-review groups of that shape. Treat -that as a baseline-independent safety-rule prompt: inspect the candidate's -missing, leak, collision, duplicate-synthetic, and supported-label columns -rather than requiring exact parity with a known-flawed substitute baseline. 
- -After adding narrow keyed rules for `http_cookie`, `pin`, `unique_id`, and -`user_name`, the same audit-style label set can now short-circuit both -detection and local replacement for a structured record. In a one-row local -vLLM check, default detection plus DataDesigner substitute found 4 entities and -missed the `unique_id`; the rules/local arm found 5 entities, had zero -original-value leaks, and moved pipeline latency from 9.2s to 0.005s, requests -from 5 to 0, and tokens from 6,075 to 0. The pairwise comparison remains -`review`, not `candidate_viable`, because the candidate has rule-only -provenance and the evidence is a single row. The output used for this result is -`/tmp/anonymizer-perf-goal/structured-identifiers-expanded-rules`. - -On a three-row shell-secrets slice with labels `[api_key, password, email, url]`, -`rules_only` preserved all 12 stable signatures across three repetitions while -moving median latency from 7.2s to 0.004s, requests from 12 to 0, and tokens -from 11,019 to 0 in the refreshed failure-aware comparison. The comparison tool -still marks the candidate for review because it has no contextual detector spans -and skips LLM validation. That is the right gate: a pure rule strategy is -acceptable only when missing contextual spans is part of the test contract. - -`rules_seed_no_augment` preserved the same 12 signatures and reduced median -tokens from 11,017 to 7,732, but median latency moved from 8.0s to 8.5s on the -same slice. In this run, seeding rules into the validator path reduced token -work but did not improve end-to-end latency. Prefer `rules_only` for tightly -scoped secret scans; prefer rule guardrails plus contextual detection for prose, -legal text, support tickets, and mixed records. - -Use `rules_filter_guardrail` as the mixed-workload version of that idea. It -keeps LLM augmentation, but rule-covered spans are not sent to the seed -validator. 
The rule spans are reinserted before augmentation so the augmenter -does not waste work rediscovering them. This is a candidate for datasets that -combine structured secrets with contextual prose; it still needs repeated -signature comparison because filtered detector spans no longer receive the -LLM validator's reclassification/drop pass. In a local shell-secrets smoke run, -the completed candidate repetition reduced seed validation candidates to zero -and preserved all stable signatures, but the repeated comparison rejected it -because a later candidate case hit a GLiNER health-check rate limit. - -## Metric Interpretation +## Metric interpretation Use metrics as signals, not as a single score. Latency and throughput: -- `elapsed_sec`: wall time for a measured stage or DataDesigner workflow. - Staged DD-free detection cases report end-to-end case wall time here. -- `rows_per_sec`: completed output rows per second for the measured block. -- `tokens_per_sec`: observed total tokens per second when token usage exists. -- `text_length_tokens_bucket`: a coarse text-size bucket for comparing similar - inputs without storing text. -- `record_count` and `input_text_tokens_total`: case-level workload size - derived from record measurements. These are independent of provider-reported - token usage. -- `records_per_pipeline_sec` and `input_text_tokens_per_pipeline_sec`: dataset - throughput normalized by the measured Anonymizer pipeline stage. The matching - `*_per_ndd_sec` fields use summed DataDesigner workflow wall time instead. -- `input_text_tokens_per_endpoint_sec` and - `input_text_tokens_per_gpu_sec`: optional topology-normalized dataset - throughput. These are populated only when benchmark run tags provide portable - topology counts such as `endpoint_count` or `gpu_count`. - -LLM usage: - -- `observed_input_tokens`, `observed_output_tokens`, and - `observed_total_tokens`: provider-reported usage when available. 
Missing or - zero values mean the provider path did not expose usage, not necessarily that - no tokens were consumed. -- `observed_total_requests`, `observed_successful_requests`, and - `observed_failed_requests`: request counts when DataDesigner or a native - benchmark model workflow exposes them. -- `observed_failed_request_rate`: failed requests divided by total requests. - Case and group tables expose this as the end-to-end retry pressure for a - strategy; model usage tables expose it per workflow/model. Sort by this - together with total token count to find retry-heavy workflow/model pairs. -- `observed_bridge_fallback_requests`: DataDesigner sync-to-async bridge - fallbacks, derived from message traces when `--dd-trace` is enabled. Treat - these as adapter accounting, not provider/model failures. -- `model_elapsed_sec`: staged DD-free detection only; sum of direct model-call - durations for seed, validation, and augmentation. This stays `0.0` for fully - local rule-covered runs even when `elapsed_sec` records nonzero local work. -- `observed_non_bridge_total_requests`, - `observed_non_bridge_failed_requests`, and - `observed_non_bridge_failed_request_rate`: request metrics after subtracting - sync-to-async bridge fallbacks. Prefer these fields over raw failed-request - counts when diagnosing provider reliability from traced runs. -- `nominal_llm_call_count`: an internal estimate based on the Anonymizer - pipeline shape. Treat it as expected work, not observed provider traffic. -- `seed_validation_candidate_count`: number of detector candidates sent to the - seed validator, derived from detection artifacts without storing values. -- `estimated_seed_validation_chunk_count`: estimated validator chunk count after - applying `detect.validation_max_entities_per_call`. If this does not change - between benchmark configs, chunk-size experiments are not expected to reduce - successful validator calls. 
- -Entity and quality metrics: - -- `final_entity_count`: entities that survive detection and validation. -- `original_value_leak_count`: number of final entity original values that - still appear verbatim in the replaced or rewritten output text. This is a - conservative replace/rewrite safety signal and stores only counts, not raw - values. -- `original_value_leak_label_counts`: per-label counts for those surviving - original values. The analysis tables aggregate these as - `original_value_leak_record_count`, `sum_original_value_leak_count`, - `leaking_case_count`, and `median_original_value_leak_count`. -- `replacement_missing_final_entity_count`: number of final entity occurrences - whose original value has no entry in the replacement map. This is sanitized - replacement-map coverage, not raw leakage text. -- `replacement_missing_final_value_count`: number of unique final entity values - with no replacement-map entry. Compare it with - `original_value_leak_count` to distinguish omitted replacement-map entries - from replacement-application or metric issues. -- `replacement_missing_final_entity_label_counts`: per-label counts for missing - replacement-map coverage. -- `replacement_synthetic_original_collision_count`: number of final entity - occurrences whose original value was reused as a synthetic replacement value - elsewhere in the same record. This is a substitute safety signal; map - coverage can be complete while this is nonzero. -- `replacement_synthetic_original_collision_value_count`: number of unique - protected original values reused as synthetic replacement values. -- `replacement_synthetic_original_collision_label_counts`: per-label counts for - synthetic-original collisions. -- `artifact_final_detector_entity_count`, - `artifact_final_rule_entity_count`, and +- `elapsed_sec`: wall time for a measured stage or workflow. +- `pipeline_elapsed_sec`: end-to-end Anonymizer wall time for a case. 
+- `records_per_pipeline_sec`: completed input records per pipeline second. +- `input_text_tokens_per_pipeline_sec`: input text tokens processed per + pipeline second. + +Model work: + +- `observed_total_requests`: measured model requests from DataDesigner or direct + model workflow records. +- `observed_total_tokens`: measured input plus output tokens. +- `observed_failed_requests`: provider-level failed requests. +- `observed_bridge_fallback_requests`: sync-client fallback requests recorded + from DataDesigner traces. +- `observed_non_bridge_failed_requests`: failed requests after subtracting + sync-client bridge fallbacks. Prefer this field when judging endpoint + reliability from trace-enabled runs. + +Detection artifacts: + +- `seed_entity_count`: detector or direct-seed candidate count before + validation. +- `seed_validation_candidate_count`: candidates sent to validation. +- `estimated_seed_validation_chunk_count`: estimated validator chunks from the + active validation chunk size. +- `augmented_entity_count`: augmenter suggestions. +- `augmented_new_final_value_count`: augmenter suggestions that add values not + already present in the seed/final set. +- `artifact_final_detector_entity_count` and `artifact_final_augmenter_entity_count`: final entity source counts derived - from detection artifact sidecars. These are useful safety signals for - rule-backed benchmark strategies. + from detection artifact sidecars. - `artifact_final_entity_signature_count` and `artifact_final_entity_signature_hashes`: opaque final-span signatures derived - from detection artifacts. `artifact_final_entity_signature_labels` maps each - hash to a label, but still does not include raw entity values. Use these to - catch and triage safety regressions where total entity count is unchanged but - the candidate lost a baseline-protected span. + from detection artifacts. These do not include raw entity values. 
+ +Safety and replacement: + +- `original_value_leak_count`: count of protected original values still present + in replaced output. +- `replacement_missing_final_entity_count`: final entity occurrences whose + original value has no replacement-map entry. +- `replacement_missing_final_value_count`: unique final entity values with no + replacement-map entry. +- `replacement_synthetic_original_collision_count`: final entity occurrences + whose original value was reused as a synthetic replacement value elsewhere in + the same record. - `baseline_only_candidate_covered_signature_count`, `baseline_only_candidate_overlapping_signature_count`, and `baseline_only_candidate_uncovered_signature_count`: comparison-only fields from `compare_strategy_pairs.py`. These split exact signature deltas into - baseline spans protected by a containing candidate span, protected by a - high-overlap or small keyed-boundary candidate span, or not protected by any - candidate span metadata. Overlapping coverage sets `span_boundary_mismatch` - and keeps the candidate in review; uncovered signatures set - `entity_signature_loss` and fail the safety verdict. -- `baseline_only_candidate_label_mismatch_signature_count`: comparison-only - field for baseline signatures whose raw span is covered by the candidate, but - under a different label. This sets `covered_label_mismatch` and keeps the - candidate in review because the value is protected but label semantics may no - longer match replacement/audit expectations. -- `value_protection_verdict`: comparison-only pass/review/fail verdict focused - on whether candidate output still protects baseline values. Covered - label-mismatch spans can still pass this axis because the sensitive value is - protected, while uncovered signatures, candidate leaks, and candidate case - failures fail it. -- `signature_parity_verdict`: comparison-only pass/review/fail verdict focused - on exact baseline signature semantics. 
Covered label mismatches and boundary - mismatches review-gate this axis even when `value_protection_verdict` passes. - This split is useful for DataDesigner-free experiments: a candidate can be a - plausible protection backend while still requiring label-policy review before - it can replace a DataDesigner-backed baseline. -- `final_entity_label_counts`: per-label entity counts serialized as JSON in - exported tabular files. -- `ground_truth_*` and `entity_*`: exact value+label precision, recall, F1, - false positives, and false negatives when the input includes one of the - supported ground-truth entity columns. -- `entity_relaxed_*`: span-overlap precision, recall, and F1. The - label-compatible variants require both span overlap and equivalent labels, - while the non-label-compatible relaxed metrics only ask whether a - ground-truth span was protected by any detected span. -- `empty_detection_count`, `empty_detection_rate`, - `empty_detection_with_ground_truth_count`, and - `empty_detection_with_ground_truth_rate`: diagnostics for records where the - detector returned no final entities. The ground-truth-specific fields are the - important safety signal when a benchmark includes labels. -- `utility_score`, `leakage_mass`, `weighted_leakage_rate`, - `needs_repair`, and `needs_human_review`: rewrite-mode evaluation fields. - These are null for replace-mode runs. - -Error and reliability metrics: - -- `failed_record_count`: records dropped by a DataDesigner workflow. -- `status`: completion state for a stage or workflow. -- `case_failed`: true when a benchmark case has any error-status stage or - DataDesigner workflow measurement. -- `error_stage_count`, `error_ndd_workflow_count`, and - `error_model_workflow_count`: error-status measurement rows counted per case. -- `failed_case_count` and `failed_case_rate`: group-level failed-case count and - rate for a workload/config/strategy. 
-- `summary.json` case errors: runner-level failures, such as invalid inputs or - model endpoint failures. - -## Reading Results Safely - -Compare like with like. A shell-history workload, a support-ticket workload, -and a legal-document workload stress different parts of Anonymizer. Group by -`workload_id` before drawing conclusions about model routing, speculative -decoding, validation chunk size, or rewrite repair settings. - -Record-level rows describe input shape and output quality, not per-record wall -time. Stage and workflow rows carry timing. To explain a slow run, first find -the slow stage, then inspect the records in that run for text length, entity -count, nominal call count, and rewrite repair signals. - -When token or request fields are missing, check `ndd_workflow.model_usage` and -`model_workflow.model_usage`. The measurement layer records deeper provider -usage only when the underlying executor returns it. + covered, boundary-overlap, and uncovered losses. +- `candidate_verdict`: `candidate_viable`, `review`, or `reject`. + +Treat `candidate_viable` as a promotion candidate, not as an automatic default. +It means the sampled comparison passed the current gates and improved at least +one performance metric without regressing another. Re-run candidates on the +target workload family, with repetitions, before changing production defaults. 
diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py index f4768de6..056c2287 100644 --- a/tools/measurement/analyze_benchmark_output.py +++ b/tools/measurement/analyze_benchmark_output.py @@ -99,9 +99,6 @@ class CaseAnalysisRow(BaseModel): topology_shard_count: float | None = None input_text_tokens_per_endpoint_sec: float | None = None input_text_tokens_per_gpu_sec: float | None = None - route_total_row_count: float | None = None - route_rule_row_count: float | None = None - route_fallback_row_count: float | None = None final_entity_count: float | None = None empty_detection_count: int = 0 empty_detection_rate: float | None = None @@ -144,7 +141,6 @@ class CaseAnalysisRow(BaseModel): augmented_new_final_value_count: float | None = None artifact_final_entity_count: float | None = None artifact_final_detector_entity_count: float | None = None - artifact_final_rule_entity_count: float | None = None artifact_final_augmenter_entity_count: float | None = None artifact_final_entity_signature_count: float | None = None artifact_final_entity_signature_hashes: list[str] = Field(default_factory=list) @@ -194,9 +190,6 @@ class GroupAnalysisRow(BaseModel): median_topology_shard_count: float | None = None median_input_text_tokens_per_endpoint_sec: float | None = None median_input_text_tokens_per_gpu_sec: float | None = None - median_route_total_row_count: float | None = None - median_route_rule_row_count: float | None = None - median_route_fallback_row_count: float | None = None median_final_entity_count: float | None = None total_empty_detection_count: int = 0 empty_detection_rate: float | None = None @@ -238,7 +231,6 @@ class GroupAnalysisRow(BaseModel): median_augmented_new_final_value_count: float | None = None median_artifact_final_entity_count: float | None = None median_artifact_final_detector_entity_count: float | None = None - median_artifact_final_rule_entity_count: float | None = None 
median_artifact_final_augmenter_entity_count: float | None = None median_artifact_final_entity_signature_count: float | None = None @@ -511,9 +503,6 @@ def _build_case_row( measurement_rows, input_text_tokens_per_pipeline_sec=input_text_tokens_per_pipeline_sec, ), - route_total_row_count=_sum_or_none(model_rows, "route_total_row_count"), - route_rule_row_count=_sum_or_none(model_rows, "route_rule_row_count"), - route_fallback_row_count=_sum_or_none(model_rows, "route_fallback_row_count"), final_entity_count=final_entity_count, **_case_empty_detection_metrics(record_rows, record_count=record_count), **_case_ground_truth_metrics(record_rows, final_entity_count=final_entity_count), @@ -754,7 +743,6 @@ def _case_artifact_metrics( "augmented_new_final_value_count": _sum_or_none(artifact_rows, "augmented_new_final_value_count"), "artifact_final_entity_count": _sum_or_none(artifact_rows, "final_entity_count"), "artifact_final_detector_entity_count": _sum_or_none(artifact_rows, "final_source_counts.detector"), - "artifact_final_rule_entity_count": _sum_or_none(artifact_rows, "final_source_counts.rule"), "artifact_final_augmenter_entity_count": _sum_or_none(artifact_rows, "final_source_counts.augmenter"), "artifact_final_entity_signature_count": _signature_count(artifact_rows, signature_hashes=signature_hashes), "artifact_final_entity_signature_hashes": signature_hashes, @@ -1294,9 +1282,6 @@ def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysi median_topology_shard_count=_median_or_none(group, "topology_shard_count"), median_input_text_tokens_per_endpoint_sec=_median_or_none(group, "input_text_tokens_per_endpoint_sec"), median_input_text_tokens_per_gpu_sec=_median_or_none(group, "input_text_tokens_per_gpu_sec"), - median_route_total_row_count=_median_or_none(group, "route_total_row_count"), - median_route_rule_row_count=_median_or_none(group, "route_rule_row_count"), - median_route_fallback_row_count=_median_or_none(group, 
"route_fallback_row_count"), median_final_entity_count=_median_or_none(group, "final_entity_count"), total_empty_detection_count=total_empty_detection_count, empty_detection_rate=_safe_ratio(total_empty_detection_count, total_record_count), @@ -1359,7 +1344,6 @@ def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysi median_augmented_new_final_value_count=_median_or_none(group, "augmented_new_final_value_count"), median_artifact_final_entity_count=_median_or_none(group, "artifact_final_entity_count"), median_artifact_final_detector_entity_count=_median_or_none(group, "artifact_final_detector_entity_count"), - median_artifact_final_rule_entity_count=_median_or_none(group, "artifact_final_rule_entity_count"), median_artifact_final_augmenter_entity_count=_median_or_none(group, "artifact_final_augmenter_entity_count"), median_artifact_final_entity_signature_count=_median_or_none(group, "artifact_final_entity_signature_count"), ) diff --git a/tools/measurement/analyze_staged_detection_output.py b/tools/measurement/analyze_staged_detection_output.py index 225cc045..737d439a 100644 --- a/tools/measurement/analyze_staged_detection_output.py +++ b/tools/measurement/analyze_staged_detection_output.py @@ -40,9 +40,6 @@ class LogFormat(StrEnum): _log_format = LogFormat.plain -_FAST_LANE_MIN_CASES = 3 - - class StagedCaseAnalysisRow(BaseModel): source_path: str case_id: str @@ -54,7 +51,6 @@ class StagedCaseAnalysisRow(BaseModel): model_elapsed_sec: float | None = None model_phase_count: int = 0 model_request_count: int = 0 - rule_covered_label_set: bool = False prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 @@ -82,7 +78,6 @@ class StagedGroupAnalysisRow(BaseModel): completed_case_count: int = 0 error_case_count: int = 0 failed_case_rate: float | None = None - rule_covered_case_count: int = 0 elapsed_sec_sum: float | None = None elapsed_sec_mean: float | None = None model_elapsed_sec_sum: float | None = None @@ -100,8 +95,6 @@ 
class StagedGroupAnalysisRow(BaseModel): direct_only_final_entity_signature_count_sum: int = 0 baseline_shared_signature_rate: float | None = None baseline_loss_signature_rate: float | None = None - fast_lane_verdict: str = "review" - flags: list[str] = Field(default_factory=list) class LabelDeltaAnalysisRow(BaseModel): @@ -213,7 +206,6 @@ def _build_case_row(record: dict[str, Any], *, source_path: Path) -> StagedCaseA model_elapsed_sec=_optional_float(record.get("model_elapsed_sec")), model_phase_count=_int_value(record.get("model_phase_count")), model_request_count=_int_value(record.get("model_request_count")), - rule_covered_label_set=bool(record.get("rule_covered_label_set")), **_usage_fields(_dict_value(record.get("total_usage"))), **_entity_count_fields(record), baseline_final_entity_signature_count=baseline_count, @@ -264,22 +256,12 @@ def _build_group_row(seed_source: str | None, rows: list[StagedCaseAnalysisRow]) shared_total = _sum_optional_int(rows, "shared_final_entity_signature_count") baseline_only_total = _sum_optional_int(rows, "baseline_only_final_entity_signature_count") model_request_count = sum(row.model_request_count for row in rows) - rule_covered_count = sum(1 for row in rows if row.rule_covered_label_set) - flags = _fast_lane_flags( - case_count=case_count, - error_count=error_count, - baseline_total=baseline_total, - baseline_only_total=baseline_only_total, - model_request_count=model_request_count, - rule_covered_count=rule_covered_count, - ) return StagedGroupAnalysisRow( seed_source=seed_source, case_count=case_count, completed_case_count=case_count - error_count, error_case_count=error_count, failed_case_rate=_rate(error_count, case_count), - rule_covered_case_count=rule_covered_count, elapsed_sec_sum=_sum_optional_float(rows, "elapsed_sec"), elapsed_sec_mean=_mean_optional_float(rows, "elapsed_sec"), model_elapsed_sec_sum=_sum_optional_float(rows, "model_elapsed_sec"), @@ -299,8 +281,6 @@ def _build_group_row(seed_source: str | None, 
rows: list[StagedCaseAnalysisRow]) ), baseline_shared_signature_rate=_rate(shared_total, baseline_total), baseline_loss_signature_rate=_rate(baseline_only_total, baseline_total), - fast_lane_verdict=_fast_lane_verdict(flags), - flags=flags, ) @@ -308,39 +288,6 @@ def _group_sort_key(item: tuple[str | None, list[StagedCaseAnalysisRow]]) -> str return item[0] or "" -def _fast_lane_flags( - *, - case_count: int, - error_count: int, - baseline_total: int, - baseline_only_total: int, - model_request_count: int, - rule_covered_count: int, -) -> list[str]: - flags: list[str] = [] - if case_count < _FAST_LANE_MIN_CASES: - flags.append("too_few_cases") - if error_count: - flags.append("case_errors") - if baseline_total == 0: - flags.append("missing_baseline_comparison") - if baseline_only_total: - flags.append("baseline_signature_loss") - if model_request_count: - flags.append("uses_model") - if rule_covered_count != case_count: - flags.append("not_fully_rule_covered") - return flags - - -def _fast_lane_verdict(flags: list[str]) -> str: - if "case_errors" in flags or "baseline_signature_loss" in flags: - return "reject" - if not flags: - return "fast_lane_candidate" - return "review" - - def build_label_delta_rows(cases: list[StagedCaseAnalysisRow]) -> list[LabelDeltaAnalysisRow]: counts: Counter[tuple[str | None, str, str]] = Counter() for case in cases: @@ -477,7 +424,6 @@ def _render_group_line(group: StagedGroupAnalysisRow, label_deltas: list[LabelDe lost = _top_labels(label_deltas, seed_source=group.seed_source, delta_type="baseline_only") return ( f"- {label}: cases={group.case_count}, errors={group.error_case_count}, " - f"verdict={group.fast_lane_verdict}, flags={_label_count_summary(group.flags)}, " f"elapsed_sum={_fmt_float(group.elapsed_sec_sum)}s, " f"model_elapsed_sum={_fmt_float(group.model_elapsed_sec_sum)}s, " f"requests={group.model_request_count_sum}, tokens={group.total_tokens_sum}, " @@ -487,11 +433,6 @@ def _render_group_line(group: 
StagedGroupAnalysisRow, label_deltas: list[LabelDe f"direct_only={group.direct_only_final_entity_signature_count_sum}, lost_labels={lost}" ) - -def _label_count_summary(items: list[str]) -> str: - return "[]" if not items else "[" + ", ".join(items) + "]" - - def _top_labels(label_deltas: list[LabelDeltaAnalysisRow], *, seed_source: str | None, delta_type: str) -> str: matches = [delta for delta in label_deltas if delta.seed_source == seed_source and delta.delta_type == delta_type] if not matches: diff --git a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py index a4d9550d..bdbad0d5 100644 --- a/tools/measurement/compare_strategy_pairs.py +++ b/tools/measurement/compare_strategy_pairs.py @@ -5,12 +5,12 @@ Usage: uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.csv \ - --baseline-strategy no_augment --candidate-strategy rules_filter_guardrail_no_augment + --baseline-strategy default --candidate-strategy detector_native_validate_no_augment uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.parquet \ --baseline-config default --candidate-config no-augment --output comparisons.csv uv run python tools/measurement/compare_strategy_pairs.py baseline/case_analysis.csv \ --candidate-case-analysis candidate/case_analysis.csv \ - --baseline-strategy no_augment --candidate-strategy rules_guardrail_no_augment + --baseline-strategy default --candidate-strategy native_single_pass """ from __future__ import annotations @@ -165,8 +165,6 @@ class ComparisonRow(BaseModel): augmented_new_final_value_count_delta: float | None = None baseline_detector_entity_count: float | None = None candidate_detector_entity_count: float | None = None - baseline_rule_entity_count: float | None = None - candidate_rule_entity_count: float | None = None baseline_augmenter_entity_count: float | None = None candidate_augmenter_entity_count: float | None = None baseline_only_final_entity_signature_count: 
int | None = None @@ -444,7 +442,6 @@ def _single_string(rows: pd.DataFrame, column: str) -> str | None: "augmented_entity_count", "augmented_new_final_value_count", "artifact_final_detector_entity_count", - "artifact_final_rule_entity_count", "artifact_final_augmenter_entity_count", ] @@ -521,8 +518,6 @@ def _source_counts(baseline: dict[str, object], candidate: dict[str, object]) -> return { "baseline_detector_entity_count": _optional_float(baseline.get("artifact_final_detector_entity_count")), "candidate_detector_entity_count": _optional_float(candidate.get("artifact_final_detector_entity_count")), - "baseline_rule_entity_count": _optional_float(baseline.get("artifact_final_rule_entity_count")), - "candidate_rule_entity_count": _optional_float(candidate.get("artifact_final_rule_entity_count")), "baseline_augmenter_entity_count": _optional_float(baseline.get("artifact_final_augmenter_entity_count")), "candidate_augmenter_entity_count": _optional_float(candidate.get("artifact_final_augmenter_entity_count")), "baseline_original_value_leak_label_counts": _coerce_count_map( @@ -608,8 +603,6 @@ def _comparison_flags( _append_if_positive(flags, metrics, "observed_total_requests_delta", "request_increase") if _candidate_lacks_detector_entities(metrics): flags.append("no_candidate_detector_entities") - if _optional_float(metrics.get("candidate_rule_entity_count")): - flags.append("candidate_uses_rule_entities") if candidate_strategy in _SKIPS_LLM_VALIDATION_STRATEGIES: flags.append("candidate_skips_llm_validation") if _replacement_only_detection_instability( @@ -661,7 +654,7 @@ def _stable_signature_loss_metric(metrics: dict[str, object]) -> str: return "baseline_stable_candidate_unstable_final_entity_signature_count" -_SKIPS_LLM_VALIDATION_STRATEGIES = {"detector_only", "rules_guardrail_detector_only", "rules_only"} +_SKIPS_LLM_VALIDATION_STRATEGIES = {"detector_only"} def _has_metric_pair(metrics: dict[str, object], name: str) -> bool: @@ -739,7 +732,6 @@ def 
_safety_verdict(metrics: dict[str, object]) -> SafetyVerdict: return SafetyVerdict.fail if flags & { "no_candidate_detector_entities", - "candidate_uses_rule_entities", "candidate_skips_llm_validation", "failed_request_increase", "bridge_fallback_increase", @@ -812,7 +804,6 @@ def _candidate_lacks_detector_entities(metrics: dict[str, object]) -> bool: def _known_non_detector_candidate_count(metrics: dict[str, object]) -> float | None: known_counts = [ - _optional_float(metrics.get("candidate_rule_entity_count")), _optional_float(metrics.get("candidate_augmenter_entity_count")), ] if all(value is None for value in known_counts): diff --git a/tools/measurement/detection_strategies.py b/tools/measurement/detection_strategies.py index a8955cd9..8c45c6ab 100644 --- a/tools/measurement/detection_strategies.py +++ b/tools/measurement/detection_strategies.py @@ -6,7 +6,6 @@ from __future__ import annotations import json -import re import time from collections import Counter from collections.abc import Callable, Iterator @@ -18,7 +17,7 @@ import pandas as pd from data_designer.config import custom_column_generator -from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig from data_designer.config.models import ModelConfig from dd_parser_compat import _load_embedded_json from direct_detection_probe import DirectDetectionRequest, DirectGenerationRequest, PromptMode, build_direct_prompt @@ -45,17 +44,14 @@ from anonymizer.engine.constants import ( COL_AUGMENTED_ENTITIES, COL_DETECTED_ENTITIES, - COL_INITIAL_TAGGED_TEXT, COL_MERGED_ENTITIES, COL_RAW_DETECTED, COL_SEED_ENTITIES, COL_SEED_ENTITIES_JSON, COL_SEED_VALIDATION_CANDIDATES, - COL_TAG_NOTATION, COL_TAGGED_TEXT, COL_TEXT, COL_VALIDATED_ENTITIES, - COL_VALIDATED_SEED_ENTITIES, COL_VALIDATION_DECISIONS, _jinja, ) @@ -73,14 +69,10 @@ EntitySpan, build_tagged_text, 
expand_entity_occurrences, - get_tag_notation, - parse_raw_entities, resolve_overlaps, ) -from anonymizer.engine.detection.rules import detect_high_confidence_entities from anonymizer.engine.ndd.adapter import FailedRecord from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases -from anonymizer.engine.row_partitioning import merge_and_reorder, split_rows from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema, ValidationCandidatesSchema from anonymizer.measurement import record_model_workflow @@ -91,37 +83,14 @@ _DIRECT_MAX_TOKENS = 4096 _DIRECT_TIMEOUT_SEC = 180.0 _DIRECT_MAX_WORKERS = 4 -_STRUCTURED_ASSIGNMENT_RE = re.compile( - r"(?" - r"api[_-]?key|aws[_-]?access[_-]?key[_-]?id|access[_-]?key[_-]?id|hf[_-]?token|" - r"token|auth[_-]?token|session[_-]?id|authorization|" - r"password|pass|secret|aws[_-]?secret[_-]?access[_-]?key|django[_-]?secret|database[_-]?url|" - r"pin|user(?:_?name)?|username|login|account|cookie|" - r"trace[-_]?id|request[-_]?id|req[-_]?id|order[-_]?id|tenant[-_]?id|unique[-_]?id|" - r"url|uri|endpoint|callback|email" - r")['\"]?\s*[:=]\s*" - r"(?:['\"](?P[^'\"\r\n]+)['\"]|(?P[^\s'\",;]+))", - flags=re.IGNORECASE, -) class ExperimentalDetectionStrategy(StrEnum): default = "default" prose_augment_focus = "prose_augment_focus" compact_validation = "compact_validation" - rules_guardrail_compact_validation = "rules_guardrail_compact_validation" - rules_guardrail = "rules_guardrail" - rules_filter_guardrail = "rules_filter_guardrail" no_augment = "no_augment" - rules_seed_no_augment = "rules_seed_no_augment" - rules_guardrail_no_augment = "rules_guardrail_no_augment" - rules_filter_guardrail_no_augment = "rules_filter_guardrail_no_augment" - rules_guardrail_detector_only = "rules_guardrail_detector_only" detector_only = "detector_only" - rules_only = "rules_only" - rules_covered_or_default = "rules_covered_or_default" - native_rules_router = "native_rules_router" 
native_candidate_validate_no_augment = "native_candidate_validate_no_augment" detector_native_validate_no_augment = "detector_native_validate_no_augment" detector_native_validate_native_augment = "detector_native_validate_native_augment" @@ -135,7 +104,6 @@ class ExperimentalDetectionStrategy(StrEnum): _DetectAndValidate = Callable[..., dw.EntityDetectionResult] _AugmentPrompt = Callable[..., str] -_MaterializeFinalEntities = Callable[..., dict] PROSE_AUGMENT_FOCUS_TEXT = """\ Contextual prose recall focus: - Re-scan untagged narrative prose for organization and institution names, named facilities, labs, research centers, street or place names, self-described beliefs, occupations, titles, family member names, and other quasi-identifiers that combine with already tagged entities. @@ -163,14 +131,6 @@ class NativeDetectionRuntime: max_workers: int = _DIRECT_MAX_WORKERS -@dataclass(frozen=True) -class _NoAugmentOptions: - include_rules: bool - final_rule_guardrail: bool = False - filter_rule_overlaps: bool = False - rule_labels: tuple[str, ...] 
= () - - @dataclass(frozen=True) class _NativeStagedTask: ordinal: int @@ -232,7 +192,6 @@ class _DetectorNativeValidationRowResult: def experimental_detection_strategy_context( strategy: ExperimentalDetectionStrategy, *, - rule_labels: list[str] | None = None, native_client: DirectDetectionClient | None = None, gliner_seed_client: GlinerSeedClient | None = None, native_runtime: NativeDetectionRuntime | None = None, @@ -244,19 +203,12 @@ def experimental_detection_strategy_context( original_method = dw.EntityDetectionWorkflow.detect_and_validate_entities original_augment_prompt = dw._get_augment_prompt - original_materialize_final_entities = dw._materialize_final_entities - if rule_labels: - dw._materialize_final_entities = _make_rule_label_materializer( # type: ignore[assignment] - original_materialize_final_entities, - rule_labels=rule_labels, - ) if strategy == ExperimentalDetectionStrategy.prose_augment_focus: dw._get_augment_prompt = _make_prose_augment_prompt(original_augment_prompt) # type: ignore[assignment] else: dw.EntityDetectionWorkflow.detect_and_validate_entities = _method_for_strategy( # type: ignore[method-assign] strategy, original=original_method, - rule_labels=rule_labels, native_client=native_client, gliner_seed_client=gliner_seed_client, native_runtime=native_runtime or NativeDetectionRuntime(), @@ -266,20 +218,6 @@ def experimental_detection_strategy_context( finally: dw.EntityDetectionWorkflow.detect_and_validate_entities = original_method # type: ignore[method-assign] dw._get_augment_prompt = original_augment_prompt # type: ignore[assignment] - dw._materialize_final_entities = original_materialize_final_entities # type: ignore[assignment] - - -def _make_rule_label_materializer( - original: _MaterializeFinalEntities, - *, - rule_labels: list[str], -) -> _MaterializeFinalEntities: - def materialize_final_entities(raw: object, *, allowed_labels: set[str] | None) -> dict: - if allowed_labels is None: - return original(raw, 
allowed_labels=allowed_labels) - return original(raw, allowed_labels={*allowed_labels, *rule_labels}) - - return materialize_final_entities def _make_prose_augment_prompt(original: _AugmentPrompt) -> _AugmentPrompt: @@ -294,7 +232,6 @@ def _method_for_strategy( strategy: ExperimentalDetectionStrategy, *, original: _DetectAndValidate | None = None, - rule_labels: list[str] | None = None, native_client: DirectDetectionClient | None = None, gliner_seed_client: GlinerSeedClient | None = None, native_runtime: NativeDetectionRuntime | None = None, @@ -304,49 +241,10 @@ def _method_for_strategy( if original is None: raise ValueError("compact_validation requires the original detection method") return _make_default_compact_validation_method(original) - if strategy == ExperimentalDetectionStrategy.rules_guardrail: - if original is None: - raise ValueError("rules_guardrail requires the original detection method") - return _make_default_with_rule_guardrail_method(original, rule_labels=rule_labels) - if strategy == ExperimentalDetectionStrategy.rules_filter_guardrail: - return _make_validated_augmented_rule_filter_guardrail_method(rule_labels=rule_labels) - if strategy == ExperimentalDetectionStrategy.rules_guardrail_compact_validation: - if original is None: - raise ValueError("rules_guardrail_compact_validation requires the original detection method") - return _make_default_with_rule_guardrail_method( - original, - rule_labels=rule_labels, - compact_validation=True, - ) if strategy == ExperimentalDetectionStrategy.no_augment: - return _make_validated_no_augment_method(include_rules=False) - if strategy == ExperimentalDetectionStrategy.rules_seed_no_augment: - return _make_validated_no_augment_method(include_rules=True, rule_labels=rule_labels) - if strategy == ExperimentalDetectionStrategy.rules_guardrail_no_augment: - return _make_validated_no_augment_method( - include_rules=False, - final_rule_guardrail=True, - rule_labels=rule_labels, - ) - if strategy == 
ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment: - return _make_validated_no_augment_method( - include_rules=False, - final_rule_guardrail=True, - filter_rule_overlaps=True, - rule_labels=rule_labels, - ) - if strategy == ExperimentalDetectionStrategy.rules_guardrail_detector_only: - return _make_detector_only_with_rule_guardrail_method(rule_labels=rule_labels) + return _make_validated_no_augment_method() if strategy == ExperimentalDetectionStrategy.detector_only: return _detect_with_detector_only - if strategy == ExperimentalDetectionStrategy.rules_covered_or_default: - if original is None: - raise ValueError("rules_covered_or_default requires the original detection method") - return _make_rules_covered_or_default_method(original) - if strategy == ExperimentalDetectionStrategy.rules_only: - return _detect_with_rules_only - if strategy == ExperimentalDetectionStrategy.native_rules_router: - return _make_native_rules_router_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.native_candidate_validate_no_augment: return _make_native_candidate_validate_no_augment_method(native_client=native_client, native_runtime=runtime) if strategy == ExperimentalDetectionStrategy.detector_native_validate_no_augment: @@ -416,236 +314,7 @@ def detect_and_validate_entities( return detect_and_validate_entities -def _make_default_with_rule_guardrail_method( - original: _DetectAndValidate, - *, - rule_labels: list[str] | None = None, - compact_validation: bool = False, -) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - entity_labels: list[str] | None = 
None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - result = original( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=not compact_validation, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - output = _apply_rule_guardrail( - result.dataframe.copy(), - labels=_rule_labels_for_detection(entity_labels, extra_rule_labels=rule_labels), - ) - return dw.EntityDetectionResult(dataframe=output, failed_records=result.failed_records) - - return detect_and_validate_entities - - -def _make_validated_augmented_rule_filter_guardrail_method( - *, - rule_labels: list[str] | None = None, -) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - return _run_validated_augmented_rule_filter_guardrail_detection( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - preview_num_records=preview_num_records, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - entity_labels=entity_labels, - data_summary=data_summary, - 
rule_labels=rule_labels, - ) - - return detect_and_validate_entities - - -def _run_validated_augmented_rule_filter_guardrail_detection( - workflow: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - preview_num_records: int | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - entity_labels: list[str] | None, - data_summary: str | None, - rule_labels: list[str] | None, -) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - workflow_model_configs = workflow._inject_detector_params( - model_configs=model_configs, - selected_models=selected_models, - labels=labels, - gliner_detection_threshold=gliner_detection_threshold, - ) - detection_result = workflow._adapter.run_workflow( - dataframe, - model_configs=workflow_model_configs, - columns=_validated_augmented_rule_filter_guardrail_columns( - selected_models=selected_models, - labels=labels, - data_summary=data_summary, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - strict_labels=entity_labels is not None, - rule_labels=rule_labels, - ), - workflow_name="entity-detection-rules-filter-guardrail", - preview_num_records=preview_num_records, - ) - return dw.EntityDetectionResult( - dataframe=detection_result.dataframe.copy(), - failed_records=detection_result.failed_records, - ) - - -def _validated_augmented_rule_filter_guardrail_columns( - *, - selected_models: DetectionModelSelection, - labels: list[str], - data_summary: str | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - strict_labels: bool, - rule_labels: list[str] | None, -) -> list[LLMTextColumnConfig | LLMStructuredColumnConfig | CustomColumnConfig]: - validator_params = _validator_params( - selected_models=selected_models, - 
labels=labels, - data_summary=data_summary, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - ) - rule_detection_labels = _rule_labels_for_detection(labels, extra_rule_labels=rule_labels) - return [ - LLMTextColumnConfig( - name=COL_RAW_DETECTED, - prompt=_jinja(COL_TEXT), - model_alias=_detector_alias(selected_models), - ), - CustomColumnConfig( - name=COL_SEED_ENTITIES, - generator_function=_make_parse_detected_entities_filtering_rules(rule_detection_labels), - ), - CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), - _validation_decisions_column(selected_models, validator_params), - CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), - CustomColumnConfig( - name=COL_SEED_ENTITIES_JSON, - generator_function=_make_apply_validation_to_seed_entities_with_additive_rule_guardrail( - rule_detection_labels - ), - ), - LLMStructuredColumnConfig( - name=COL_AUGMENTED_ENTITIES, - prompt=dw._get_augment_prompt(data_summary=data_summary, labels=labels, strict_labels=strict_labels), - model_alias=resolve_model_alias("entity_augmenter", selected_models), - output_format=AugmentedEntitiesSchema, - ), - CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates), - CustomColumnConfig( - name=COL_DETECTED_ENTITIES, - generator_function=_make_apply_validation_and_finalize_with_additive_rule_guardrail(rule_detection_labels), - ), - ] - - -def _rule_labels_for_detection( - entity_labels: list[str] | None, - *, - extra_rule_labels: list[str] | tuple[str, ...] 
| None = None, -) -> list[str]: - labels = set(dw._resolve_detection_labels(entity_labels)) - labels.update(extra_rule_labels or []) - return sorted(labels) - - -def _apply_rule_guardrail(dataframe: pd.DataFrame, *, labels: list[str]) -> pd.DataFrame: - if COL_TEXT not in dataframe.columns or COL_DETECTED_ENTITIES not in dataframe.columns: - return dataframe - dataframe[COL_DETECTED_ENTITIES] = dataframe[COL_DETECTED_ENTITIES].astype("object") - if COL_TAGGED_TEXT in dataframe.columns: - dataframe[COL_TAGGED_TEXT] = dataframe[COL_TAGGED_TEXT].astype("object") - for index, row in dataframe.iterrows(): - guarded = _guarded_entities( - text=str(row.get(COL_TEXT, "")), raw_entities=row.get(COL_DETECTED_ENTITIES), labels=labels - ) - dataframe.at[index, COL_DETECTED_ENTITIES] = EntitiesSchema( - entities=[entity.as_dict() for entity in guarded] - ).model_dump(mode="json") - if COL_TAGGED_TEXT in dataframe.columns: - dataframe.at[index, COL_TAGGED_TEXT] = build_tagged_text(text=str(row.get(COL_TEXT, "")), entities=guarded) - return dataframe - - -def _guarded_entities(*, text: str, raw_entities: object, labels: list[str]) -> list[EntitySpan]: - final_spans = _entity_spans_from_payload(raw_entities) - rule_spans = detect_high_confidence_entities(text, labels=labels) - return _merge_rule_guardrail_spans(final_spans, rule_spans) - - -def _merge_rule_guardrail_spans(final_spans: list[EntitySpan], rule_spans: list[EntitySpan]) -> list[EntitySpan]: - filtered_final = [ - entity - for entity in final_spans - if not any( - rule.start_position == entity.start_position - and rule.end_position == entity.end_position - and rule.label != entity.label - for rule in rule_spans - ) - ] - return resolve_overlaps([*filtered_final, *rule_spans]) - - -def _make_validated_no_augment_method( - *, - include_rules: bool, - final_rule_guardrail: bool = False, - filter_rule_overlaps: bool = False, - rule_labels: list[str] | None = None, -) -> _DetectAndValidate: - options = _NoAugmentOptions( - 
include_rules=include_rules, - final_rule_guardrail=final_rule_guardrail, - filter_rule_overlaps=filter_rule_overlaps, - rule_labels=tuple(rule_labels or ()), - ) - +def _make_validated_no_augment_method() -> _DetectAndValidate: def detect_and_validate_entities( self: dw.EntityDetectionWorkflow, dataframe: pd.DataFrame, @@ -670,7 +339,6 @@ def detect_and_validate_entities( validation_excerpt_window_chars=validation_excerpt_window_chars, entity_labels=entity_labels, data_summary=data_summary, - options=options, ) return detect_and_validate_entities @@ -688,7 +356,6 @@ def _run_validated_no_augment_detection( validation_excerpt_window_chars: int, entity_labels: list[str] | None, data_summary: str | None, - options: _NoAugmentOptions, ) -> dw.EntityDetectionResult: labels = dw._resolve_detection_labels(entity_labels) workflow_model_configs = workflow._inject_detector_params( @@ -706,9 +373,8 @@ def _run_validated_no_augment_detection( data_summary=data_summary, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, - options=options, ), - workflow_name=_workflow_name_for_no_augment(options), + workflow_name="entity-detection-no-augment", preview_num_records=preview_num_records, ) return dw.EntityDetectionResult( @@ -724,7 +390,6 @@ def _validated_no_augment_columns( data_summary: str | None, validation_max_entities_per_call: int, validation_excerpt_window_chars: int, - options: _NoAugmentOptions, ) -> list[LLMTextColumnConfig | CustomColumnConfig]: validator_params = _validator_params( selected_models=selected_models, @@ -733,28 +398,18 @@ def _validated_no_augment_columns( validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, ) - parse_generator = _parse_generator( - labels=_rule_labels_for_detection(labels, extra_rule_labels=options.rule_labels), - include_rules=options.include_rules, - 
filter_rule_overlaps=options.filter_rule_overlaps, - ) return [ LLMTextColumnConfig( name=COL_RAW_DETECTED, prompt=_jinja(COL_TEXT), model_alias=_detector_alias(selected_models) ), - CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_generator), + CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), _validation_decisions_column(selected_models, validator_params), CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=apply_validation_to_seed_entities), CustomColumnConfig(name=COL_AUGMENTED_ENTITIES, generator_function=_empty_augmentation), CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates), - CustomColumnConfig( - name=COL_DETECTED_ENTITIES, - generator_function=_finalizer( - _rule_labels_for_detection(labels, extra_rule_labels=options.rule_labels), options - ), - ), + CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=apply_validation_and_finalize), ] @@ -772,24 +427,6 @@ def _validation_decisions_column( ) -def _finalizer(labels: list[str], options: _NoAugmentOptions) -> Callable[[dict[str, Any]], dict[str, Any]]: - if options.filter_rule_overlaps: - return _make_apply_validation_and_finalize_with_additive_rule_guardrail(labels) - if options.final_rule_guardrail: - return _make_apply_validation_and_finalize_with_rule_guardrail(labels) - return apply_validation_and_finalize - - -def _workflow_name_for_no_augment(options: _NoAugmentOptions) -> str: - if options.filter_rule_overlaps: - return "entity-detection-rules-filter-guardrail-no-augment" - if options.final_rule_guardrail: - return "entity-detection-rules-guardrail-no-augment" - if options.include_rules: - return "entity-detection-rules-no-augment" - return "entity-detection-no-augment" - - def 
_validator_params( *, selected_models: DetectionModelSelection, @@ -833,41 +470,10 @@ def _detect_with_detector_only( gliner_detection_threshold=gliner_detection_threshold, entity_labels=entity_labels, preview_num_records=preview_num_records, - rule_labels=None, workflow_name="entity-detection-detector-only", ) -def _make_detector_only_with_rule_guardrail_method(rule_labels: list[str] | None) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - return _run_detector_only_detection( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - entity_labels=entity_labels, - preview_num_records=preview_num_records, - rule_labels=_rule_labels_for_detection(entity_labels, extra_rule_labels=rule_labels), - workflow_name="entity-detection-rules-guardrail-detector-only", - ) - - return detect_and_validate_entities - - def _run_detector_only_detection( workflow: dw.EntityDetectionWorkflow, dataframe: pd.DataFrame, @@ -877,7 +483,6 @@ def _run_detector_only_detection( gliner_detection_threshold: float, entity_labels: list[str] | None, preview_num_records: int | None, - rule_labels: list[str] | None, workflow_name: str, ) -> dw.EntityDetectionResult: labels = dw._resolve_detection_labels(entity_labels) @@ -890,7 +495,7 @@ def _run_detector_only_detection( detection_result = workflow._adapter.run_workflow( dataframe, 
model_configs=workflow_model_configs, - columns=_detector_only_columns(selected_models, rule_labels=rule_labels), + columns=_detector_only_columns(selected_models), workflow_name=workflow_name, preview_num_records=preview_num_records, ) @@ -900,11 +505,7 @@ def _run_detector_only_detection( ) -def _detector_only_columns( - selected_models: DetectionModelSelection, - *, - rule_labels: list[str] | None, -) -> list[LLMTextColumnConfig | CustomColumnConfig]: +def _detector_only_columns(selected_models: DetectionModelSelection) -> list[LLMTextColumnConfig | CustomColumnConfig]: return [ LLMTextColumnConfig( name=COL_RAW_DETECTED, @@ -913,16 +514,10 @@ def _detector_only_columns( ), CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=_copy_seed_entities_json), - CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=_detector_only_finalizer(rule_labels)), + CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=_finalize_detector_only), ] -def _detector_only_finalizer(rule_labels: list[str] | None) -> Callable[[dict[str, Any]], dict[str, Any]]: - if rule_labels is None: - return _finalize_detector_only - return _make_finalize_detector_only_with_rule_guardrail(rule_labels) - - @custom_column_generator(required_columns=[COL_SEED_ENTITIES]) def _copy_seed_entities_json(row: dict[str, Any]) -> dict[str, Any]: row[COL_SEED_ENTITIES_JSON] = json.dumps( @@ -943,191 +538,12 @@ def _finalize_detector_only(row: dict[str, Any]) -> dict[str, Any]: return row -def _make_finalize_detector_only_with_rule_guardrail(labels: list[str]) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - required_columns=[COL_TEXT, COL_SEED_ENTITIES], - side_effect_columns=[COL_TAGGED_TEXT], - ) - def finalize_detector_only_with_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]: - row = _finalize_detector_only(row) - text = str(row.get(COL_TEXT, "")) - 
final_spans = _entity_spans_from_payload(row.get(COL_DETECTED_ENTITIES, {})) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _merge_rule_guardrail_spans(final_spans, rule_spans) - row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in guarded]).model_dump( - mode="json" - ) - row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) - return row - - return finalize_detector_only_with_rule_guardrail - - @custom_column_generator(required_columns=[COL_TEXT]) def _empty_augmentation(row: dict[str, Any]) -> dict[str, Any]: row[COL_AUGMENTED_ENTITIES] = AugmentedEntitiesSchema().model_dump(mode="json") return row -def _parse_generator( - *, - labels: list[str], - include_rules: bool, - filter_rule_overlaps: bool, -) -> Callable[[dict[str, Any]], dict[str, Any]]: - if filter_rule_overlaps: - return _make_parse_detected_entities_filtering_rules(labels) - if include_rules: - return _make_parse_detected_entities_with_rules(labels) - return parse_detected_entities - - -def _make_parse_detected_entities_with_rules(labels: list[str]) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - required_columns=[COL_TEXT, COL_RAW_DETECTED], - side_effect_columns=[COL_TAG_NOTATION], - ) - def parse_detected_entities_with_rules(row: dict[str, Any]) -> dict[str, Any]: - text = str(row.get(COL_TEXT, "")) - detected = parse_raw_entities(raw_response=str(row.get(COL_RAW_DETECTED, "")), text=text) - rule_spans = detect_high_confidence_entities(text, labels=labels) - row[COL_SEED_ENTITIES] = EntitiesSchema( - entities=[entity.as_dict() for entity in resolve_overlaps([*detected, *rule_spans])] - ).model_dump(mode="json") - row[COL_TAG_NOTATION] = get_tag_notation(text=text) - return row - - return parse_detected_entities_with_rules - - -def _make_parse_detected_entities_filtering_rules(labels: list[str]) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - 
required_columns=[COL_TEXT, COL_RAW_DETECTED], - side_effect_columns=[COL_TAG_NOTATION], - ) - def parse_detected_entities_filtering_rules(row: dict[str, Any]) -> dict[str, Any]: - text = str(row.get(COL_TEXT, "")) - detected = parse_raw_entities(raw_response=str(row.get(COL_RAW_DETECTED, "")), text=text) - rule_spans = detect_high_confidence_entities(text, labels=labels) - filtered = [entity for entity in detected if not _is_rule_covered_detector_span(entity, rule_spans)] - row[COL_SEED_ENTITIES] = EntitiesSchema( - entities=[entity.as_dict() for entity in resolve_overlaps(filtered)] - ).model_dump(mode="json") - row[COL_TAG_NOTATION] = get_tag_notation(text=text) - return row - - return parse_detected_entities_filtering_rules - - -def _make_apply_validation_and_finalize_with_rule_guardrail( - labels: list[str], -) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - required_columns=[COL_TEXT, COL_MERGED_ENTITIES, COL_VALIDATED_ENTITIES], - side_effect_columns=[COL_TAGGED_TEXT], - ) - def apply_validation_and_finalize_with_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]: - row = apply_validation_and_finalize(row) - text = str(row.get(COL_TEXT, "")) - final_spans = _entity_spans_from_payload(row.get(COL_DETECTED_ENTITIES, {})) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _merge_rule_guardrail_spans(final_spans, rule_spans) - row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[entity.as_dict() for entity in guarded]).model_dump( - mode="json" - ) - row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) - return row - - return apply_validation_and_finalize_with_rule_guardrail - - -def _make_apply_validation_and_finalize_with_additive_rule_guardrail( - labels: list[str], -) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - required_columns=[COL_TEXT, COL_MERGED_ENTITIES, COL_VALIDATED_ENTITIES], - side_effect_columns=[COL_TAGGED_TEXT], - ) - def 
apply_validation_and_finalize_with_additive_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]: - row = apply_validation_and_finalize(row) - text = str(row.get(COL_TEXT, "")) - final_spans = _entity_spans_from_payload(row.get(COL_DETECTED_ENTITIES, {})) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _add_non_overlapping_rule_spans(final_spans, rule_spans) - row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[entity.as_dict() for entity in guarded]).model_dump( - mode="json" - ) - row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) - return row - - return apply_validation_and_finalize_with_additive_rule_guardrail - - -def _make_apply_validation_to_seed_entities_with_rule_guardrail( - labels: list[str], -) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - required_columns=[COL_TEXT, COL_SEED_ENTITIES, COL_VALIDATED_ENTITIES], - side_effect_columns=[COL_INITIAL_TAGGED_TEXT, COL_SEED_ENTITIES_JSON, COL_VALIDATED_SEED_ENTITIES], - ) - def apply_validation_to_seed_entities_with_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]: - row = apply_validation_to_seed_entities(row) - text = str(row.get(COL_TEXT, "")) - validated_seed = _entity_spans_from_payload(row.get(COL_VALIDATED_SEED_ENTITIES, {})) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _merge_rule_guardrail_spans(validated_seed, rule_spans) - seed_entities = [entity.as_dict() for entity in guarded] - row[COL_VALIDATED_SEED_ENTITIES] = EntitiesSchema(entities=seed_entities).model_dump(mode="json") - row[COL_SEED_ENTITIES_JSON] = json.dumps(seed_entities) - row[COL_INITIAL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) - return row - - return apply_validation_to_seed_entities_with_rule_guardrail - - -def _make_apply_validation_to_seed_entities_with_additive_rule_guardrail( - labels: list[str], -) -> Callable[[dict[str, Any]], dict[str, Any]]: - @custom_column_generator( - 
required_columns=[COL_TEXT, COL_SEED_ENTITIES, COL_VALIDATED_ENTITIES], - side_effect_columns=[COL_INITIAL_TAGGED_TEXT, COL_SEED_ENTITIES_JSON, COL_VALIDATED_SEED_ENTITIES], - ) - def apply_validation_to_seed_entities_with_additive_rule_guardrail(row: dict[str, Any]) -> dict[str, Any]: - row = apply_validation_to_seed_entities(row) - text = str(row.get(COL_TEXT, "")) - validated_seed = _entity_spans_from_payload(row.get(COL_VALIDATED_SEED_ENTITIES, {})) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _add_non_overlapping_rule_spans(validated_seed, rule_spans) - seed_entities = [entity.as_dict() for entity in guarded] - row[COL_VALIDATED_SEED_ENTITIES] = EntitiesSchema(entities=seed_entities).model_dump(mode="json") - row[COL_SEED_ENTITIES_JSON] = json.dumps(seed_entities) - row[COL_INITIAL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) - return row - - return apply_validation_to_seed_entities_with_additive_rule_guardrail - - -def _is_rule_covered_detector_span(entity: EntitySpan, spans: list[EntitySpan]) -> bool: - return any( - entity.label == span.label - and span.start_position <= entity.start_position - and span.end_position >= entity.end_position - for span in spans - ) - - -def _add_non_overlapping_rule_spans( - existing_spans: list[EntitySpan], - rule_spans: list[EntitySpan], -) -> list[EntitySpan]: - additions = [rule for rule in rule_spans if not any(_spans_overlap(rule, existing) for existing in existing_spans)] - return resolve_overlaps([*existing_spans, *additions]) - - -def _spans_overlap(left: EntitySpan, right: EntitySpan) -> bool: - return max(left.start_position, right.start_position) < min(left.end_position, right.end_position) - - def _entity_spans_from_payload(raw_payload: object) -> list[EntitySpan]: return [ EntitySpan( @@ -1266,7 +682,7 @@ def _execute_native_single_pass_row( failed_record_count=0, runtime=runtime, ) - return _native_single_pass_result_row(row, spans=spans, labels=labels), None + 
return _native_single_pass_result_row(row, spans=spans), None def _complete_native_single_pass( @@ -1418,15 +834,13 @@ def _native_single_pass_span(*, value: str, label: str, start: int, end: int) -> ) -def _native_single_pass_result_row(row: pd.Series, *, spans: list[EntitySpan], labels: list[str]) -> dict[str, Any]: +def _native_single_pass_result_row(row: pd.Series, *, spans: list[EntitySpan]) -> dict[str, Any]: text = str(row.get(COL_TEXT, "")) - rule_spans = detect_high_confidence_entities(text, labels=labels) - guarded = _add_non_overlapping_rule_spans(spans, rule_spans) output_row = row.to_dict() - output_row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in guarded]).model_dump( + output_row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump( mode="json" ) - output_row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=guarded) + output_row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=spans) return output_row @@ -1522,21 +936,6 @@ def _require_native_endpoint(runtime: NativeDetectionRuntime) -> None: ) -def _make_native_rules_router_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_native_staged_method( - native_client=native_client, - gliner_seed_client=None, - native_runtime=native_runtime, - seed_source=SeedSource.rules_router, - workflow_name="entity-detection-native-rules-router", - skip_augmentation=False, - ) - - def _make_native_candidate_validate_no_augment_method( *, native_client: DirectDetectionClient | None, @@ -1546,7 +945,7 @@ def _make_native_candidate_validate_no_augment_method( native_client=native_client, gliner_seed_client=None, native_runtime=native_runtime, - seed_source=SeedSource.rules_plus_direct_llm, + seed_source=SeedSource.direct_llm, workflow_name="entity-detection-native-candidate-validate-no-augment", skip_augmentation=True, ) @@ -2257,7 +1656,7 
@@ def _native_staged_request( data_summary: str | None, ) -> StagedDetectionRequest: return StagedDetectionRequest( - case_id=f"native-rules-router-{ordinal}", + case_id=f"native-staged-{ordinal}", text=str(row.get(COL_TEXT, "")), labels=labels, row_index=_safe_row_index(index, fallback=ordinal), @@ -2294,280 +1693,3 @@ def _native_output_dataframe( output[COL_DETECTED_ENTITIES] = pd.Series(dtype="object") output[COL_TAGGED_TEXT] = pd.Series(dtype="object") return output - - -def _detect_with_rules_only( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, -) -> dw.EntityDetectionResult: - return self.detect_with_high_confidence_rules(dataframe, entity_labels=entity_labels) - - -def _make_rules_covered_or_default_method(original: _DetectAndValidate) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - if _labels_are_rules_only(labels): - return _detect_rules_covered_rows_or_default( - 
original, - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - labels=labels, - ) - return original( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - - return detect_and_validate_entities - - -def _detect_rules_covered_rows_or_default( - original: _DetectAndValidate, - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - validation_single_chunk_full_text: bool, - entity_labels: list[str] | None, - data_summary: str | None, - preview_num_records: int | None, - labels: list[str], -) -> dw.EntityDetectionResult: - started = time.perf_counter() - if dataframe.empty: - result = _detect_with_rules_only( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - 
entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - _record_rules_covered_route( - started=started, - total_row_count=0, - rule_row_count=0, - fallback_row_count=0, - result=result, - ) - return result - - coverage_mask = dataframe[COL_TEXT].apply( - lambda text: _structured_assignments_are_rule_covered(str(text), labels=labels) - ) - if bool(coverage_mask.all()): - result = _detect_with_rules_only( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - _record_rules_covered_route( - started=started, - total_row_count=len(dataframe), - rule_row_count=len(dataframe), - fallback_row_count=0, - result=result, - ) - return result - if not bool(coverage_mask.any()): - result = original( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - _record_rules_covered_route( - started=started, - total_row_count=len(dataframe), - rule_row_count=0, - fallback_row_count=len(dataframe), - result=result, - ) - return result - - rule_rows, default_rows = split_rows( - dataframe, - column=COL_TEXT, - predicate=lambda text: _structured_assignments_are_rule_covered(str(text), labels=labels), - ) - - rule_result = 
_detect_with_rules_only( - self, - rule_rows, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - default_result = original( - self, - default_rows, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - result = dw.EntityDetectionResult( - dataframe=merge_and_reorder(rule_result.dataframe, default_result.dataframe), - failed_records=[*rule_result.failed_records, *default_result.failed_records], - ) - _record_rules_covered_route( - started=started, - total_row_count=len(dataframe), - rule_row_count=len(rule_rows), - fallback_row_count=len(default_rows), - result=result, - ) - return result - - -def _record_rules_covered_route( - *, - started: float, - total_row_count: int, - rule_row_count: int, - fallback_row_count: int, - result: dw.EntityDetectionResult, -) -> None: - record_model_workflow( - workflow_name="entity-detection-rules-covered-router", - model_aliases=[], - input_row_count=total_row_count, - output_row_count=len(result.dataframe), - failed_record_count=len(result.failed_records), - elapsed_sec=time.perf_counter() - started, - status="completed" if not result.failed_records else "partial", - extra_fields={ - "route_total_row_count": total_row_count, - "route_rule_row_count": rule_row_count, - 
"route_fallback_row_count": fallback_row_count, - }, - ) - - -def _structured_assignments_are_rule_covered(text: str, *, labels: list[str]) -> bool: - allowed_labels = set(labels) - rule_spans = detect_high_confidence_entities(text, labels=labels) - covered_ranges = [(span.start_position, span.end_position) for span in rule_spans] - for match in _STRUCTURED_ASSIGNMENT_RE.finditer(text): - label = _structured_assignment_label(match.group("key")) - if label not in allowed_labels: - continue - start, end = _structured_assignment_value_span(match) - if not _range_overlaps_any(start, end, covered_ranges): - return False - return True - - -def _structured_assignment_value_span(match: re.Match[str]) -> tuple[int, int]: - if match.group("quoted") is not None: - return match.span("quoted") - return match.span("bare") - - -def _range_overlaps_any(start: int, end: int, ranges: list[tuple[int, int]]) -> bool: - return any(start < range_end and end > range_start for range_start, range_end in ranges) - - -def _structured_assignment_label(key: str) -> str: - normalized = key.lower().replace("-", "_") - if normalized in {"api_key", "aws_access_key_id", "access_key_id", "hf_token", "token", "auth_token", "session_id"}: - return "api_key" - if normalized == "authorization": - return "api_key" - if normalized in {"password", "pass", "secret", "aws_secret_access_key", "django_secret"}: - return "password" - if normalized == "database_url": - return "url" - if normalized == "pin": - return "pin" - if normalized in {"user", "username", "user_name", "login", "account"}: - return "user_name" - if normalized == "cookie": - return "http_cookie" - if normalized in {"trace_id", "request_id", "req_id", "order_id", "tenant_id", "unique_id"}: - return "unique_id" - if normalized in {"url", "uri", "endpoint", "callback"}: - return "url" - if normalized == "email": - return "email" - return "" - - -def _labels_are_rules_only(labels: list[str]) -> bool: - return 
dw.labels_are_supported_by_structured_rule_fast_lane(labels) diff --git a/tools/measurement/extract_signature_deltas.py b/tools/measurement/extract_signature_deltas.py index 66a8e40c..7732c147 100644 --- a/tools/measurement/extract_signature_deltas.py +++ b/tools/measurement/extract_signature_deltas.py @@ -26,7 +26,6 @@ from pydantic import BaseModel, Field, ValidationError from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT -from anonymizer.engine.detection.rules import detect_high_confidence_entities from anonymizer.engine.schemas import EntitiesSchema, EntitySchema app = cyclopts.App(help=__doc__) @@ -52,7 +51,6 @@ class DeltaSide(StrEnum): class ContextResolution(StrEnum): parquet = "parquet" artifact_details = "artifact_details" - rule = "rule" metadata_only = "metadata_only" @@ -304,9 +302,6 @@ def _resolve_signature_context( parquet_context = _parquet_entity_context(artifact_row, signature, artifact_root, context_window) if parquet_context is not None: return parquet_context - rule_context = _rule_entity_context(artifact_row, signature, label, artifact_root, context_window) - if rule_context is not None: - return rule_context detail_context = _artifact_detail_context(artifact_row, signature, label, artifact_root, context_window) return detail_context or {"resolution": ContextResolution.metadata_only} @@ -327,24 +322,6 @@ def _parquet_entity_context( return None -def _rule_entity_context( - artifact_row: dict[str, object], - signature: str, - label: str | None, - artifact_root: Path, - context_window: int, -) -> dict[str, object] | None: - record = _artifact_record(artifact_row, artifact_root) - if record is None or label is None: - return None - text, row_index, _row = record - for span in detect_high_confidence_entities(text, labels=[label]): - entity = EntitySchema.model_validate(span.as_dict()) - if _entity_signature_hash(entity, row_index=row_index) == signature: - return _entity_context(entity, text, signature, context_window, 
ContextResolution.rule) - return None - - def _artifact_detail_context( artifact_row: dict[str, object], signature: str, diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 77b24493..e6a871de 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -54,15 +54,10 @@ from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_FINAL_ENTITIES -from anonymizer.engine.detection.rules import ( - STRUCTURED_RULE_FAST_LANE_LABELS, - SUPPORTED_RULE_LABELS, - detect_high_confidence_entities, -) from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references from anonymizer.engine.replace.structured_substitute import SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS -from anonymizer.engine.schemas import EntitiesSchema, EntitySchema +from anonymizer.engine.schemas import EntitiesSchema from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import MeasurementConfig, configured_measurement_session @@ -170,7 +165,6 @@ class ConfigSpec(BaseModel): emit_telemetry: bool = False experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default experimental_replacement_strategy: ExperimentalReplacementStrategy = ExperimentalReplacementStrategy.default - experimental_rule_labels: list[str] | None = None native_runtime: NativeRuntimeSpec | None = None @model_validator(mode="after") @@ -278,11 +272,6 @@ class _CaseExecution: _TRACE_FINAL_ARTIFACT_STRATEGIES = { - ExperimentalDetectionStrategy.rules_guardrail, - ExperimentalDetectionStrategy.rules_covered_or_default, - ExperimentalDetectionStrategy.rules_guardrail_compact_validation, - 
ExperimentalDetectionStrategy.rules_filter_guardrail, - ExperimentalDetectionStrategy.native_rules_router, ExperimentalDetectionStrategy.native_candidate_validate_no_augment, ExperimentalDetectionStrategy.detector_native_validate_no_augment, ExperimentalDetectionStrategy.detector_native_validate_native_augment, @@ -293,18 +282,7 @@ class _CaseExecution: ExperimentalDetectionStrategy.native_single_pass_values, ExperimentalDetectionStrategy.native_single_pass_values_recall, } -_RULE_BACKED_STRATEGIES = { - ExperimentalDetectionStrategy.rules_guardrail, - ExperimentalDetectionStrategy.rules_guardrail_compact_validation, - ExperimentalDetectionStrategy.rules_filter_guardrail, - ExperimentalDetectionStrategy.rules_seed_no_augment, - ExperimentalDetectionStrategy.rules_guardrail_no_augment, - ExperimentalDetectionStrategy.rules_filter_guardrail_no_augment, - ExperimentalDetectionStrategy.rules_guardrail_detector_only, - ExperimentalDetectionStrategy.rules_only, -} _NATIVE_RUNTIME_STRATEGIES = { - ExperimentalDetectionStrategy.native_rules_router, ExperimentalDetectionStrategy.native_candidate_validate_no_augment, ExperimentalDetectionStrategy.detector_native_validate_no_augment, ExperimentalDetectionStrategy.detector_native_validate_native_augment, @@ -450,10 +428,6 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) except Exception as exc: errors.append(f"config '{config.id}' invalid: {exc}") continue - try: - _preflight_experimental_detection_strategy(config, anonymizer_config) - except Exception as exc: - errors.append(f"config '{config.id}' experimental_detection_strategy invalid: {exc}") try: _preflight_native_runtime(config, spec=spec) except Exception as exc: @@ -483,39 +457,6 @@ def _active_config_ids(spec: BenchmarkSpec) -> set[str]: return {entry.config for entry in spec.matrix} -def _preflight_experimental_detection_strategy(config: ConfigSpec, anonymizer_config: AnonymizerConfig) -> None: - 
_preflight_experimental_rule_labels(config) - if config.experimental_detection_strategy != ExperimentalDetectionStrategy.rules_only: - return - entity_labels = anonymizer_config.detect.entity_labels - supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) - if entity_labels is None: - raise ValueError( - f"`rules_only` requires explicit detect.entity_labels limited to deterministic rule labels: {supported}" - ) - unsupported = sorted(set(entity_labels) - SUPPORTED_RULE_LABELS) - if unsupported: - raise ValueError( - f"unsupported high-confidence rule labels: {', '.join(unsupported)}; supported labels: {supported}" - ) - - -def _preflight_experimental_rule_labels(config: ConfigSpec) -> None: - if not config.experimental_rule_labels: - return - supported = ", ".join(sorted(SUPPORTED_RULE_LABELS)) - if config.experimental_detection_strategy not in _RULE_BACKED_STRATEGIES: - raise ValueError( - "experimental_rule_labels requires a rule-backed strategy: " - + ", ".join(sorted(strategy.value for strategy in _RULE_BACKED_STRATEGIES)) - ) - unsupported = sorted(set(config.experimental_rule_labels) - SUPPORTED_RULE_LABELS) - if unsupported: - raise ValueError( - f"unsupported experimental_rule_labels: {', '.join(unsupported)}; supported labels: {supported}" - ) - - def _preflight_native_runtime(config: ConfigSpec, *, spec: BenchmarkSpec) -> None: strategy = config.experimental_detection_strategy if strategy not in _NATIVE_RUNTIME_STRATEGIES: @@ -955,12 +896,7 @@ def _case_detection_artifact_path( ) if detection_artifact_path is not None or paths.artifact_snapshot is None: return detection_artifact_path - return export_rules_only_case_detection_artifacts( - config, - execution.input_data, - paths.artifact_output_path, - case=case, - ) + return None def _trace_final_artifact_path_if_requested( @@ -979,8 +915,6 @@ def _trace_final_artifact_path_if_requested( detection_artifact_path or output_path, trace_dataframe, case=case, - 
replace_existing=config.experimental_detection_strategy - == ExperimentalDetectionStrategy.rules_covered_or_default, ) @@ -989,16 +923,12 @@ def patch_case_detection_artifacts_from_trace_dataframe( trace_dataframe: pd.DataFrame, *, case: BenchmarkCase | None = None, - replace_existing: bool = False, ) -> Path | None: final_rows = _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe) if not final_rows: return None - if replace_existing: - patched = final_rows - else: - rows = _read_detection_artifact_payloads(output_path) if output_path.exists() else [] - patched = _merge_final_entity_artifact_rows(rows, final_rows) + rows = _read_detection_artifact_payloads(output_path) if output_path.exists() else [] + patched = _merge_final_entity_artifact_rows(rows, final_rows) if case is not None: patched = [_with_case_metadata(row, case=case) for row in patched] write_detection_artifact_payloads(patched, output_path) @@ -1148,7 +1078,7 @@ def _execute_case( ) with configured_measurement_session(measurement): with dd_parser_compat_context(dd_parser_compat): - detection_context_kwargs: dict[str, Any] = {"rule_labels": config.experimental_rule_labels} + detection_context_kwargs: dict[str, Any] = {} if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: detection_context_kwargs["native_runtime"] = _native_detection_runtime(spec, config) with experimental_detection_strategy_context( @@ -1390,64 +1320,6 @@ def export_case_detection_artifact_analysis( return output_path -def export_rules_only_case_detection_artifacts( - config: ConfigSpec, - input_data: AnonymizerInput, - output_path: Path, - *, - case: BenchmarkCase, -) -> Path | None: - if not _is_local_input_source(input_data.source): - return None - labels = build_anonymizer_config(config).detect.entity_labels - if not _uses_rules_only_artifact_export(config, labels): - return None - source = Path(input_data.source) - dataframe = _read_local_input_dataframe(source, 
suffix=infer_input_source_suffix(str(source))) - rows = [ - _with_case_metadata( - _rules_only_artifact_row( - text=record[input_data.text_column], - labels=labels, - row_index=int(row_index), - ), - case=case, - ) - for row_index, record in dataframe.iterrows() - ] - if not rows: - return None - write_detection_artifact_payloads(rows, output_path) - return output_path - - -def _uses_rules_only_artifact_export(config: ConfigSpec, labels: list[str] | None) -> bool: - if labels is None: - return False - if config.experimental_detection_strategy == ExperimentalDetectionStrategy.rules_only: - return True - if config.experimental_detection_strategy != ExperimentalDetectionStrategy.rules_covered_or_default: - return False - return set(labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) - - -def _rules_only_artifact_row(*, text: object, labels: list[str], row_index: int) -> dict[str, Any]: - entities = [ - EntitySchema.model_validate(span.as_dict()) - for span in detect_high_confidence_entities(str(text), labels=labels) - ] - return build_detection_artifact_row_from_entities( - workflow_name="entity-detection-rules-only", - batch_file="synthetic-rules-only", - row_index=row_index, - seed_entities=entities, - seed_validation_candidate_count=len(entities), - merged_validation_candidate_count=len(entities), - augmented_entities=[], - final_entities=entities, - ).model_dump() - - def _with_case_metadata(row: dict[str, Any], *, case: BenchmarkCase) -> dict[str, Any]: return { "suite_id": case.suite_id, @@ -1490,7 +1362,6 @@ def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: "case_id": case.case_id, "experimental_detection_strategy": config.experimental_detection_strategy.value, "experimental_replacement_strategy": config.experimental_replacement_strategy.value, - "experimental_rule_labels": config.experimental_rule_labels, "dd_parser_compat": spec.dd_parser_compat.value, } if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: diff 
--git a/tools/measurement/screen_strategy_comparisons.py b/tools/measurement/screen_strategy_comparisons.py index 59dc280e..0cd9e141 100644 --- a/tools/measurement/screen_strategy_comparisons.py +++ b/tools/measurement/screen_strategy_comparisons.py @@ -662,8 +662,6 @@ def group_recommendation(group: ScreenGroup) -> str: return "reliability_review" if is_label_policy_review_group(group): return "label_policy_review" - if is_fast_lane_review_group(group): - return "fast_lane_review" if group.performance_verdict_counts.get("improved", 0) == group.review_count: return "review_only" if group.performance_verdict_counts.get("improved", 0) or group.performance_verdict_counts.get("mixed", 0): @@ -757,36 +755,6 @@ def is_label_policy_review_group(group: ScreenGroup) -> bool: return bool(group.label_mismatch_label_counts or group.flag_counts.get("covered_label_mismatch")) -_FAST_LANE_REVIEW_STRATEGIES = {"rules_only", "rules_covered_or_default"} -_FAST_LANE_REVIEW_FLAGS = { - "candidate_skips_llm_validation", - "candidate_uses_rule_entities", - "entity_count_loss", - "no_candidate_detector_entities", - "span_boundary_mismatch", -} - - -def is_fast_lane_review_group(group: ScreenGroup) -> bool: - if group.candidate_strategy not in _FAST_LANE_REVIEW_STRATEGIES: - return False - if group.review_count != group.row_count: - return False - if group.performance_verdict_counts.get("improved", 0) != group.review_count: - return False - leak_count = group.sum_candidate_original_value_leak_count - if leak_count is None or leak_count != 0: - return False - if ( - group.baseline_only_label_counts - or group.stable_lost_label_counts - or group.candidate_original_value_leak_label_counts - ): - return False - flags = set(group.flag_counts) - return bool(flags) and flags.issubset(_FAST_LANE_REVIEW_FLAGS) - - def group_performance_summary(group: ScreenGroup) -> str: if group.performance_verdict_counts: return label_summary(group.performance_verdict_counts) diff --git 
a/tools/measurement/staged_detection_probe.py b/tools/measurement/staged_detection_probe.py index 6da611da..ddb1a5d0 100644 --- a/tools/measurement/staged_detection_probe.py +++ b/tools/measurement/staged_detection_probe.py @@ -78,20 +78,13 @@ TagNotation, apply_augmented_entities, build_tagged_text, - build_validation_candidates, get_tag_notation, parse_raw_entities, - resolve_overlaps, -) -from anonymizer.engine.detection.rules import ( - STRUCTURED_RULE_FAST_LANE_LABELS, - detect_high_confidence_entities, ) from anonymizer.engine.schemas import ( EntitiesSchema, EntitySchema, RawValidationDecisionsSchema, - ValidatedDecisionSchema, ValidatedDecisionsSchema, ValidationCandidatesSchema, ) @@ -114,10 +107,6 @@ class SeedSource(StrEnum): direct_llm = "direct_llm" gliner = "gliner" - rules = "rules" - rules_trusted = "rules_trusted" - rules_plus_direct_llm = "rules_plus_direct_llm" - rules_router = "rules_router" class ValidationPromptMode(StrEnum): @@ -151,7 +140,6 @@ class StagedExecutionConfig(BaseModel): max_tokens: int = Field(default=4096, gt=0) timeout_sec: float = Field(default=180.0, gt=0) skip_augmentation: bool = False - skip_augmentation_when_rule_covered: bool = False validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text validation_max_entities_per_call: int = Field(default=10, gt=0) validation_excerpt_window_chars: int = Field(default=160, gt=0) @@ -252,7 +240,6 @@ class StagedDetectionCase(BaseModel): total_usage: dict[str, int] = Field(default_factory=dict) model_phase_count: int = 0 model_request_count: int = 0 - rule_covered_label_set: bool = False seed_suggestion_count: int = 0 seed_entity_count: int = 0 validation_candidate_count: int = 0 @@ -315,7 +302,6 @@ def run_staged_detection_case( max_tokens: int = 4096, timeout_sec: float = 180.0, skip_augmentation: bool = False, - skip_augmentation_when_rule_covered: bool = False, validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, 
validation_max_entities_per_call: int = 10, validation_excerpt_window_chars: int = 160, @@ -334,7 +320,6 @@ def run_staged_detection_case( max_tokens=max_tokens, timeout_sec=timeout_sec, skip_augmentation=skip_augmentation, - skip_augmentation_when_rule_covered=skip_augmentation_when_rule_covered, validation_prompt_mode=validation_prompt_mode, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, @@ -356,7 +341,6 @@ def execute_staged_detection_case( max_tokens: int = 4096, timeout_sec: float = 180.0, skip_augmentation: bool = False, - skip_augmentation_when_rule_covered: bool = False, validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, validation_max_entities_per_call: int = 10, validation_excerpt_window_chars: int = 160, @@ -429,14 +413,8 @@ def _run_seed_phase( seed_client: GlinerSeedClient | None, config: StagedExecutionConfig, ) -> tuple[dict[str, Any], int, DirectCompletion]: - if _uses_rule_short_circuit(request, config): - return _run_rules_seed_phase(request) if config.seed_source == SeedSource.gliner: return _run_gliner_seed_phase(request, seed_client or HttpxGlinerSeedClient(), config) - if config.seed_source in {SeedSource.rules, SeedSource.rules_trusted}: - return _run_rules_seed_phase(request) - if config.seed_source in {SeedSource.rules_plus_direct_llm, SeedSource.rules_router}: - return _run_rules_plus_direct_llm_seed_phase(request, client, config) return _run_direct_llm_seed_phase(request, client, config) @@ -475,25 +453,6 @@ def _run_direct_llm_seed_phase( return row, len(seed_suggestions), completion -def _run_rules_plus_direct_llm_seed_phase( - request: StagedDetectionRequest, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> tuple[dict[str, Any], int, DirectCompletion]: - completion = _complete(client, prompt=_seed_prompt(request), config=config) - direct_spans, seed_suggestions = _direct_seed_spans(request, 
completion.content) - rule_spans = detect_high_confidence_entities(request.text, labels=request.labels) - row = _seed_row_from_spans(request, resolve_overlaps([*rule_spans, *direct_spans])) - _limit_validation_candidates_to_sources(row, sources={"direct_seed"}) - return row, len(seed_suggestions) + len(rule_spans), completion - - -def _run_rules_seed_phase(request: StagedDetectionRequest) -> tuple[dict[str, Any], int, DirectCompletion]: - seed_spans = detect_high_confidence_entities(request.text, labels=request.labels) - completion = DirectCompletion(content="", elapsed_sec=0.0, usage={}) - return _seed_row_from_spans(request, seed_spans), len(seed_spans), completion - - def _complete( client: DirectDetectionClient, *, @@ -577,23 +536,12 @@ def _seed_entity_id(label: str, span: EntitySpan) -> str: return f"{label}_{span.start_position}_{span.end_position}" -def _limit_validation_candidates_to_sources(row: dict[str, Any], *, sources: set[str]) -> None: - text = str(row.get(COL_TEXT, "")) - seed_spans = [span for span in _seed_entity_spans(row) if span.source in sources] - row[COL_SEED_VALIDATION_CANDIDATES] = ValidationCandidatesSchema( - candidates=build_validation_candidates(text=text, entities=seed_spans) - ).model_dump(mode="json") - - def _run_validation_phase( row: dict[str, Any], request: StagedDetectionRequest, client: DirectDetectionClient, config: StagedExecutionConfig, ) -> DirectCompletion: - if config.seed_source == SeedSource.rules_trusted or _uses_rule_short_circuit(request, config): - _trust_seed_entities(row) - return DirectCompletion(content="", elapsed_sec=0.0, usage={}) candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) if not candidates.candidates: row[COL_VALIDATION_DECISIONS] = {"decisions": []} @@ -723,23 +671,6 @@ def _sum_completion_usage(completions: list[DirectCompletion]) -> dict[str, int] return dict(sorted(totals.items())) -def _trust_seed_entities(row: dict[str, Any]) -> None: - candidates = 
ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) - row[COL_VALIDATED_ENTITIES] = ValidatedDecisionsSchema( - decisions=[ - ValidatedDecisionSchema( - id=candidate.id, - decision="keep", - value=candidate.value, - label=candidate.label, - reason="trusted deterministic rule", - ) - for candidate in candidates.candidates - ] - ).model_dump(mode="json") - apply_validation_to_seed_entities(row) - - def _validation_prompt( request: StagedDetectionRequest, candidates: ValidationCandidatesSchema, @@ -799,19 +730,7 @@ def _run_augmentation_phase( def _should_skip_augmentation(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: - if config.skip_augmentation: - return True - if _uses_rule_short_circuit(request, config): - return True - if not config.skip_augmentation_when_rule_covered: - return False - if config.seed_source not in {SeedSource.rules, SeedSource.rules_trusted, SeedSource.rules_plus_direct_llm}: - return False - return set(request.labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) - - -def _uses_rule_short_circuit(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: - return config.seed_source == SeedSource.rules_router and _is_rule_covered_label_set(request) + return config.skip_augmentation def _augmentation_prompt(request: StagedDetectionRequest, row: dict[str, Any]) -> str: @@ -898,7 +817,6 @@ def _completed_case( total_usage=_sum_usage(phase_usage), model_phase_count=_model_phase_count(phase_model_work), model_request_count=_model_request_count(phase_model_requests), - rule_covered_label_set=_is_rule_covered_label_set(request), seed_suggestion_count=seed_suggestion_count, seed_entity_count=artifact.seed_entity_count, validation_candidate_count=artifact.seed_validation_candidate_count, @@ -923,28 +841,12 @@ def _phase_model_work( def _uses_seed_model(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: - if _uses_rule_short_circuit(request, config): - return False - 
return config.seed_source in { - SeedSource.direct_llm, - SeedSource.gliner, - SeedSource.rules_plus_direct_llm, - SeedSource.rules_router, - } + return config.seed_source in {SeedSource.direct_llm, SeedSource.gliner} def _uses_validation_model( request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig ) -> bool: - if _uses_rule_short_circuit(request, config): - return False - if ( - config.seed_source in {SeedSource.rules_trusted, SeedSource.rules_router} - and artifact.seed_validation_candidate_count == 0 - ): - return False - if config.seed_source == SeedSource.rules_trusted: - return False return artifact.seed_validation_candidate_count > 0 @@ -959,16 +861,12 @@ def _phase_skip_reasons( def _seed_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: - if config.seed_source in {SeedSource.rules, SeedSource.rules_trusted} or _uses_rule_short_circuit(request, config): - return "deterministic_rules" return None def _validation_skip_reason( request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig ) -> str | None: - if config.seed_source == SeedSource.rules_trusted or _uses_rule_short_circuit(request, config): - return "trusted_rules" if artifact.seed_validation_candidate_count == 0: return "no_seed_candidates" return None @@ -977,8 +875,6 @@ def _validation_skip_reason( def _augmentation_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: if config.skip_augmentation: return "disabled" - if _should_skip_augmentation(request, config): - return "rule_covered_labels" return None @@ -1018,10 +914,6 @@ def _model_request_count(phase_model_requests: PhaseModelRequests) -> int: return phase_model_requests.seed + phase_model_requests.validation + phase_model_requests.augmentation -def _is_rule_covered_label_set(request: StagedDetectionRequest) -> bool: - return set(request.labels).issubset(STRUCTURED_RULE_FAST_LANE_LABELS) - - 
def _extract_entity_suggestions(content: str) -> list[dict[str, str]]: payload = _load_embedded_json(content) raw_entities = payload.get("entities", []) if isinstance(payload, dict) else [] @@ -1329,7 +1221,6 @@ def run_probe( gliner_api_key_env: str = "NVIDIA_API_KEY", gliner_threshold: float = 0.3, skip_augmentation: bool = False, - skip_augmentation_when_rule_covered: bool = False, validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, validation_max_entities_per_call: int = 10, validation_excerpt_window_chars: int = 160, @@ -1388,7 +1279,6 @@ def _run_probe_cases( gliner_api_key_env=config.gliner_api_key_env, gliner_threshold=config.gliner_threshold, skip_augmentation=config.skip_augmentation, - skip_augmentation_when_rule_covered=config.skip_augmentation_when_rule_covered, validation_prompt_mode=config.validation_prompt_mode, validation_max_entities_per_call=config.validation_max_entities_per_call, validation_excerpt_window_chars=config.validation_excerpt_window_chars, @@ -1467,9 +1357,6 @@ def main( gliner_api_key_env: Annotated[str, cyclopts.Parameter("--gliner-api-key-env")] = "NVIDIA_API_KEY", gliner_threshold: Annotated[float, cyclopts.Parameter("--gliner-threshold")] = 0.3, skip_augmentation: Annotated[bool, cyclopts.Parameter("--skip-augmentation")] = False, - skip_augmentation_when_rule_covered: Annotated[ - bool, cyclopts.Parameter("--skip-augmentation-when-rule-covered") - ] = False, validation_prompt_mode: Annotated[ ValidationPromptMode, cyclopts.Parameter("--validation-prompt-mode") ] = ValidationPromptMode.full_text, From c015b878d2e2380db1a8f27964d34566571ea73f Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 8 Jun 2026 22:52:44 +0000 Subject: [PATCH 08/26] Harden DD trace measurement hooks Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 64 ++++++++++++++--------- tests/engine/test_ndd_adapter.py | 5 ++ tests/test_measurement.py | 76 ++++++++++++++++++++++++++++ 3 files changed, 122 
insertions(+), 23 deletions(-) diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 70b05f9c..f48dee4a 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -281,7 +281,7 @@ def _as_alias_list(raw: Any) -> list[str]: if isinstance(raw, str): return [raw] if isinstance(raw, (list, tuple, set)): - return [str(item) for item in raw if str(item)] + return [str(item) for item in raw if item is not None and str(item)] return [str(raw)] @@ -373,6 +373,13 @@ def _model_usage_as_json(stats: object) -> Any: @contextmanager def _dd_message_trace(*, workflow_name: str) -> Iterator[None]: + """Trace DataDesigner model messages for the active measurement context. + + DataDesigner constructs ``ModelFacade`` instances internally, so this hook + wraps the facade class while a traced workflow runs. The wrappers re-check + the active collector before writing a trace record so unrelated concurrent + workflows pass through without contaminating the traced collector. 
+ """ collector = current_collector() if collector is None or not collector.dd_trace_enabled: yield @@ -435,17 +442,19 @@ def _run_traced_completion( error_type = type(exc).__name__ raise finally: - _record_dd_message_trace( - collector=collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - response=response, - elapsed_sec=time.perf_counter() - started, - status=status, - error_type=error_type, - is_async=False, - ) + active_collector = _active_trace_collector(collector) + if active_collector is not None: + _record_dd_message_trace( + collector=active_collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=False, + ) def _traced_acompletion(original_acompletion: Any, *, collector: Any, workflow_name: str) -> Any: @@ -485,17 +494,26 @@ async def _run_traced_acompletion( error_type = type(exc).__name__ raise finally: - _record_dd_message_trace( - collector=collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - response=response, - elapsed_sec=time.perf_counter() - started, - status=status, - error_type=error_type, - is_async=True, - ) + active_collector = _active_trace_collector(collector) + if active_collector is not None: + _record_dd_message_trace( + collector=active_collector, + workflow_name=workflow_name, + model_facade=model_facade, + messages=messages, + response=response, + elapsed_sec=time.perf_counter() - started, + status=status, + error_type=error_type, + is_async=True, + ) + + +def _active_trace_collector(expected_collector: Any) -> Any | None: + active_collector = current_collector() + if active_collector is expected_collector and active_collector.dd_trace_enabled: + return active_collector + return None def _record_dd_message_trace( diff --git a/tests/engine/test_ndd_adapter.py b/tests/engine/test_ndd_adapter.py index 
ea0a01a8..7563f371 100644 --- a/tests/engine/test_ndd_adapter.py +++ b/tests/engine/test_ndd_adapter.py @@ -13,6 +13,7 @@ from data_designer.config.models import ModelConfig from data_designer.interface.data_designer import DataDesigner +from anonymizer.engine.ndd import adapter as ndd_adapter from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter from anonymizer.interface.errors import AnonymizerWorkflowError @@ -60,6 +61,10 @@ def _make_columns() -> list[ColumnConfigT]: ] +def test_as_alias_list_drops_none_items_before_stringifying() -> None: + assert ndd_adapter._as_alias_list(["validator", None, "", 0]) == ["validator", "0"] + + def test_attach_record_ids_adds_deterministic_ids() -> None: adapter = NddAdapter(data_designer=Mock(spec=DataDesigner)) input_df = pd.DataFrame({"text": ["a", "b"]}) diff --git a/tests/test_measurement.py b/tests/test_measurement.py index ebcbe95c..38dc0881 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -5,6 +5,7 @@ import json import logging +import threading from pathlib import Path from types import SimpleNamespace from typing import cast @@ -745,6 +746,81 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa assert "secret response" not in serialized_measurements +def test_dd_message_trace_does_not_capture_concurrent_unmeasured_calls( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + def fake_completion( + _self: object, + _messages: list[ChatMessage], + **_kwargs: object, + ) -> ChatCompletionResponse: + return ChatCompletionResponse( + message=AssistantMessage(content="response"), + usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), + ) + + monkeypatch.setattr(ModelFacade, "completion", fake_completion) + + class TraceDataDesigner: + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + errors: 
list[BaseException] = [] + + def concurrent_call_without_measurement_context() -> None: + try: + ModelFacade.completion( + SimpleNamespace( + model_alias="outside", + model_name="outside-model", + model_provider_name="provider", + ), + [ChatMessage.as_user("outside prompt")], + ) + except BaseException as exc: + errors.append(exc) + + thread = threading.Thread(target=concurrent_call_without_measurement_context) + thread.start() + thread.join() + assert errors == [] + + ModelFacade.completion( + SimpleNamespace(model_alias="inside", model_name="inside-model", model_provider_name="provider"), + [ChatMessage.as_user("inside prompt")], + ) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + + adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + trace_path = tmp_path / "trace.jsonl" + + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", dd_trace="all_messages", dd_trace_path=trace_path + ) + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="inside", model="dummy")], + columns=[ + LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="inside", + ) + ], + workflow_name="entity-detection", + preview_num_records=1, + ) + + traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] + assert [trace["model_alias"] for trace in traces] == ["inside"] + serialized_traces = json.dumps(traces) + assert "inside prompt" in serialized_traces + assert "outside prompt" not in serialized_traces + + def test_record_metrics_capture_generic_counts_without_raw_values() -> None: final_entities = { "entities": [ From f4fbe73c358cbc7626fadeea30a1dd42c90a4d25 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Mon, 8 Jun 2026 23:38:19 +0000 Subject: [PATCH 09/26] Move local structured substitute to stacked branch Signed-off-by: Aaron Gonzales --- .../engine/replace/structured_substitute.py | 299 ------- 
src/anonymizer/measurement.py | 12 +- tests/engine/test_structured_substitute.py | 160 ---- tests/test_measurement.py | 30 - tests/tools/test_benchmark_output_analysis.py | 24 +- tests/tools/test_compare_strategy_pairs.py | 14 +- tests/tools/test_measurement_tools.py | 57 -- tests/tools/test_replacement_strategies.py | 76 -- .../test_replay_replacement_strategies.py | 430 ---------- .../tools/test_screen_strategy_comparisons.py | 26 +- tools/measurement/README.md | 31 - tools/measurement/replacement_strategies.py | 58 -- .../replay_replacement_strategies.py | 736 ------------------ tools/measurement/run_benchmarks.py | 44 +- 14 files changed, 42 insertions(+), 1955 deletions(-) delete mode 100644 src/anonymizer/engine/replace/structured_substitute.py delete mode 100644 tests/engine/test_structured_substitute.py delete mode 100644 tests/tools/test_replacement_strategies.py delete mode 100644 tests/tools/test_replay_replacement_strategies.py delete mode 100644 tools/measurement/replacement_strategies.py delete mode 100644 tools/measurement/replay_replacement_strategies.py diff --git a/src/anonymizer/engine/replace/structured_substitute.py b/src/anonymizer/engine/replace/structured_substitute.py deleted file mode 100644 index 6205c005..00000000 --- a/src/anonymizer/engine/replace/structured_substitute.py +++ /dev/null @@ -1,299 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import hashlib -import re -from collections.abc import Callable, Iterable - -import pandas as pd - -from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP, COL_REPLACEMENT_MAP_SOURCE -from anonymizer.engine.schemas import EntitiesByValueSchema - -REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED = "local_structured" -SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS = frozenset( - { - "api_key", - "date_of_birth", - "email", - "http_cookie", - "organization_name", - "password", - "pin", - "religious_belief", - "street_address", - "unique_id", - "url", - "user_name", - } -) - -_RELIGIOUS_BELIEF_SUBSTITUTES = ( - "agnostic", - "atheist", - "buddhist", - "catholic", - "christian", - "hindu", - "jewish", - "muslim", - "secular", -) - - -def apply_structured_substitution_maps( - dataframe: pd.DataFrame, - *, - entities_column: str = COL_ENTITIES_BY_VALUE, -) -> pd.DataFrame: - """Attach deterministic substitute maps for supported structured labels. - - This helper intentionally builds only replacement maps. Text rewriting still - uses the normal replacement-map application path, so span handling remains - identical to LLM-backed ``Substitute``. 
- """ - output_df = dataframe.copy() - output_df[COL_REPLACEMENT_MAP] = output_df[entities_column].apply(build_structured_substitution_map) - output_df[COL_REPLACEMENT_MAP_SOURCE] = REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED - return output_df - - -def build_structured_substitution_map(raw_entities: object) -> dict[str, list[dict[str, str]]]: - """Build a substitute map without model calls for narrow structured labels.""" - parsed = EntitiesByValueSchema.from_raw(raw_entities) - unsupported = _unsupported_labels(parsed) - if unsupported: - supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) - raise ValueError( - f"local structured substitute supports only deterministic structured labels; " - f"unsupported labels: {', '.join(unsupported)}; supported labels: {supported}" - ) - - original_values = {entity.value for entity in parsed.entities_by_value if entity.value} - synthetic_values: set[str] = set() - replacements: list[dict[str, str]] = [] - seen: set[tuple[str, str]] = set() - for entity in parsed.entities_by_value: - if not entity.value: - continue - for label in entity.labels: - if not label: - continue - key = (entity.value, label) - if key in seen: - continue - seen.add(key) - synthetic = structured_substitute_value( - entity.value, - label, - forbidden_values=original_values | synthetic_values, - ) - synthetic_values.add(synthetic) - replacements.append( - { - "original": entity.value, - "label": label, - "synthetic": synthetic, - } - ) - return {"replacements": replacements} - - -def structured_substitute_value( - value: str, - label: str, - *, - forbidden_values: Iterable[str] | None = None, -) -> str: - """Return a deterministic synthetic value for one supported structured label.""" - generator = _GENERATORS.get(label) - if generator is None: - supported = ", ".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) - raise ValueError(f"unsupported local structured substitute label: {label}; supported labels: {supported}") - forbidden = 
{str(item) for item in forbidden_values or () if item is not None} - forbidden.add(value) - - for salt in ("", "alternate"): - synthetic = generator(value, _digest(value=value, label=label, salt=salt)) - if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): - return synthetic - - fallback_index = 0 - while True: - synthetic = f"synthetic-{label}-{_digest(value=value, label=label, salt=f'fallback-{fallback_index}')[:12]}" - if _synthetic_is_allowed(synthetic, original=value, forbidden_values=forbidden): - return synthetic - fallback_index += 1 - - -def _synthetic_is_allowed(synthetic: str, *, original: str, forbidden_values: set[str]) -> bool: - """Reject self-preservation and exact collisions with other protected originals.""" - return synthetic != original and original not in synthetic and synthetic not in forbidden_values - - -def _unsupported_labels(parsed: EntitiesByValueSchema) -> list[str]: - labels = {label for entity in parsed.entities_by_value for label in entity.labels if label} - return sorted(labels - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) - - -def _digest(*, value: str, label: str, salt: str = "") -> str: - return hashlib.sha256(f"{label}\0{salt}\0{value}".encode("utf-8")).hexdigest() - - -def _api_key(value: str, digest: str) -> str: - if value.startswith("ghp_"): - return "ghp_" + digest[:36] - if value.startswith("hf_"): - return "hf_" + digest[:40] - if value.startswith("pat-"): - return "pat-" + digest[:40] - if value.startswith("xoxb-"): - return f"xoxb-{digest[:12]}-{digest[12:24]}-{digest[24:36]}" - if value.startswith("ya29."): - return "ya29." 
+ digest[:44] - if value.startswith("AKIA"): - return "AKIA" + digest[:16].upper() - sk_match = re.match(r"^(sk-(?:test|ant-api03|proj|prod)-)", value) - if sk_match: - return sk_match.group(1) + digest[:48] - return "tok_" + digest[:32] - - -def _password(_value: str, digest: str) -> str: - return f"Synthetic!{digest[:10]}A7" - - -def _email(_value: str, digest: str) -> str: - return f"user-{digest[:12]}@example.invalid" - - -def _http_cookie(value: str, digest: str) -> str: - parts = [part.strip() for part in value.split(";") if part.strip()] - rendered: list[str] = [] - for index, part in enumerate(parts): - if "=" not in part: - continue - name = part.split("=", 1)[0].strip() or f"cookie_{index}" - rendered.append(f"{name}={_cookie_value(name=name, digest=digest, index=index)}") - if rendered: - return "; ".join(rendered) - return f"session_id={digest[:32]}; auth_token={digest[32:56]}" - - -def _cookie_value(*, name: str, digest: str, index: int) -> str: - offset = (index * 8) % max(len(digest) - 8, 1) - chunk = digest[offset : offset + 24] - normalized = name.lower() - if "session" in normalized: - return chunk[:32] - if normalized.endswith("id") or normalized == "user_id": - return str(10000 + (int(chunk[:8], 16) % 90000)) - if "token" in normalized or "auth" in normalized or "jwt" in normalized: - return f"tok_{chunk}" - return chunk - - -def _pin(value: str, digest: str) -> str: - length = max(4, min(len(value), 8)) - number = int(digest[:12], 16) % (10**length) - return f"{number:0{length}d}" - - -def _unique_id(value: str, digest: str) -> str: - if re.fullmatch(r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}", value): - return f"{digest[:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" - prefix_match = re.match(r"^([A-Za-z]+[-_])", value) - if prefix_match: - prefix = prefix_match.group(1) - return f"{prefix}{digest[:20]}" - if value.isdigit(): - return str(100000 + (int(digest[:12], 16) % 900000)) - return 
f"id_{digest[:24]}" - - -def _user_name(value: str, digest: str) -> str: - separator = "." if "." in value else "_" if "_" in value else "" - if separator: - return f"user{separator}{digest[:10]}" - return f"user{digest[:12]}" - - -def _url(value: str, digest: str) -> str: - scheme_match = re.match(r"^([A-Za-z][A-Za-z0-9+.-]*://)", value) - scheme = scheme_match.group(1) if scheme_match else "https://" - scheme_name = scheme[:-3].lower() - if scheme_name in {"postgres", "postgresql", "mysql", "mariadb", "mongodb", "mongodb+srv", "redis", "rediss"}: - return f"{scheme}user_{digest[:8]}:Synthetic!{digest[8:16]}@db-{digest[16:24]}.example.invalid:5432/app" - return f"{scheme}synthetic-{digest[:16]}.example.invalid/resource/{digest[16:24]}" - - -def _date_of_birth(value: str, digest: str) -> str: - year = 1950 + (int(digest[:4], 16) % 50) - month = 1 + (int(digest[4:6], 16) % 12) - day = 1 + (int(digest[6:8], 16) % 28) - if re.fullmatch(r"\d{4}", value): - return str(year) - ymd = re.fullmatch(r"\d{4}([/-])\d{1,2}\1\d{1,2}", value) - if ymd: - sep = ymd.group(1) - return f"{year:04d}{sep}{month:02d}{sep}{day:02d}" - mdy = re.fullmatch(r"\d{1,2}([/-])\d{1,2}\1(\d{2}|\d{4})", value) - if mdy: - sep = mdy.group(1) - rendered_year = f"{year % 100:02d}" if len(value.rsplit(sep, 1)[-1]) == 2 else f"{year:04d}" - return f"{month:02d}{sep}{day:02d}{sep}{rendered_year}" - return f"{year:04d}-{month:02d}-{day:02d}" - - -def _street_address(value: str, digest: str) -> str: - suffix_match = re.search( - r"\b(Street|St\.?|Avenue|Ave\.?|Road|Rd\.?|Drive|Dr\.?|Trail|Boulevard|Blvd\.?|Lane|Ln\.?|Court|Ct\.?)$", - value, - ) - suffix = suffix_match.group(1) if suffix_match else "Street" - number = 100 + (int(digest[:4], 16) % 8900) - return f"{number} Cedar Ridge {suffix}" - - -def _organization_name(value: str, digest: str) -> str: - suffixes = ( - "Center", - "Hospital", - "Clinic", - "University", - "College", - "Institute", - "Bank", - "Builders", - "Construction", - "Woodworks", - 
"Health", - ) - suffix = next((candidate for candidate in suffixes if value.endswith(candidate)), "Group") - prefixes = ("Northbridge", "Helios", "Mariner", "Summit", "Cedar") - prefix = prefixes[int(digest[:2], 16) % len(prefixes)] - return f"{prefix} {suffix}" - - -def _religious_belief(value: str, digest: str) -> str: - normalized = value.lower() - candidates = [candidate for candidate in _RELIGIOUS_BELIEF_SUBSTITUTES if candidate != normalized] - return candidates[int(digest[:2], 16) % len(candidates)] - - -_GENERATORS: dict[str, Callable[[str, str], str]] = { - "api_key": _api_key, - "date_of_birth": _date_of_birth, - "email": _email, - "http_cookie": _http_cookie, - "organization_name": _organization_name, - "password": _password, - "pin": _pin, - "religious_belief": _religious_belief, - "street_address": _street_address, - "unique_id": _unique_id, - "url": _url, - "user_name": _user_name, -} diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index d6c77be3..18c249eb 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -858,12 +858,8 @@ def _llm_record_fields( def _replace_map_generation_uses_llm(row: Any, *, columns: set[str]) -> bool: - from anonymizer.engine.constants import COL_REPLACEMENT_MAP_SOURCE - from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED - - if COL_REPLACEMENT_MAP_SOURCE not in columns: - return True - return row.get(COL_REPLACEMENT_MAP_SOURCE) != REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED + del row, columns + return True def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None: @@ -1174,7 +1170,9 @@ def _entity_relaxed_ground_truth_metrics( ground_truth_entities: list[dict[str, Any]], ) -> dict[str, Any]: gt_found = sum( - 1 for ground_truth_entity in ground_truth_entities if _has_relaxed_entity_match(final_entities, ground_truth_entity) + 1 + for ground_truth_entity in ground_truth_entities + if 
_has_relaxed_entity_match(final_entities, ground_truth_entity) ) detected_tp = sum( 1 for final_entity in final_entities if _has_relaxed_entity_match(ground_truth_entities, final_entity) diff --git a/tests/engine/test_structured_substitute.py b/tests/engine/test_structured_substitute.py deleted file mode 100644 index 83014cba..00000000 --- a/tests/engine/test_structured_substitute.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import pandas as pd -import pytest - -from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_REPLACEMENT_MAP -from anonymizer.engine.replace import structured_substitute as structured_substitute_module -from anonymizer.engine.replace.structured_substitute import ( - apply_structured_substitution_maps, - build_structured_substitution_map, - structured_substitute_value, -) - - -@pytest.mark.parametrize( - ("label", "value"), - [ - ("api_key", "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"), - ("date_of_birth", "1978-02-03"), - ("email", "alice@example.com"), - ("http_cookie", "session_id=abc123xyz; user_id=12345; auth_token=secret-token"), - ("organization_name", "Acme Research Center"), - ("password", "CorrectHorse!123"), - ("pin", "97294"), - ("religious_belief", "secular"), - ("street_address", "123 Maple Street"), - ("unique_id", "req_KA5k78XNwT0yUNZkPpwq"), - ("url", "https://staging.example.com/admin"), - ("user_name", "sloanenguy217"), - ], -) -def test_structured_substitute_value_does_not_preserve_original(label: str, value: str) -> None: - synthetic = structured_substitute_value(value, label) - - assert synthetic != value - assert value not in synthetic - - -def test_build_structured_substitution_map_for_entities_by_value() -> None: - raw_entities = { - "entities_by_value": [ - {"value": "alice@example.com", "labels": ["email"]}, - {"value": 
"sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "labels": ["api_key"]}, - ] - } - - replacement_map = build_structured_substitution_map(raw_entities) - - replacements = replacement_map["replacements"] - assert [(item["original"], item["label"]) for item in replacements] == [ - ("alice@example.com", "email"), - ("sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "api_key"), - ] - serialized = str(replacement_map) - assert "alice@example.com" in serialized - assert "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" in serialized - assert all(item["original"] != item["synthetic"] for item in replacements) - - -def test_build_structured_substitution_map_avoids_other_original_values(monkeypatch: pytest.MonkeyPatch) -> None: - first_original = "alice@example.com" - second_original = "bob@example.com" - first_default_digest = structured_substitute_module._digest(value=first_original, label="email") - - def synthetic_email(value: str, digest: str) -> str: - if value == first_original and digest == first_default_digest: - return second_original - return f"user-{digest[:12]}@example.invalid" - - monkeypatch.setitem(structured_substitute_module._GENERATORS, "email", synthetic_email) - raw_entities = { - "entities_by_value": [ - {"value": first_original, "labels": ["email"]}, - {"value": second_original, "labels": ["email"]}, - ] - } - - replacement_map = build_structured_substitution_map(raw_entities) - - replacements = {(item["original"], item["label"]): item["synthetic"] for item in replacement_map["replacements"]} - assert replacements[(first_original, "email")] != second_original - assert set(replacements.values()).isdisjoint({first_original, second_original}) - - -def test_build_structured_substitution_map_avoids_duplicate_synthetic_values( - monkeypatch: pytest.MonkeyPatch, -) -> None: - first_original = "Acme Research Center" - second_original = "Globex Research Center" - default_digests = { - structured_substitute_module._digest(value=first_original, label="organization_name"), - 
structured_substitute_module._digest(value=second_original, label="organization_name"), - } - - def synthetic_organization(_value: str, digest: str) -> str: - if digest in default_digests: - return "Northbridge Center" - return f"Helios {digest[:8]} Center" - - monkeypatch.setitem(structured_substitute_module._GENERATORS, "organization_name", synthetic_organization) - raw_entities = { - "entities_by_value": [ - {"value": first_original, "labels": ["organization_name"]}, - {"value": second_original, "labels": ["organization_name"]}, - ] - } - - replacement_map = build_structured_substitution_map(raw_entities) - - synthetics = [item["synthetic"] for item in replacement_map["replacements"]] - assert len(synthetics) == 2 - assert len(set(synthetics)) == 2 - assert "Northbridge Center" in synthetics - - -def test_structured_cookie_substitute_preserves_cookie_shape() -> None: - synthetic = structured_substitute_value( - "session_id=abc123xyz; user_id=12345; auth_token=secret-token", - "http_cookie", - ) - - assert "session_id=" in synthetic - assert "user_id=" in synthetic - assert "auth_token=" in synthetic - session_value = synthetic.split("session_id=", 1)[1].split(";", 1)[0] - assert len(session_value) >= 16 - assert not session_value.isdigit() - assert "abc123xyz" not in synthetic - assert "secret-token" not in synthetic - - -def test_structured_unique_id_substitute_preserves_uuid_shape() -> None: - synthetic = structured_substitute_value("1ce21179-998b-447b-2dee-3e8adb6afa35", "unique_id") - - assert synthetic.count("-") == 4 - assert synthetic != "1ce21179-998b-447b-2dee-3e8adb6afa35" - - -def test_build_structured_substitution_map_rejects_unsupported_labels() -> None: - raw_entities = {"entities_by_value": [{"value": "Alice", "labels": ["person"]}]} - - with pytest.raises(ValueError, match="unsupported labels: person"): - build_structured_substitution_map(raw_entities) - - -def test_apply_structured_substitution_maps_adds_replacement_map_column() -> None: - 
dataframe = pd.DataFrame( - {COL_ENTITIES_BY_VALUE: [{"entities_by_value": [{"value": "alice@example.com", "labels": ["email"]}]}]} - ) - - output = apply_structured_substitution_maps(dataframe) - - assert COL_REPLACEMENT_MAP in output.columns - replacement = output[COL_REPLACEMENT_MAP].iloc[0]["replacements"][0] - assert replacement["original"] == "alice@example.com" - assert replacement["label"] == "email" - assert replacement["synthetic"].endswith("@example.invalid") diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 38dc0881..9a4665be 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -34,7 +34,6 @@ COL_REPAIR_ITERATIONS, COL_REPLACED_TEXT, COL_REPLACEMENT_MAP, - COL_REPLACEMENT_MAP_SOURCE, COL_SEED_VALIDATION_CANDIDATES, COL_TEXT, COL_UTILITY_SCORE, @@ -43,7 +42,6 @@ from anonymizer.engine.detection.detection_workflow import EntityDetectionResult, EntityDetectionWorkflow from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter from anonymizer.engine.replace.replace_runner import ReplacementResult, ReplacementWorkflow -from anonymizer.engine.replace.structured_substitute import REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED from anonymizer.engine.rewrite.rewrite_workflow import RewriteResult, RewriteWorkflow from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import ( @@ -1121,34 +1119,6 @@ def test_record_metrics_counts_standalone_short_value_replacement_leaks() -> Non assert record["original_value_leak_label_counts"] == {"age": 1} -def test_record_metrics_do_not_estimate_llm_replace_map_call_for_local_structured_source() -> None: - dataframe = pd.DataFrame( - { - COL_TEXT: ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"], - COL_FINAL_ENTITIES: [{"entities": [{"value": "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA", "label": "api_key"}]}], - COL_REPLACEMENT_MAP: [{"replacements": []}], - COL_REPLACEMENT_MAP_SOURCE: [REPLACEMENT_MAP_SOURCE_LOCAL_STRUCTURED], - } - ) - collector = 
MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - record_record_metrics( - dataframe, - mode="replace", - strategy="Substitute", - text_column=COL_TEXT, - validation_max_entities_per_call=100, - ) - - record = collector.records[0] - assert record["llm_calls_estimated_by_stage"] == { - "entity_detection": None, - "replace_map_generation": 0, - } - assert record["llm_calls_estimated_total"] is None - - def test_record_metrics_normalizes_integral_row_index_types() -> None: dataframe = pd.DataFrame( { diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index 76edb177..e83de30d 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -216,7 +216,7 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp "workload_id": "shell", "config_id": "native-local", "experimental_detection_strategy": "native_single_pass", - "experimental_replacement_strategy": "local_structured_substitute", + "experimental_replacement_strategy": "custom_replacement_strategy", "dd_parser_compat": "raw_json", "repetition": 0, "case_id": "shell__native-local__r000", @@ -423,7 +423,7 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp } assert cases["bio__default__r000"].artifact_final_entity_signature_count == 2 assert cases["shell__native-local__r000"].observed_total_requests == 0 - assert cases["shell__native-local__r000"].experimental_replacement_strategy == "local_structured_substitute" + assert cases["shell__native-local__r000"].experimental_replacement_strategy == "custom_replacement_strategy" assert cases["shell__native-local__r000"].observed_failed_request_rate is None assert cases["shell__native-local__r000"].observed_bridge_fallback_requests is None assert cases["shell__native-local__r000"].observed_non_bridge_failed_requests is None @@ -440,7 +440,9 @@ def 
test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert cases["shell__native-local__r000"].artifact_final_augmenter_entity_count == 8 assert cases["shell__native-local__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] assert cases["shell__native-local__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} - assert cases["shell__native-local__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" + assert ( + cases["shell__native-local__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" + ) model_rows = {row.model_name: row for row in result.model_usage} assert model_rows["nvidia/gliner-pii"].observed_failed_requests == 0 assert model_rows["nvidia/gliner-pii"].observed_failed_request_rate == 0 @@ -498,7 +500,7 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert bio_group.micro_entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) assert bio_group.micro_entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) shell_group = next(group for group in result.groups if group.workload_id == "shell") - assert shell_group.experimental_replacement_strategy == "local_structured_substitute" + assert shell_group.experimental_replacement_strategy == "custom_replacement_strategy" assert shell_group.sum_original_value_leak_count == 1 assert shell_group.leaking_case_count == 1 assert shell_group.median_original_value_leak_count == 1 @@ -692,7 +694,7 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> workload_id="shell", config_id="native", experimental_detection_strategy="native_single_pass", - experimental_replacement_strategy="local_structured_substitute", + experimental_replacement_strategy="custom_replacement_strategy", dd_parser_compat="raw_json", repetition=0, case_id="shell__native__r000", @@ -705,7 +707,7 @@ def 
test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> workload_id="shell", config_id="native", experimental_detection_strategy="native_single_pass", - experimental_replacement_strategy="local_structured_substitute", + experimental_replacement_strategy="custom_replacement_strategy", case_count=1, median_final_entity_count=8, median_observed_successful_requests=0, @@ -720,7 +722,7 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> workload_id="shell", config_id="native", experimental_detection_strategy="native_single_pass", - experimental_replacement_strategy="local_structured_substitute", + experimental_replacement_strategy="custom_replacement_strategy", dd_parser_compat="raw_json", case_id="shell__native__r000", run_id="shell__native__r000", @@ -736,7 +738,7 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> workload_id="shell", config_id="native", experimental_detection_strategy="native_single_pass", - experimental_replacement_strategy="local_structured_substitute", + experimental_replacement_strategy="custom_replacement_strategy", dd_parser_compat="raw_json", workflow_name="entity-detection", model_name="nvidia/gliner-pii", @@ -756,7 +758,7 @@ def test_write_analysis_tables_exports_case_and_group_tables(tmp_path: Path) -> assert pd.read_csv(output_dir / "case_analysis.csv")["case_id"].tolist() == ["shell__native__r000"] assert pd.read_csv(output_dir / "case_analysis.csv")["experimental_replacement_strategy"].tolist() == [ - "local_structured_substitute" + "custom_replacement_strategy" ] assert pd.read_csv(output_dir / "group_analysis.csv")["case_count"].tolist() == [1] assert pd.read_csv(output_dir / "model_analysis.csv")["model_name"].tolist() == ["nvidia/gliner-pii"] @@ -825,7 +827,7 @@ def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_p "workload_id": "secrets", "config_id": "candidate", "experimental_detection_strategy": "native_single_pass", - 
"experimental_replacement_strategy": "local_structured_substitute", + "experimental_replacement_strategy": "custom_replacement_strategy", "case_id": "secrets__candidate__r001", }, }, @@ -837,7 +839,7 @@ def test_analyze_benchmark_output_groups_replacement_strategies_separately(tmp_p assert result.group_count == 2 assert {group.experimental_replacement_strategy for group in result.groups} == { "default", - "local_structured_substitute", + "custom_replacement_strategy", } diff --git a/tests/tools/test_compare_strategy_pairs.py b/tests/tools/test_compare_strategy_pairs.py index ad0ab436..4576fe81 100644 --- a/tests/tools/test_compare_strategy_pairs.py +++ b/tests/tools/test_compare_strategy_pairs.py @@ -48,7 +48,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N "workload_id": "shell-1", "config_id": "shell-native", "experimental_detection_strategy": "native_candidate_validate_no_augment", - "experimental_replacement_strategy": "local_structured_substitute", + "experimental_replacement_strategy": "custom_replacement_strategy", "case_id": "shell__candidate", "pipeline_elapsed_sec": 0.8, "observed_total_requests": 1, @@ -78,7 +78,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N "workload_id": "legal-1", "config_id": "legal-native", "experimental_detection_strategy": "native_candidate_validate_no_augment", - "experimental_replacement_strategy": "local_structured_substitute", + "experimental_replacement_strategy": "custom_replacement_strategy", "case_id": "legal__candidate", "pipeline_elapsed_sec": 20.9, "observed_total_requests": 3, @@ -100,7 +100,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N by_workload = {row.workload_id: row for row in rows} shell = by_workload["shell-1"] assert shell.baseline_replacement_strategy == "default" - assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.candidate_replacement_strategy == 
"custom_replacement_strategy" assert shell.final_entity_count_delta == 4 assert shell.observed_total_requests_delta == -2 assert shell.observed_total_tokens_delta == -2774 @@ -119,7 +119,7 @@ def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> N legal = by_workload["legal-1"] assert legal.baseline_replacement_strategy == "default" - assert legal.candidate_replacement_strategy == "local_structured_substitute" + assert legal.candidate_replacement_strategy == "custom_replacement_strategy" assert legal.final_entity_count_delta == 0 assert legal.observed_total_tokens_delta == 0 assert legal.candidate_detector_entity_count == 26 @@ -896,7 +896,7 @@ def test_compare_case_analysis_flags_replacement_only_detection_instability() -> "workload_id": "structured-identifiers", "config_id": "local-substitute", "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "local_structured_substitute", + "experimental_replacement_strategy": "custom_replacement_strategy", "case_id": "candidate-r0", "pipeline_elapsed_sec": 7, "observed_total_requests": 3, @@ -1374,7 +1374,7 @@ def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: baseline_config_id="base", candidate_config_id="candidate", baseline_replacement_strategy="default", - candidate_replacement_strategy="local_structured_substitute", + candidate_replacement_strategy="custom_replacement_strategy", baseline_case_count=1, candidate_case_count=1, value_protection_verdict="review", @@ -1394,7 +1394,7 @@ def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: exported = pd.read_csv(output) assert exported["workload_id"].tolist() == ["shell-1"] - assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["candidate_replacement_strategy"].tolist() == ["custom_replacement_strategy"] assert exported["final_entity_count_delta"].tolist() == [4] assert exported["flags"].tolist() == 
['["candidate_skips_llm_validation"]'] diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 99642535..e3f5652d 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -753,63 +753,6 @@ def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) tool.preflight_suite(spec, spec_path=spec_path) -def test_benchmark_preflight_rejects_local_structured_substitute_for_contextual_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_local_substitute_contextual_labels", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice has token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: local-substitute-suite -workloads: - - id: input - source: input.csv -configs: - - id: local-substitute - detect: - entity_labels: [api_key, person] - replace: substitute - experimental_replacement_strategy: local_structured_substitute -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="unsupported local structured substitute labels: person"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_preflight_accepts_local_structured_substitute_supported_labels(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_local_substitute_supported_labels", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(input_path, index=False) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: local-substitute-suite -workloads: - - id: input - source: input.csv -configs: - - id: local-substitute - detect: - entity_labels: [api_key, email, http_cookie, password, pin, unique_id, url, 
user_name] - replace: substitute - experimental_replacement_strategy: local_structured_substitute -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - - def test_benchmark_example_suites_are_portable() -> None: example_paths = sorted((REPO_ROOT / "tools/measurement/examples").glob("*.yaml")) assert example_paths diff --git a/tests/tools/test_replacement_strategies.py b/tests/tools/test_replacement_strategies.py deleted file mode 100644 index 3b1ce86b..00000000 --- a/tests/tools/test_replacement_strategies.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType -from unittest.mock import Mock - -import pandas as pd - -from anonymizer.config.replace_strategies import Substitute -from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE, COL_FINAL_ENTITIES, COL_REPLACED_TEXT, COL_TEXT -from anonymizer.engine.ndd.model_loader import load_default_model_selection -from anonymizer.engine.replace.llm_replace_workflow import LlmReplaceWorkflow -from anonymizer.engine.replace.replace_runner import ReplacementWorkflow - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def test_local_structured_substitute_context_bypasses_dd_and_restores_method() -> None: - tool = load_tool( - "measurement_replacement_strategies_local_substitute", - REPO_ROOT / "tools/measurement/replacement_strategies.py", - ) - 
original_method = LlmReplaceWorkflow.generate_map_only - secret = "sk-test-AAAAAAAAAAAAAAAAAAAAAAAA" - adapter = Mock() - runner = ReplacementWorkflow(llm_workflow=LlmReplaceWorkflow(adapter=adapter)) - dataframe = pd.DataFrame( - { - COL_TEXT: [f"export API_KEY={secret}"], - COL_FINAL_ENTITIES: [ - { - "entities": [ - { - "value": secret, - "label": "api_key", - "start_position": len("export API_KEY="), - "end_position": len("export API_KEY=") + len(secret), - } - ] - } - ], - COL_ENTITIES_BY_VALUE: [{"entities_by_value": [{"value": secret, "labels": ["api_key"]}]}], - } - ) - - with tool.experimental_replacement_strategy_context( - tool.ExperimentalReplacementStrategy.local_structured_substitute - ): - result = runner.run( - dataframe, - replace_method=Substitute(), - model_configs=[], - selected_models=load_default_model_selection().replace, - ) - - adapter.run_workflow.assert_not_called() - assert LlmReplaceWorkflow.generate_map_only is original_method - replaced = result.dataframe[COL_REPLACED_TEXT].iloc[0] - assert secret not in replaced - assert "sk-test-" in replaced diff --git a/tests/tools/test_replay_replacement_strategies.py b/tests/tools/test_replay_replacement_strategies.py deleted file mode 100644 index 482d1919..00000000 --- a/tests/tools/test_replay_replacement_strategies.py +++ /dev/null @@ -1,430 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType, SimpleNamespace - -import pandas as pd - -from anonymizer.engine.constants import ( - COL_FINAL_ENTITIES, - COL_REPLACED_TEXT, - COL_REPLACEMENT_MAP, - COL_REPLACEMENT_MAP_SOURCE, - COL_TEXT, -) - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def test_replacement_row_metrics_counts_sanitized_replay_failures() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_metrics", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - row = pd.Series( - { - COL_FINAL_ENTITIES: { - "entities": [ - {"value": "secret-a", "label": "api_key", "start_position": 0, "end_position": 8}, - {"value": "secret-b", "label": "password", "start_position": 9, "end_position": 17}, - {"value": "1234", "label": "pin", "start_position": 18, "end_position": 22}, - ] - }, - COL_REPLACEMENT_MAP: { - "replacements": [ - {"original": "secret-a", "label": "api_key", "synthetic": "secret-b"}, - {"original": "secret-b", "label": "password", "synthetic": "Synthetic!123"}, - ] - }, - COL_REPLACED_TEXT: "the leaked pin is 1234", - } - ) - - metrics = tool.replacement_row_metrics(row) - - assert metrics.final_entity_count == 3 - assert metrics.replacement_count == 2 - assert metrics.missing_count == 1 - assert metrics.missing_labels == {"pin": 1} - assert metrics.collision_count == 1 - assert metrics.collision_labels == {"api_key": 1} - assert metrics.leak_count == 1 - assert metrics.leak_labels == {"pin": 1} - - -def 
test_summarize_replacement_dataframe_counts_sources_and_totals() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_summary", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - dataframe = pd.DataFrame( - [ - { - COL_FINAL_ENTITIES: { - "entities": [ - {"value": "alice@example.com", "label": "email", "start_position": 0, "end_position": 17} - ] - }, - COL_REPLACEMENT_MAP: { - "replacements": [ - { - "original": "alice@example.com", - "label": "email", - "synthetic": "user-123@example.invalid", - } - ] - }, - COL_REPLACED_TEXT: "user-123@example.invalid", - COL_REPLACEMENT_MAP_SOURCE: "local_structured", - } - ] - ) - - summary = tool.summarize_replacement_dataframe( - dataframe, - strategy=tool.ReplacementReplayStrategy.local_structured_substitute, - elapsed_sec=0.01, - ) - - assert summary.status == tool.ReplayStatus.completed - assert summary.final_entity_count == 1 - assert summary.replacement_count == 1 - assert summary.missing_count == 0 - assert summary.collision_count == 0 - assert summary.leak_count == 0 - assert summary.replacement_map_sources == {"local_structured": 1} - - -def test_aggregate_strategy_summaries_sums_repeated_backend_runs() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_aggregate", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - - summary = tool.aggregate_strategy_summaries( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - summaries=[ - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=1.5, - row_count=5, - final_entity_count=10, - replacement_count=9, - missing_count=1, - leak_count=1, - missing_labels={"api_key": 1}, - leak_labels={"api_key": 1}, - replacement_map_sources={"llm": 5}, - ), - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.error, - elapsed_sec=2.5, - 
row_count=5, - final_entity_count=10, - replacement_count=10, - collision_count=1, - collision_labels={"password": 1}, - replacement_map_sources={"llm": 4}, - error="provider failed", - ), - ], - ) - - assert summary.status == tool.ReplayStatus.error - assert summary.repetition_count == 2 - assert summary.elapsed_sec == 4.0 - assert summary.row_count == 10 - assert summary.final_entity_count == 20 - assert summary.replacement_count == 19 - assert summary.missing_count == 1 - assert summary.collision_count == 1 - assert summary.leak_count == 1 - assert summary.missing_labels == {"api_key": 1} - assert summary.collision_labels == {"password": 1} - assert summary.leak_labels == {"api_key": 1} - assert summary.replacement_map_sources == {"llm": 9} - assert summary.error == "provider failed" - - -def test_replay_comparison_row_marks_fast_complete_local_replay_viable() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_comparison_row", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - result = tool.ReplacementReplayResult( - input_path="/tmp/structured_identifiers.csv", - text_column="text", - labels=["api_key"], - dd_parser_compat=tool.DDParserCompatMode.raw_json, - detect_elapsed_sec=8.8, - detected_final_entity_count=2, - strategies=[ - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=6.2, - row_count=1, - final_entity_count=2, - replacement_count=2, - ), - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.local_structured_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=0.003, - row_count=1, - final_entity_count=2, - replacement_count=2, - ), - ], - ) - - row = tool.replay_comparison_row(result) - - assert row.workload_id == "structured_identifiers" - assert row.baseline_replacement_strategy == "default" - assert row.candidate_replacement_strategy == "local_structured_substitute" - assert 
row.pipeline_elapsed_sec_delta < 0 - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "pass" - assert row.safety_verdict == "pass" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "candidate_viable" - assert row.flags == [] - - -def test_replay_comparison_row_rejects_missing_local_replacements() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_comparison_missing", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - result = tool.ReplacementReplayResult( - input_path="/tmp/structured_identifiers.csv", - text_column="text", - labels=["api_key"], - dd_parser_compat=tool.DDParserCompatMode.raw_json, - detect_elapsed_sec=8.8, - detected_final_entity_count=2, - strategies=[ - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=6.2, - row_count=1, - final_entity_count=2, - replacement_count=2, - ), - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.local_structured_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=0.003, - row_count=1, - final_entity_count=2, - replacement_count=1, - missing_count=1, - ), - ], - ) - - row = tool.replay_comparison_row(result) - - assert row.candidate_replacement_missing_final_entity_count == 1 - assert "candidate_replacement_missing_final_entity" in row.flags - assert "replacement_count_loss" in row.flags - assert row.value_protection_verdict == "fail" - assert row.signature_parity_verdict == "fail" - assert row.safety_verdict == "fail" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "reject" - - -def test_replay_comparison_row_reviews_duplicate_local_synthetic_values() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_comparison_duplicates", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - result = 
tool.ReplacementReplayResult( - input_path="/tmp/biographies.csv", - text_column="biography", - labels=["organization_name", "religious_belief"], - nrows=5, - dd_parser_compat=tool.DDParserCompatMode.raw_json, - detect_elapsed_sec=18.8, - detected_final_entity_count=10, - strategies=[ - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=7.9, - row_count=5, - final_entity_count=10, - replacement_count=10, - ), - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.local_structured_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=0.003, - row_count=5, - final_entity_count=10, - replacement_count=10, - duplicate_synthetic_count=2, - ), - ], - ) - - row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") - - assert row.candidate_duplicate_synthetic_replacement_count == 2 - assert "candidate_duplicate_synthetic_replacement" in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_replay_comparison_row_reviews_candidate_replacement_count_gain_over_flawed_baseline() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_comparison_baseline_flaw", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - result = tool.ReplacementReplayResult( - input_path="/tmp/biographies.csv", - text_column="biography", - labels=["organization_name"], - nrows=5, - dd_parser_compat=tool.DDParserCompatMode.raw_json, - detect_elapsed_sec=12.9, - detected_final_entity_count=2, - strategies=[ - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.dd_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=7.6, - row_count=5, - final_entity_count=2, - replacement_count=1, - 
missing_count=1, - leak_count=1, - missing_labels={"organization_name": 1}, - leak_labels={"organization_name": 1}, - ), - tool.ReplacementReplaySummary( - strategy=tool.ReplacementReplayStrategy.local_structured_substitute, - status=tool.ReplayStatus.completed, - elapsed_sec=0.003, - row_count=5, - final_entity_count=2, - replacement_count=2, - ), - ], - ) - - row = tool.replay_comparison_row(result, workload_id="biography-supported-structured") - - assert row.replacement_count_delta == 1 - assert row.replacement_missing_final_entity_count_delta == -1 - assert row.baseline_replacement_missing_final_entity_label_counts == {"organization_name": 1} - assert row.candidate_replacement_missing_final_entity_label_counts == {} - assert row.baseline_original_value_leak_label_counts == {"organization_name": 1} - assert row.candidate_original_value_leak_label_counts == {} - assert "baseline_replacement_missing_final_entity" in row.flags - assert "baseline_original_value_leak" in row.flags - assert "candidate_covers_baseline_replacement_missing_final_entity" in row.flags - assert "candidate_covers_baseline_original_value_leak" in row.flags - assert "replacement_count_delta" in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_strip_replacement_columns_removes_prior_strategy_outputs() -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_strip", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - dataframe = pd.DataFrame( - { - "text": ["hello"], - COL_REPLACEMENT_MAP: [{"replacements": []}], - COL_REPLACEMENT_MAP_SOURCE: ["redact"], - COL_REPLACED_TEXT: ["hello"], - } - ) - - stripped = tool.strip_replacement_columns(dataframe) - - assert list(stripped.columns) == ["text"] - - -def 
test_build_replay_dataframe_uses_preview_when_nrows_is_set(tmp_path: Path) -> None: - tool = load_tool( - "measurement_replay_replacement_strategies_nrows", - REPO_ROOT / "tools/measurement/replay_replacement_strategies.py", - ) - trace_dataframe = pd.DataFrame( - { - COL_TEXT: ["born on 1988-11-21"], - COL_FINAL_ENTITIES: [ - { - "entities": [ - { - "value": "1988-11-21", - "label": "date_of_birth", - "start_position": 8, - "end_position": 18, - } - ] - } - ], - COL_REPLACEMENT_MAP: [{"replacements": []}], - COL_REPLACED_TEXT: ["born on [REDACTED_DATE_OF_BIRTH]"], - } - ) - - class StubAnonymizer: - preview_num_records: int | None = None - - def run(self, **_kwargs: object) -> object: - raise AssertionError("run should not be used for row-limited replay") - - def preview(self, **kwargs: object) -> object: - self.preview_num_records = kwargs["num_records"] # type: ignore[assignment] - return SimpleNamespace(trace_dataframe=trace_dataframe, resolved_text_column="text") - - anonymizer = StubAnonymizer() - source = tmp_path / "multiline.csv" - source.write_text('text\n"born on 1988-11-21"\n', encoding="utf-8") - - _elapsed, replay_df = tool.build_replay_dataframe( - anonymizer, - source=source, - text_column="text", - labels=["date_of_birth"], - nrows=1, - dd_parser_compat=tool.DDParserCompatMode.none, - ) - - assert anonymizer.preview_num_records == 1 - assert COL_REPLACEMENT_MAP not in replay_df.columns - assert COL_REPLACED_TEXT not in replay_df.columns - assert replay_df[COL_FINAL_ENTITIES].iloc[0]["entities"][0]["value"] == "1988-11-21" diff --git a/tests/tools/test_screen_strategy_comparisons.py b/tests/tools/test_screen_strategy_comparisons.py index 8b2fa237..ee690cac 100644 --- a/tests/tools/test_screen_strategy_comparisons.py +++ b/tests/tools/test_screen_strategy_comparisons.py @@ -39,7 +39,7 @@ def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) "baseline_strategy": "default", "candidate_strategy": "detector_only", 
"baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": "custom_replacement_strategy", "baseline_case_count": 3, "candidate_case_count": 3, "safety_verdict": "review", @@ -117,16 +117,16 @@ def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) assert legal.augmented_new_final_value_count_delta == -2 shell = next(row for row in result.rows if row.workload_id == "shell-3") assert shell.baseline_replacement_strategy == "default" - assert shell.candidate_replacement_strategy == "local_structured_substitute" + assert shell.candidate_replacement_strategy == "custom_replacement_strategy" assert shell.baseline_case_count == 3 assert shell.candidate_case_count == 3 assert shell.shared_stable_final_entity_signature_count == 12 detector_local = next( group for group in result.groups - if group.group_key == "strategy:detector_only|replacement:local_structured_substitute" + if group.group_key == "strategy:detector_only|replacement:custom_replacement_strategy" ) - assert detector_local.candidate_replacement_strategy == "local_structured_substitute" + assert detector_local.candidate_replacement_strategy == "custom_replacement_strategy" assert detector_local.row_count == 1 no_augment = next(group for group in result.groups if group.group_key == "strategy:no_augment") assert no_augment.row_count == 1 @@ -146,7 +146,7 @@ def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: baseline_config_id="default", candidate_config_id="detector-only", baseline_replacement_strategy="default", - candidate_replacement_strategy="local_structured_substitute", + candidate_replacement_strategy="custom_replacement_strategy", safety_verdict="review", performance_verdict="improved", candidate_verdict="review", @@ -159,7 +159,7 @@ def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: exported = pd.read_csv(output) assert exported["workload_id"].tolist() 
== ["shell"] - assert exported["candidate_replacement_strategy"].tolist() == ["local_structured_substitute"] + assert exported["candidate_replacement_strategy"].tolist() == ["custom_replacement_strategy"] assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] @@ -205,7 +205,7 @@ def test_screen_strategy_comparisons_surfaces_evidence_level_counts(tmp_path: Pa "baseline_config_id": "default", "candidate_config_id": "local-substitute", "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": "custom_replacement_strategy", "value_protection_verdict": "pass", "signature_parity_verdict": "review", "safety_verdict": "review", @@ -217,7 +217,7 @@ def test_screen_strategy_comparisons_surfaces_evidence_level_counts(tmp_path: Pa "baseline_config_id": "default", "candidate_config_id": "local-substitute-legacy", "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": "custom_replacement_strategy", "safety_verdict": "pass", "performance_verdict": "improved", "candidate_verdict": "candidate_viable", @@ -348,7 +348,7 @@ def test_screen_strategy_comparisons_surfaces_reliability_review(tmp_path: Path) "baseline_config_id": "default-substitute", "candidate_config_id": "local-substitute", "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": "custom_replacement_strategy", "value_protection_verdict": "pass", "signature_parity_verdict": "pass", "safety_verdict": "review", @@ -380,7 +380,7 @@ def test_screen_strategy_comparisons_surfaces_replacement_replay_review(tmp_path "baseline_strategy": "default", "candidate_strategy": "default", "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": 
"custom_replacement_strategy", "value_protection_verdict": "pass", "signature_parity_verdict": "review", "safety_verdict": "review", @@ -416,7 +416,7 @@ def test_screen_strategy_comparisons_surfaces_baseline_defect_improvement_review "baseline_strategy": "default", "candidate_strategy": "default", "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "local_structured_substitute", + "candidate_replacement_strategy": "custom_replacement_strategy", "value_protection_verdict": "pass", "signature_parity_verdict": "review", "safety_verdict": "review", @@ -487,13 +487,13 @@ def test_screen_strategy_comparisons_groups_default_detection_by_replacement_str candidate_config_id="substitute-local", candidate_strategy="default", baseline_replacement_strategy="default", - candidate_replacement_strategy="local_structured_substitute", + candidate_replacement_strategy="custom_replacement_strategy", safety_verdict="pass", performance_verdict="improved", candidate_verdict="candidate_viable", ) - assert tool.group_base_for_row(row, config_aliases={}) == "replacement:local_structured_substitute" + assert tool.group_base_for_row(row, config_aliases={}) == "replacement:custom_replacement_strategy" def test_screen_strategy_comparisons_keeps_generic_review_without_leak_metrics() -> None: diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 69a017af..a970fc4d 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -214,37 +214,6 @@ more parser- and recall-sensitive. Any malformed JSON response becomes a failed case in analysis, and any missed baseline signature should be treated as a rejection rather than a latency win. 
-## Replacement strategy probes - -Replacement-map generation has a separate benchmark-only knob: - -```yaml -configs: - - id: structured-local-substitute - experimental_replacement_strategy: local_structured_substitute - detect: - entity_labels: [api_key, email, password, url] - replace: - strategy: substitute -``` - -Supported replacement strategy values: - -- `default`: run normal replacement behavior. `Substitute` uses the configured - `replacement_generator` role through DataDesigner. -- `local_structured_substitute`: for `replace: substitute`, build deterministic - synthetic replacement maps locally for supported structured labels. Text - replacement still uses Anonymizer's normal replacement-map application code. - -`local_structured_substitute` requires `replace: substitute` and explicit -`detect.entity_labels`. Every label must be one of the structured substitute -labels: `api_key`, `date_of_birth`, `email`, `organization_name`, `password`, -`http_cookie`, `pin`, `religious_belief`, `street_address`, `unique_id`, `url`, -or `user_name`. The preflight rejects contextual labels such as `person`. This -is deliberate. The local substitute map generator does not understand names, -social relations, cultural consistency, or prose semantics; use the default -DataDesigner-backed `Substitute` path for those workloads. - ## DataDesigner traces For debugging DataDesigner calls, pass `--dd-trace last-message` or diff --git a/tools/measurement/replacement_strategies.py b/tools/measurement/replacement_strategies.py deleted file mode 100644 index 0089ac72..00000000 --- a/tools/measurement/replacement_strategies.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -"""Experimental replacement strategies for benchmark-only performance probes.""" - -from __future__ import annotations - -from collections.abc import Iterator -from contextlib import contextmanager -from enum import StrEnum - -import pandas as pd -from data_designer.config.models import ModelConfig - -from anonymizer.config.models import ReplaceModelSelection -from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE -from anonymizer.engine.replace import llm_replace_workflow as lrw -from anonymizer.engine.replace.structured_substitute import apply_structured_substitution_maps - - -class ExperimentalReplacementStrategy(StrEnum): - default = "default" - local_structured_substitute = "local_structured_substitute" - - -@contextmanager -def experimental_replacement_strategy_context(strategy: ExperimentalReplacementStrategy) -> Iterator[None]: - """Temporarily apply a benchmark-only replacement strategy.""" - if strategy == ExperimentalReplacementStrategy.default: - yield - return - - original_method = lrw.LlmReplaceWorkflow.generate_map_only - if strategy == ExperimentalReplacementStrategy.local_structured_substitute: - lrw.LlmReplaceWorkflow.generate_map_only = _local_structured_generate_map_only # type: ignore[method-assign] - else: - raise ValueError(f"unsupported experimental replacement strategy: {strategy}") - try: - yield - finally: - lrw.LlmReplaceWorkflow.generate_map_only = original_method # type: ignore[method-assign] - - -def _local_structured_generate_map_only( - self: lrw.LlmReplaceWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: ReplaceModelSelection, - instructions: str | None = None, - entities_column: str = COL_ENTITIES_BY_VALUE, - preview_num_records: int | None = None, -) -> lrw.LlmReplaceResult: - _ = self, model_configs, selected_models, instructions, preview_num_records - return lrw.LlmReplaceResult( - 
dataframe=apply_structured_substitution_maps(dataframe, entities_column=entities_column), - failed_records=[], - ) diff --git a/tools/measurement/replay_replacement_strategies.py b/tools/measurement/replay_replacement_strategies.py deleted file mode 100644 index f5512aa5..00000000 --- a/tools/measurement/replay_replacement_strategies.py +++ /dev/null @@ -1,736 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Replay substitute strategies on one fixed detection trace. - -Usage: - uv run python tools/measurement/replay_replacement_strategies.py data.csv \ - --text-column text --labels api_key,http_cookie,password,pin,unique_id,user_name \ - --model-configs /stable-cache/anonymizer/local-vllm-json-models.yaml \ - --model-providers /stable-cache/anonymizer/local-vllm-providers.yaml \ - --dd-parser-compat raw_json --json -""" - -from __future__ import annotations - -import json -import logging -import sys -import time -from collections import Counter -from enum import StrEnum -from pathlib import Path -from typing import Annotated, Any - -import cyclopts -import pandas as pd -from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context -from pydantic import BaseModel, Field, ValidationError -from replacement_strategies import ( - ExperimentalReplacementStrategy, - experimental_replacement_strategy_context, -) - -from anonymizer.config.anonymizer_config import AnonymizerConfig, AnonymizerInput, Detect -from anonymizer.config.replace_strategies import Redact, Substitute -from anonymizer.engine.constants import ( - COL_FINAL_ENTITIES, - COL_REPLACED_TEXT, - COL_REPLACEMENT_MAP, - COL_REPLACEMENT_MAP_SOURCE, -) -from anonymizer.engine.schemas import EntitiesSchema, EntityReplacementMapSchema -from anonymizer.interface.anonymizer import Anonymizer, _unrename_output_columns -from anonymizer.measurement import _output_contains_original_value 
- -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.replay_replacement_strategies") - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -class ReplayStatus(StrEnum): - completed = "completed" - error = "error" - - -class ReplacementReplayStrategy(StrEnum): - dd_substitute = "dd_substitute" - local_structured_substitute = "local_structured_substitute" - - -class ReplacementRowMetrics(BaseModel): - final_entity_count: int = 0 - replacement_count: int = 0 - missing_count: int = 0 - collision_count: int = 0 - leak_count: int = 0 - duplicate_synthetic_count: int = 0 - missing_labels: dict[str, int] = Field(default_factory=dict) - collision_labels: dict[str, int] = Field(default_factory=dict) - leak_labels: dict[str, int] = Field(default_factory=dict) - - -class ReplacementReplaySummary(BaseModel): - strategy: ReplacementReplayStrategy - status: ReplayStatus - repetition_count: int = 1 - elapsed_sec: float | None = None - row_count: int = 0 - final_entity_count: int = 0 - replacement_count: int = 0 - missing_count: int = 0 - collision_count: int = 0 - leak_count: int = 0 - duplicate_synthetic_count: int = 0 - missing_labels: dict[str, int] = Field(default_factory=dict) - collision_labels: dict[str, int] = Field(default_factory=dict) - leak_labels: dict[str, int] = Field(default_factory=dict) - replacement_map_sources: dict[str, int] = Field(default_factory=dict) - error: str | None = None - - -class ReplacementReplayResult(BaseModel): - input_path: str - text_column: str - labels: list[str] - nrows: int | None = None - replacement_repetitions: int = 1 - dd_parser_compat: DDParserCompatMode - detect_elapsed_sec: float - detected_final_entity_count: int - strategies: list[ReplacementReplaySummary] - - -class ReplacementReplayComparisonRow(BaseModel): - workload_id: str - baseline_config_id: str = "dd_substitute_replay" - candidate_config_id: str = "local_structured_substitute_replay" - baseline_strategy: str = "default" - 
candidate_strategy: str = "default" - baseline_replacement_strategy: str = "default" - candidate_replacement_strategy: str = "local_structured_substitute" - baseline_case_count: int - candidate_case_count: int - baseline_pipeline_elapsed_sec: float | None = None - candidate_pipeline_elapsed_sec: float | None = None - pipeline_elapsed_sec_delta: float | None = None - pipeline_elapsed_sec_delta_pct: float | None = None - baseline_final_entity_count: int - candidate_final_entity_count: int - final_entity_count_delta: int - baseline_replacement_count: int - candidate_replacement_count: int - replacement_count_delta: int - baseline_replacement_missing_final_entity_count: int - candidate_replacement_missing_final_entity_count: int - replacement_missing_final_entity_count_delta: int - baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_synthetic_original_collision_count: int - candidate_replacement_synthetic_original_collision_count: int - replacement_synthetic_original_collision_count_delta: int - baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_duplicate_synthetic_replacement_count: int - candidate_duplicate_synthetic_replacement_count: int - duplicate_synthetic_replacement_count_delta: int - baseline_original_value_leak_count: int - candidate_original_value_leak_count: int - original_value_leak_count_delta: int - baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - value_protection_verdict: str - signature_parity_verdict: str - safety_verdict: str - performance_verdict: str - 
candidate_verdict: str - flags: list[str] = Field(default_factory=list) - - -_log_format = LogFormat.plain - - -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") - else: - sys.stderr.write(f"ERROR: bad_input error={error}\n") - - -def parse_labels(raw: str) -> list[str]: - labels = [label.strip() for label in raw.split(",") if label.strip()] - if not labels: - raise ValueError("--labels must contain at least one non-empty label") - return list(dict.fromkeys(labels)) - - -def run_replacement_replay( - *, - source: Path, - text_column: str, - labels: list[str], - nrows: int | None, - model_configs: Path | None, - model_providers: Path | None, - artifact_path: Path, - dd_parser_compat: DDParserCompatMode, - replacement_repetitions: int, -) -> ReplacementReplayResult: - anonymizer = Anonymizer( - model_configs=model_configs, - model_providers=model_providers, - artifact_path=artifact_path, - ) - detect_elapsed, replay_df = build_replay_dataframe( - anonymizer, - source=source, - text_column=text_column, - labels=labels, - nrows=nrows, - dd_parser_compat=dd_parser_compat, - ) - strategies = [ - run_strategy_repetitions( - anonymizer, - replay_df, - ReplacementReplayStrategy.dd_substitute, - dd_parser_compat, - repetitions=replacement_repetitions, - ), - run_strategy_repetitions( - anonymizer, - replay_df, - ReplacementReplayStrategy.local_structured_substitute, - dd_parser_compat, - repetitions=replacement_repetitions, - ), - ] - return ReplacementReplayResult( - input_path=str(source), - text_column=text_column, - labels=labels, - nrows=nrows, - replacement_repetitions=replacement_repetitions, - dd_parser_compat=dd_parser_compat, - detect_elapsed_sec=detect_elapsed, 
- detected_final_entity_count=count_final_entities(replay_df), - strategies=strategies, - ) - - -def build_replay_dataframe( - anonymizer: Anonymizer, - *, - source: Path, - text_column: str, - labels: list[str], - nrows: int | None, - dd_parser_compat: DDParserCompatMode, -) -> tuple[float, pd.DataFrame]: - input_data = AnonymizerInput(source=str(source), text_column=text_column) - config = AnonymizerConfig(detect=Detect(entity_labels=labels), replace=Redact()) - with dd_parser_compat_context(dd_parser_compat): - start = time.perf_counter() - if nrows is None: - detected = anonymizer.run(config=config, data=input_data) - else: - detected = anonymizer.preview(config=config, data=input_data, num_records=nrows) - elapsed = time.perf_counter() - start - internal_df = _unrename_output_columns(detected.trace_dataframe, resolved_text_column=detected.resolved_text_column) - return elapsed, strip_replacement_columns(internal_df) - - -def strip_replacement_columns(dataframe: pd.DataFrame) -> pd.DataFrame: - return dataframe.drop(columns=[COL_REPLACEMENT_MAP, COL_REPLACEMENT_MAP_SOURCE, COL_REPLACED_TEXT], errors="ignore") - - -def run_strategy_repetitions( - anonymizer: Anonymizer, - dataframe: pd.DataFrame, - strategy: ReplacementReplayStrategy, - dd_parser_compat: DDParserCompatMode, - *, - repetitions: int, -) -> ReplacementReplaySummary: - summaries = [run_strategy(anonymizer, dataframe, strategy, dd_parser_compat) for _ in range(repetitions)] - return aggregate_strategy_summaries(strategy=strategy, summaries=summaries) - - -def aggregate_strategy_summaries( - *, - strategy: ReplacementReplayStrategy, - summaries: list[ReplacementReplaySummary], -) -> ReplacementReplaySummary: - sources: Counter[str] = Counter() - missing_labels: Counter[str] = Counter() - collision_labels: Counter[str] = Counter() - leak_labels: Counter[str] = Counter() - errors: list[str] = [] - elapsed_values: list[float] = [] - status = ReplayStatus.completed - for summary in summaries: - 
sources.update(summary.replacement_map_sources) - missing_labels.update(summary.missing_labels) - collision_labels.update(summary.collision_labels) - leak_labels.update(summary.leak_labels) - if summary.elapsed_sec is not None: - elapsed_values.append(summary.elapsed_sec) - if summary.status == ReplayStatus.error: - status = ReplayStatus.error - if summary.error: - errors.append(summary.error) - return ReplacementReplaySummary( - strategy=strategy, - status=status, - repetition_count=len(summaries), - elapsed_sec=sum(elapsed_values) if elapsed_values else None, - row_count=sum(summary.row_count for summary in summaries), - final_entity_count=sum(summary.final_entity_count for summary in summaries), - replacement_count=sum(summary.replacement_count for summary in summaries), - missing_count=sum(summary.missing_count for summary in summaries), - collision_count=sum(summary.collision_count for summary in summaries), - leak_count=sum(summary.leak_count for summary in summaries), - duplicate_synthetic_count=sum(summary.duplicate_synthetic_count for summary in summaries), - missing_labels=dict(sorted(missing_labels.items())), - collision_labels=dict(sorted(collision_labels.items())), - leak_labels=dict(sorted(leak_labels.items())), - replacement_map_sources=dict(sorted(sources.items())), - error="; ".join(errors) if errors else None, - ) - - -def run_strategy( - anonymizer: Anonymizer, - dataframe: pd.DataFrame, - strategy: ReplacementReplayStrategy, - dd_parser_compat: DDParserCompatMode, -) -> ReplacementReplaySummary: - try: - start = time.perf_counter() - result_df = execute_strategy(anonymizer, dataframe.copy(), strategy, dd_parser_compat) - elapsed = time.perf_counter() - start - except Exception as exc: - logger.exception("replacement replay strategy failed: %s", strategy) - return ReplacementReplaySummary(strategy=strategy, status=ReplayStatus.error, error=str(exc)) - return summarize_replacement_dataframe(result_df, strategy=strategy, elapsed_sec=elapsed) - - 
-def execute_strategy( - anonymizer: Anonymizer, - dataframe: pd.DataFrame, - strategy: ReplacementReplayStrategy, - dd_parser_compat: DDParserCompatMode, -) -> pd.DataFrame: - if strategy == ReplacementReplayStrategy.local_structured_substitute: - with experimental_replacement_strategy_context(ExperimentalReplacementStrategy.local_structured_substitute): - return execute_substitute(anonymizer, dataframe) - with dd_parser_compat_context(dd_parser_compat): - return execute_substitute(anonymizer, dataframe) - - -def execute_substitute(anonymizer: Anonymizer, dataframe: pd.DataFrame) -> pd.DataFrame: - result = anonymizer._replace_runner.run( # benchmark probe against the configured runner - dataframe, - replace_method=Substitute(), - model_configs=anonymizer._model_configs, - selected_models=anonymizer._selected_models.replace, - ) - return result.dataframe - - -def summarize_replacement_dataframe( - dataframe: pd.DataFrame, - *, - strategy: ReplacementReplayStrategy, - elapsed_sec: float, -) -> ReplacementReplaySummary: - rows = [replacement_row_metrics(row) for _, row in dataframe.iterrows()] - return ReplacementReplaySummary( - strategy=strategy, - status=ReplayStatus.completed, - elapsed_sec=elapsed_sec, - row_count=len(rows), - final_entity_count=sum(row.final_entity_count for row in rows), - replacement_count=sum(row.replacement_count for row in rows), - missing_count=sum(row.missing_count for row in rows), - collision_count=sum(row.collision_count for row in rows), - leak_count=sum(row.leak_count for row in rows), - duplicate_synthetic_count=sum(row.duplicate_synthetic_count for row in rows), - missing_labels=sum_label_counts(row.missing_labels for row in rows), - collision_labels=sum_label_counts(row.collision_labels for row in rows), - leak_labels=sum_label_counts(row.leak_labels for row in rows), - replacement_map_sources=count_sources(dataframe), - ) - - -def replacement_row_metrics(row: Any) -> ReplacementRowMetrics: - entities = [entity.model_dump() for 
entity in EntitiesSchema.from_raw(row[COL_FINAL_ENTITIES]).entities] - replacements = parse_replacements(row[COL_REPLACEMENT_MAP]) - missing = missing_replacements(entities, replacements) - collisions = synthetic_original_collisions(entities, replacements) - leaks = leaked_entities(entities, str(row[COL_REPLACED_TEXT])) - synthetic_values = [item["synthetic"] for item in replacements if item.get("synthetic")] - return ReplacementRowMetrics( - final_entity_count=len(entities), - replacement_count=len(replacements), - missing_count=len(missing), - collision_count=len(collisions), - leak_count=len(leaks), - duplicate_synthetic_count=max(0, len(synthetic_values) - len(set(synthetic_values))), - missing_labels=count_labels(missing), - collision_labels=count_labels(collisions), - leak_labels=count_labels(leaks), - ) - - -def parse_replacements(raw: object) -> list[dict[str, str]]: - if hasattr(raw, "model_dump"): - raw = raw.model_dump(mode="python") - if isinstance(raw, str): - raw = json.loads(raw) - return [item.model_dump() for item in EntityReplacementMapSchema.model_validate(raw).replacements] - - -def missing_replacements(entities: list[dict[str, Any]], replacements: list[dict[str, str]]) -> list[dict[str, Any]]: - replacement_pairs = {(item["original"], item["label"]) for item in replacements} - return [entity for entity in entities if (entity.get("value"), entity.get("label")) not in replacement_pairs] - - -def synthetic_original_collisions( - entities: list[dict[str, Any]], - replacements: list[dict[str, str]], -) -> list[dict[str, str]]: - original_values = {entity["value"] for entity in entities if entity.get("value")} - return [item for item in replacements if item.get("synthetic") in original_values] - - -def leaked_entities(entities: list[dict[str, Any]], output_text: str) -> list[dict[str, Any]]: - return [ - entity - for entity in entities - if entity.get("value") and _output_contains_original_value(output_text, str(entity["value"])) - ] - - -def 
count_labels(items: list[dict[str, Any]]) -> dict[str, int]: - return dict(sorted(Counter(str(item.get("label") or "") for item in items if item.get("label")).items())) - - -def sum_label_counts(values: object) -> dict[str, int]: - counts: Counter[str] = Counter() - for mapping in values: - counts.update(mapping) - return dict(sorted(counts.items())) - - -def count_sources(dataframe: pd.DataFrame) -> dict[str, int]: - if COL_REPLACEMENT_MAP_SOURCE not in dataframe.columns: - return {} - return dict(sorted(Counter(str(source) for source in dataframe[COL_REPLACEMENT_MAP_SOURCE]).items())) - - -def count_final_entities(dataframe: pd.DataFrame) -> int: - return sum(len(EntitiesSchema.from_raw(raw).entities) for raw in dataframe[COL_FINAL_ENTITIES]) - - -def render_result(result: ReplacementReplayResult, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - lines = [ - f"input={result.input_path}", - f"labels={','.join(result.labels)}", - f"nrows={result.nrows if result.nrows is not None else 'all'}", - f"replacement_repetitions={result.replacement_repetitions}", - f"detected_final_entities={result.detected_final_entity_count}", - f"detection_elapsed_sec={result.detect_elapsed_sec:.3f}", - ] - lines.extend(render_strategy(summary) for summary in result.strategies) - return "\n".join(lines) - - -def render_strategy(summary: ReplacementReplaySummary) -> str: - if summary.status == ReplayStatus.error: - return f"{summary.strategy}: error={summary.error}" - return ( - f"{summary.strategy}: repetitions={summary.repetition_count} elapsed_sec={summary.elapsed_sec:.3f} " - f"replacements={summary.replacement_count} missing={summary.missing_count} " - f"leaks={summary.leak_count} collisions={summary.collision_count}" - ) - - -def replay_comparison_row( - result: ReplacementReplayResult, *, workload_id: str | None = None -) -> ReplacementReplayComparisonRow: - summaries = {summary.strategy: summary for summary in result.strategies} - baseline 
= summaries.get(ReplacementReplayStrategy.dd_substitute) - candidate = summaries.get(ReplacementReplayStrategy.local_structured_substitute) - if baseline is None or candidate is None: - raise ValueError( - "replacement replay result must include dd_substitute and local_structured_substitute summaries" - ) - flags = replay_comparison_flags(baseline, candidate) - value_protection = replay_value_protection_verdict(flags) - signature_parity = replay_signature_parity_verdict(baseline, candidate, flags) - safety = replay_safety_verdict(value_protection, signature_parity) - performance = replay_performance_verdict(baseline.elapsed_sec, candidate.elapsed_sec) - return ReplacementReplayComparisonRow( - workload_id=workload_id or Path(result.input_path).stem, - baseline_case_count=baseline.row_count, - candidate_case_count=candidate.row_count, - baseline_pipeline_elapsed_sec=baseline.elapsed_sec, - candidate_pipeline_elapsed_sec=candidate.elapsed_sec, - pipeline_elapsed_sec_delta=delta(baseline.elapsed_sec, candidate.elapsed_sec), - pipeline_elapsed_sec_delta_pct=delta_pct(baseline.elapsed_sec, candidate.elapsed_sec), - baseline_final_entity_count=baseline.final_entity_count, - candidate_final_entity_count=candidate.final_entity_count, - final_entity_count_delta=candidate.final_entity_count - baseline.final_entity_count, - baseline_replacement_count=baseline.replacement_count, - candidate_replacement_count=candidate.replacement_count, - replacement_count_delta=candidate.replacement_count - baseline.replacement_count, - baseline_replacement_missing_final_entity_count=baseline.missing_count, - candidate_replacement_missing_final_entity_count=candidate.missing_count, - replacement_missing_final_entity_count_delta=candidate.missing_count - baseline.missing_count, - baseline_replacement_missing_final_entity_label_counts=baseline.missing_labels, - candidate_replacement_missing_final_entity_label_counts=candidate.missing_labels, - 
baseline_replacement_synthetic_original_collision_count=baseline.collision_count, - candidate_replacement_synthetic_original_collision_count=candidate.collision_count, - replacement_synthetic_original_collision_count_delta=candidate.collision_count - baseline.collision_count, - baseline_replacement_synthetic_original_collision_label_counts=baseline.collision_labels, - candidate_replacement_synthetic_original_collision_label_counts=candidate.collision_labels, - baseline_duplicate_synthetic_replacement_count=baseline.duplicate_synthetic_count, - candidate_duplicate_synthetic_replacement_count=candidate.duplicate_synthetic_count, - duplicate_synthetic_replacement_count_delta=( - candidate.duplicate_synthetic_count - baseline.duplicate_synthetic_count - ), - baseline_original_value_leak_count=baseline.leak_count, - candidate_original_value_leak_count=candidate.leak_count, - original_value_leak_count_delta=candidate.leak_count - baseline.leak_count, - baseline_original_value_leak_label_counts=baseline.leak_labels, - candidate_original_value_leak_label_counts=candidate.leak_labels, - value_protection_verdict=value_protection, - signature_parity_verdict=signature_parity, - safety_verdict=safety, - performance_verdict=performance, - candidate_verdict=replay_candidate_verdict(safety, performance), - flags=flags, - ) - - -def replay_comparison_flags( - baseline: ReplacementReplaySummary, - candidate: ReplacementReplaySummary, -) -> list[str]: - flags: list[str] = [] - if baseline.status == ReplayStatus.error: - flags.append("baseline_case_failures") - if candidate.status == ReplayStatus.error: - flags.append("candidate_case_failures") - if baseline.missing_count: - flags.append("baseline_replacement_missing_final_entity") - if candidate.missing_count: - flags.append("candidate_replacement_missing_final_entity") - if baseline.missing_count and candidate.missing_count < baseline.missing_count: - flags.append("candidate_reduces_baseline_replacement_missing_final_entity") - if 
baseline.missing_count and candidate.missing_count == 0: - flags.append("candidate_covers_baseline_replacement_missing_final_entity") - if baseline.collision_count: - flags.append("baseline_replacement_synthetic_original_collision") - if candidate.collision_count: - flags.append("candidate_replacement_synthetic_original_collision") - if baseline.collision_count and candidate.collision_count < baseline.collision_count: - flags.append("candidate_reduces_baseline_replacement_synthetic_original_collision") - if baseline.collision_count and candidate.collision_count == 0: - flags.append("candidate_covers_baseline_replacement_synthetic_original_collision") - if baseline.duplicate_synthetic_count: - flags.append("baseline_duplicate_synthetic_replacement") - if candidate.duplicate_synthetic_count: - flags.append("candidate_duplicate_synthetic_replacement") - if baseline.leak_count: - flags.append("baseline_original_value_leak") - if candidate.leak_count: - flags.append("candidate_original_value_leak") - if baseline.leak_count and candidate.leak_count < baseline.leak_count: - flags.append("candidate_reduces_baseline_original_value_leak") - if baseline.leak_count and candidate.leak_count == 0: - flags.append("candidate_covers_baseline_original_value_leak") - if baseline.final_entity_count != candidate.final_entity_count: - flags.append( - "entity_count_loss" if candidate.final_entity_count < baseline.final_entity_count else "entity_count_delta" - ) - if baseline.replacement_count != candidate.replacement_count: - flags.append( - "replacement_count_loss" - if candidate.replacement_count < baseline.replacement_count - else "replacement_count_delta" - ) - return flags - - -def replay_value_protection_verdict(flags: list[str]) -> str: - flag_set = set(flags) - if flag_set & { - "candidate_case_failures", - "candidate_original_value_leak", - "candidate_replacement_missing_final_entity", - "candidate_replacement_synthetic_original_collision", - "entity_count_loss", - 
"replacement_count_loss", - }: - return "fail" - if "baseline_case_failures" in flag_set: - return "review" - baseline_defects = { - "baseline_original_value_leak", - "baseline_replacement_missing_final_entity", - "baseline_replacement_synthetic_original_collision", - } - corrected_baseline_defects = { - "baseline_original_value_leak": "candidate_covers_baseline_original_value_leak", - "baseline_replacement_missing_final_entity": ("candidate_covers_baseline_replacement_missing_final_entity"), - "baseline_replacement_synthetic_original_collision": ( - "candidate_covers_baseline_replacement_synthetic_original_collision" - ), - } - for defect in baseline_defects & flag_set: - if corrected_baseline_defects[defect] not in flag_set: - return "review" - return "pass" - - -def replay_signature_parity_verdict( - baseline: ReplacementReplaySummary, - candidate: ReplacementReplaySummary, - flags: list[str], -) -> str: - if candidate.status == ReplayStatus.error: - return "fail" - if baseline.status == ReplayStatus.error: - return "review" - if candidate.final_entity_count < baseline.final_entity_count: - return "fail" - if candidate.final_entity_count != baseline.final_entity_count: - return "review" - if candidate.replacement_count < baseline.replacement_count: - return "fail" - if candidate.replacement_count != baseline.replacement_count: - return "review" - if flags: - return "review" - return "pass" - - -def replay_safety_verdict(value_protection: str, signature_parity: str) -> str: - if "fail" in {value_protection, signature_parity}: - return "fail" - if "review" in {value_protection, signature_parity}: - return "review" - return "pass" - - -def replay_performance_verdict(baseline_elapsed: float | None, candidate_elapsed: float | None) -> str: - elapsed_delta = delta(baseline_elapsed, candidate_elapsed) - if elapsed_delta is None: - return "unknown" - if elapsed_delta < 0: - return "improved" - if elapsed_delta > 0: - return "regressed" - return "unchanged" - - -def 
replay_candidate_verdict(safety: str, performance: str) -> str: - if safety == "fail": - return "reject" - if safety == "pass" and performance == "improved": - return "candidate_viable" - return "review" - - -def delta(baseline: float | None, candidate: float | None) -> float | None: - if baseline is None or candidate is None: - return None - return candidate - baseline - - -def delta_pct(baseline: float | None, candidate: float | None) -> float | None: - value = delta(baseline, candidate) - if value is None or baseline in {None, 0}: - return None - return value / baseline * 100 - - -def write_replay_comparison( - result: ReplacementReplayResult, - output: Path, - *, - workload_id: str | None = None, -) -> None: - row = replay_comparison_row(result, workload_id=workload_id).model_dump(mode="json") - row["flags"] = json.dumps(row["flags"], ensure_ascii=True) - pd.DataFrame([row]).to_csv(output, index=False) - - -@app.default -def main( - source: Path, - *, - text_column: Annotated[str, cyclopts.Parameter("--text-column")] = "text", - labels: Annotated[str, cyclopts.Parameter("--labels")], - nrows: Annotated[int | None, cyclopts.Parameter("--nrows")] = None, - replacement_repetitions: Annotated[int, cyclopts.Parameter("--replacement-repetitions")] = 1, - model_configs: Annotated[Path | None, cyclopts.Parameter("--model-configs")] = None, - model_providers: Annotated[Path | None, cyclopts.Parameter("--model-providers")] = None, - artifact_path: Annotated[Path, cyclopts.Parameter("--artifact-path")] = Path( - "/tmp/anonymizer-replacement-replay-artifacts" - ), - dd_parser_compat: Annotated[DDParserCompatMode, cyclopts.Parameter("--dd-parser-compat")] = ( - DDParserCompatMode.none - ), - output: Annotated[Path | None, cyclopts.Parameter("--output")] = None, - comparison_output: Annotated[Path | None, cyclopts.Parameter("--comparison-output")] = None, - workload_id: Annotated[str | None, cyclopts.Parameter("--workload-id")] = None, - json_output: Annotated[bool, 
cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - if nrows is not None and nrows <= 0: - raise ValueError("--nrows must be greater than zero") - if replacement_repetitions <= 0: - raise ValueError("--replacement-repetitions must be greater than zero") - result = run_replacement_replay( - source=source, - text_column=text_column, - labels=parse_labels(labels), - nrows=nrows, - model_configs=model_configs, - model_providers=model_providers, - artifact_path=artifact_path, - dd_parser_compat=dd_parser_compat, - replacement_repetitions=replacement_repetitions, - ) - except (ValidationError, ValueError, FileNotFoundError) as exc: - log_bad_input(str(exc)) - raise SystemExit(125) from exc - rendered = render_result(result, json_output=json_output) - if output is not None: - output.write_text(rendered + "\n", encoding="utf-8") - else: - sys.stdout.write(rendered + "\n") - if comparison_output is not None: - write_replay_comparison(result, comparison_output, workload_id=workload_id) - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index e6a871de..2306aecc 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -38,10 +38,6 @@ ) from export_measurements import ExportFormat, export_tables, read_measurements from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator -from replacement_strategies import ( - ExperimentalReplacementStrategy, - experimental_replacement_strategy_context, -) from anonymizer.config.anonymizer_config import ( AnonymizerConfig, @@ -56,7 +52,6 @@ from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_FINAL_ENTITIES from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS from anonymizer.engine.ndd.model_loader import parse_model_configs, 
validate_model_alias_references -from anonymizer.engine.replace.structured_substitute import SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS from anonymizer.engine.schemas import EntitiesSchema from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import MeasurementConfig, configured_measurement_session @@ -164,7 +159,6 @@ class ConfigSpec(BaseModel): rewrite: RewriteSpec | None = None emit_telemetry: bool = False experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default - experimental_replacement_strategy: ExperimentalReplacementStrategy = ExperimentalReplacementStrategy.default native_runtime: NativeRuntimeSpec | None = None @model_validator(mode="after") @@ -432,10 +426,6 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) _preflight_native_runtime(config, spec=spec) except Exception as exc: errors.append(f"config '{config.id}' native_runtime invalid: {exc}") - try: - _preflight_experimental_replacement_strategy(config, anonymizer_config) - except Exception as exc: - errors.append(f"config '{config.id}' experimental_replacement_strategy invalid: {exc}") if parsed_models is None: continue try: @@ -516,30 +506,6 @@ def _native_detection_runtime(spec: BenchmarkSpec, config: ConfigSpec) -> Native ) -def _preflight_experimental_replacement_strategy( - config: ConfigSpec, - anonymizer_config: AnonymizerConfig, -) -> None: - if config.experimental_replacement_strategy == ExperimentalReplacementStrategy.default: - return - if config.experimental_replacement_strategy != ExperimentalReplacementStrategy.local_structured_substitute: - raise ValueError(f"unsupported strategy: {config.experimental_replacement_strategy}") - if not isinstance(anonymizer_config.replace, Substitute): - raise ValueError("local_structured_substitute requires replace: substitute") - entity_labels = anonymizer_config.detect.entity_labels - supported = ", 
".join(sorted(SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS)) - if entity_labels is None: - raise ValueError( - "`local_structured_substitute` requires explicit detect.entity_labels limited to " - f"structured substitute labels: {supported}" - ) - unsupported = sorted(set(entity_labels) - SUPPORTED_STRUCTURED_SUBSTITUTE_LABELS) - if unsupported: - raise ValueError( - f"unsupported local structured substitute labels: {', '.join(unsupported)}; supported labels: {supported}" - ) - - def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: raw = _resolve_config_source(spec.model_providers, base_dir) if raw is None: @@ -1085,11 +1051,10 @@ def _execute_case( config.experimental_detection_strategy, **detection_context_kwargs, ): - with experimental_replacement_strategy_context(config.experimental_replacement_strategy): - result = anonymizer.run( - config=anonymizer_config, - data=input_data, - ) + result = anonymizer.run( + config=anonymizer_config, + data=input_data, + ) return _CaseExecution(input_data=input_data, trace_dataframe=getattr(result, "trace_dataframe", None)) @@ -1361,7 +1326,6 @@ def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: "repetition": case.repetition, "case_id": case.case_id, "experimental_detection_strategy": config.experimental_detection_strategy.value, - "experimental_replacement_strategy": config.experimental_replacement_strategy.value, "dd_parser_compat": spec.dd_parser_compat.value, } if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: From 00a5f818ac224635a2074806395a422129c56d1f Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 00:02:37 +0000 Subject: [PATCH 10/26] Format staged detection analysis tool Signed-off-by: Aaron Gonzales --- tools/measurement/analyze_staged_detection_output.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/measurement/analyze_staged_detection_output.py b/tools/measurement/analyze_staged_detection_output.py index 
737d439a..de0b3780 100644 --- a/tools/measurement/analyze_staged_detection_output.py +++ b/tools/measurement/analyze_staged_detection_output.py @@ -40,6 +40,8 @@ class LogFormat(StrEnum): _log_format = LogFormat.plain + + class StagedCaseAnalysisRow(BaseModel): source_path: str case_id: str @@ -433,6 +435,7 @@ def _render_group_line(group: StagedGroupAnalysisRow, label_deltas: list[LabelDe f"direct_only={group.direct_only_final_entity_signature_count_sum}, lost_labels={lost}" ) + def _top_labels(label_deltas: list[LabelDeltaAnalysisRow], *, seed_source: str | None, delta_type: str) -> str: matches = [delta for delta in label_deltas if delta.seed_source == seed_source and delta.delta_type == delta_type] if not matches: From d011a7cd7f9ba8b5c40677850cd284ef7f5062ee Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 00:18:14 +0000 Subject: [PATCH 11/26] Clarify measurement docs and fixtures Signed-off-by: Aaron Gonzales --- .../detection-artifacts.jsonl | 2 + .../benchmark-output/measurements.jsonl | 5 + .../traces/bio__default__r000.jsonl | 2 + tests/tools/test_benchmark_output_analysis.py | 490 ++---------------- tests/tools/test_measurement_tools.py | 38 +- tools/measurement/AGENTS.md | 39 ++ tools/measurement/README.md | 50 +- 7 files changed, 157 insertions(+), 469 deletions(-) create mode 100644 tests/fixtures/measurement/benchmark-output/detection-artifacts.jsonl create mode 100644 tests/fixtures/measurement/benchmark-output/measurements.jsonl create mode 100644 tests/fixtures/measurement/benchmark-output/traces/bio__default__r000.jsonl create mode 100644 tools/measurement/AGENTS.md diff --git a/tests/fixtures/measurement/benchmark-output/detection-artifacts.jsonl b/tests/fixtures/measurement/benchmark-output/detection-artifacts.jsonl new file mode 100644 index 00000000..c656eaa0 --- /dev/null +++ b/tests/fixtures/measurement/benchmark-output/detection-artifacts.jsonl @@ -0,0 +1,2 @@ 
+{"suite_id":"suite","workload_id":"bio","config_id":"default","repetition":0,"case_id":"bio__default__r000","run_id":"bio__default__r000","workflow_name":"entity-detection","seed_entity_count":13,"seed_validation_candidate_count":13,"augmented_entity_count":1,"augmented_new_final_value_count":1,"final_entity_count":14,"final_source_counts":{"detector":11,"augmenter":3},"final_entity_signature_hashes":["bio-hash-a","bio-hash-b"],"final_entity_signature_labels":{"bio-hash-a":"person","bio-hash-b":"city"},"final_entity_signature_details":{"bio-hash-a":{"label":"person","source":"detector","row_index":0,"start_position":0,"end_position":5,"value_hash":"hash-person","value_length":5}},"final_entity_signature_count":2} +{"suite_id":"suite","workload_id":"shell","config_id":"native-local","repetition":0,"case_id":"shell__native-local__r000","run_id":"shell__native-local__r000","workflow_name":"native-single-pass","seed_entity_count":8,"seed_validation_candidate_count":0,"augmented_entity_count":0,"augmented_new_final_value_count":0,"final_entity_count":8,"final_source_counts":{"augmenter":8},"final_entity_signature_hashes":["shell-hash-a"],"final_entity_signature_labels":{"shell-hash-a":"api_key"},"final_entity_signature_details":{"shell-hash-a":{"label":"api_key","source":"native","row_index":0,"start_position":12,"end_position":32,"value_hash":"hash-secret","value_length":20}},"final_entity_signature_count":1} diff --git a/tests/fixtures/measurement/benchmark-output/measurements.jsonl b/tests/fixtures/measurement/benchmark-output/measurements.jsonl new file mode 100644 index 00000000..c776a21d --- /dev/null +++ b/tests/fixtures/measurement/benchmark-output/measurements.jsonl @@ -0,0 +1,5 @@ 
+{"record_type":"ndd_workflow","run_id":"bio__default__r000","workflow_name":"entity-detection","elapsed_sec":8.5,"observed_total_requests":4,"observed_successful_requests":3,"observed_input_tokens":5000,"observed_output_tokens":1000,"observed_total_tokens":6000,"observed_failed_requests":1,"model_usage":{"nvidia/gliner-pii":{"request_usage":{"successful_requests":1,"failed_requests":0,"total_requests":1},"token_usage":{"input_tokens":1000,"output_tokens":100,"total_tokens":1100}},"local-nemotron-json":{"model_alias":"local-nemotron-json","model_name":"nvidia/nemotron-3-super","model_provider_name":"local-vllm","request_usage":{"successful_requests":2,"failed_requests":1,"total_requests":3},"token_usage":{"input_tokens":4000,"output_tokens":900,"total_tokens":4900}}},"detect":{"validation_max_entities_per_call":10},"run_tags":{"suite_id":"suite","workload_id":"bio","workload_category":"synthetic_biography","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","entity_label_set_id":"agent","entity_label_count":4,"gliner_threshold":0.3,"topology_endpoint_count":2,"topology_gpu_count":4,"topology_tensor_parallelism":2,"repetition":0,"case_id":"bio__default__r000"}} +{"record_type":"stage","run_id":"bio__default__r000","stage":"Anonymizer._run_internal","elapsed_sec":10.0,"status":"completed","run_tags":{"suite_id":"suite","workload_id":"bio","workload_category":"synthetic_biography","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","entity_label_set_id":"agent","entity_label_count":4,"gliner_threshold":0.3,"topology_endpoint_count":2,"topology_gpu_count":4,"topology_tensor_parallelism":2,"repetition":0,"case_id":"bio__default__r000"}} 
+{"record_type":"record","run_id":"bio__default__r000","text_length_tokens":1200,"final_entity_count":14,"ground_truth_entity_count":20,"entity_true_positive_count":10,"entity_false_positive_count":4,"entity_false_negative_count":10,"entity_relaxed_gt_found_count":15,"entity_relaxed_detected_tp_count":14,"entity_relaxed_label_compatible_gt_found_count":13,"entity_relaxed_label_compatible_detected_tp_count":12,"replacement_count":12,"replacement_missing_final_entity_count":2,"replacement_missing_final_entity_label_counts":{"date":2},"replacement_missing_final_value_count":1,"replacement_synthetic_original_collision_count":1,"replacement_synthetic_original_collision_label_counts":{"date":1},"replacement_synthetic_original_collision_value_count":1,"original_value_leak_count":0,"original_value_leak_label_counts":{},"run_tags":{"suite_id":"suite","workload_id":"bio","workload_category":"synthetic_biography","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","entity_label_set_id":"agent","entity_label_count":4,"gliner_threshold":0.3,"repetition":0,"case_id":"bio__default__r000"}} 
+{"record_type":"record","run_id":"bio__default__r000","text_length_tokens":300,"final_entity_count":0,"ground_truth_entity_count":2,"entity_true_positive_count":0,"entity_false_positive_count":0,"entity_false_negative_count":2,"entity_relaxed_gt_found_count":0,"entity_relaxed_detected_tp_count":0,"entity_relaxed_label_compatible_gt_found_count":0,"entity_relaxed_label_compatible_detected_tp_count":0,"replacement_count":0,"replacement_missing_final_entity_count":0,"replacement_missing_final_entity_label_counts":{},"replacement_missing_final_value_count":0,"replacement_synthetic_original_collision_count":0,"replacement_synthetic_original_collision_label_counts":{},"replacement_synthetic_original_collision_value_count":0,"original_value_leak_count":0,"original_value_leak_label_counts":{},"run_tags":{"suite_id":"suite","workload_id":"bio","workload_category":"synthetic_biography","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","entity_label_set_id":"agent","entity_label_count":4,"gliner_threshold":0.3,"repetition":0,"case_id":"bio__default__r000"}} +{"record_type":"record","run_id":"shell__native-local__r000","text_length_tokens":750,"final_entity_count":8,"replacement_count":8,"replacement_missing_final_entity_count":0,"replacement_missing_final_entity_label_counts":{},"replacement_missing_final_value_count":0,"replacement_synthetic_original_collision_count":0,"replacement_synthetic_original_collision_label_counts":{},"replacement_synthetic_original_collision_value_count":0,"original_value_leak_count":1,"original_value_leak_label_counts":{"api_key":1},"run_tags":{"suite_id":"suite","workload_id":"shell","config_id":"native-local","experimental_detection_strategy":"native_single_pass","experimental_replacement_strategy":"custom_replacement_strategy","dd_parser_compat":"raw_json","repetition":0,"case_id":"shell__native-local__r000"}} diff --git 
a/tests/fixtures/measurement/benchmark-output/traces/bio__default__r000.jsonl b/tests/fixtures/measurement/benchmark-output/traces/bio__default__r000.jsonl new file mode 100644 index 00000000..67ce5dc4 --- /dev/null +++ b/tests/fixtures/measurement/benchmark-output/traces/bio__default__r000.jsonl @@ -0,0 +1,2 @@ +{"record_type":"dd_message_trace","run_id":"bio__default__r000","workflow_name":"entity-detection","model_alias":"local-nemotron-json","status":"error","error_type":"SyncClientUnavailableError","is_async":false,"messages":[{"role":"user","content":"Alice has sk-test"}],"response":"Alice still has sk-test","run_tags":{"suite_id":"suite","workload_id":"bio","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","repetition":0,"case_id":"bio__default__r000"}} +{"record_type":"dd_message_trace","run_id":"bio__default__r000","workflow_name":"entity-detection","model_alias":"local-nemotron-json","status":"success","is_async":true,"messages":[{"role":"user","content":"sk-test"}],"response":"Alice","run_tags":{"suite_id":"suite","workload_id":"bio","config_id":"default","experimental_detection_strategy":"default","experimental_replacement_strategy":"default","dd_parser_compat":"raw_json","repetition":0,"case_id":"bio__default__r000"}} diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index e83de30d..dbb8afc4 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -5,6 +5,7 @@ import importlib.util import json +import shutil import sys from pathlib import Path from types import ModuleType @@ -30,311 +31,19 @@ def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") +def _copy_fixture(tmp_path: Path, fixture_name: str) -> Path: + fixture_dir = REPO_ROOT / "tests" / 
"fixtures" / "measurement" / fixture_name + destination = tmp_path / fixture_name + shutil.copytree(fixture_dir, destination) + return destination + + def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp_path: Path) -> None: tool = load_tool( "measurement_benchmark_output_analysis", REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", ) - benchmark_dir = tmp_path / "benchmark" - benchmark_dir.mkdir() - _write_jsonl( - benchmark_dir / "measurements.jsonl", - [ - { - "record_type": "ndd_workflow", - "run_id": "bio__default__r000", - "workflow_name": "entity-detection", - "elapsed_sec": 8.5, - "observed_total_requests": 4, - "observed_successful_requests": 3, - "observed_input_tokens": 5000, - "observed_output_tokens": 1000, - "observed_total_tokens": 6000, - "observed_failed_requests": 1, - "model_usage": { - "nvidia/gliner-pii": { - "request_usage": { - "successful_requests": 1, - "failed_requests": 0, - "total_requests": 1, - }, - "token_usage": { - "input_tokens": 1000, - "output_tokens": 100, - "total_tokens": 1100, - }, - }, - "local-nemotron-json": { - "model_alias": "local-nemotron-json", - "model_name": "nvidia/nemotron-3-super", - "model_provider_name": "local-vllm", - "request_usage": { - "successful_requests": 2, - "failed_requests": 1, - "total_requests": 3, - }, - "token_usage": { - "input_tokens": 4000, - "output_tokens": 900, - "total_tokens": 4900, - }, - }, - }, - "detect": {"validation_max_entities_per_call": 10}, - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "workload_category": "synthetic_biography", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "entity_label_set_id": "agent", - "entity_label_count": 4, - "gliner_threshold": 0.3, - "topology_endpoint_count": 2, - "topology_gpu_count": 4, - "topology_tensor_parallelism": 2, - "repetition": 0, - "case_id": "bio__default__r000", - }, 
- }, - { - "record_type": "stage", - "run_id": "bio__default__r000", - "stage": "Anonymizer._run_internal", - "elapsed_sec": 10.0, - "status": "completed", - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "workload_category": "synthetic_biography", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "entity_label_set_id": "agent", - "entity_label_count": 4, - "gliner_threshold": 0.3, - "topology_endpoint_count": 2, - "topology_gpu_count": 4, - "topology_tensor_parallelism": 2, - "repetition": 0, - "case_id": "bio__default__r000", - }, - }, - { - "record_type": "record", - "run_id": "bio__default__r000", - "text_length_tokens": 1200, - "final_entity_count": 14, - "ground_truth_entity_count": 20, - "entity_true_positive_count": 10, - "entity_false_positive_count": 4, - "entity_false_negative_count": 10, - "entity_relaxed_gt_found_count": 15, - "entity_relaxed_detected_tp_count": 14, - "entity_relaxed_label_compatible_gt_found_count": 13, - "entity_relaxed_label_compatible_detected_tp_count": 12, - "replacement_count": 12, - "replacement_missing_final_entity_count": 2, - "replacement_missing_final_entity_label_counts": {"date": 2}, - "replacement_missing_final_value_count": 1, - "replacement_synthetic_original_collision_count": 1, - "replacement_synthetic_original_collision_label_counts": {"date": 1}, - "replacement_synthetic_original_collision_value_count": 1, - "original_value_leak_count": 0, - "original_value_leak_label_counts": {}, - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "workload_category": "synthetic_biography", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "entity_label_set_id": "agent", - "entity_label_count": 4, - "gliner_threshold": 0.3, - "repetition": 0, - "case_id": "bio__default__r000", - }, - }, 
- { - "record_type": "record", - "run_id": "bio__default__r000", - "text_length_tokens": 300, - "final_entity_count": 0, - "ground_truth_entity_count": 2, - "entity_true_positive_count": 0, - "entity_false_positive_count": 0, - "entity_false_negative_count": 2, - "entity_relaxed_gt_found_count": 0, - "entity_relaxed_detected_tp_count": 0, - "entity_relaxed_label_compatible_gt_found_count": 0, - "entity_relaxed_label_compatible_detected_tp_count": 0, - "replacement_count": 0, - "replacement_missing_final_entity_count": 0, - "replacement_missing_final_entity_label_counts": {}, - "replacement_missing_final_value_count": 0, - "replacement_synthetic_original_collision_count": 0, - "replacement_synthetic_original_collision_label_counts": {}, - "replacement_synthetic_original_collision_value_count": 0, - "original_value_leak_count": 0, - "original_value_leak_label_counts": {}, - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "workload_category": "synthetic_biography", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "entity_label_set_id": "agent", - "entity_label_count": 4, - "gliner_threshold": 0.3, - "repetition": 0, - "case_id": "bio__default__r000", - }, - }, - { - "record_type": "record", - "run_id": "shell__native-local__r000", - "text_length_tokens": 750, - "final_entity_count": 8, - "replacement_count": 8, - "replacement_missing_final_entity_count": 0, - "replacement_missing_final_entity_label_counts": {}, - "replacement_missing_final_value_count": 0, - "replacement_synthetic_original_collision_count": 0, - "replacement_synthetic_original_collision_label_counts": {}, - "replacement_synthetic_original_collision_value_count": 0, - "original_value_leak_count": 1, - "original_value_leak_label_counts": {"api_key": 1}, - "run_tags": { - "suite_id": "suite", - "workload_id": "shell", - "config_id": "native-local", - 
"experimental_detection_strategy": "native_single_pass", - "experimental_replacement_strategy": "custom_replacement_strategy", - "dd_parser_compat": "raw_json", - "repetition": 0, - "case_id": "shell__native-local__r000", - }, - }, - ], - ) - _write_jsonl( - benchmark_dir / "detection-artifacts.jsonl", - [ - { - "suite_id": "suite", - "workload_id": "bio", - "config_id": "default", - "repetition": 0, - "case_id": "bio__default__r000", - "run_id": "bio__default__r000", - "workflow_name": "entity-detection", - "seed_entity_count": 13, - "seed_validation_candidate_count": 13, - "augmented_entity_count": 1, - "augmented_new_final_value_count": 1, - "final_entity_count": 14, - "final_source_counts": {"detector": 11, "augmenter": 3}, - "final_entity_signature_hashes": ["bio-hash-a", "bio-hash-b"], - "final_entity_signature_labels": {"bio-hash-a": "person", "bio-hash-b": "city"}, - "final_entity_signature_details": { - "bio-hash-a": { - "label": "person", - "source": "detector", - "row_index": 0, - "start_position": 0, - "end_position": 5, - "value_hash": "hash-person", - "value_length": 5, - } - }, - "final_entity_signature_count": 2, - }, - { - "suite_id": "suite", - "workload_id": "shell", - "config_id": "native-local", - "repetition": 0, - "case_id": "shell__native-local__r000", - "run_id": "shell__native-local__r000", - "workflow_name": "native-single-pass", - "seed_entity_count": 8, - "seed_validation_candidate_count": 0, - "augmented_entity_count": 0, - "augmented_new_final_value_count": 0, - "final_entity_count": 8, - "final_source_counts": {"augmenter": 8}, - "final_entity_signature_hashes": ["shell-hash-a"], - "final_entity_signature_labels": {"shell-hash-a": "api_key"}, - "final_entity_signature_details": { - "shell-hash-a": { - "label": "api_key", - "source": "native", - "row_index": 0, - "start_position": 12, - "end_position": 32, - "value_hash": "hash-secret", - "value_length": 20, - } - }, - "final_entity_signature_count": 1, - }, - ], - ) - traces_dir = 
benchmark_dir / "traces" - traces_dir.mkdir() - _write_jsonl( - traces_dir / "bio__default__r000.jsonl", - [ - { - "record_type": "dd_message_trace", - "run_id": "bio__default__r000", - "workflow_name": "entity-detection", - "model_alias": "local-nemotron-json", - "status": "error", - "error_type": "SyncClientUnavailableError", - "is_async": False, - "messages": [{"role": "user", "content": "Alice has sk-test"}], - "response": "Alice still has sk-test", - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "repetition": 0, - "case_id": "bio__default__r000", - }, - }, - { - "record_type": "dd_message_trace", - "run_id": "bio__default__r000", - "workflow_name": "entity-detection", - "model_alias": "local-nemotron-json", - "status": "success", - "is_async": True, - "messages": [{"role": "user", "content": "sk-test"}], - "response": "Alice", - "run_tags": { - "suite_id": "suite", - "workload_id": "bio", - "config_id": "default", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "dd_parser_compat": "raw_json", - "repetition": 0, - "case_id": "bio__default__r000", - }, - }, - ], - ) + benchmark_dir = _copy_fixture(tmp_path, "benchmark-output") result = tool.analyze_benchmark_output(benchmark_dir) @@ -342,168 +51,49 @@ def test_analyze_benchmark_output_joins_measurements_and_detection_artifacts(tmp assert result.group_count == 2 assert result.model_usage_count == 2 assert result.model_usage_group_count == 2 + cases = {row.case_id: row for row in result.cases} - assert cases["bio__default__r000"].workload_category == "synthetic_biography" - assert cases["bio__default__r000"].entity_label_set_id == "agent" - assert cases["bio__default__r000"].entity_label_count == 4 - assert cases["bio__default__r000"].gliner_threshold == pytest.approx(0.3) - assert 
cases["bio__default__r000"].experimental_replacement_strategy == "default" - assert cases["bio__default__r000"].observed_total_requests == 4 - assert cases["bio__default__r000"].observed_successful_requests == 3 - assert cases["bio__default__r000"].observed_input_tokens == 5000 - assert cases["bio__default__r000"].observed_output_tokens == 1000 - assert cases["bio__default__r000"].observed_total_tokens == 6000 - assert cases["bio__default__r000"].observed_failed_request_rate == pytest.approx(1 / 4) - assert cases["bio__default__r000"].dd_trace_record_count == 2 - assert cases["bio__default__r000"].dd_trace_error_count == 1 - assert cases["bio__default__r000"].dd_trace_sync_client_unavailable_count == 1 - assert cases["bio__default__r000"].observed_bridge_fallback_requests == 1 - assert cases["bio__default__r000"].observed_non_bridge_total_requests == 3 - assert cases["bio__default__r000"].observed_non_bridge_failed_requests == 0 - assert cases["bio__default__r000"].observed_non_bridge_failed_request_rate == 0 - assert cases["bio__default__r000"].record_count == 2 - assert cases["bio__default__r000"].input_text_tokens_total == 1500 - assert cases["bio__default__r000"].records_per_pipeline_sec == pytest.approx(0.2) - assert cases["bio__default__r000"].records_per_ndd_sec == pytest.approx(2 / 8.5) - assert cases["bio__default__r000"].input_text_tokens_per_pipeline_sec == 150 - assert cases["bio__default__r000"].input_text_tokens_per_ndd_sec == pytest.approx(1500 / 8.5) - assert cases["bio__default__r000"].topology_endpoint_count == 2 - assert cases["bio__default__r000"].topology_gpu_count == 4 - assert cases["bio__default__r000"].topology_tensor_parallelism == 2 - assert cases["bio__default__r000"].input_text_tokens_per_endpoint_sec == 75 - assert cases["bio__default__r000"].input_text_tokens_per_gpu_sec == 37.5 - assert cases["bio__default__r000"].empty_detection_count == 1 - assert cases["bio__default__r000"].empty_detection_rate == 0.5 - assert 
cases["bio__default__r000"].empty_detection_with_ground_truth_count == 1 - assert cases["bio__default__r000"].empty_detection_with_ground_truth_rate == 0.5 - assert cases["bio__default__r000"].ground_truth_record_count == 2 - assert cases["bio__default__r000"].ground_truth_entity_count == 22 - assert cases["bio__default__r000"].entity_true_positive_count == 10 - assert cases["bio__default__r000"].entity_false_positive_count == 4 - assert cases["bio__default__r000"].entity_false_negative_count == 12 - assert cases["bio__default__r000"].entity_precision == pytest.approx(10 / 14) - assert cases["bio__default__r000"].entity_recall == pytest.approx(10 / 22) - assert cases["bio__default__r000"].entity_relaxed_gt_found_count == 15 - assert cases["bio__default__r000"].entity_relaxed_detected_tp_count == 14 - assert cases["bio__default__r000"].entity_relaxed_label_compatible_gt_found_count == 13 - assert cases["bio__default__r000"].entity_relaxed_label_compatible_detected_tp_count == 12 - assert cases["bio__default__r000"].entity_relaxed_precision == 1.0 - assert cases["bio__default__r000"].entity_relaxed_recall == pytest.approx(15 / 22) - assert cases["bio__default__r000"].entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) - assert cases["bio__default__r000"].entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) - assert cases["bio__default__r000"].validation_max_entities_per_call == 10 - assert cases["bio__default__r000"].original_value_leak_count == 0 - assert cases["bio__default__r000"].original_value_leak_record_count == 0 - assert cases["bio__default__r000"].original_value_leak_label_counts == {} - assert cases["bio__default__r000"].replacement_missing_final_entity_count == 2 - assert cases["bio__default__r000"].replacement_missing_final_entity_label_counts == {"date": 2} - assert cases["bio__default__r000"].replacement_missing_final_value_count == 1 - assert cases["bio__default__r000"].replacement_synthetic_original_collision_count == 1 - 
assert cases["bio__default__r000"].replacement_synthetic_original_collision_label_counts == {"date": 1} - assert cases["bio__default__r000"].replacement_synthetic_original_collision_value_count == 1 - assert cases["bio__default__r000"].seed_validation_candidate_count == 13 - assert cases["bio__default__r000"].estimated_seed_validation_chunk_count == 2 - assert cases["bio__default__r000"].augmented_new_final_value_count == 1 - assert cases["bio__default__r000"].artifact_final_detector_entity_count == 11 - assert cases["bio__default__r000"].artifact_final_augmenter_entity_count == 3 - assert cases["bio__default__r000"].artifact_final_entity_signature_hashes == ["bio-hash-a", "bio-hash-b"] - assert cases["bio__default__r000"].artifact_final_entity_signature_labels == { - "bio-hash-a": "person", - "bio-hash-b": "city", + bio = cases["bio__default__r000"] + assert bio.workload_category == "synthetic_biography" + assert bio.observed_failed_request_rate == pytest.approx(1 / 4) + assert bio.dd_trace_error_count == 1 + assert bio.observed_bridge_fallback_requests == 1 + assert bio.record_count == 2 + assert bio.entity_precision == pytest.approx(10 / 14) + assert bio.entity_recall == pytest.approx(10 / 22) + assert bio.replacement_missing_final_entity_label_counts == {"date": 2} + assert bio.replacement_synthetic_original_collision_label_counts == {"date": 1} + assert bio.artifact_final_detector_entity_count == 11 + assert bio.artifact_final_augmenter_entity_count == 3 + assert bio.artifact_final_entity_signature_hashes == ["bio-hash-a", "bio-hash-b"] + assert bio.artifact_final_entity_signature_details["bio-hash-a"] == { + "label": "person", + "source": "detector", + "row_index": 0, + "start_position": 0, + "end_position": 5, + "value_length": 5, } - assert cases["bio__default__r000"].artifact_final_entity_signature_details == { - "bio-hash-a": { - "label": "person", - "source": "detector", - "row_index": 0, - "start_position": 0, - "end_position": 5, - "value_length": 5, - 
} - } - assert cases["bio__default__r000"].artifact_final_entity_signature_count == 2 - assert cases["shell__native-local__r000"].observed_total_requests == 0 - assert cases["shell__native-local__r000"].experimental_replacement_strategy == "custom_replacement_strategy" - assert cases["shell__native-local__r000"].observed_failed_request_rate is None - assert cases["shell__native-local__r000"].observed_bridge_fallback_requests is None - assert cases["shell__native-local__r000"].observed_non_bridge_failed_requests is None - assert cases["shell__native-local__r000"].final_entity_count == 8 - assert cases["shell__native-local__r000"].replacement_missing_final_entity_count == 0 - assert cases["shell__native-local__r000"].replacement_missing_final_entity_label_counts == {} - assert cases["shell__native-local__r000"].replacement_missing_final_value_count == 0 - assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_count == 0 - assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_label_counts == {} - assert cases["shell__native-local__r000"].replacement_synthetic_original_collision_value_count == 0 - assert cases["shell__native-local__r000"].original_value_leak_count == 1 - assert cases["shell__native-local__r000"].original_value_leak_record_count == 1 - assert cases["shell__native-local__r000"].original_value_leak_label_counts == {"api_key": 1} - assert cases["shell__native-local__r000"].artifact_final_augmenter_entity_count == 8 - assert cases["shell__native-local__r000"].artifact_final_entity_signature_hashes == ["shell-hash-a"] - assert cases["shell__native-local__r000"].artifact_final_entity_signature_labels == {"shell-hash-a": "api_key"} - assert ( - cases["shell__native-local__r000"].artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" - ) + + shell = cases["shell__native-local__r000"] + assert shell.experimental_replacement_strategy == "custom_replacement_strategy" + assert 
shell.original_value_leak_count == 1 + assert shell.artifact_final_entity_signature_details["shell-hash-a"]["source"] == "native" + model_rows = {row.model_name: row for row in result.model_usage} - assert model_rows["nvidia/gliner-pii"].observed_failed_requests == 0 - assert model_rows["nvidia/gliner-pii"].observed_failed_request_rate == 0 assert model_rows["nvidia/gliner-pii"].observed_total_tokens == 1100 - assert model_rows["nvidia/nemotron-3-super"].model_alias == "local-nemotron-json" - assert model_rows["nvidia/nemotron-3-super"].experimental_replacement_strategy == "default" assert model_rows["nvidia/nemotron-3-super"].model_provider_name == "local-vllm" - assert model_rows["nvidia/nemotron-3-super"].observed_failed_requests == 1 assert model_rows["nvidia/nemotron-3-super"].observed_failed_request_rate == pytest.approx(1 / 3) - assert model_rows["nvidia/nemotron-3-super"].observed_total_tokens == 4900 - model_groups = {(row.model_alias, row.model_name): row for row in result.model_usage_groups} - nemotron_group = model_groups[("local-nemotron-json", "nvidia/nemotron-3-super")] - assert nemotron_group.model_provider_name == "local-vllm" - assert nemotron_group.sum_observed_failed_requests == 1 - assert nemotron_group.observed_failed_request_rate == pytest.approx(1 / 3) - assert nemotron_group.median_observed_total_requests == 3 + bio_group = next(group for group in result.groups if group.workload_id == "bio") - assert bio_group.workload_category == "synthetic_biography" - assert bio_group.entity_label_set_id == "agent" - assert bio_group.entity_label_count == 4 - assert bio_group.gliner_threshold == pytest.approx(0.3) - assert bio_group.experimental_replacement_strategy == "default" - assert bio_group.median_observed_bridge_fallback_requests == 1 - assert bio_group.median_observed_non_bridge_total_requests == 3 - assert bio_group.median_observed_non_bridge_failed_requests == 0 - assert bio_group.median_observed_non_bridge_failed_request_rate == 0 - assert 
bio_group.median_replacement_missing_final_entity_count == 2 - assert bio_group.median_replacement_missing_final_value_count == 1 - assert bio_group.replacement_missing_final_entity_label_counts == {"date": 2} - assert bio_group.median_replacement_synthetic_original_collision_count == 1 - assert bio_group.median_replacement_synthetic_original_collision_value_count == 1 - assert bio_group.replacement_synthetic_original_collision_label_counts == {"date": 1} assert bio_group.total_record_count == 2 - assert bio_group.total_input_text_tokens == 1500 - assert bio_group.median_input_text_tokens_per_pipeline_sec == 150 - assert bio_group.median_input_text_tokens_per_endpoint_sec == 75 - assert bio_group.median_input_text_tokens_per_gpu_sec == 37.5 - assert bio_group.total_empty_detection_count == 1 - assert bio_group.empty_detection_rate == 0.5 - assert bio_group.total_empty_detection_with_ground_truth_count == 1 - assert bio_group.empty_detection_with_ground_truth_rate == 0.5 - assert bio_group.total_ground_truth_record_count == 2 - assert bio_group.sum_ground_truth_entity_count == 22 - assert bio_group.sum_entity_true_positive_count == 10 - assert bio_group.sum_entity_false_positive_count == 4 - assert bio_group.sum_entity_false_negative_count == 12 assert bio_group.micro_entity_precision == pytest.approx(10 / 14) - assert bio_group.micro_entity_recall == pytest.approx(10 / 22) - assert bio_group.sum_entity_relaxed_gt_found_count == 15 - assert bio_group.sum_entity_relaxed_detected_tp_count == 14 - assert bio_group.sum_entity_relaxed_label_compatible_gt_found_count == 13 - assert bio_group.sum_entity_relaxed_label_compatible_detected_tp_count == 12 - assert bio_group.micro_entity_relaxed_precision == 1.0 - assert bio_group.micro_entity_relaxed_recall == pytest.approx(15 / 22) - assert bio_group.micro_entity_relaxed_label_compatible_precision == pytest.approx(12 / 14) - assert bio_group.micro_entity_relaxed_label_compatible_recall == pytest.approx(13 / 22) + assert 
bio_group.replacement_missing_final_entity_label_counts == {"date": 2} + assert bio_group.replacement_synthetic_original_collision_label_counts == {"date": 1} + shell_group = next(group for group in result.groups if group.workload_id == "shell") assert shell_group.experimental_replacement_strategy == "custom_replacement_strategy" assert shell_group.sum_original_value_leak_count == 1 - assert shell_group.leaking_case_count == 1 - assert shell_group.median_original_value_leak_count == 1 serialized = result.model_dump_json() assert "Alice" not in serialized diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index e3f5652d..c5ba6e0d 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -47,6 +47,13 @@ def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[ } +def _copy_biography_data(tmp_path: Path, filename: str = "input.csv") -> Path: + source = REPO_ROOT / "docs" / "data" / "NVIDIA_synthetic_biographies.csv" + destination = tmp_path / filename + destination.write_bytes(source.read_bytes()) + return destination + + def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") dataframe = pd.DataFrame( @@ -609,6 +616,7 @@ def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None workloads: - id: biography source: biographies.csv + text_column: biography configs: - id: redact replace: redact @@ -619,7 +627,7 @@ def test_benchmark_dry_run_expands_cases_without_writing(tmp_path: Path) -> None """, encoding="utf-8", ) - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "biographies.csv", index=False) + _copy_biography_data(tmp_path, "biographies.csv") output_dir = tmp_path / "dry-run-output" result = tool.run_or_plan( @@ -717,8 +725,7 @@ def test_benchmark_preflight_rejects_sliced_remote_workload(tmp_path: Path) 
-> N def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) -> None: tool = load_tool("measurement_benchmark_tool_preflight_models", REPO_ROOT / "tools/measurement/run_benchmarks.py") - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ @@ -741,6 +748,7 @@ def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) workloads: - id: biography source: input.csv + text_column: biography configs: - id: substitute replace: substitute @@ -794,8 +802,7 @@ def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None tool = load_tool( "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) provider_path = tmp_path / "providers.yaml" provider_path.write_text("not_providers: []\n", encoding="utf-8") spec_path = tmp_path / "suite.yaml" @@ -806,6 +813,7 @@ def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None workloads: - id: biography source: input.csv + text_column: biography configs: - id: redact replace: redact @@ -822,8 +830,7 @@ def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> Non tool = load_tool( "measurement_benchmark_tool_preflight_provider_path", REPO_ROOT / "tools/measurement/run_benchmarks.py" ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) provider_path = tmp_path / "providers.yaml" provider_path.write_text( """ @@ -843,6 +850,7 @@ def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> Non workloads: - id: biography source: input.csv + text_column: biography 
configs: - id: redact replace: redact @@ -864,8 +872,7 @@ def test_benchmark_preflight_rejects_native_strategy_without_runtime( ) monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ @@ -873,6 +880,7 @@ def test_benchmark_preflight_rejects_native_strategy_without_runtime( workloads: - id: input source: input.csv + text_column: biography configs: - id: native-single-pass experimental_detection_strategy: native_single_pass @@ -928,8 +936,7 @@ def test_benchmark_preflight_rejects_native_strategy_without_endpoint_or_model( ) monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ @@ -939,6 +946,7 @@ def test_benchmark_preflight_rejects_native_strategy_without_endpoint_or_model( workloads: - id: input source: input.csv + text_column: biography configs: - id: native-single-pass experimental_detection_strategy: native_single_pass @@ -962,8 +970,7 @@ def test_benchmark_native_runtime_resolves_endpoint_and_model_from_env( ) monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", "http://runtime-from-env/v1") monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_MODEL", "env-model") - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ @@ -973,6 +980,7 @@ def test_benchmark_native_runtime_resolves_endpoint_and_model_from_env( workloads: - id: 
input source: input.csv + text_column: biography configs: - id: native-single-pass experimental_detection_strategy: native_single_pass @@ -1008,8 +1016,7 @@ def test_benchmark_preflight_skips_inactive_native_configs(tmp_path: Path) -> No "measurement_benchmark_tool_inactive_native_runtime", REPO_ROOT / "tools/measurement/run_benchmarks.py", ) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(input_path, index=False) + _copy_biography_data(tmp_path) spec_path = tmp_path / "suite.yaml" spec_path.write_text( """ @@ -1017,6 +1024,7 @@ def test_benchmark_preflight_skips_inactive_native_configs(tmp_path: Path) -> No workloads: - id: input source: input.csv + text_column: biography configs: - id: redact replace: redact diff --git a/tools/measurement/AGENTS.md b/tools/measurement/AGENTS.md new file mode 100644 index 00000000..9b17f879 --- /dev/null +++ b/tools/measurement/AGENTS.md @@ -0,0 +1,39 @@ + + + +# Measurement tool agent notes + +This directory is for benchmark and analysis tooling around Anonymizer. Keep +product behavior in `src/anonymizer` and keep benchmark-only strategy switches +inside `tools/measurement`. + +## Boundaries + +- The measurement layer records facts about Anonymizer runs. It should not + decide production defaults. +- `run_benchmarks.py` owns local benchmark suite execution, preflight checks, + per-case raw shards, and measurement export. +- Direct and staged probe scripts are prompt/runtime experiments. Promote a + probe into `run_benchmarks.py` only after it has stable artifacts, analysis + fields, and regression coverage. +- Distributed DataDesigner execution belongs outside this directory. Detection + export APIs build configs for an external runtime; the measurement tools + should analyze the artifacts that runtime writes. + +## Tests + +Prefer fixtures that look like tool inputs over large constructed tables inside +test functions. 
For analysis tools, checked-in fixture directories under +`tests/fixtures/measurement/` are easier to review than hundreds of inline +JSON-like rows. + +Keep tests focused on contracts: + +- input files accepted by the tool +- output table shape and key grouping fields +- safety gates and verdicts +- sensitive values excluded from sanitized analysis output +- preflight failures for user-actionable mistakes + +Avoid exhaustive assertions for every derived metric in one test. Add a focused +test when a metric has non-obvious behavior or has regressed before. diff --git a/tools/measurement/README.md b/tools/measurement/README.md index a970fc4d..db29dcff 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -3,10 +3,52 @@ # Measurement tools -This directory contains developer tools for measuring Anonymizer runs, exporting -measurement JSONL to tables, and comparing benchmark strategies. Run the tools -inside the project environment, either with an activated venv or through -`uv run`. +This directory contains developer tools for measuring Anonymizer runs, +exporting measurement JSONL to tables, and comparing benchmark strategies. +Run the tools inside the project environment, either with an activated venv or +through `uv run`. + +Use these tools when you need evidence about cost, latency, reliability, or +anonymization quality. They are not product entry points and the benchmark-only +strategy knobs are not public Anonymizer defaults. + +## System overview + +The measurement system has three layers: + +- Instrumentation in Anonymizer emits JSONL records for runs, stages, + DataDesigner workflows, direct model workflows, and per-record safety metrics. +- Benchmark runners and probes create repeatable workloads and write those JSONL + records plus optional sidecars such as detection artifacts and DataDesigner + traces. +- Analysis tools convert raw run artifacts into case, group, model, and + comparison tables. 
+ +External/distributed execution is a separate boundary. Detection export APIs are +responsible for building DataDesigner configs that an external runtime can +execute. The tools here should consume the resulting measurement JSONL, +detection artifacts, and trace sidecars; they should not own SLURM +orchestration or distributed DataDesigner execution. + +## Tool map + +| Task | Tool | +| --- | --- | +| Export raw measurement JSONL to tables | `export_measurements.py` | +| Run repeatable Anonymizer suites | `run_benchmarks.py` | +| Inspect DataDesigner traces | `analyze_dd_traces.py` | +| Inspect detection artifact sidecars | `analyze_detection_artifacts.py` | +| Probe one direct detection prompt | `direct_detection_probe.py` | +| Probe staged seed/validate/augment paths | `staged_detection_probe.py` | +| Analyze staged probe outputs | `analyze_staged_detection_output.py` | +| Analyze benchmark output directories | `analyze_benchmark_output.py` | +| Compare a candidate against a baseline | `compare_strategy_pairs.py` | +| Screen many comparison files | `screen_strategy_comparisons.py` | +| Extract exact-signature deltas | `extract_signature_deltas.py` | + +Most workflows start with `run_benchmarks.py`, then +`analyze_benchmark_output.py`, then either `compare_strategy_pairs.py` or +`screen_strategy_comparisons.py`. 
```bash uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables From 3a2ea400d20bb696da83caaa271444b45da6da18 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 00:37:50 +0000 Subject: [PATCH 12/26] Factor measurement tool support helpers Signed-off-by: Aaron Gonzales --- src/anonymizer/measurement.py | 6 +- tools/measurement/AGENTS.md | 7 + tools/measurement/README.md | 15 ++ tools/measurement/analyze_benchmark_output.py | 141 +++--------------- tools/measurement/analyze_dd_traces.py | 100 ++----------- .../analyze_detection_artifacts.py | 34 +---- .../analyze_staged_detection_output.py | 104 ++----------- tools/measurement/compare_strategy_pairs.py | 31 +--- tools/measurement/direct_detection_probe.py | 25 +--- tools/measurement/export_measurements.py | 48 +----- tools/measurement/extract_signature_deltas.py | 31 +--- .../measurement/measurement_tools/__init__.py | 4 + tools/measurement/measurement_tools/cli.py | 34 +++++ tools/measurement/measurement_tools/stats.py | 49 ++++++ tools/measurement/measurement_tools/tables.py | 93 ++++++++++++ tools/measurement/run_benchmarks.py | 29 +--- .../screen_strategy_comparisons.py | 33 +--- tools/measurement/staged_detection_probe.py | 19 +-- 18 files changed, 274 insertions(+), 529 deletions(-) create mode 100644 tools/measurement/measurement_tools/__init__.py create mode 100644 tools/measurement/measurement_tools/cli.py create mode 100644 tools/measurement/measurement_tools/stats.py create mode 100644 tools/measurement/measurement_tools/tables.py diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index 18c249eb..182c1c5c 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -364,7 +364,11 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator return sink = _JsonlMeasurementSink(config.output_path) if config.streaming else None - dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) if 
config.dd_trace != "none" else None + dd_trace_sink = None + if config.dd_trace != "none": + if config.dd_trace_path is None: + raise ValueError("dd_trace_path is required when dd_trace is enabled") + dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) collector = MeasurementCollector( run_id=config.run_id, record_hash_key=config.record_hash_key, diff --git a/tools/measurement/AGENTS.md b/tools/measurement/AGENTS.md index 9b17f879..a2db7203 100644 --- a/tools/measurement/AGENTS.md +++ b/tools/measurement/AGENTS.md @@ -19,6 +19,13 @@ inside `tools/measurement`. - Distributed DataDesigner execution belongs outside this directory. Detection export APIs build configs for an external runtime; the measurement tools should analyze the artifacts that runtime writes. +- Shared command-line concerns live in `measurement_tools/`: CLI logging, + output formats, table writing, and small numeric aggregations. Do not + redefine `LogFormat`, `ExportFormat`, bad-input logging, or model-row table + export in each script. +- Prefer explicit specs and functions over analyzer base classes. A script + should own its row models, parsing, and metric semantics; shared helpers + should own boring IO/aggregation policy. ## Tests diff --git a/tools/measurement/README.md b/tools/measurement/README.md index db29dcff..d8492fc1 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -50,6 +50,21 @@ Most workflows start with `run_benchmarks.py`, then `analyze_benchmark_output.py`, then either `compare_strategy_pairs.py` or `screen_strategy_comparisons.py`. +## Implementation shape + +The scripts keep workload-specific row models and metric logic local, but share +boring command and export policy through `measurement_tools/`: + +- `measurement_tools.cli`: `LogFormat`, logging setup, and structured + bad-input errors. +- `measurement_tools.tables`: `ExportFormat`, model-row table specs, manifest + writing, and CSV/Parquet/JSONL output. 
+- `measurement_tools.stats`: small numeric helpers used by analysis groupers. + +This is intentionally composition-based. New analysis tools should declare +their own row models and call the shared helpers rather than inheriting from a +common analyzer base class. + ```bash uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables ``` diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py index 056c2287..5e1a482e 100644 --- a/tools/measurement/analyze_benchmark_output.py +++ b/tools/measurement/analyze_benchmark_output.py @@ -16,12 +16,19 @@ import logging import math import sys -from enum import StrEnum from pathlib import Path from typing import Annotated, Any, cast import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.stats import median_or_none as _median_or_none +from measurement_tools.stats import none_if_nan as _none_if_nan +from measurement_tools.stats import sum_int_or_zero as _sum_int_or_zero +from measurement_tools.stats import sum_or_none as _sum_or_none +from measurement_tools.stats import sum_or_zero as _sum_or_zero +from measurement_tools.tables import AnalysisExportResult, ExportFormat, ModelTableSpec +from measurement_tools.tables import write_analysis_tables as _write_analysis_table_specs from pydantic import BaseModel, Field, computed_field app = cyclopts.App(help=__doc__) @@ -38,20 +45,6 @@ } -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class CaseAnalysisRow(BaseModel): suite_id: str | None = None workload_id: str | None = None @@ -314,34 +307,6 @@ def model_usage_group_count(self) -> int: return len(self.model_usage_groups) -class TableSummary(BaseModel): - table: str - rows: int - path: str - - -class AnalysisExportResult(BaseModel): - 
output_dir: str - format: ExportFormat - tables: list[TableSummary] - manifest_path: str - - -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def analyze_benchmark_output( benchmark_dir: Path, *, @@ -1048,21 +1013,6 @@ def _signature_count(artifact_rows: pd.DataFrame, *, signature_hashes: list[str] return _sum_or_none(artifact_rows, "final_entity_signature_count") -def _sum_or_zero(dataframe: pd.DataFrame, column: str) -> float: - if column not in dataframe.columns: - return 0.0 - return float(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) - - -def _sum_or_none(dataframe: pd.DataFrame, column: str) -> float | None: - if column not in dataframe.columns: - return None - values = pd.to_numeric(dataframe[column], errors="coerce").dropna() - if values.empty: - return None - return float(values.sum()) - - def _positive_count(dataframe: pd.DataFrame, column: str) -> int: if column not in dataframe.columns: return 0 @@ -1349,12 +1299,6 @@ def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysi ) -def _none_if_nan(value: object) -> str | None: - if pd.isna(value): - return None - return str(value) - - def _int_if_not_nan(value: object) -> int | None: if pd.isna(value): return None @@ -1367,17 +1311,6 @@ def _float_if_not_nan(value: object) -> float | None: return float(cast(Any, value)) -def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: - values = pd.to_numeric(dataframe[column], errors="coerce").dropna() - if values.empty: - return None - return float(values.median()) - - -def 
_sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: - return int(_sum_or_zero(dataframe, column)) - - def _sum_bool_or_zero(dataframe: pd.DataFrame, column: str) -> int: if column not in dataframe.columns: return 0 @@ -1452,54 +1385,16 @@ def write_analysis_tables( output_dir: Path, export_format: ExportFormat, ) -> AnalysisExportResult: - output_dir.mkdir(parents=True, exist_ok=True) - tables = [ - _write_model_rows( - result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, CaseAnalysisRow - ), - _write_model_rows( - result.groups, output_dir / f"group_analysis.{export_format.value}", export_format, GroupAnalysisRow - ), - _write_model_rows( - result.model_usage, - output_dir / f"model_analysis.{export_format.value}", - export_format, - ModelUsageAnalysisRow, - ), - _write_model_rows( - result.model_usage_groups, - output_dir / f"model_group_analysis.{export_format.value}", - export_format, - ModelUsageGroupAnalysisRow, - ), - ] - export_result = AnalysisExportResult( - output_dir=str(output_dir), - format=export_format, - tables=tables, - manifest_path=str(output_dir / "manifest.json"), + return _write_analysis_table_specs( + output_dir, + export_format, + [ + ModelTableSpec("case_analysis", result.cases, CaseAnalysisRow), + ModelTableSpec("group_analysis", result.groups, GroupAnalysisRow), + ModelTableSpec("model_analysis", result.model_usage, ModelUsageAnalysisRow), + ModelTableSpec("model_group_analysis", result.model_usage_groups, ModelUsageGroupAnalysisRow), + ], ) - Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") - return export_result - - -def _write_model_rows( - rows: list[BaseModel], - path: Path, - export_format: ExportFormat, - row_model: type[BaseModel], -) -> TableSummary: - if rows: - table = pd.json_normalize([row.model_dump() for row in rows], sep=".") - else: - table = pd.DataFrame(columns=list(row_model.model_fields)) - if export_format == 
ExportFormat.parquet: - table.to_parquet(path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(path, index=False) - else: - table.to_json(path, orient="records", lines=True) - return TableSummary(table=path.stem, rows=len(table), path=str(path)) def render_result(result: BenchmarkOutputAnalysis, *, json_output: bool) -> str: @@ -1543,7 +1438,7 @@ def main( if output is not None: write_analysis_tables(result, output, format) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") diff --git a/tools/measurement/analyze_dd_traces.py b/tools/measurement/analyze_dd_traces.py index 157beec9..b0fcae55 100644 --- a/tools/measurement/analyze_dd_traces.py +++ b/tools/measurement/analyze_dd_traces.py @@ -15,12 +15,17 @@ import logging import re import sys -from enum import StrEnum from pathlib import Path from typing import Annotated, Any import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.stats import median_or_none as _median_or_none +from measurement_tools.stats import none_if_nan as _none_if_nan +from measurement_tools.stats import sum_int_or_zero as _sum_int_or_zero +from measurement_tools.tables import AnalysisExportResult, ExportFormat, ModelTableSpec +from measurement_tools.tables import write_analysis_tables as _write_analysis_table_specs from pydantic import BaseModel, Field, computed_field app = cyclopts.App(help=__doc__) @@ -29,20 +34,6 @@ JSON_FENCE_RE = re.compile(r"```\s*json\b", re.IGNORECASE) -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class TraceAnalysisRow(BaseModel): trace_file: str line_number: int @@ -110,34 +101,6 @@ def group_count(self) -> int: return len(self.groups) 
-class TableSummary(BaseModel): - table: str - rows: int - path: str - - -class AnalysisExportResult(BaseModel): - output_dir: str - format: ExportFormat - tables: list[TableSummary] - manifest_path: str - - -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def analyze_trace_path(trace_path: Path) -> TraceAnalysis: rows = [ _trace_row(record, trace_file=path, line_number=line_number) @@ -346,50 +309,15 @@ def _int_or_none(value: object) -> int | None: return int(value) if value is not None and not pd.isna(value) else None -def _none_if_nan(value: object) -> str | None: - if pd.isna(value): - return None - return str(value) - - -def _median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: - values = pd.to_numeric(dataframe[column], errors="coerce").dropna() - if values.empty: - return None - return float(values.median()) - - -def _sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: - if column not in dataframe.columns: - return 0 - return int(pd.to_numeric(dataframe[column], errors="coerce").fillna(0).sum()) - - def write_analysis_tables(result: TraceAnalysis, output_dir: Path, export_format: ExportFormat) -> AnalysisExportResult: - output_dir.mkdir(parents=True, exist_ok=True) - tables = [ - _write_model_rows(result.rows, output_dir / f"trace_analysis.{export_format.value}", export_format), - _write_model_rows(result.groups, output_dir / f"trace_group_analysis.{export_format.value}", export_format), - ] - export_result = AnalysisExportResult( - output_dir=str(output_dir), - format=export_format, - tables=tables, - 
manifest_path=str(output_dir / "manifest.json"), + return _write_analysis_table_specs( + output_dir, + export_format, + [ + ModelTableSpec("trace_analysis", result.rows, TraceAnalysisRow), + ModelTableSpec("trace_group_analysis", result.groups, TraceGroupAnalysisRow), + ], ) - Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") - return export_result - - -def _write_model_rows(rows: list[BaseModel], path: Path, export_format: ExportFormat) -> TableSummary: - table = pd.DataFrame([row.model_dump() for row in rows]) - if export_format == ExportFormat.parquet: - table.to_parquet(path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(path, index=False) - else: - table.to_json(path, orient="records", lines=True) - return TableSummary(table=path.stem, rows=len(table), path=str(path)) def render_result(result: TraceAnalysis, *, json_output: bool) -> str: @@ -422,7 +350,7 @@ def main( if output is not None: write_analysis_tables(result, output, format) except (OSError, ValueError, json.JSONDecodeError) as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") diff --git a/tools/measurement/analyze_detection_artifacts.py b/tools/measurement/analyze_detection_artifacts.py index 0c16da75..2d340c55 100644 --- a/tools/measurement/analyze_detection_artifacts.py +++ b/tools/measurement/analyze_detection_artifacts.py @@ -19,12 +19,13 @@ import re import sys from collections import Counter -from enum import StrEnum from pathlib import Path from typing import Annotated, Iterable import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat from pydantic import BaseModel, Field from anonymizer.engine.constants import ( @@ -47,20 +48,6 @@ API_KEY_PREFIX_RE = 
re.compile(r"^(sk-|sk_|sk-ant-|sk-proj-|ghp_|pat-|hf_|xox[a-z]-|ya29\.|aiza|akia|bearer\s+)", re.I) -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class DetectionArtifactRow(BaseModel): workflow_name: str batch_file: str @@ -88,21 +75,6 @@ class DetectionArtifactAnalysis(BaseModel): rows: list[DetectionArtifactRow] = Field(default_factory=list) -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def analyze_artifacts( artifact_path: Path, *, @@ -361,7 +333,7 @@ def main( if output is not None: write_rows(result.rows, output, format) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") diff --git a/tools/measurement/analyze_staged_detection_output.py b/tools/measurement/analyze_staged_detection_output.py index de0b3780..d74e0f32 100644 --- a/tools/measurement/analyze_staged_detection_output.py +++ b/tools/measurement/analyze_staged_detection_output.py @@ -16,32 +16,19 @@ import logging import sys from collections import Counter, defaultdict -from enum import StrEnum from pathlib import Path from typing import Annotated, Any import cyclopts -import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import AnalysisExportResult, ExportFormat, ModelTableSpec +from measurement_tools.tables import 
write_analysis_tables as _write_analysis_table_specs from pydantic import BaseModel, Field, computed_field app = cyclopts.App(help=__doc__) logger = logging.getLogger("measurement.staged_detection_output") -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class StagedCaseAnalysisRow(BaseModel): source_path: str case_id: str @@ -106,19 +93,6 @@ class LabelDeltaAnalysisRow(BaseModel): count: int -class TableSummary(BaseModel): - table: str - rows: int - path: str - - -class AnalysisExportResult(BaseModel): - output_dir: str - format: ExportFormat - tables: list[TableSummary] - manifest_path: str - - class StagedDetectionOutputAnalysis(BaseModel): source_path: str cases: list[StagedCaseAnalysisRow] = Field(default_factory=list) @@ -141,21 +115,6 @@ def label_delta_count(self) -> int: return len(self.label_deltas) -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def analyze_staged_detection_output(input_path: Path) -> StagedDetectionOutputAnalysis: case_path = resolve_case_path(input_path) case_rows = [_build_case_row(row, source_path=case_path) for row in read_case_records(case_path)] @@ -363,54 +322,15 @@ def write_analysis_tables( output_dir: Path, export_format: ExportFormat, ) -> AnalysisExportResult: - output_dir.mkdir(parents=True, exist_ok=True) - tables = [ - _write_model_rows( - result.cases, output_dir / f"case_analysis.{export_format.value}", export_format, StagedCaseAnalysisRow - ), - _write_model_rows( - 
result.groups, - output_dir / f"group_analysis.{export_format.value}", - export_format, - StagedGroupAnalysisRow, - ), - _write_model_rows( - result.label_deltas, - output_dir / f"label_delta_analysis.{export_format.value}", - export_format, - LabelDeltaAnalysisRow, - ), - ] - export_result = AnalysisExportResult( - output_dir=str(output_dir), - format=export_format, - tables=tables, - manifest_path=str(output_dir / "manifest.json"), + return _write_analysis_table_specs( + output_dir, + export_format, + [ + ModelTableSpec("case_analysis", result.cases, StagedCaseAnalysisRow), + ModelTableSpec("group_analysis", result.groups, StagedGroupAnalysisRow), + ModelTableSpec("label_delta_analysis", result.label_deltas, LabelDeltaAnalysisRow), + ], ) - Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") - return export_result - - -def _write_model_rows( - rows: list[BaseModel], - path: Path, - export_format: ExportFormat, - row_model: type[BaseModel], -) -> TableSummary: - table = _rows_to_table(rows, row_model) - if export_format == ExportFormat.parquet: - table.to_parquet(path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(path, index=False) - else: - table.to_json(path, orient="records", lines=True) - return TableSummary(table=path.stem, rows=len(table), path=str(path)) - - -def _rows_to_table(rows: list[BaseModel], row_model: type[BaseModel]) -> pd.DataFrame: - if rows: - return pd.json_normalize([row.model_dump() for row in rows], sep=".") - return pd.DataFrame(columns=list(row_model.model_fields)) def render_result(result: StagedDetectionOutputAnalysis, *, json_output: bool) -> str: @@ -463,7 +383,7 @@ def main( if output is not None: write_analysis_tables(result, output, format) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") diff --git 
a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py index bdbad0d5..e9075e2c 100644 --- a/tools/measurement/compare_strategy_pairs.py +++ b/tools/measurement/compare_strategy_pairs.py @@ -25,6 +25,8 @@ import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat from pydantic import BaseModel, Field app = cyclopts.App(help=__doc__) @@ -40,17 +42,6 @@ } -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - class SafetyVerdict(StrEnum): passed = "pass" review = "review" @@ -71,7 +62,6 @@ class CandidateVerdict(StrEnum): reject = "reject" -_log_format = LogFormat.plain _MIN_CANDIDATE_OVERLAP_RATIO = 0.8 _MAX_CANDIDATE_BOUNDARY_GAP_CHARS = 8 _MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS = 8 @@ -218,21 +208,6 @@ def comparison_count(self) -> int: return len(self.comparisons) -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def read_case_analysis(path: Path) -> pd.DataFrame: if not path.exists() or path.is_dir(): raise ValueError(f"case analysis path is not a file: {path}") @@ -1444,7 +1419,7 @@ def main( candidate_strategy=candidate_strategy, ) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc result = ComparisonResult( input_path=str(case_analysis), diff --git a/tools/measurement/direct_detection_probe.py b/tools/measurement/direct_detection_probe.py index 
8659fec0..05f93e0d 100644 --- a/tools/measurement/direct_detection_probe.py +++ b/tools/measurement/direct_detection_probe.py @@ -26,6 +26,7 @@ import pandas as pd from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities from dd_parser_compat import _load_embedded_json +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input from pydantic import BaseModel, Field, ValidationError, model_validator from anonymizer.engine.detection.detection_workflow import _format_label_examples @@ -46,11 +47,6 @@ class CaseStatus(StrEnum): error = "error" -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - class PromptMode(StrEnum): compact = "compact" recall = "recall" @@ -151,23 +147,6 @@ def complete(self, request: DirectGenerationRequest) -> DirectCompletion: ) -_log_format = LogFormat.plain - - -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") - return - logger.error("bad_input error=%s", error) - - def build_chat_payload(request: DirectGenerationRequest) -> dict[str, Any]: payload: dict[str, Any] = { "model": request.model, @@ -565,7 +544,7 @@ def main( baseline_artifacts=baseline_artifacts, ) except (ValueError, ValidationError, httpx.HTTPError) as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") if result.error_count: diff --git a/tools/measurement/export_measurements.py b/tools/measurement/export_measurements.py index 796fcc40..6480bbf9 100755 --- a/tools/measurement/export_measurements.py +++ b/tools/measurement/export_measurements.py @@ -12,12 +12,13 @@ import json 
import logging import sys -from enum import StrEnum from pathlib import Path from typing import Annotated import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat, ensure_can_write, write_table from pydantic import BaseModel, Field app = cyclopts.App(help=__doc__) @@ -26,20 +27,6 @@ MANIFEST_FILENAME = "manifest.json" -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class TableSummary(BaseModel): record_type: str rows: int @@ -56,21 +43,6 @@ class ExportResult(BaseModel): manifest_path: str -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def read_measurements(path: Path) -> pd.DataFrame: if not path.exists(): raise ValueError(f"input path does not exist: {path}") @@ -104,15 +76,6 @@ def _json_cell(value: object) -> object: return json.dumps(value, ensure_ascii=True, sort_keys=True) -def write_table(table: pd.DataFrame, path: Path, export_format: ExportFormat) -> None: - if export_format == ExportFormat.parquet: - table.to_parquet(path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(path, index=False) - else: - table.to_json(path, orient="records", lines=True) - - def export_tables( dataframe: pd.DataFrame, *, @@ -158,11 +121,6 @@ def write_manifest(result: ExportResult, path: Path, *, overwrite: bool) -> None path.write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") 
-def ensure_can_write(path: Path, *, overwrite: bool) -> None: - if path.exists() and not overwrite: - raise ValueError(f"output already exists: {path}; pass --overwrite to replace it") - - def render_result(result: ExportResult, *, json_output: bool) -> str: if json_output: return result.model_dump_json(indent=2) @@ -196,7 +154,7 @@ def main( overwrite=overwrite, ) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") diff --git a/tools/measurement/extract_signature_deltas.py b/tools/measurement/extract_signature_deltas.py index 7732c147..a676255f 100644 --- a/tools/measurement/extract_signature_deltas.py +++ b/tools/measurement/extract_signature_deltas.py @@ -23,6 +23,8 @@ import cyclopts import pandas as pd from analyze_detection_artifacts import _entity_signature_hash +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat from pydantic import BaseModel, Field, ValidationError from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT @@ -32,17 +34,6 @@ logger = logging.getLogger("measurement.signature_deltas") -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - class DeltaSide(StrEnum): baseline_only = "baseline_only" candidate_only = "candidate_only" @@ -54,7 +45,6 @@ class ContextResolution(StrEnum): metadata_only = "metadata_only" -_log_format = LogFormat.plain _SIGNATURE_DETAIL_FIELDS = { "label", "source", @@ -97,21 +87,6 @@ class _ArtifactSide(BaseModel): config_id: str | None = None -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == 
LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def extract_signature_deltas( baseline_artifacts: Path, candidate_artifacts: Path, @@ -527,7 +502,7 @@ def main( context_window=context_window, ) except (ValueError, ValidationError) as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc if output is not None: write_rows(result.rows, output, format) diff --git a/tools/measurement/measurement_tools/__init__.py b/tools/measurement/measurement_tools/__init__.py new file mode 100644 index 00000000..28a69b6a --- /dev/null +++ b/tools/measurement/measurement_tools/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Shared support for measurement command-line tools.""" diff --git a/tools/measurement/measurement_tools/cli.py b/tools/measurement/measurement_tools/cli.py new file mode 100644 index 00000000..505bbf2f --- /dev/null +++ b/tools/measurement/measurement_tools/cli.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Shared CLI logging helpers for measurement tools.""" + +from __future__ import annotations + +import json +import logging +import sys +from enum import StrEnum + + +class LogFormat(StrEnum): + plain = "plain" + json = "json" + + +_log_format = LogFormat.plain + + +def configure_logging(log_format: LogFormat) -> None: + global _log_format + + _log_format = log_format + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def log_bad_input(logger: logging.Logger, error: str) -> None: + if _log_format == LogFormat.json: + payload = {"level": "error", "event": "bad_input", "error": error} + sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") + return + logger.error("bad_input error=%s", error) diff --git a/tools/measurement/measurement_tools/stats.py b/tools/measurement/measurement_tools/stats.py new file mode 100644 index 00000000..bdc30977 --- /dev/null +++ b/tools/measurement/measurement_tools/stats.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Small aggregation helpers shared by measurement analysis tools.""" + +from __future__ import annotations + +from typing import cast + +import pandas as pd + + +def none_if_nan(value: object) -> str | None: + if pd.isna(value): + return None + return str(value) + + +def median_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + if column not in dataframe.columns: + return None + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.median()) + + +def sum_int_or_zero(dataframe: pd.DataFrame, column: str) -> int: + return int(sum_or_zero(dataframe, column)) + + +def sum_or_zero(dataframe: pd.DataFrame, column: str) -> float: + value = sum_or_none(dataframe, column) + return 0.0 if value is None else value + + +def sum_or_none(dataframe: pd.DataFrame, column: str) -> float | None: + if column not in dataframe.columns: + return None + values = pd.to_numeric(dataframe[column], errors="coerce").dropna() + if values.empty: + return None + return float(values.sum()) + + +def optional_number(value: object) -> float | None: + if value is None or pd.isna(value): + return None + return float(cast(float, value)) diff --git a/tools/measurement/measurement_tools/tables.py b/tools/measurement/measurement_tools/tables.py new file mode 100644 index 00000000..4e697aa8 --- /dev/null +++ b/tools/measurement/measurement_tools/tables.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Shared table export helpers for measurement analysis tools.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from pathlib import Path +from typing import Sequence + +import pandas as pd +from pydantic import BaseModel, Field + + +class ExportFormat(StrEnum): + parquet = "parquet" + csv = "csv" + jsonl = "jsonl" + + +class TableSummary(BaseModel): + table: str + rows: int + path: str + + +class AnalysisExportResult(BaseModel): + output_dir: str + format: ExportFormat + tables: list[TableSummary] = Field(default_factory=list) + manifest_path: str + + +@dataclass(frozen=True) +class ModelTableSpec: + name: str + rows: Sequence[BaseModel] + row_model: type[BaseModel] | None = None + + +def write_analysis_tables( + output_dir: Path, + export_format: ExportFormat, + specs: Sequence[ModelTableSpec], +) -> AnalysisExportResult: + output_dir.mkdir(parents=True, exist_ok=True) + tables = [ + write_model_rows(spec.rows, output_dir / f"{spec.name}.{export_format.value}", export_format, spec.row_model) + for spec in specs + ] + export_result = AnalysisExportResult( + output_dir=str(output_dir), + format=export_format, + tables=tables, + manifest_path=str(output_dir / "manifest.json"), + ) + Path(export_result.manifest_path).write_text(export_result.model_dump_json(indent=2) + "\n", encoding="utf-8") + return export_result + + +def write_model_rows( + rows: Sequence[BaseModel], + path: Path, + export_format: ExportFormat, + row_model: type[BaseModel] | None = None, +) -> TableSummary: + table = rows_to_table(rows, row_model) + write_table(table, path, export_format) + return TableSummary(table=path.stem, rows=len(table), path=str(path)) + + +def rows_to_table(rows: Sequence[BaseModel], row_model: type[BaseModel] | None = None) -> pd.DataFrame: + if rows: + return pd.json_normalize([row.model_dump() for row in rows], sep=".") + if row_model is None: + return pd.DataFrame() + return 
pd.DataFrame(columns=list(row_model.model_fields)) + + +def write_table(table: pd.DataFrame, path: Path, export_format: ExportFormat) -> None: + if export_format == ExportFormat.parquet: + table.to_parquet(path, index=False) + elif export_format == ExportFormat.csv: + table.to_csv(path, index=False) + else: + table.to_json(path, orient="records", lines=True) + + +def ensure_can_write(path: Path, *, overwrite: bool) -> None: + if path.exists() and not overwrite: + raise ValueError(f"output already exists: {path}; pass --overwrite to replace it") diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 2306aecc..a832e704 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -36,7 +36,9 @@ NativeDetectionRuntime, experimental_detection_strategy_context, ) -from export_measurements import ExportFormat, export_tables, read_measurements +from export_measurements import export_tables, read_measurements +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator from anonymizer.config.anonymizer_config import ( @@ -60,14 +62,6 @@ logger = logging.getLogger("measurement.benchmark") -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - -_log_format = LogFormat.plain - - class CaseStatus(StrEnum): planned = "planned" completed = "completed" @@ -307,21 +301,6 @@ class _CaseExecution: _FINAL_ARTIFACT_PREFIXES = tuple(f"{key}." 
for key in _FINAL_ARTIFACT_KEYS if key != "final_entity_signature_hashes") -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def load_spec(path: Path) -> BenchmarkSpec: if not path.exists() or path.is_dir(): raise ValueError(f"spec path is not a file: {path}") @@ -1444,7 +1423,7 @@ def main( trace_dir=trace_dir, ) except (ValueError, ValidationError) as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc sys.stdout.write(render_result(result, json_output=json_output) + "\n") if any(case.status == CaseStatus.error for case in result.cases): diff --git a/tools/measurement/screen_strategy_comparisons.py b/tools/measurement/screen_strategy_comparisons.py index 0cd9e141..69542205 100644 --- a/tools/measurement/screen_strategy_comparisons.py +++ b/tools/measurement/screen_strategy_comparisons.py @@ -21,6 +21,8 @@ import cyclopts import pandas as pd +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input +from measurement_tools.tables import ExportFormat from pydantic import BaseModel, Field, model_validator app = cyclopts.App(help=__doc__) @@ -36,17 +38,6 @@ } -class ExportFormat(StrEnum): - parquet = "parquet" - csv = "csv" - jsonl = "jsonl" - - -class LogFormat(StrEnum): - plain = "plain" - json = "json" - - class GroupBy(StrEnum): strategy = "strategy" strategy_workload_family = "strategy_workload_family" @@ -202,24 +193,6 @@ class ScreenResult(BaseModel): groups: list[ScreenGroup] = Field(default_factory=list) -_log_format = LogFormat.plain - - -def configure_logging(log_format: 
LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format == LogFormat.json: - payload = {"level": "error", "event": "bad_input", "error": error} - sys.stderr.write(json.dumps(payload, ensure_ascii=True, sort_keys=True) + "\n") - return - logger.error("bad_input error=%s", error) - - def screen_comparison_paths( paths: list[Path], *, @@ -1121,7 +1094,7 @@ def main( source_excludes=source_exclude or [], ) except ValueError as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc if output is not None: write_rows(result.rows, output, format) diff --git a/tools/measurement/staged_detection_probe.py b/tools/measurement/staged_detection_probe.py index ddb1a5d0..57295a54 100644 --- a/tools/measurement/staged_detection_probe.py +++ b/tools/measurement/staged_detection_probe.py @@ -36,13 +36,13 @@ DirectDetectionRequest, DirectGenerationRequest, HttpxDirectDetectionClient, - LogFormat, PromptMode, SignatureComparison, build_direct_prompt, compare_signature_sets, parse_labels, ) +from measurement_tools.cli import LogFormat, configure_logging, log_bad_input from pydantic import BaseModel, Field, ValidationError, model_validator from anonymizer.engine.constants import ( @@ -100,7 +100,6 @@ _UNCONFIGURED_MODEL = "configured-native-model" _UNCONFIGURED_GLINER_ENDPOINT = "configured-gliner-endpoint" _UNCONFIGURED_GLINER_MODEL = "configured-gliner-model" -_log_format = LogFormat.plain _DATE_OF_BIRTH_CONTEXT_RE = re.compile(r"\b(born|birth|date of birth|dob)\b", re.IGNORECASE) @@ -273,20 +272,6 @@ class StagedDetectionExecution: row: dict[str, Any] -def configure_logging(log_format: LogFormat) -> None: - global _log_format - - _log_format = log_format - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") - - -def log_bad_input(error: str) -> None: - if _log_format 
== LogFormat.json: - sys.stderr.write(json.dumps({"level": "error", "event": "bad_input", "error": error}) + "\n") - return - logger.error("bad_input error=%s", error) - - def run_staged_detection_case( request: StagedDetectionRequest, *, @@ -1382,7 +1367,7 @@ def _run_main_probe(params: dict[str, Any]) -> StagedDetectionRun: try: return run_probe(**params) except (ValueError, ValidationError, httpx.HTTPError) as exc: - log_bad_input(str(exc)) + log_bad_input(logger, str(exc)) raise SystemExit(125) from exc From ed8ba48f17b1ba40bfc7ed1a84fd8517bb2050c9 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 04:58:05 +0000 Subject: [PATCH 13/26] Fix measurement cleanup and benchmark metrics Signed-off-by: Aaron Gonzales --- src/anonymizer/measurement.py | 111 +++++++++++++++++--------- tests/test_measurement.py | 103 +++++++++++++++++++++++- tests/tools/test_measurement_tools.py | 59 ++++++++++++++ tools/measurement/run_benchmarks.py | 32 ++++++-- 4 files changed, 258 insertions(+), 47 deletions(-) diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index 182c1c5c..150876cf 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -194,10 +194,17 @@ def record(self, record_type: str, **fields: Any) -> None: def close(self) -> None: """Close any streaming measurement sink attached to this collector.""" - if self._record_sink is not None: - self._record_sink.close() - if self._dd_trace_sink is not None: - self._dd_trace_sink.close() + close_error: Exception | None = None + for sink in (self._record_sink, self._dd_trace_sink): + if sink is None: + continue + try: + sink.close() + except Exception as exc: + if close_error is None: + close_error = exc + if close_error is not None: + raise close_error @property def dd_trace_mode(self) -> DDTraceMode: @@ -391,8 +398,18 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator if config.streaming: _close_collector_safely(config=config, 
collector=collector, body_error=body_error) else: - _write_collector_safely(config=config, collector=collector, body_error=body_error) - _close_collector_safely(config=config, collector=collector, body_error=body_error) + write_error: BaseException | None = None + try: + _write_collector_safely(config=config, collector=collector, body_error=body_error) + except BaseException as exc: + write_error = exc + raise + finally: + _close_collector_safely( + config=config, + collector=collector, + body_error=body_error or write_error, + ) def current_collector() -> MeasurementCollector | None: @@ -1136,11 +1153,11 @@ def _entity_ground_truth_metrics( "entity_relaxed_label_compatible_f1": None, } - predicted = _entity_identity_set(final_entities) - expected = _entity_identity_set(ground_truth_entities) - true_positive = len(predicted & expected) - false_positive = len(predicted - expected) - false_negative = len(expected - predicted) + predicted = _entity_identity_counts(final_entities) + expected = _entity_identity_counts(ground_truth_entities) + true_positive = sum((predicted & expected).values()) + false_positive = sum((predicted - expected).values()) + false_negative = sum((expected - predicted).values()) precision = _safe_ratio(true_positive, true_positive + false_positive) recall = _safe_ratio(true_positive, true_positive + false_negative) return { @@ -1158,14 +1175,14 @@ def _entity_ground_truth_metrics( } -def _entity_identity_set(entities: list[dict[str, Any]]) -> set[tuple[str, str]]: - identities: set[tuple[str, str]] = set() +def _entity_identity_counts(entities: list[dict[str, Any]]) -> Counter[tuple[str, str]]: + identities: Counter[tuple[str, str]] = Counter() for entity in entities: label = entity.get("label") value = entity.get("value") if label is None or value is None: continue - identities.add((str(value), str(label))) + identities[(str(value), str(label))] += 1 return identities @@ -1173,24 +1190,16 @@ def _entity_relaxed_ground_truth_metrics( 
final_entities: list[dict[str, Any]], ground_truth_entities: list[dict[str, Any]], ) -> dict[str, Any]: - gt_found = sum( - 1 - for ground_truth_entity in ground_truth_entities - if _has_relaxed_entity_match(final_entities, ground_truth_entity) - ) - detected_tp = sum( - 1 for final_entity in final_entities if _has_relaxed_entity_match(ground_truth_entities, final_entity) - ) - label_compatible_gt_found = sum( - 1 - for ground_truth_entity in ground_truth_entities - if _has_relaxed_entity_match(final_entities, ground_truth_entity, require_label_compatible=True) - ) - label_compatible_detected_tp = sum( - 1 - for final_entity in final_entities - if _has_relaxed_entity_match(ground_truth_entities, final_entity, require_label_compatible=True) + relaxed_match_count = _relaxed_entity_match_count(final_entities, ground_truth_entities) + label_compatible_match_count = _relaxed_entity_match_count( + final_entities, + ground_truth_entities, + require_label_compatible=True, ) + gt_found = relaxed_match_count + detected_tp = relaxed_match_count + label_compatible_gt_found = label_compatible_match_count + label_compatible_detected_tp = label_compatible_match_count precision = _safe_ratio(detected_tp, len(final_entities)) recall = _safe_ratio(gt_found, len(ground_truth_entities)) label_compatible_precision = _safe_ratio(label_compatible_detected_tp, len(final_entities)) @@ -1209,16 +1218,40 @@ def _entity_relaxed_ground_truth_metrics( } -def _has_relaxed_entity_match( - candidates: list[dict[str, Any]], - target: dict[str, Any], +def _relaxed_entity_match_count( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]], *, require_label_compatible: bool = False, -) -> bool: - return any( - _entities_match_relaxed(candidate, target, require_label_compatible=require_label_compatible) - for candidate in candidates - ) +) -> int: + matches_by_ground_truth = [ + [ + final_index + for final_index, final_entity in enumerate(final_entities) + if 
_entities_match_relaxed( + final_entity, + ground_truth_entity, + require_label_compatible=require_label_compatible, + ) + ] + for ground_truth_entity in ground_truth_entities + ] + matched_ground_truth_by_final: dict[int, int] = {} + + def assign(ground_truth_index: int, seen: set[int]) -> bool: + for final_index in matches_by_ground_truth[ground_truth_index]: + if final_index in seen: + continue + seen.add(final_index) + if final_index not in matched_ground_truth_by_final or assign( + matched_ground_truth_by_final[final_index], + seen, + ): + matched_ground_truth_by_final[final_index] = ground_truth_index + return True + return False + + return sum(1 for ground_truth_index in range(len(ground_truth_entities)) if assign(ground_truth_index, set())) def _entities_match_relaxed( diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 9a4665be..c462a0cb 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -8,7 +8,7 @@ import threading from pathlib import Path from types import SimpleNamespace -from typing import cast +from typing import Any, cast from unittest.mock import Mock import numpy as np @@ -630,6 +630,35 @@ def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector collector.record("example") +def test_measurement_config_strict_write_errors_still_close_collector( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + closed_run_ids: list[str] = [] + + def raise_write_error(_self: MeasurementConfig, _collector: MeasurementCollector) -> None: + raise OSError("cannot write") + + def close_collector(self: MeasurementCollector) -> None: + closed_run_ids.append(self.run_id) + + monkeypatch.setattr(MeasurementConfig, "write_collector", raise_write_error) + monkeypatch.setattr(MeasurementCollector, "close", close_collector) + + with pytest.raises(OSError, match="cannot write"): + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", + 
fail_on_write_error=True, + run_id="strict-write-run", + ) + ) as collector: + assert collector is not None + collector.record("example") + + assert closed_run_ids == ["strict-write-run"] + + def test_measurement_config_write_errors_do_not_mask_body_errors( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -667,6 +696,33 @@ def test_streaming_measurement_session_writes_jsonl_without_retaining_records(tm assert [json.loads(line)["value"] for line in lines] == [1, 2] +def test_measurement_collector_close_attempts_all_sinks_after_failure() -> None: + close_events: list[str] = [] + + class FakeSink: + def __init__(self, name: str, *, fail: bool = False) -> None: + self.name = name + self.fail = fail + + def write_record(self, _record: dict[str, Any]) -> None: + pass + + def close(self) -> None: + close_events.append(self.name) + if self.fail: + raise OSError(f"{self.name} close failed") + + collector = MeasurementCollector( + record_sink=FakeSink("records", fail=True), + dd_trace_sink=FakeSink("dd-trace"), + ) + + with pytest.raises(OSError, match="records close failed"): + collector.close() + + assert close_events == ["records", "dd-trace"] + + def test_streaming_measurement_requires_jsonl_output(tmp_path: Path) -> None: with pytest.raises(ValueError, match="streaming measurement output only supports jsonl"): MeasurementConfig(output_path=tmp_path / "measurements.json", output_format="json", streaming=True) @@ -908,6 +964,51 @@ def test_record_metrics_capture_generic_counts_without_raw_values() -> None: assert "Maya" not in serialized +def test_record_metrics_counts_duplicate_ground_truth_entities_by_occurrence() -> None: + final_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + ] + } + ground_truth_entities = { + "entities": [ + {"value": "Alice", "label": "first_name", "start_position": 0, "end_position": 5}, + {"value": "Alice", "label": "first_name", "start_position": 18, "end_position": 23}, + ] 
+ } + dataframe = pd.DataFrame( + { + COL_TEXT: ["Alice talked with Alice"], + COL_FINAL_ENTITIES: [final_entities], + "ground_truth_entities": [ground_truth_entities], + } + ) + collector = MeasurementCollector(record_hash_key="test-key") + + with measurement_session(collector): + record_record_metrics( + dataframe, + mode="replace", + strategy="Redact", + text_column=COL_TEXT, + validation_max_entities_per_call=2, + ) + + record = collector.records[0] + assert record["ground_truth_entity_count"] == 2 + assert record["ground_truth_entity_label_counts"] == {"first_name": 2} + assert record["entity_true_positive_count"] == 1 + assert record["entity_false_positive_count"] == 0 + assert record["entity_false_negative_count"] == 1 + assert record["entity_precision"] == 1.0 + assert record["entity_recall"] == 0.5 + assert record["entity_f1"] == pytest.approx(2 / 3) + assert record["entity_relaxed_gt_found_count"] == 1 + assert record["entity_relaxed_detected_tp_count"] == 1 + assert record["entity_relaxed_precision"] == 1.0 + assert record["entity_relaxed_recall"] == 0.5 + + def test_record_metrics_capture_relaxed_gt_label_equivalence_without_raw_values() -> None: final_entities = { "entities": [ diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index c5ba6e0d..41be00d4 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -54,6 +54,21 @@ def _copy_biography_data(tmp_path: Path, filename: str = "input.csv") -> Path: return destination +def test_benchmark_spec_rejects_duplicate_matrix_entries() -> None: + tool = load_tool("measurement_benchmark_tool_duplicate_matrix", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + with pytest.raises(ValidationError, match="duplicate matrix workload/config entry"): + tool.BenchmarkSpec( + suite_id="duplicate-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + 
matrix=[ + tool.MatrixEntry(workload="input", config="redact"), + tool.MatrixEntry(workload="input", config="redact"), + ], + ) + + def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") dataframe = pd.DataFrame( @@ -271,6 +286,50 @@ def test_benchmark_patches_detection_artifacts_from_final_trace_dataframe(tmp_pa assert "final_entity_signature_labels.stale" not in row +def test_benchmark_no_export_disables_trace_detection_artifact_sidecar(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_no_export_trace_artifact", + REPO_ROOT / "tools/measurement/run_benchmarks.py", + ) + spec = tool.BenchmarkSpec( + suite_id="no-export-suite", + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[ + tool.ConfigSpec( + id="native-single-pass-redact", + replace="redact", + experimental_detection_strategy="native_single_pass", + ) + ], + ) + case = tool.BenchmarkCase( + suite_id="no-export-suite", + workload_id="input", + config_id="native-single-pass-redact", + repetition=0, + case_id="input__native-single-pass-redact__r000", + ) + contexts = _minimal_case_contexts(tool, spec, tmp_path) + paths = tool._case_run_paths(case, contexts=contexts, export_detection_artifacts=False) + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) + execution = tool._CaseExecution( + input_data=tool.AnonymizerInput(source=str(input_path)), + trace_dataframe=_final_trace_dataframe_with_rule_entity(), + ) + + result = tool._case_detection_artifact_path( + contexts, + paths, + case=case, + config=spec.configs[0], + execution=execution, + ) + + assert result is None + assert not paths.artifact_output_path.exists() + + def test_run_suite_records_detection_artifact_analysis_path( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, diff --git a/tools/measurement/run_benchmarks.py 
b/tools/measurement/run_benchmarks.py index a832e704..a1ad72c9 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -217,6 +217,10 @@ def _validate_matrix_references(self, workload_ids: set[str], config_ids: set[st raise ValueError(f"matrix references unknown workload id(s): {', '.join(missing_workloads)}") if missing_configs: raise ValueError(f"matrix references unknown config id(s): {', '.join(missing_configs)}") + duplicate_entries = _duplicate_matrix_entries(self.matrix) + if duplicate_entries: + formatted = ", ".join(f"{workload}/{config}" for workload, config in duplicate_entries) + raise ValueError(f"duplicate matrix workload/config entry(s): {formatted}; use repetitions for repeats") class BenchmarkCase(BaseModel): @@ -251,6 +255,7 @@ class _CaseRunPaths: artifact_output_path: Path trace_path: Path | None artifact_snapshot: dict[str, int] | None + export_detection_artifacts: bool @dataclass(frozen=True) @@ -333,6 +338,17 @@ def _cross_product_matrix(spec: BenchmarkSpec) -> list[MatrixEntry]: ] +def _duplicate_matrix_entries(matrix: list[MatrixEntry]) -> list[tuple[str, str]]: + seen: set[tuple[str, str]] = set() + duplicates: set[tuple[str, str]] = set() + for entry in matrix: + key = (entry.workload, entry.config) + if key in seen: + duplicates.add(key) + seen.add(key) + return sorted(duplicates) + + def prepare_output_dir(output_dir: Path, *, overwrite: bool, dry_run: bool) -> None: if dry_run: return @@ -742,6 +758,7 @@ def _case_run_paths( artifact_snapshot=snapshot_detection_artifacts(contexts["artifact_path"]) if export_detection_artifacts else None, + export_detection_artifacts=export_detection_artifacts, ) @@ -832,13 +849,14 @@ def _case_detection_artifact_path( case=case, artifact_snapshot=paths.artifact_snapshot, ) - detection_artifact_path = _trace_final_artifact_path_if_requested( - config, - detection_artifact_path, - paths.artifact_output_path, - case=case, - trace_dataframe=execution.trace_dataframe, - 
) + if paths.export_detection_artifacts: + detection_artifact_path = _trace_final_artifact_path_if_requested( + config, + detection_artifact_path, + paths.artifact_output_path, + case=case, + trace_dataframe=execution.trace_dataframe, + ) if detection_artifact_path is not None or paths.artifact_snapshot is None: return detection_artifact_path return None From a852659434d217aa3c7e98c130c8ae45dcc59407 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 05:04:00 +0000 Subject: [PATCH 14/26] Use native DataDesigner message traces Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 334 ++++++++++++--------------- tests/test_measurement.py | 167 ++++++-------- 2 files changed, 223 insertions(+), 278 deletions(-) diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index f48dee4a..d279769f 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -8,18 +8,19 @@ import tempfile import time import uuid -from collections.abc import Iterator, Mapping -from contextlib import contextmanager +from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from threading import RLock from typing import TYPE_CHECKING, Any, cast +from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.models import ModelConfig from data_designer.config.seed import SamplingStrategy from data_designer.config.seed_source import LocalFileSeedSource +from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX +from data_designer.config.utils.trace_type import TraceType from anonymizer.interface.errors import AnonymizerWorkflowError from anonymizer.measurement import current_collector, record_ndd_workflow @@ -31,7 +32,6 @@ logger = 
logging.getLogger("anonymizer.ndd") RECORD_ID_COLUMN = "_anonymizer_record_id" -_DD_MESSAGE_TRACE_PATCH_LOCK = RLock() @dataclass(frozen=True) @@ -51,6 +51,15 @@ class WorkflowRunResult: failed_records: list[FailedRecord] +@dataclass(frozen=True) +class _NativeTraceColumn: + column_name: str + trace_column_name: str + model_alias: str | None + model_name: str | None + model_provider_name: str | None + + class NddAdapter: """Adapter for running NDD workflows with uniform I/O and record tracking.""" @@ -100,7 +109,19 @@ def run_workflow( else len(workflow_input_df) ) started = time.perf_counter() - usage_probe = _DataDesignerUsageProbe(self._data_designer, enabled=current_collector() is not None) + collector = current_collector() + usage_probe = _DataDesignerUsageProbe(self._data_designer, enabled=collector is not None) + columns, native_trace_columns, unsupported_trace_columns = _configure_native_dd_message_traces( + columns=columns, + model_configs=model_configs, + collector=collector, + ) + _record_dd_trace_coverage( + workflow_name=workflow_name, + collector=collector, + native_trace_columns=native_trace_columns, + unsupported_trace_columns=unsupported_trace_columns, + ) with tempfile.TemporaryDirectory(prefix=f"anonymizer_{workflow_name}_") as tmp_dir: seed_path = str(Path(tmp_dir) / "seed.parquet") @@ -112,7 +133,7 @@ def run_workflow( config_builder.add_column(column) try: - with usage_probe, _dd_message_trace(workflow_name=workflow_name): + with usage_probe: if preview_num_records is None: run_results = self._data_designer.create( config_builder, @@ -129,6 +150,12 @@ def run_workflow( output_df = workflow_input_df.iloc[0:0].copy() else: output_df = preview_results.dataset + output_df = _record_and_strip_native_dd_message_traces( + output_df=output_df, + workflow_name=workflow_name, + collector=collector, + native_trace_columns=native_trace_columns, + ) except Exception as exc: logger.warning( "Workflow failed for %d input record(s) on model(s) %s: %s", @@ 
-371,182 +398,148 @@ def _model_usage_as_json(stats: object) -> Any: return stats -@contextmanager -def _dd_message_trace(*, workflow_name: str) -> Iterator[None]: - """Trace DataDesigner model messages for the active measurement context. - - DataDesigner constructs ``ModelFacade`` instances internally, so this hook - wraps the facade class while a traced workflow runs. The wrappers re-check - the active collector before writing a trace record so unrelated concurrent - workflows pass through without contaminating the traced collector. - """ - collector = current_collector() +def _configure_native_dd_message_traces( + *, + columns: list[ColumnConfigT], + model_configs: list[ModelConfig], + collector: Any | None, +) -> tuple[list[ColumnConfigT], list[_NativeTraceColumn], list[ColumnConfigT]]: if collector is None or not collector.dd_trace_enabled: - yield - return + return columns, [], [] - from data_designer.engine.models.facade import ModelFacade + model_configs_by_alias = {model_config.alias: model_config for model_config in model_configs} + traced_columns: list[_NativeTraceColumn] = [] + unsupported_columns: list[ColumnConfigT] = [] + configured_columns: list[ColumnConfigT] = [] + trace_type = _native_dd_trace_type() - with _DD_MESSAGE_TRACE_PATCH_LOCK: - original_completion = ModelFacade.completion - original_acompletion = ModelFacade.acompletion - ModelFacade.completion = _traced_completion( - original_completion, collector=collector, workflow_name=workflow_name - ) - ModelFacade.acompletion = _traced_acompletion( - original_acompletion, - collector=collector, - workflow_name=workflow_name, - ) - try: - yield - finally: - ModelFacade.completion = original_completion - ModelFacade.acompletion = original_acompletion + for column in columns: + if isinstance(column, LLMTextColumnConfig): + configured_column = cast(ColumnConfigT, column.model_copy(update={"with_trace": trace_type})) + configured_columns.append(configured_column) + model_config = 
model_configs_by_alias.get(column.model_alias) + traced_columns.append( + _NativeTraceColumn( + column_name=column.name, + trace_column_name=f"{column.name}{TRACE_COLUMN_POSTFIX}", + model_alias=column.model_alias, + model_name=getattr(model_config, "model", None), + model_provider_name=getattr(model_config, "provider", None), + ) + ) + continue + configured_columns.append(column) + if _column_has_untraced_model_calls(column): + unsupported_columns.append(column) -def _traced_completion(original_completion: Any, *, collector: Any, workflow_name: str) -> Any: - def traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: - return _run_traced_completion( - original_completion, - collector=collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - args=args, - kwargs=kwargs, - ) + return configured_columns, traced_columns, unsupported_columns + + +def _native_dd_trace_type() -> TraceType: + # Preserve Anonymizer's existing dd_trace=last_message semantics: the trace + # sink records the final prompt message and response separately, while DD's + # native LAST_MESSAGE side effect only keeps the final assistant message. 
+ return TraceType.ALL_MESSAGES - return traced +def _column_has_untraced_model_calls(column: ColumnConfigT) -> bool: + return isinstance(column, CustomColumnConfig) and bool(_extract_workflow_model_aliases([column])) -def _run_traced_completion( - completion: Any, + +def _record_dd_trace_coverage( *, - collector: Any, workflow_name: str, - model_facade: Any, - messages: list[Any], - args: tuple[Any, ...], - kwargs: dict[str, Any], + collector: Any, + native_trace_columns: list[_NativeTraceColumn], + unsupported_trace_columns: list[ColumnConfigT], ) -> Any: - started = time.perf_counter() - response: Any | None = None - status = "completed" - error_type: str | None = None - try: - response = completion(model_facade, messages, *args, **kwargs) - return response - except BaseException as exc: - status = "error" - error_type = type(exc).__name__ - raise - finally: - active_collector = _active_trace_collector(collector) - if active_collector is not None: - _record_dd_message_trace( - collector=active_collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - response=response, - elapsed_sec=time.perf_counter() - started, - status=status, - error_type=error_type, - is_async=False, - ) - + if collector is None or not collector.dd_trace_enabled: + return + collector.record( + "dd_trace_coverage", + workflow_name=workflow_name, + trace_backend="data_designer_column", + trace_mode=collector.dd_trace_mode, + native_trace_type=_native_dd_trace_type().value, + traced_column_count=len(native_trace_columns), + traced_column_names=[column.column_name for column in native_trace_columns], + unsupported_column_count=len(unsupported_trace_columns), + unsupported_column_names=[column.name for column in unsupported_trace_columns], + unsupported_column_types=[_column_type_name(column) for column in unsupported_trace_columns], + ) -def _traced_acompletion(original_acompletion: Any, *, collector: Any, workflow_name: str) -> Any: - async def 
traced(model_facade: Any, messages: list[Any], *args: Any, **kwargs: Any) -> Any: - return await _run_traced_acompletion( - original_acompletion, - collector=collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - args=args, - kwargs=kwargs, - ) - return traced +def _column_type_name(column: ColumnConfigT) -> str: + column_type = getattr(column, "column_type", None) + return str(column_type) if column_type is not None else type(column).__name__ -async def _run_traced_acompletion( - acompletion: Any, +def _record_and_strip_native_dd_message_traces( *, - collector: Any, + output_df: pd.DataFrame, workflow_name: str, - model_facade: Any, - messages: list[Any], - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> Any: - started = time.perf_counter() - response: Any | None = None - status = "completed" - error_type: str | None = None - try: - response = await acompletion(model_facade, messages, *args, **kwargs) - return response - except BaseException as exc: - status = "error" - error_type = type(exc).__name__ - raise - finally: - active_collector = _active_trace_collector(collector) - if active_collector is not None: - _record_dd_message_trace( - collector=active_collector, - workflow_name=workflow_name, - model_facade=model_facade, - messages=messages, - response=response, - elapsed_sec=time.perf_counter() - started, - status=status, - error_type=error_type, - is_async=True, - ) + collector: Any, + native_trace_columns: list[_NativeTraceColumn], +) -> pd.DataFrame: + if not native_trace_columns: + return output_df + + trace_column_names = [column.trace_column_name for column in native_trace_columns] + if collector is not None and collector.dd_trace_enabled: + for _, row in output_df.iterrows(): + for trace_column in native_trace_columns: + if trace_column.trace_column_name not in output_df.columns: + continue + trace_messages = _native_trace_messages(row.get(trace_column.trace_column_name)) + if not trace_messages: + continue 
+ collector.record_dd_message_trace( + workflow_name=workflow_name, + trace_source="data_designer_column", + column_name=trace_column.column_name, + trace_column_name=trace_column.trace_column_name, + model_alias=trace_column.model_alias, + model_name=trace_column.model_name, + model_provider_name=trace_column.model_provider_name, + modality="chat", + is_async=None, + status="completed", + error_type=None, + elapsed_sec=None, + messages=_select_native_trace_messages(trace_messages, mode=collector.dd_trace_mode), + response=_native_trace_response(trace_messages), + usage=None, + ) + existing_trace_columns = [column_name for column_name in trace_column_names if column_name in output_df.columns] + if not existing_trace_columns: + return output_df + return output_df.drop(columns=existing_trace_columns) -def _active_trace_collector(expected_collector: Any) -> Any | None: - active_collector = current_collector() - if active_collector is expected_collector and active_collector.dd_trace_enabled: - return active_collector - return None +def _native_trace_messages(value: Any) -> list[dict[str, Any]]: + if not isinstance(value, list): + return [] + return [_trace_message(message) for message in value] -def _record_dd_message_trace( - *, - collector: Any, - workflow_name: str, - model_facade: Any, - messages: list[Any], - response: Any | None, - elapsed_sec: float, - status: str, - error_type: str | None, - is_async: bool, -) -> None: - collector.record_dd_message_trace( - workflow_name=workflow_name, - model_alias=getattr(model_facade, "model_alias", None), - model_name=getattr(model_facade, "model_name", None), - model_provider_name=getattr(model_facade, "model_provider_name", None), - modality="chat", - is_async=is_async, - status=status, - error_type=error_type, - elapsed_sec=elapsed_sec, - messages=_trace_messages(messages, mode=collector.dd_trace_mode), - response=_trace_response(response), - usage=_trace_usage(getattr(response, "usage", None) if response is not None 
else None), - ) + +def _select_native_trace_messages(messages: list[dict[str, Any]], *, mode: str) -> list[dict[str, Any]]: + if mode == "all_messages": + return messages + last_prompt = next((message for message in reversed(messages) if message.get("role") != "assistant"), None) + return [last_prompt] if last_prompt is not None else [] -def _trace_messages(messages: list[Any], *, mode: str) -> list[dict[str, Any]]: - selected = messages if mode == "all_messages" else messages[-1:] - return [_trace_message(message) for message in selected] +def _native_trace_response(messages: list[dict[str, Any]]) -> dict[str, Any] | None: + assistant_message = next((message for message in reversed(messages) if message.get("role") == "assistant"), None) + if assistant_message is None: + return None + return { + "content": assistant_message.get("content"), + "reasoning_content": assistant_message.get("reasoning_content"), + "tool_calls": _trace_tool_calls(assistant_message.get("tool_calls", [])), + } def _trace_message(message: Any) -> dict[str, Any]: @@ -558,30 +551,7 @@ def _trace_message(message: Any) -> dict[str, Any]: return {"role": getattr(message, "role", None), "content": getattr(message, "content", None)} -def _trace_response(response: Any | None) -> dict[str, Any] | None: - if response is None: - return None - message = getattr(response, "message", None) - if message is None: - return None - return { - "content": getattr(message, "content", None), - "reasoning_content": getattr(message, "reasoning_content", None), - "tool_calls": _trace_tool_calls(getattr(message, "tool_calls", [])), - } - - def _trace_tool_calls(tool_calls: Any) -> list[Any]: if isinstance(tool_calls, list): return [getattr(tool_call, "__dict__", tool_call) for tool_call in tool_calls] return [] - - -def _trace_usage(usage: Any | None) -> dict[str, Any] | None: - if usage is None: - return None - return { - "input_tokens": getattr(usage, "input_tokens", None), - "output_tokens": getattr(usage, 
"output_tokens", None), - "total_tokens": getattr(usage, "total_tokens", None), - } diff --git a/tests/test_measurement.py b/tests/test_measurement.py index c462a0cb..65472fed 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -5,7 +5,6 @@ import json import logging -import threading from pathlib import Path from types import SimpleNamespace from typing import Any, cast @@ -14,11 +13,11 @@ import numpy as np import pandas as pd import pytest -from data_designer.config.column_configs import LLMTextColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig +from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ModelConfig -from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, Usage -from data_designer.engine.models.facade import ModelFacade -from data_designer.engine.models.utils import ChatMessage +from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX +from data_designer.config.utils.trace_type import TraceType from data_designer.interface.data_designer import DataDesigner import anonymizer.measurement as measurement @@ -733,36 +732,35 @@ def test_dd_message_trace_requires_trace_path(tmp_path: Path) -> None: MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_trace="last_message") -def test_ndd_adapter_writes_opt_in_dd_message_trace( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) - - def fake_completion( - _self: object, - _messages: list[ChatMessage], - skip_usage_tracking: bool = False, - **_kwargs: object, - ) -> ChatCompletionResponse: - assert skip_usage_tracking is False - return ChatCompletionResponse( - message=AssistantMessage(content="secret response"), - usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), - ) - - monkeypatch.setattr(ModelFacade, 
"completion", fake_completion) +def test_ndd_adapter_writes_native_dd_message_trace_and_strips_trace_columns(tmp_path: Path) -> None: + input_df = pd.DataFrame( + { + "text": ["Alice works at Acme"], + f"notes{TRACE_COLUMN_POSTFIX}": ["user supplied trace-looking column"], + RECORD_ID_COLUMN: ["record-a"], + } + ) + original_column = LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="alias", + ) + captured_columns: list[LLMTextColumnConfig] = [] class TraceDataDesigner: - def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: - ModelFacade.completion( - SimpleNamespace(model_alias="alias", model_name="dummy-model", model_provider_name="provider"), + def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespace: + traced_column = cast(Any, config_builder).get_column_config("raw_detected") + captured_columns.append(traced_column) + output = input_df.iloc[:num_records].copy() + output["raw_detected"] = "[]" + output[f"raw_detected{TRACE_COLUMN_POSTFIX}"] = [ [ - ChatMessage.as_system("system secret"), - ChatMessage.as_user("prompt secret"), - ], - ) - return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + {"role": "system", "content": [{"type": "text", "text": "system secret"}]}, + {"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}, + {"role": "assistant", "content": "secret response", "reasoning_content": "scratch"}, + ] + ] + return SimpleNamespace(dataset=output) adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) trace_path = tmp_path / "trace.jsonl" @@ -772,79 +770,59 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa output_path=tmp_path / "measurements.jsonl", dd_trace="last_message", dd_trace_path=trace_path ) ): - adapter.run_workflow( + result = adapter.run_workflow( input_df, - model_configs=[ModelConfig(alias="alias", model="dummy")], - columns=[ - LLMTextColumnConfig( - 
name="raw_detected", - prompt="{{ text }}", - model_alias="alias", - ) - ], + model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], + columns=[original_column], workflow_name="entity-detection", preview_num_records=1, ) + assert original_column.with_trace == TraceType.NONE + assert captured_columns[0].with_trace == TraceType.ALL_MESSAGES + assert f"raw_detected{TRACE_COLUMN_POSTFIX}" not in result.dataframe.columns + assert f"notes{TRACE_COLUMN_POSTFIX}" in result.dataframe.columns + trace = json.loads(trace_path.read_text(encoding="utf-8").strip()) assert trace["record_type"] == "dd_message_trace" + assert trace["trace_source"] == "data_designer_column" assert trace["workflow_name"] == "entity-detection" assert trace["model_alias"] == "alias" + assert trace["model_name"] == "dummy-model" + assert trace["model_provider_name"] == "provider" assert trace["status"] == "completed" assert trace["messages"] == [{"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}] assert trace["response"]["content"] == "secret response" - assert trace["usage"] == {"input_tokens": 3, "output_tokens": 2, "total_tokens": 5} + assert trace["usage"] is None + + measurements = [json.loads(line) for line in (tmp_path / "measurements.jsonl").read_text().splitlines()] + coverage = [record for record in measurements if record["record_type"] == "dd_trace_coverage"] + assert len(coverage) == 1 + assert coverage[0]["traced_column_count"] == 1 + assert coverage[0]["unsupported_column_count"] == 0 - serialized_measurements = (tmp_path / "measurements.jsonl").read_text(encoding="utf-8") + serialized_measurements = json.dumps(measurements) assert "prompt secret" not in serialized_measurements assert "secret response" not in serialized_measurements -def test_dd_message_trace_does_not_capture_concurrent_unmeasured_calls( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: +def 
test_ndd_adapter_records_custom_column_dd_trace_coverage_gap(tmp_path: Path) -> None: input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) - def fake_completion( - _self: object, - _messages: list[ChatMessage], - **_kwargs: object, - ) -> ChatCompletionResponse: - return ChatCompletionResponse( - message=AssistantMessage(content="response"), - usage=Usage(input_tokens=3, output_tokens=2, total_tokens=5), - ) - - monkeypatch.setattr(ModelFacade, "completion", fake_completion) + @custom_column_generator(required_columns=["text"], model_aliases=["alias"]) + def custom_generator( + row: dict[str, Any], + generator_params: Any, + models: dict[str, Any], + ) -> dict[str, str]: + _ = row, generator_params, models + return {"raw_detected": "[]"} class TraceDataDesigner: def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: - errors: list[BaseException] = [] - - def concurrent_call_without_measurement_context() -> None: - try: - ModelFacade.completion( - SimpleNamespace( - model_alias="outside", - model_name="outside-model", - model_provider_name="provider", - ), - [ChatMessage.as_user("outside prompt")], - ) - except BaseException as exc: - errors.append(exc) - - thread = threading.Thread(target=concurrent_call_without_measurement_context) - thread.start() - thread.join() - assert errors == [] - - ModelFacade.completion( - SimpleNamespace(model_alias="inside", model_name="inside-model", model_provider_name="provider"), - [ChatMessage.as_user("inside prompt")], - ) - return SimpleNamespace(dataset=input_df.iloc[:num_records].copy()) + output = input_df.iloc[:num_records].copy() + output["raw_detected"] = "[]" + return SimpleNamespace(dataset=output) adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) trace_path = tmp_path / "trace.jsonl" @@ -856,23 +834,20 @@ def concurrent_call_without_measurement_context() -> None: ): adapter.run_workflow( input_df, - 
model_configs=[ModelConfig(alias="inside", model="dummy")], - columns=[ - LLMTextColumnConfig( - name="raw_detected", - prompt="{{ text }}", - model_alias="inside", - ) - ], + model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], + columns=[CustomColumnConfig(name="raw_detected", generator_function=custom_generator)], workflow_name="entity-detection", preview_num_records=1, ) - traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] - assert [trace["model_alias"] for trace in traces] == ["inside"] - serialized_traces = json.dumps(traces) - assert "inside prompt" in serialized_traces - assert "outside prompt" not in serialized_traces + assert trace_path.read_text(encoding="utf-8") == "" + measurements = [json.loads(line) for line in (tmp_path / "measurements.jsonl").read_text().splitlines()] + coverage = [record for record in measurements if record["record_type"] == "dd_trace_coverage"] + assert len(coverage) == 1 + assert coverage[0]["traced_column_count"] == 0 + assert coverage[0]["unsupported_column_count"] == 1 + assert coverage[0]["unsupported_column_names"] == ["raw_detected"] + assert coverage[0]["unsupported_column_types"] == ["custom"] def test_record_metrics_capture_generic_counts_without_raw_values() -> None: From 8d9fab8aa0d931d70e06f2cce8ade26c3ac0fe62 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 05:12:48 +0000 Subject: [PATCH 15/26] Add sanitized DataDesigner task traces Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 87 +++++++++++++++++++++- src/anonymizer/measurement.py | 37 +++++++++- tests/test_measurement.py | 101 +++++++++++++++++++++++++- tests/tools/test_measurement_tools.py | 6 ++ tools/measurement/README.md | 31 ++++++++ tools/measurement/run_benchmarks.py | 44 +++++++++++ 6 files changed, 302 insertions(+), 4 deletions(-) diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 
d279769f..924e89fb 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -8,15 +8,18 @@ import tempfile import time import uuid -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path +from threading import RLock from typing import TYPE_CHECKING, Any, cast from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.models import ModelConfig +from data_designer.config.run_config import RunConfig from data_designer.config.seed import SamplingStrategy from data_designer.config.seed_source import LocalFileSeedSource from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX @@ -65,6 +68,7 @@ class NddAdapter: def __init__(self, data_designer: DataDesigner) -> None: self._data_designer = data_designer + self._run_lock = RLock() logger.debug("NDD adapter: artifact_path=%s", getattr(data_designer, "_artifact_path", "unknown")) def run_workflow( @@ -133,19 +137,22 @@ def run_workflow( config_builder.add_column(column) try: - with usage_probe: + task_traces: list[Any] = [] + with self._run_lock, usage_probe, _temporary_dd_task_trace(self._data_designer, collector=collector): if preview_num_records is None: run_results = self._data_designer.create( config_builder, num_records=len(workflow_input_df), dataset_name=workflow_name, ) + task_traces = list(getattr(run_results, "task_traces", []) or []) output_df = run_results.load_dataset() else: preview_results = self._data_designer.preview( config_builder, num_records=record_count, ) + task_traces = list(getattr(preview_results, "task_traces", []) or []) if preview_results.dataset is None: output_df = workflow_input_df.iloc[0:0].copy() else: @@ -156,6 +163,11 @@ def 
run_workflow( collector=collector, native_trace_columns=native_trace_columns, ) + _record_dd_task_traces( + workflow_name=workflow_name, + collector=collector, + task_traces=task_traces, + ) except Exception as exc: logger.warning( "Workflow failed for %d input record(s) on model(s) %s: %s", @@ -398,6 +410,35 @@ def _model_usage_as_json(stats: object) -> Any: return stats +@contextmanager +def _temporary_dd_task_trace(data_designer: DataDesigner, *, collector: Any | None) -> Iterator[None]: + if collector is None or not collector.dd_task_trace_enabled: + yield + return + + original_run_config = getattr(data_designer, "run_config", None) + set_run_config = getattr(data_designer, "set_run_config", None) + if original_run_config is None or not callable(set_run_config): + yield + return + + traced_run_config = _run_config_with_async_trace(original_run_config) + set_run_config(traced_run_config) + try: + yield + finally: + set_run_config(original_run_config) + + +def _run_config_with_async_trace(run_config: Any) -> Any: + model_copy = getattr(run_config, "model_copy", None) + if callable(model_copy): + return model_copy(update={"async_trace": True}) + if isinstance(run_config, RunConfig): + return run_config.model_copy(update={"async_trace": True}) + return run_config + + def _configure_native_dd_message_traces( *, columns: list[ColumnConfigT], @@ -555,3 +596,45 @@ def _trace_tool_calls(tool_calls: Any) -> list[Any]: if isinstance(tool_calls, list): return [getattr(tool_call, "__dict__", tool_call) for tool_call in tool_calls] return [] + + +def _record_dd_task_traces(*, workflow_name: str, collector: Any | None, task_traces: list[Any]) -> None: + if collector is None or not collector.dd_task_trace_enabled: + return + for task_trace in task_traces: + collector.record_dd_task_trace( + workflow_name=workflow_name, + trace_source="data_designer_scheduler", + column=_trace_attr(task_trace, "column"), + row_group=_trace_attr(task_trace, "row_group"), + 
row_index=_trace_attr(task_trace, "row_index"), + task_type=_trace_attr(task_trace, "task_type"), + status=_trace_attr(task_trace, "status"), + error_present=bool(_trace_attr(task_trace, "error")), + queue_wait_sec=_trace_duration( + _trace_attr(task_trace, "dispatched_at"), + _trace_attr(task_trace, "slot_acquired_at"), + ), + execution_sec=_trace_duration( + _trace_attr(task_trace, "slot_acquired_at"), + _trace_attr(task_trace, "completed_at"), + ), + total_sec=_trace_duration( + _trace_attr(task_trace, "dispatched_at"), + _trace_attr(task_trace, "completed_at"), + ), + ) + + +def _trace_attr(task_trace: Any, name: str) -> Any: + if isinstance(task_trace, Mapping): + return task_trace.get(name) + return getattr(task_trace, name, None) + + +def _trace_duration(start: Any, end: Any) -> float | None: + if not isinstance(start, (int, float)) or not isinstance(end, (int, float)): + return None + if start <= 0 or end <= 0 or end < start: + return None + return float(end - start) diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py index 150876cf..0f55a1c1 100644 --- a/src/anonymizer/measurement.py +++ b/src/anonymizer/measurement.py @@ -127,6 +127,7 @@ class _MeasurementEnvSettings(BaseSettings): keep_records: bool = True dd_trace: DDTraceMode = "none" dd_trace_path: str | None = None + dd_task_trace_path: str | None = None fail_on_write_error: bool = False run_id: str | None = None run_tags: dict[str, Any] = Field(default_factory=dict) @@ -151,6 +152,7 @@ def __init__( keep_records: bool = True, dd_trace_mode: DDTraceMode = "none", dd_trace_sink: _MeasurementSink | None = None, + dd_task_trace_sink: _MeasurementSink | None = None, fail_on_write_error: bool = False, ) -> None: self.run_id = run_id or uuid.uuid4().hex @@ -160,9 +162,11 @@ def __init__( self._keep_records = keep_records self._dd_trace_mode = dd_trace_mode self._dd_trace_sink = dd_trace_sink + self._dd_task_trace_sink = dd_task_trace_sink self._fail_on_write_error = 
fail_on_write_error self._sink_failed = False self._dd_trace_failed = False + self._dd_task_trace_failed = False if record_hash_key is None: self._record_hash_key = secrets.token_bytes(32) elif isinstance(record_hash_key, str): @@ -195,7 +199,7 @@ def record(self, record_type: str, **fields: Any) -> None: def close(self) -> None: """Close any streaming measurement sink attached to this collector.""" close_error: Exception | None = None - for sink in (self._record_sink, self._dd_trace_sink): + for sink in (self._record_sink, self._dd_trace_sink, self._dd_task_trace_sink): if sink is None: continue try: @@ -214,6 +218,10 @@ def dd_trace_mode(self) -> DDTraceMode: def dd_trace_enabled(self) -> bool: return self._dd_trace_mode != "none" and self._dd_trace_sink is not None + @property + def dd_task_trace_enabled(self) -> bool: + return self._dd_task_trace_sink is not None + def record_dd_message_trace(self, **fields: Any) -> None: """Write an explicitly opt-in DataDesigner message trace record. 
@@ -242,6 +250,29 @@ def record_dd_message_trace(self, **fields: Any) -> None: if self._fail_on_write_error: raise + def record_dd_task_trace(self, **fields: Any) -> None: + """Write an opt-in sanitized DataDesigner scheduler task trace record.""" + if not self.dd_task_trace_enabled or self._dd_task_trace_failed: + return + + record = _json_safe( + { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": "dd_task_trace", + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + ) + try: + cast(_MeasurementSink, self._dd_task_trace_sink).write_record(record) + except Exception: + self._dd_task_trace_failed = True + logger.warning("Failed to write DataDesigner task trace records") + if self._fail_on_write_error: + raise + def _write_record_to_sink(self, record: dict[str, Any]) -> None: if self._sink_failed: return @@ -289,6 +320,7 @@ class MeasurementConfig: keep_records: bool = True dd_trace: DDTraceMode = "none" dd_trace_path: str | Path | None = None + dd_task_trace_path: str | Path | None = None run_id: str | None = None record_hash_key: bytes | str | None = None run_tags: Mapping[str, Any] | None = None @@ -327,6 +359,7 @@ def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> Measuremen keep_records=settings.keep_records, dd_trace=settings.dd_trace, dd_trace_path=settings.dd_trace_path, + dd_task_trace_path=settings.dd_task_trace_path, run_id=settings.run_id, run_tags=settings.run_tags, fail_on_write_error=settings.fail_on_write_error, @@ -376,6 +409,7 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator if config.dd_trace_path is None: raise ValueError("dd_trace_path is required when dd_trace is enabled") dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) + dd_task_trace_sink = _JsonlMeasurementSink(config.dd_task_trace_path) if config.dd_task_trace_path else None collector = MeasurementCollector( run_id=config.run_id, 
record_hash_key=config.record_hash_key, @@ -385,6 +419,7 @@ def configured_measurement_session(config: MeasurementConfig | None) -> Iterator keep_records=config.keep_records, dd_trace_mode=config.dd_trace, dd_trace_sink=dd_trace_sink, + dd_task_trace_sink=dd_task_trace_sink, fail_on_write_error=config.fail_on_write_error, ) with measurement_session(collector): diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 65472fed..4032718d 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -16,6 +16,7 @@ from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ModelConfig +from data_designer.config.run_config import RunConfig from data_designer.config.utils.constants import TRACE_COLUMN_POSTFIX from data_designer.config.utils.trace_type import TraceType from data_designer.interface.data_designer import DataDesigner @@ -43,6 +44,7 @@ from anonymizer.engine.replace.replace_runner import ReplacementResult, ReplacementWorkflow from anonymizer.engine.rewrite.rewrite_workflow import RewriteResult, RewriteWorkflow from anonymizer.interface.anonymizer import Anonymizer +from anonymizer.interface.errors import AnonymizerWorkflowError from anonymizer.measurement import ( DEFAULT_MEASUREMENT_ENV_PREFIX, MEASUREMENT_SCHEMA_VERSION, @@ -714,12 +716,13 @@ def close(self) -> None: collector = MeasurementCollector( record_sink=FakeSink("records", fail=True), dd_trace_sink=FakeSink("dd-trace"), + dd_task_trace_sink=FakeSink("dd-task-trace"), ) with pytest.raises(OSError, match="records close failed"): collector.close() - assert close_events == ["records", "dd-trace"] + assert close_events == ["records", "dd-trace", "dd-task-trace"] def test_streaming_measurement_requires_jsonl_output(tmp_path: Path) -> None: @@ -850,6 +853,102 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa 
assert coverage[0]["unsupported_column_types"] == ["custom"] +def test_ndd_adapter_writes_sanitized_dd_task_traces_and_restores_run_config(tmp_path: Path) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + class TraceDataDesigner: + def __init__(self) -> None: + self.run_config = RunConfig(async_trace=False) + self.async_trace_values: list[bool] = [] + + def set_run_config(self, run_config: RunConfig) -> None: + self.async_trace_values.append(run_config.async_trace) + self.run_config = run_config + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + assert self.run_config.async_trace is True + task_trace = SimpleNamespace( + column="raw_detected", + row_group=0, + row_index=7, + task_type="llm", + dispatched_at=10.0, + slot_acquired_at=10.25, + completed_at=12.0, + status="error", + error="raw secret token Alice", + ) + return SimpleNamespace(dataset=input_df.iloc[:num_records].copy(), task_traces=[task_trace]) + + data_designer = TraceDataDesigner() + adapter = NddAdapter(data_designer=cast(DataDesigner, data_designer)) + task_trace_path = tmp_path / "task-trace.jsonl" + + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_task_trace_path=task_trace_path) + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], + columns=[LLMTextColumnConfig(name="raw_detected", prompt="{{ text }}", model_alias="alias")], + workflow_name="entity-detection", + preview_num_records=1, + ) + + assert data_designer.async_trace_values == [True, False] + assert data_designer.run_config.async_trace is False + task_trace = json.loads(task_trace_path.read_text(encoding="utf-8").strip()) + assert task_trace["record_type"] == "dd_task_trace" + assert task_trace["workflow_name"] == "entity-detection" + assert task_trace["column"] == "raw_detected" + assert task_trace["row_group"] 
== 0 + assert task_trace["row_index"] == 7 + assert task_trace["task_type"] == "llm" + assert task_trace["status"] == "error" + assert task_trace["error_present"] is True + assert task_trace["queue_wait_sec"] == pytest.approx(0.25) + assert task_trace["execution_sec"] == pytest.approx(1.75) + assert task_trace["total_sec"] == pytest.approx(2.0) + assert "raw secret token Alice" not in task_trace_path.read_text(encoding="utf-8") + assert "raw secret token Alice" not in (tmp_path / "measurements.jsonl").read_text(encoding="utf-8") + + +def test_ndd_adapter_restores_run_config_when_task_traced_workflow_fails(tmp_path: Path) -> None: + input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + class TraceDataDesigner: + def __init__(self) -> None: + self.run_config = RunConfig(async_trace=False) + self.async_trace_values: list[bool] = [] + + def set_run_config(self, run_config: RunConfig) -> None: + self.async_trace_values.append(run_config.async_trace) + self.run_config = run_config + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + assert self.run_config.async_trace is True + _ = num_records + raise RuntimeError("raw secret failure") + + data_designer = TraceDataDesigner() + adapter = NddAdapter(data_designer=cast(DataDesigner, data_designer)) + + with pytest.raises(AnonymizerWorkflowError, match="Workflow failed"): + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_task_trace_path=tmp_path / "task.jsonl") + ): + adapter.run_workflow( + input_df, + model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], + columns=[LLMTextColumnConfig(name="raw_detected", prompt="{{ text }}", model_alias="alias")], + workflow_name="entity-detection", + preview_num_records=1, + ) + + assert data_designer.async_trace_values == [True, False] + assert data_designer.run_config.async_trace is False + + def 
test_record_metrics_capture_generic_counts_without_raw_values() -> None: final_entities = { "entities": [ diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 41be00d4..f27078df 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -42,6 +42,8 @@ def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[ "raw_dir": tmp_path / "raw", "dd_trace": tool.DDTraceMode.none, "trace_dir": tmp_path / "traces", + "dd_task_trace": False, + "task_trace_dir": tmp_path / "task-traces", "dd_parser_compat": spec.dd_parser_compat, "artifact_path": tmp_path / "artifacts", } @@ -1134,6 +1136,7 @@ def run(self, *, config: Any, data: Any) -> None: case_id="input__redact__r000", ) trace_path = tmp_path / "traces" / "input__redact__r000.jsonl" + task_trace_path = tmp_path / "task-traces" / "input__redact__r000.jsonl" tool._execute_case( FakeAnonymizer(), @@ -1141,6 +1144,7 @@ def run(self, *, config: Any, data: Any) -> None: spec.configs[0], raw_path=tmp_path / "raw" / "input__redact__r000.jsonl", trace_path=trace_path, + task_trace_path=task_trace_path, case=case, spec=spec, base_dir=tmp_path, @@ -1151,6 +1155,7 @@ def run(self, *, config: Any, data: Any) -> None: assert len(captured) == 1 assert captured[0].dd_trace == "all_messages" assert captured[0].dd_trace_path == trace_path + assert captured[0].dd_task_trace_path == task_trace_path assert captured[0].streaming is True assert captured[0].keep_records is False @@ -1351,6 +1356,7 @@ def run(self, *, config: Any, data: Any) -> None: spec.configs[0], raw_path=tmp_path / "raw" / "input__native-single-pass-redact__r000.jsonl", trace_path=None, + task_trace_path=None, case=case, spec=spec, base_dir=tmp_path, diff --git a/tools/measurement/README.md b/tools/measurement/README.md index d8492fc1..09ed5022 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -93,6 +93,9 @@ uv run python 
tools/measurement/run_benchmarks.py suite.yaml --dry-run --json uv run python tools/measurement/run_benchmarks.py suite.yaml \ --output benchmark-runs/suite \ --dd-trace last-message +uv run python tools/measurement/run_benchmarks.py suite.yaml \ + --output benchmark-runs/suite \ + --dd-task-trace ``` The repo-data smoke suite can be run with DataDesigner traces enabled: @@ -284,6 +287,14 @@ values, replacement values, secrets, and PII. Treat them as debug artifacts: keep them out of shared benchmark bundles unless they have been reviewed or redacted. +Anonymizer requests these traces through DataDesigner native LLM column trace +side effects. That covers `LLMTextColumnConfig` and +`LLMStructuredColumnConfig`, but not model calls made inside +`CustomColumnConfig` generator functions. Safe measurement output includes a +`dd_trace_coverage` record with unsupported custom columns so trace-enabled +runs can detect this gap. Full custom-column message tracing would need a +DataDesigner hook; it is a good candidate for an upstream DataDesigner PR. + Summarize traced calls without copying raw prompts or responses into analysis output: @@ -299,6 +310,26 @@ captures run tags, workflow/model metadata, status, elapsed time, prompt and response lengths, token counts, and response-shape flags such as `raw_json`, `fenced_json`, `embedded_json`, `text`, and `none`. +## DataDesigner Scheduler Task Traces + +Pass `--dd-task-trace` to collect sanitized DataDesigner async scheduler task +timing records. The benchmark runner writes one sidecar per case under +`task-traces/{case_id}.jsonl` by default; use `--task-trace-dir` to choose +another directory. + +Task trace records are separate from raw message traces. They include scheduler +metadata such as workflow name, column, row group, row index, task type, status, +queue wait time, execution time, total time, and whether an error was present. 
+They intentionally do not store raw DataDesigner error strings because those +can contain prompts, outputs, or source values. + +```bash +uv run python tools/measurement/run_benchmarks.py \ + suite.yaml \ + --output benchmark-runs/suite \ + --dd-task-trace +``` + ## Direct probes `direct_detection_probe.py` calls a local OpenAI-compatible endpoint directly diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index a1ad72c9..23c586af 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -234,6 +234,7 @@ class BenchmarkCase(BaseModel): measurement_path: str | None = None detection_artifact_path: str | None = None trace_path: str | None = None + task_trace_path: str | None = None error: str | None = None attempt_count: int = 0 attempt_errors: list[str] = Field(default_factory=list) @@ -254,6 +255,7 @@ class _CaseRunPaths: raw_path: Path artifact_output_path: Path trace_path: Path | None + task_trace_path: Path | None artifact_snapshot: dict[str, int] | None export_detection_artifacts: bool @@ -566,6 +568,8 @@ def run_suite( fail_fast: bool, dd_trace: DDTraceMode, trace_dir: Path | None, + dd_task_trace: bool = False, + task_trace_dir: Path | None = None, ) -> BenchmarkResult: contexts = _build_contexts( spec, @@ -573,6 +577,8 @@ def run_suite( output_dir=output_dir, dd_trace=dd_trace, trace_dir=trace_dir, + dd_task_trace=dd_task_trace, + task_trace_dir=task_trace_dir, ) anonymizer = Anonymizer(**contexts["anonymizer_kwargs"]) cases = _run_cases(spec, contexts=contexts, anonymizer=anonymizer, fail_fast=fail_fast, export=export) @@ -663,6 +669,8 @@ def _build_contexts( output_dir: Path, dd_trace: DDTraceMode, trace_dir: Path | None, + dd_task_trace: bool = False, + task_trace_dir: Path | None = None, ) -> dict[str, Any]: base_dir = spec_path.parent artifact_path = _resolve_optional_path(spec.artifact_path, base_dir) or output_dir / "artifacts" @@ -673,6 +681,8 @@ def _build_contexts( "raw_dir": 
output_dir / "raw", "dd_trace": dd_trace, "trace_dir": trace_dir or output_dir / "traces", + "dd_task_trace": dd_task_trace, + "task_trace_dir": task_trace_dir or output_dir / "task-traces", "dd_parser_compat": spec.dd_parser_compat, "artifact_path": artifact_path, "anonymizer_kwargs": { @@ -755,6 +765,7 @@ def _case_run_paths( raw_path=contexts["raw_dir"] / f"{case.case_id}.jsonl", artifact_output_path=contexts["raw_dir"] / f"{case.case_id}.detection-artifacts.jsonl", trace_path=_case_trace_path(case, contexts=contexts), + task_trace_path=_case_task_trace_path(case, contexts=contexts), artifact_snapshot=snapshot_detection_artifacts(contexts["artifact_path"]) if export_detection_artifacts else None, @@ -781,6 +792,7 @@ def _run_case_success( config, raw_path=paths.raw_path, trace_path=paths.trace_path, + task_trace_path=paths.task_trace_path, case=case, spec=spec, base_dir=contexts["base_dir"], @@ -801,6 +813,7 @@ def _run_case_success( raw_path=paths.raw_path, detection_artifact_path=detection_artifact_path, trace_path=paths.trace_path, + task_trace_path=paths.task_trace_path, attempt_count=attempt_count, attempt_errors=attempt_errors, ) @@ -829,6 +842,7 @@ def _run_case_error( raw_path=paths.raw_path, detection_artifact_path=detection_artifact_path, trace_path=paths.trace_path, + task_trace_path=paths.task_trace_path, error=str(error), attempt_count=attempt_count, attempt_errors=attempt_errors, @@ -968,6 +982,7 @@ def _case_with_result( raw_path: Path, detection_artifact_path: Path | None, trace_path: Path | None, + task_trace_path: Path | None, attempt_count: int, attempt_errors: list[str], error: str | None = None, @@ -979,6 +994,7 @@ def _case_with_result( "measurement_path": str(raw_path), "detection_artifact_path": (str(detection_artifact_path) if detection_artifact_path is not None else None), "trace_path": str(trace_path) if trace_path is not None else None, + "task_trace_path": str(task_trace_path) if task_trace_path is not None else None, "error": error, 
"attempt_count": attempt_count, "attempt_errors": list(attempt_errors), @@ -1009,6 +1025,12 @@ def _case_trace_path(case: BenchmarkCase, *, contexts: dict[str, Any]) -> Path | return contexts["trace_dir"] / f"{case.case_id}.jsonl" +def _case_task_trace_path(case: BenchmarkCase, *, contexts: dict[str, Any]) -> Path | None: + if not contexts["dd_task_trace"]: + return None + return contexts["task_trace_dir"] / f"{case.case_id}.jsonl" + + def _execute_case( anonymizer: Anonymizer, workload: WorkloadSpec, @@ -1016,6 +1038,7 @@ def _execute_case( *, raw_path: Path, trace_path: Path | None, + task_trace_path: Path | None, case: BenchmarkCase, spec: BenchmarkSpec, base_dir: Path, @@ -1037,6 +1060,7 @@ def _execute_case( keep_records=False, dd_trace=dd_trace.value, dd_trace_path=trace_path, + dd_task_trace_path=task_trace_path, fail_on_write_error=True, ) with configured_measurement_session(measurement): @@ -1396,6 +1420,8 @@ def dry_run_result( export: bool, dd_trace: DDTraceMode, trace_dir: Path | None, + dd_task_trace: bool = False, + task_trace_dir: Path | None = None, ) -> BenchmarkResult: cases = build_cases(spec) if dd_trace != DDTraceMode.none: @@ -1403,6 +1429,12 @@ def dry_run_result( cases = [ case.model_copy(update={"trace_path": str(resolved_trace_dir / f"{case.case_id}.jsonl")}) for case in cases ] + if dd_task_trace: + resolved_task_trace_dir = task_trace_dir or output_dir / "task-traces" + cases = [ + case.model_copy(update={"task_trace_path": str(resolved_task_trace_dir / f"{case.case_id}.jsonl")}) + for case in cases + ] return BenchmarkResult( suite_id=spec.suite_id, output_dir=str(output_dir), @@ -1425,6 +1457,8 @@ def main( fail_fast: Annotated[bool, cyclopts.Parameter("--fail-fast")] = False, dd_trace: Annotated[DDTraceMode, cyclopts.Parameter("--dd-trace")] = DDTraceMode.none, trace_dir: Annotated[Path | None, cyclopts.Parameter("--trace-dir")] = None, + dd_task_trace: Annotated[bool, cyclopts.Parameter("--dd-task-trace")] = False, + task_trace_dir: 
Annotated[Path | None, cyclopts.Parameter("--task-trace-dir")] = None, json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, ) -> None: @@ -1439,6 +1473,8 @@ def main( fail_fast=fail_fast, dd_trace=dd_trace, trace_dir=trace_dir, + dd_task_trace=dd_task_trace, + task_trace_dir=task_trace_dir, ) except (ValueError, ValidationError) as exc: log_bad_input(logger, str(exc)) @@ -1458,11 +1494,15 @@ def run_or_plan( fail_fast: bool, dd_trace: DDTraceMode = DDTraceMode.none, trace_dir: Path | None = None, + dd_task_trace: bool = False, + task_trace_dir: Path | None = None, ) -> BenchmarkResult: benchmark_spec = load_spec(spec_path) output_dir = output or Path("benchmark-runs") / benchmark_spec.suite_id if trace_dir is not None and dd_trace == DDTraceMode.none: raise ValueError("--trace-dir requires --dd-trace") + if task_trace_dir is not None and not dd_task_trace: + raise ValueError("--task-trace-dir requires --dd-task-trace") preflight_suite(benchmark_spec, spec_path=spec_path) if dry_run: return dry_run_result( @@ -1471,6 +1511,8 @@ def run_or_plan( export=export, dd_trace=dd_trace, trace_dir=trace_dir, + dd_task_trace=dd_task_trace, + task_trace_dir=task_trace_dir, ) prepare_output_dir(output_dir, overwrite=overwrite, dry_run=dry_run) return run_suite( @@ -1481,6 +1523,8 @@ def run_or_plan( fail_fast=fail_fast, dd_trace=dd_trace, trace_dir=trace_dir, + dd_task_trace=dd_task_trace, + task_trace_dir=task_trace_dir, ) From eddb29189d34ac4ae26819a20511c58e81ad6d96 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 18:25:14 +0000 Subject: [PATCH 16/26] Cover structured DataDesigner message traces Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 5 +++-- tests/test_measurement.py | 26 ++++++++++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/anonymizer/engine/ndd/adapter.py 
b/src/anonymizer/engine/ndd/adapter.py index 924e89fb..9a13cf5d 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -15,7 +15,7 @@ from threading import RLock from typing import TYPE_CHECKING, Any, cast -from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig from data_designer.config.column_types import ColumnConfigT from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.models import ModelConfig @@ -35,6 +35,7 @@ logger = logging.getLogger("anonymizer.ndd") RECORD_ID_COLUMN = "_anonymizer_record_id" +_TRACEABLE_LLM_COLUMN_TYPES = (LLMTextColumnConfig, LLMStructuredColumnConfig) @dataclass(frozen=True) @@ -455,7 +456,7 @@ def _configure_native_dd_message_traces( trace_type = _native_dd_trace_type() for column in columns: - if isinstance(column, LLMTextColumnConfig): + if isinstance(column, _TRACEABLE_LLM_COLUMN_TYPES): configured_column = cast(ColumnConfigT, column.model_copy(update={"with_trace": trace_type})) configured_columns.append(configured_column) model_config = model_configs_by_alias.get(column.model_alias) diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 4032718d..4c16b783 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -13,7 +13,7 @@ import numpy as np import pandas as pd import pytest -from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig +from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ModelConfig from data_designer.config.run_config import RunConfig @@ -735,7 +735,12 @@ def test_dd_message_trace_requires_trace_path(tmp_path: Path) -> None: 
MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_trace="last_message") -def test_ndd_adapter_writes_native_dd_message_trace_and_strips_trace_columns(tmp_path: Path) -> None: +@pytest.mark.parametrize("structured", [False, True]) +def test_ndd_adapter_writes_native_dd_message_trace_and_strips_trace_columns( + tmp_path: Path, + *, + structured: bool, +) -> None: input_df = pd.DataFrame( { "text": ["Alice works at Acme"], @@ -743,10 +748,19 @@ def test_ndd_adapter_writes_native_dd_message_trace_and_strips_trace_columns(tmp RECORD_ID_COLUMN: ["record-a"], } ) - original_column = LLMTextColumnConfig( - name="raw_detected", - prompt="{{ text }}", - model_alias="alias", + original_column = ( + LLMStructuredColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="alias", + output_format={"type": "object", "properties": {"entities": {"type": "array"}}}, + ) + if structured + else LLMTextColumnConfig( + name="raw_detected", + prompt="{{ text }}", + model_alias="alias", + ) ) captured_columns: list[LLMTextColumnConfig] = [] From aa96fa648f245a67bc015a310df773c927dbe79c Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 19:38:06 +0000 Subject: [PATCH 17/26] Add manual benchmark CI workflow Signed-off-by: Aaron Gonzales --- .github/workflows/benchmark-ci.yml | 174 +++++++++++++++++++++++++++++ .gitignore | 1 + tools/measurement/README.md | 15 +++ 3 files changed, 190 insertions(+) create mode 100644 .github/workflows/benchmark-ci.yml diff --git a/.github/workflows/benchmark-ci.yml b/.github/workflows/benchmark-ci.yml new file mode 100644 index 00000000..a67ffebf --- /dev/null +++ b/.github/workflows/benchmark-ci.yml @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +name: Benchmark CI + +on: + workflow_dispatch: + inputs: + ref: + description: "Commit SHA, branch, or tag to benchmark" + required: true + default: "main" + suite: + description: "Benchmark suite YAML path" + required: true + default: "tools/measurement/examples/repo-data-smoke.yaml" + output_dir: + description: "Output directory for benchmark artifacts" + required: true + default: "benchmark-results" + dd_trace: + description: "Capture DataDesigner message traces" + required: true + type: choice + options: + - "none" + - "last-message" + - "all-messages" + default: "none" + dd_task_trace: + description: "Capture sanitized DataDesigner scheduler task traces" + required: true + type: choice + options: + - "false" + - "true" + default: "false" + fail_fast: + description: "Stop at the first failed benchmark case" + required: true + type: choice + options: + - "false" + - "true" + default: "false" + +permissions: + contents: read + +env: + NEMO_TELEMETRY_ENABLED: "false" + BENCHMARK_REF: ${{ inputs.ref }} + BENCHMARK_SUITE: ${{ inputs.suite }} + BENCHMARK_OUTPUT_DIR: ${{ inputs.output_dir }} + BENCHMARK_DD_TRACE: ${{ inputs.dd_trace }} + BENCHMARK_DD_TASK_TRACE: ${{ inputs.dd_task_trace }} + BENCHMARK_FAIL_FAST: ${{ inputs.fail_fast }} + +jobs: + benchmark: + name: Benchmark + runs-on: [self-hosted, anonymizer-evals] + timeout-minutes: 120 + + steps: + - name: Checkout benchmark target + uses: actions/checkout@v4 + with: + ref: ${{ env.BENCHMARK_REF }} + fetch-depth: "0" + + - name: Resolve benchmark target commit + id: target + run: echo "commit=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT" + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.11" + + - name: Install dependencies + run: uv sync --group dev + + - name: Check NVIDIA API key + env: + NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }} + run: | + if [ -z 
"${NVIDIA_API_KEY:-}" ]; then + echo "::error::NVIDIA_API_KEY secret is required for benchmark CI" + exit 1 + fi + + - name: Run benchmark suite + env: + NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }} + run: | + TRACE_ARGS=(--dd-trace "$BENCHMARK_DD_TRACE") + if [ "$BENCHMARK_DD_TRACE" != "none" ]; then + TRACE_ARGS+=(--trace-dir "$BENCHMARK_OUTPUT_DIR/traces") + fi + + TASK_TRACE_ARGS=() + if [ "$BENCHMARK_DD_TASK_TRACE" = "true" ]; then + TASK_TRACE_ARGS+=(--dd-task-trace --task-trace-dir "$BENCHMARK_OUTPUT_DIR/task-traces") + fi + + FAIL_FAST_ARGS=() + if [ "$BENCHMARK_FAIL_FAST" = "true" ]; then + FAIL_FAST_ARGS+=(--fail-fast) + fi + + uv run python tools/measurement/run_benchmarks.py \ + "$BENCHMARK_SUITE" \ + --output "$BENCHMARK_OUTPUT_DIR" \ + --overwrite \ + "${TRACE_ARGS[@]}" \ + "${TASK_TRACE_ARGS[@]}" \ + "${FAIL_FAST_ARGS[@]}" + + - name: Add benchmark summary + if: always() + env: + BENCHMARK_COMMIT: ${{ steps.target.outputs.commit }} + run: | + python - <<'PY' + import json + import os + from pathlib import Path + + output_dir = Path(os.environ["BENCHMARK_OUTPUT_DIR"]) + summary_path = output_dir / "summary.json" + step_summary = Path(os.environ["GITHUB_STEP_SUMMARY"]) + + with step_summary.open("a", encoding="utf-8") as handle: + handle.write("# Anonymizer Benchmark\n\n") + handle.write(f"- Ref: `{os.environ['BENCHMARK_REF']}`\n") + handle.write(f"- Commit: `{os.environ.get('BENCHMARK_COMMIT', 'unknown')}`\n") + handle.write(f"- Suite: `{os.environ['BENCHMARK_SUITE']}`\n") + handle.write(f"- Output: `{output_dir}`\n") + handle.write(f"- DD traces: `{os.environ['BENCHMARK_DD_TRACE']}`\n") + handle.write(f"- DD task traces: `{os.environ['BENCHMARK_DD_TASK_TRACE']}`\n\n") + + if not summary_path.exists(): + handle.write("`summary.json` was not produced. 
Check job logs for setup or preflight failures.\n") + raise SystemExit(0) + + summary = json.loads(summary_path.read_text(encoding="utf-8")) + cases = summary.get("cases", []) + completed = sum(1 for case in cases if case.get("status") == "completed") + errors = sum(1 for case in cases if case.get("status") == "error") + handle.write(f"Ran {completed}/{len(cases)} case(s); errors={errors}.\n\n") + handle.write("| Case | Status | Elapsed | Attempts |\n") + handle.write("| --- | --- | ---: | ---: |\n") + for case in cases: + elapsed = case.get("elapsed_sec") + elapsed_text = "" if elapsed is None else f"{elapsed:.2f}s" + handle.write( + f"| `{case.get('case_id')}` | {case.get('status')} | {elapsed_text} | " + f"{case.get('attempt_count', 0)} |\n" + ) + PY + + - name: Upload benchmark artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: anonymizer-benchmark-${{ steps.target.outputs.commit }} + path: ${{ env.BENCHMARK_OUTPUT_DIR }}/ + if-no-files-found: warn diff --git a/.gitignore b/.gitignore index 6548d6ec..4382af19 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,7 @@ ai/tmp/ # Anonymizer execution artifacts .anonymizer-artifacts/ +benchmark-results/ docs/notebook_source/data/synth_bios_sample10_anonymized.csv # TLS certs and keys (if any) diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 09ed5022..69042cd9 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -114,6 +114,21 @@ DD_TRACE_MODE=all-messages \ /tmp/anonymizer-repo-data-smoke-dd-traces-full ``` +## Benchmark CI + +`.github/workflows/benchmark-ci.yml` runs the same benchmark runner from a +manual GitHub Actions dispatch. It targets the self-hosted +`anonymizer-evals` runner, checks out the requested ref, installs the project +environment, runs a suite, appends a short case summary to the GitHub step +summary, and uploads the full output directory as a workflow artifact. 
+ +The default suite is `tools/measurement/examples/repo-data-smoke.yaml`. Dispatch +inputs let operators choose the ref, suite path, output directory, +DataDesigner message trace mode, sanitized scheduler task traces, and +fail-fast behavior. The workflow requires the repository secret +`NVIDIA_API_KEY` because the default model configuration uses NVIDIA-hosted +models. + Benchmark suites are YAML files with three parts: - `workloads`: input datasets and text-column metadata. From 38451645d5e49a80322987e9b3220c759311b286 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 20:32:36 +0000 Subject: [PATCH 18/26] Document measurement observability Signed-off-by: Aaron Gonzales --- docs/development/observability.md | 170 ++++++++++++++++++++++++++++++ mkdocs.yml | 2 + tools/measurement/README.md | 11 ++ 3 files changed, 183 insertions(+) create mode 100644 docs/development/observability.md diff --git a/docs/development/observability.md b/docs/development/observability.md new file mode 100644 index 00000000..38cb047d --- /dev/null +++ b/docs/development/observability.md @@ -0,0 +1,170 @@ + + + +# Observability + +Anonymizer keeps local run measurement in `src/anonymizer/measurement.py`. +Measurement hooks record timings, counts, model-call summaries, and safety +metrics without changing anonymization behavior. Benchmark tools convert those +records into tables for latency, reliability, model usage, and quality analysis. + +Measurement is separate from anonymous NVIDIA telemetry. Telemetry can report +one product event per run or preview. Users can opt out as described in +[Telemetry and Privacy](../index.md#telemetry-and-privacy). Measurement records +are local artifacts. They are written only when developer tooling or caller code +activates a measurement session. 
+ +## Model + +Instrumentation is passive unless a `MeasurementCollector` is active in the +current context: + +```python +from anonymizer.measurement import MeasurementConfig, configured_measurement_session + +measurement = MeasurementConfig(output_path="benchmark-runs/case/measurements.jsonl") + +with configured_measurement_session(measurement): + result = anonymizer.run(config=config, data=data) +``` + +Instrumentation uses these entry points: + +- `stage_timer(...)` wraps pipeline phases and records elapsed time. +- `record_run_metadata(...)` records config, input, model, and runtime metadata + once per run, without raw source values. +- `record_record_metrics(...)` records per-row counts and safety metrics from + the trace DataFrame. +- `record_ndd_workflow(...)` records DataDesigner workflow summaries at the + `NddAdapter` boundary. +- `record_model_workflow(...)` records benchmark-only direct model calls that do + not use DataDesigner. + +The public API and CLI do not read measurement environment variables by default. +Benchmark and developer tools opt into measurement explicitly. + +## Record Types + +Measurement output is JSONL by default. Each row has a `record_type` and shared +run metadata. + +| Record type | Meaning | +| --- | --- | +| `run` | One anonymization call: mode, strategy, input shape, config metadata, model aliases, runtime metadata. | +| `stage` | Pipeline phase timing, status, row counts, and row throughput. | +| `record` | Per-input-row counts, text-length buckets, entity counts, ground-truth comparison metrics when present, replacement coverage, leakage flags, and estimated LLM calls. | +| `ndd_workflow` | DataDesigner workflow summary: workflow name, model aliases, row counts, failures, elapsed time, usage summary, and throughput. | +| `model_workflow` | Direct model workflow summary for benchmark-only paths outside DataDesigner. | +| `dd_trace_coverage` | Trace coverage summary for DataDesigner columns when message tracing is enabled. 
| + +Use `tools/measurement/export_measurements.py` to convert raw measurement JSONL +into Parquet, CSV, or JSONL tables. + +## Output and Sinks + +`MeasurementConfig` controls output: + +| Field | Purpose | +| --- | --- | +| `output_path` | Destination for measurement records. | +| `output_format` | `jsonl` or `json`; defaults to `jsonl`. | +| `record_level` | Include per-row `record` entries; defaults to `True`. | +| `streaming` | Write JSONL records to `output_path` as they are emitted instead of writing them all when the session ends. | +| `keep_records` | Keep emitted records in memory for caller access; defaults to `True`, also when `streaming` is enabled. | +| `run_id` | Optional stable run ID. | +| `run_tags` | Caller-supplied tags copied to every record. | +| `fail_on_write_error` | Raise output write/close failures when the run body succeeded. | + +Streaming mode supports JSONL only. Use it with `keep_records=False` for long benchmark +suites where holding all measurement records in memory is unnecessary. + +`MeasurementConfig.from_env()` can read `ANONYMIZER_MEASUREMENT_*` settings for +developer tooling. Product entry points do not call it automatically. 
+ +| Environment variable | Field | +| --- | --- | +| `ANONYMIZER_MEASUREMENT_OUTPUT_PATH` | `output_path` | +| `ANONYMIZER_MEASUREMENT_OUTPUT_FORMAT` | `output_format` | +| `ANONYMIZER_MEASUREMENT_RECORD_LEVEL` | `record_level` | +| `ANONYMIZER_MEASUREMENT_STREAMING` | `streaming` | +| `ANONYMIZER_MEASUREMENT_KEEP_RECORDS` | `keep_records` | +| `ANONYMIZER_MEASUREMENT_DD_TRACE` | `dd_trace` | +| `ANONYMIZER_MEASUREMENT_DD_TRACE_PATH` | `dd_trace_path` | +| `ANONYMIZER_MEASUREMENT_DD_TASK_TRACE_PATH` | `dd_task_trace_path` | +| `ANONYMIZER_MEASUREMENT_FAIL_ON_WRITE_ERROR` | `fail_on_write_error` | +| `ANONYMIZER_MEASUREMENT_RUN_ID` | `run_id` | +| `ANONYMIZER_MEASUREMENT_RUN_TAGS` | `run_tags` | + +## DataDesigner Message Traces + +DataDesigner message traces are optional sidecar artifacts for model-call +debugging: + +```python +measurement = MeasurementConfig( + output_path="benchmark-runs/case/measurements.jsonl", + dd_trace="last_message", + dd_trace_path="benchmark-runs/case/traces.jsonl", +) +``` + +`last_message` stores the final prompt message for each traced DataDesigner +model call. `all_messages` stores the full message list. + +Message traces are separate from measurement records. They may contain raw input +text, prompts, generated output, entity values, replacement values, secrets, and +PII. Do not share them unless they have been reviewed or redacted. + +Anonymizer requests these traces through DataDesigner native LLM column trace +side effects. That covers `LLMTextColumnConfig` and +`LLMStructuredColumnConfig`. It does not cover model calls made inside +`CustomColumnConfig` generator functions. When tracing is enabled, the +measurement stream records a `dd_trace_coverage` row so benchmark analysis can +see unsupported columns. 
+ +## DataDesigner Task Traces + +Scheduler task traces are a separate sidecar: + +```python +measurement = MeasurementConfig( + output_path="benchmark-runs/case/measurements.jsonl", + dd_task_trace_path="benchmark-runs/case/task-traces.jsonl", +) +``` + +Task traces capture DataDesigner scheduler timing metadata: workflow, column, +row group, row index, task type, status, queue wait time, execution time, total +time, and whether an error was present. They do not store raw DataDesigner error +strings because those strings can contain prompts, outputs, or source values. + +## Safety Rules + +Measurement records must not contain raw text, entity values, prompts, generated +outputs, replacement maps, provider secrets, or API keys. + +Use counts, labels, lengths, buckets, model aliases, status flags, elapsed time, +token counts, request counts, and run-scoped HMACs instead. The collector hashes +record identity with a per-run key. Record hashes can join artifacts from one +run, but they are not stable identifiers across unrelated runs unless the caller +supplies the same hash key deliberately. + +When adding instrumentation: + +- Put timing around stable phase boundaries, not every helper call. +- Record metadata at the boundary where the information is known. +- Keep raw debug payloads in explicit sidecars, never in measurement records. +- Prefer `run_tags` for benchmark context such as suite ID, case ID, workload, + config, or experimental strategy. +- Keep benchmark-only strategy switches in `tools/measurement`, not product + defaults. + +## Key Files + +| File | Purpose | +| --- | --- | +| `src/anonymizer/measurement.py` | Collector, config, context managers, safe record builders, and trace sidecar hooks. | +| `src/anonymizer/interface/anonymizer.py` | Run-level and per-record measurement integration. | +| `src/anonymizer/engine/ndd/adapter.py` | DataDesigner workflow measurement, native message trace capture, and scheduler task trace capture. 
| +| `tools/measurement/run_benchmarks.py` | Benchmark suite runner that activates measurement sessions and writes per-case artifacts. | +| `tools/measurement/README.md` | Detailed benchmark and analysis command reference. | diff --git a/mkdocs.yml b/mkdocs.yml index 29d11c01..45673b1d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -167,6 +167,8 @@ nav: - Choosing a Replacement Strategy: notebooks/03_choosing_a_replacement_strategy.ipynb - Rewriting Biographies: notebooks/04_rewriting_biographies.ipynb - Rewriting Legal Documents: notebooks/05_rewriting_legal_documents.ipynb + - Development: + - Observability: development/observability.md - API Reference: reference/ - Developer Notes: - devnotes/index.md diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 69042cd9..3da61579 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -122,6 +122,12 @@ manual GitHub Actions dispatch. It targets the self-hosted environment, runs a suite, appends a short case summary to the GitHub step summary, and uploads the full output directory as a workflow artifact. +The job is intentionally manual. It runs only through `workflow_dispatch`; it +does not run on `push`, `pull_request`, `schedule`, or the default PR CI path. +GitHub exposes manual dispatch only after the workflow file exists on the +repository default branch. After that, launch it from the Actions UI, GitHub +CLI, or API. + The default suite is `tools/measurement/examples/repo-data-smoke.yaml`. Dispatch inputs let operators choose the ref, suite path, output directory, DataDesigner message trace mode, sanitized scheduler task traces, and @@ -129,6 +135,11 @@ fail-fast behavior. The workflow requires the repository secret `NVIDIA_API_KEY` because the default model configuration uses NVIDIA-hosted models. +The `ref` input defaults to `main`. To benchmark a PR or experiment branch, set +`ref` to that branch name or commit SHA. 
The workflow checks out that ref and +uses the benchmark runner and suite files from the checkout, so the selected ref +must contain `tools/measurement/run_benchmarks.py` and the requested suite path. + Benchmark suites are YAML files with three parts: - `workloads`: input datasets and text-column metadata. From b3645cc37c421f9d4e2804ea882170fdfa0a8c04 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 21:43:18 +0000 Subject: [PATCH 19/26] Refactor measurement observability package Signed-off-by: Aaron Gonzales --- docs/development/observability.md | 4 +- src/anonymizer/engine/ndd/adapter.py | 16 +- src/anonymizer/measurement.py | 1577 ----------------- src/anonymizer/measurement/__init__.py | 42 + src/anonymizer/measurement/_coerce.py | 134 ++ src/anonymizer/measurement/collector.py | 200 +++ src/anonymizer/measurement/config.py | 136 ++ src/anonymizer/measurement/constants.py | 11 + .../measurement/metrics/__init__.py | 4 + .../measurement/metrics/entities.py | 216 +++ .../measurement/metrics/llm_calls.py | 51 + .../measurement/metrics/replacements.py | 85 + src/anonymizer/measurement/metrics/rewrite.py | 94 + src/anonymizer/measurement/recorders.py | 220 +++ .../measurement/records/__init__.py | 4 + src/anonymizer/measurement/records/model.py | 115 ++ src/anonymizer/measurement/records/row.py | 195 ++ src/anonymizer/measurement/records/run.py | 108 ++ src/anonymizer/measurement/session.py | 115 ++ src/anonymizer/measurement/sinks.py | 54 + tools/measurement/README.md | 66 +- 21 files changed, 1850 insertions(+), 1597 deletions(-) delete mode 100644 src/anonymizer/measurement.py create mode 100644 src/anonymizer/measurement/__init__.py create mode 100644 src/anonymizer/measurement/_coerce.py create mode 100644 src/anonymizer/measurement/collector.py create mode 100644 src/anonymizer/measurement/config.py create mode 100644 src/anonymizer/measurement/constants.py create mode 100644 src/anonymizer/measurement/metrics/__init__.py create mode 100644 
src/anonymizer/measurement/metrics/entities.py create mode 100644 src/anonymizer/measurement/metrics/llm_calls.py create mode 100644 src/anonymizer/measurement/metrics/replacements.py create mode 100644 src/anonymizer/measurement/metrics/rewrite.py create mode 100644 src/anonymizer/measurement/recorders.py create mode 100644 src/anonymizer/measurement/records/__init__.py create mode 100644 src/anonymizer/measurement/records/model.py create mode 100644 src/anonymizer/measurement/records/row.py create mode 100644 src/anonymizer/measurement/records/run.py create mode 100644 src/anonymizer/measurement/session.py create mode 100644 src/anonymizer/measurement/sinks.py diff --git a/docs/development/observability.md b/docs/development/observability.md index 38cb047d..b69cff29 100644 --- a/docs/development/observability.md +++ b/docs/development/observability.md @@ -3,7 +3,7 @@ # Observability -Anonymizer keeps local run measurement in `src/anonymizer/measurement.py`. +Anonymizer keeps local run measurement in the `anonymizer.measurement` package. Measurement hooks record timings, counts, model-call summaries, and safety metrics without changing anonymization behavior. Benchmark tools convert those records into tables for latency, reliability, model usage, and quality analysis. @@ -163,7 +163,7 @@ When adding instrumentation: | File | Purpose | | --- | --- | -| `src/anonymizer/measurement.py` | Collector, config, context managers, safe record builders, and trace sidecar hooks. | +| `src/anonymizer/measurement/` | Collector, config, context managers, safe record builders, and trace sidecar hooks. | | `src/anonymizer/interface/anonymizer.py` | Run-level and per-record measurement integration. | | `src/anonymizer/engine/ndd/adapter.py` | DataDesigner workflow measurement, native message trace capture, and scheduler task trace capture. | | `tools/measurement/run_benchmarks.py` | Benchmark suite runner that activates measurement sessions and writes per-case artifacts. 
| diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 9a13cf5d..3ca6855e 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -146,14 +146,14 @@ def run_workflow( num_records=len(workflow_input_df), dataset_name=workflow_name, ) - task_traces = list(getattr(run_results, "task_traces", []) or []) + task_traces = _task_traces_from_result(run_results) output_df = run_results.load_dataset() else: preview_results = self._data_designer.preview( config_builder, num_records=record_count, ) - task_traces = list(getattr(preview_results, "task_traces", []) or []) + task_traces = _task_traces_from_result(preview_results) if preview_results.dataset is None: output_df = workflow_input_df.iloc[0:0].copy() else: @@ -440,6 +440,18 @@ def _run_config_with_async_trace(run_config: Any) -> Any: return run_config +def _task_traces_from_result(result: Any) -> list[Any]: + raw_traces = getattr(result, "task_traces", None) + if raw_traces is None: + return [] + if isinstance(raw_traces, list): + return raw_traces + try: + return list(raw_traces) + except TypeError: + return [] + + def _configure_native_dd_message_traces( *, columns: list[ColumnConfigT], diff --git a/src/anonymizer/measurement.py b/src/anonymizer/measurement.py deleted file mode 100644 index 0f55a1c1..00000000 --- a/src/anonymizer/measurement.py +++ /dev/null @@ -1,1577 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import hashlib -import hmac -import json -import logging -import math -import platform -import secrets -import time -import uuid -from collections import Counter -from collections.abc import Iterator, Mapping -from contextlib import contextmanager -from contextvars import ContextVar -from dataclasses import dataclass -from importlib.metadata import PackageNotFoundError, version -from numbers import Integral -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Protocol, cast -from urllib.parse import urlparse - -from pydantic import Field, ValidationError -from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError - -from anonymizer.engine.constants import COL_FINAL_ENTITIES - -if TYPE_CHECKING: - import pandas as pd - - -_ACTIVE_COLLECTOR: ContextVar[MeasurementCollector | None] = ContextVar( - "anonymizer_measurement_collector", - default=None, -) -_GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") -_ENTITY_LABEL_EQUIVALENCE_CLASSES = ( - frozenset( - { - "access_token", - "api_key", - "auth_token", - "bearer_token", - "password", - "secret_key", - "session_id", - "unique_id", - "user_id", - } - ), - frozenset({"full_name", "person_name", "user", "user_name", "username"}), - frozenset({"phone", "phone_number", "telephone"}), - frozenset({"email", "email_address"}), - frozenset({"cookie", "http_cookie", "session_cookie"}), -) -_ENTITY_LABEL_EQUIVALENCE: dict[str, str] = { - label: sorted(labels)[0] for labels in _ENTITY_LABEL_EQUIVALENCE_CLASSES for label in labels -} -MEASUREMENT_SCHEMA_VERSION = 1 -DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" -DD_TRACE_MODES = {"none", "last_message", "all_messages"} -DDTraceMode = Literal["none", "last_message", "all_messages"] - -logger = logging.getLogger("anonymizer.measurement") - - -class _MeasurementWriter(Protocol): - def write(self, records: 
list[dict[str, Any]], path: str | Path) -> None: ... - - -class _MeasurementSink(Protocol): - def write_record(self, record: dict[str, Any]) -> None: ... - - def close(self) -> None: ... - - -class _JsonlMeasurementWriter: - def write(self, records: list[dict[str, Any]], path: str | Path) -> None: - output_path = Path(path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="utf-8") as f: - for record in records: - f.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") - - -class _JsonMeasurementWriter: - def write(self, records: list[dict[str, Any]], path: str | Path) -> None: - output_path = Path(path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="utf-8") as f: - json.dump(records, f, ensure_ascii=True, indent=2, sort_keys=True) - - -def _writer_for_format(output_format: Literal["jsonl", "json"]) -> _MeasurementWriter: - if output_format == "json": - return _JsonMeasurementWriter() - return _JsonlMeasurementWriter() - - -class _JsonlMeasurementSink: - def __init__(self, path: str | Path) -> None: - output_path = Path(path) - output_path.parent.mkdir(parents=True, exist_ok=True) - self._file = output_path.open("w", encoding="utf-8", buffering=1) - - def write_record(self, record: dict[str, Any]) -> None: - self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") - - def close(self) -> None: - self._file.close() - - -class _MeasurementEnvSettings(BaseSettings): - model_config = SettingsConfigDict( - env_prefix=DEFAULT_MEASUREMENT_ENV_PREFIX, - env_ignore_empty=True, - extra="ignore", - ) - - output_path: str | None = None - output_format: Literal["jsonl", "json"] = "jsonl" - record_level: bool = True - streaming: bool = False - keep_records: bool = True - dd_trace: DDTraceMode = "none" - dd_trace_path: str | None = None - dd_task_trace_path: str | None = None - fail_on_write_error: bool = False - run_id: str | None = None - run_tags: 
dict[str, Any] = Field(default_factory=dict) - - -class MeasurementCollector: - """In-memory collector for local benchmark and throughput records. - - Records contain counts, labels, lengths, aliases, timings, and run-scoped - HMACs. They must not contain raw text, entity values, prompts, generated - outputs, replacement maps, provider secrets, or API keys. - """ - - def __init__( - self, - *, - run_id: str | None = None, - record_hash_key: bytes | str | None = None, - record_level: bool = True, - run_tags: Mapping[str, Any] | None = None, - record_sink: _MeasurementSink | None = None, - keep_records: bool = True, - dd_trace_mode: DDTraceMode = "none", - dd_trace_sink: _MeasurementSink | None = None, - dd_task_trace_sink: _MeasurementSink | None = None, - fail_on_write_error: bool = False, - ) -> None: - self.run_id = run_id or uuid.uuid4().hex - self.record_level = record_level - self.run_tags = cast(dict[str, Any], _json_safe(dict(run_tags or {}))) - self._record_sink = record_sink - self._keep_records = keep_records - self._dd_trace_mode = dd_trace_mode - self._dd_trace_sink = dd_trace_sink - self._dd_task_trace_sink = dd_task_trace_sink - self._fail_on_write_error = fail_on_write_error - self._sink_failed = False - self._dd_trace_failed = False - self._dd_task_trace_failed = False - if record_hash_key is None: - self._record_hash_key = secrets.token_bytes(32) - elif isinstance(record_hash_key, str): - self._record_hash_key = record_hash_key.encode("utf-8") - else: - self._record_hash_key = bytes(record_hash_key) - self._records: list[dict[str, Any]] = [] - - @property - def records(self) -> list[dict[str, Any]]: - """Return a shallow copy of collected measurement records.""" - return list(self._records) - - def record(self, record_type: str, **fields: Any) -> None: - """Append one machine-readable measurement record.""" - record = { - **fields, - "schema_version": MEASUREMENT_SCHEMA_VERSION, - "record_type": record_type, - "run_id": self.run_id, - "run_tags": 
self.run_tags, - "timestamp_unix_sec": time.time(), - } - safe_record = _json_safe(record) - if self._keep_records: - self._records.append(safe_record) - if self._record_sink is not None: - self._write_record_to_sink(safe_record) - - def close(self) -> None: - """Close any streaming measurement sink attached to this collector.""" - close_error: Exception | None = None - for sink in (self._record_sink, self._dd_trace_sink, self._dd_task_trace_sink): - if sink is None: - continue - try: - sink.close() - except Exception as exc: - if close_error is None: - close_error = exc - if close_error is not None: - raise close_error - - @property - def dd_trace_mode(self) -> DDTraceMode: - return self._dd_trace_mode - - @property - def dd_trace_enabled(self) -> bool: - return self._dd_trace_mode != "none" and self._dd_trace_sink is not None - - @property - def dd_task_trace_enabled(self) -> bool: - return self._dd_task_trace_sink is not None - - def record_dd_message_trace(self, **fields: Any) -> None: - """Write an explicitly opt-in DataDesigner message trace record. - - These records may contain raw prompts, input text, model outputs, and - PII. They are intentionally written to a separate trace sink and are - never appended to the safe measurement record list. 
- """ - if not self.dd_trace_enabled or self._dd_trace_failed: - return - - record = _json_safe( - { - **fields, - "schema_version": MEASUREMENT_SCHEMA_VERSION, - "record_type": "dd_message_trace", - "run_id": self.run_id, - "run_tags": self.run_tags, - "timestamp_unix_sec": time.time(), - } - ) - try: - cast(_MeasurementSink, self._dd_trace_sink).write_record(record) - except Exception: - self._dd_trace_failed = True - logger.warning("Failed to write DataDesigner message trace records") - if self._fail_on_write_error: - raise - - def record_dd_task_trace(self, **fields: Any) -> None: - """Write an opt-in sanitized DataDesigner scheduler task trace record.""" - if not self.dd_task_trace_enabled or self._dd_task_trace_failed: - return - - record = _json_safe( - { - **fields, - "schema_version": MEASUREMENT_SCHEMA_VERSION, - "record_type": "dd_task_trace", - "run_id": self.run_id, - "run_tags": self.run_tags, - "timestamp_unix_sec": time.time(), - } - ) - try: - cast(_MeasurementSink, self._dd_task_trace_sink).write_record(record) - except Exception: - self._dd_task_trace_failed = True - logger.warning("Failed to write DataDesigner task trace records") - if self._fail_on_write_error: - raise - - def _write_record_to_sink(self, record: dict[str, Any]) -> None: - if self._sink_failed: - return - try: - cast(_MeasurementSink, self._record_sink).write_record(record) - except Exception: - self._sink_failed = True - logger.warning("Failed to stream Anonymizer measurement records") - if self._fail_on_write_error: - raise - - def record_hash(self, *, row_index: object, text: str) -> str: - """Return a run-scoped HMAC for joining records without storing text.""" - serialized = json.dumps( - {"row_index": str(row_index), "text": text}, - default=str, - sort_keys=True, - separators=(",", ":"), - ) - return hmac.new(self._record_hash_key, serialized.encode("utf-8"), hashlib.sha256).hexdigest() - - def write_jsonl(self, path: str | Path) -> None: - """Write records as 
newline-delimited JSON.""" - _JsonlMeasurementWriter().write(self._records, path) - - def write_json(self, path: str | Path) -> None: - """Write records as a JSON array.""" - _JsonMeasurementWriter().write(self._records, path) - - def to_dataframe(self) -> pd.DataFrame: - """Return records as a pandas DataFrame for benchmark tooling.""" - import pandas as pd - - return pd.DataFrame(self._records) - - -@dataclass(frozen=True) -class MeasurementConfig: - """Configuration for writing structured measurement records around a run.""" - - output_path: str | Path - output_format: Literal["jsonl", "json"] = "jsonl" - record_level: bool = True - streaming: bool = False - keep_records: bool = True - dd_trace: DDTraceMode = "none" - dd_trace_path: str | Path | None = None - dd_task_trace_path: str | Path | None = None - run_id: str | None = None - record_hash_key: bytes | str | None = None - run_tags: Mapping[str, Any] | None = None - fail_on_write_error: bool = False - - def __post_init__(self) -> None: - if self.output_format not in {"jsonl", "json"}: - raise ValueError("output_format must be 'jsonl' or 'json'") - if self.streaming and self.output_format != "jsonl": - raise ValueError("streaming measurement output only supports jsonl") - if self.dd_trace not in DD_TRACE_MODES: - raise ValueError("dd_trace must be 'none', 'last_message', or 'all_messages'") - if self.dd_trace != "none" and self.dd_trace_path is None: - raise ValueError("dd_trace_path is required when dd_trace is enabled") - - @classmethod - def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> MeasurementConfig | None: - """Build measurement config from environment variables, or None if output is unset. - - This is intentionally opt-in. Anonymizer API and CLI calls do not read - measurement environment variables unless benchmark/tooling code calls this - helper explicitly. 
- """ - try: - settings = _load_measurement_env_settings(prefix=prefix) - except (SettingsError, ValidationError) as exc: - raise ValueError(_measurement_env_error_message(exc, prefix=prefix)) from None - - if settings.output_path is None: - return None - return cls( - output_path=settings.output_path, - output_format=settings.output_format, - record_level=settings.record_level, - streaming=settings.streaming, - keep_records=settings.keep_records, - dd_trace=settings.dd_trace, - dd_trace_path=settings.dd_trace_path, - dd_task_trace_path=settings.dd_task_trace_path, - run_id=settings.run_id, - run_tags=settings.run_tags, - fail_on_write_error=settings.fail_on_write_error, - ) - - @classmethod - def from_sources( - cls, - explicit: MeasurementConfig | None = None, - *, - env: bool = False, - prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX, - ) -> MeasurementConfig | None: - """Resolve measurement config from explicit config first, then optional env.""" - if explicit is not None: - return explicit - if env: - return cls.from_env(prefix=prefix) - return None - - def write_collector(self, collector: MeasurementCollector) -> None: - """Write a collector using this config's output format.""" - _writer_for_format(self.output_format).write(collector.records, self.output_path) - - -@contextmanager -def measurement_session(collector: MeasurementCollector | None = None) -> Iterator[MeasurementCollector]: - """Activate a collector for code running in this context.""" - active = collector or MeasurementCollector() - token = _ACTIVE_COLLECTOR.set(active) - try: - yield active - finally: - _ACTIVE_COLLECTOR.reset(token) - - -@contextmanager -def configured_measurement_session(config: MeasurementConfig | None) -> Iterator[MeasurementCollector | None]: - """Activate and persist a collector when a measurement config is provided.""" - if config is None: - yield None - return - - sink = _JsonlMeasurementSink(config.output_path) if config.streaming else None - dd_trace_sink = None - if 
config.dd_trace != "none": - if config.dd_trace_path is None: - raise ValueError("dd_trace_path is required when dd_trace is enabled") - dd_trace_sink = _JsonlMeasurementSink(config.dd_trace_path) - dd_task_trace_sink = _JsonlMeasurementSink(config.dd_task_trace_path) if config.dd_task_trace_path else None - collector = MeasurementCollector( - run_id=config.run_id, - record_hash_key=config.record_hash_key, - record_level=config.record_level, - run_tags=config.run_tags, - record_sink=sink, - keep_records=config.keep_records, - dd_trace_mode=config.dd_trace, - dd_trace_sink=dd_trace_sink, - dd_task_trace_sink=dd_task_trace_sink, - fail_on_write_error=config.fail_on_write_error, - ) - with measurement_session(collector): - body_error: BaseException | None = None - try: - yield collector - except BaseException as exc: - body_error = exc - raise - finally: - if config.streaming: - _close_collector_safely(config=config, collector=collector, body_error=body_error) - else: - write_error: BaseException | None = None - try: - _write_collector_safely(config=config, collector=collector, body_error=body_error) - except BaseException as exc: - write_error = exc - raise - finally: - _close_collector_safely( - config=config, - collector=collector, - body_error=body_error or write_error, - ) - - -def current_collector() -> MeasurementCollector | None: - """Return the active collector, if measurement is enabled.""" - return _ACTIVE_COLLECTOR.get() - - -@contextmanager -def stage_timer(stage: str, **fields: Any) -> Iterator[dict[str, Any]]: - """Record wall time for a stage when collection is active.""" - collector = current_collector() - if collector is None: - yield fields - return - - started = time.perf_counter() - status = "completed" - try: - yield fields - except BaseException: - status = "error" - raise - finally: - elapsed_sec = time.perf_counter() - started - collector.record( - "stage", - stage=stage, - status=status, - elapsed_sec=elapsed_sec, - **fields, - 
**_row_throughput_fields( - elapsed_sec=elapsed_sec, - input_row_count=_coerce_int(fields.get("input_row_count"), default=-1), - output_row_count=_coerce_int(fields.get("output_row_count"), default=-1), - ), - ) - - -def record_stage(stage: str, *, elapsed_sec: float, status: str = "completed", **fields: Any) -> None: - """Record a pre-timed stage measurement if collection is active.""" - collector = current_collector() - if collector is None: - return - collector.record( - "stage", - stage=stage, - status=status, - elapsed_sec=elapsed_sec, - **fields, - **_row_throughput_fields( - elapsed_sec=elapsed_sec, - input_row_count=_coerce_int(fields.get("input_row_count"), default=-1), - output_row_count=_coerce_int(fields.get("output_row_count"), default=-1), - ), - ) - - -def record_ndd_workflow( - *, - workflow_name: str, - model_aliases: list[str], - input_row_count: int, - output_row_count: int | None, - failed_record_count: int | None, - elapsed_sec: float, - status: str = "completed", - seed_row_count: int | None = None, - preview_num_records: int | None = None, - column_count: int | None = None, - column_names: list[str] | None = None, - model_usage: Mapping[str, Any] | None = None, -) -> None: - """Record one DataDesigner workflow execution through the adapter boundary.""" - _record_model_workflow( - workflow_name=workflow_name, - model_aliases=model_aliases, - input_row_count=input_row_count, - output_row_count=output_row_count, - failed_record_count=failed_record_count, - elapsed_sec=elapsed_sec, - status=status, - seed_row_count=seed_row_count, - preview_num_records=preview_num_records, - column_count=column_count, - column_names=column_names, - model_usage=model_usage, - record_type="ndd_workflow", - extra_fields=None, - ) - - -def record_model_workflow( - *, - workflow_name: str, - model_aliases: list[str], - input_row_count: int, - output_row_count: int | None, - failed_record_count: int | None, - elapsed_sec: float, - status: str = "completed", - 
seed_row_count: int | None = None, - preview_num_records: int | None = None, - column_count: int | None = None, - column_names: list[str] | None = None, - model_usage: Mapping[str, Any] | None = None, - extra_fields: Mapping[str, Any] | None = None, -) -> None: - """Record one sanitized model-backed workflow execution. - - Use this for non-DataDesigner model calls that still need benchmark - accounting. Raw prompts, text, responses, and replacement values do not - belong in ``model_usage``. - """ - _record_model_workflow( - workflow_name=workflow_name, - model_aliases=model_aliases, - input_row_count=input_row_count, - output_row_count=output_row_count, - failed_record_count=failed_record_count, - elapsed_sec=elapsed_sec, - status=status, - seed_row_count=seed_row_count, - preview_num_records=preview_num_records, - column_count=column_count, - column_names=column_names, - model_usage=model_usage, - record_type="model_workflow", - extra_fields=extra_fields, - ) - - -def _record_model_workflow( - *, - workflow_name: str, - model_aliases: list[str], - input_row_count: int, - output_row_count: int | None, - failed_record_count: int | None, - elapsed_sec: float, - status: str, - seed_row_count: int | None, - preview_num_records: int | None, - column_count: int | None, - column_names: list[str] | None, - model_usage: Mapping[str, Any] | None, - record_type: str, - extra_fields: Mapping[str, Any] | None, -) -> None: - collector = current_collector() - if collector is None: - return - observed_usage = _summarize_model_usage(model_usage) - workflow_fields = { - "workflow_name": workflow_name, - "status": status, - "model_aliases": sorted(set(model_aliases)), - "input_row_count": input_row_count, - "seed_row_count": seed_row_count, - "output_row_count": output_row_count, - "failed_record_count": failed_record_count, - "elapsed_sec": elapsed_sec, - "preview_num_records": preview_num_records, - "column_count": column_count, - "column_names": column_names or [], - 
"model_usage": dict(model_usage or {}), - **dict(extra_fields or {}), - } - collector.record(record_type, **_model_workflow_fields(workflow_fields, observed_usage)) - - -def _model_workflow_fields(fields: dict[str, Any], observed_usage: dict[str, int | None]) -> dict[str, Any]: - return { - **fields, - **observed_usage, - "observed_failed_request_rate": _safe_ratio( - observed_usage["observed_failed_requests"], - observed_usage["observed_total_requests"], - ), - **_throughput_fields( - elapsed_sec=cast(float, fields["elapsed_sec"]), - input_row_count=cast(int, fields["input_row_count"]), - output_row_count=cast(int | None, fields["output_row_count"]), - total_tokens=observed_usage["observed_total_tokens"], - total_requests=observed_usage["observed_total_requests"], - successful_requests=observed_usage["observed_successful_requests"], - ), - } - - -def record_run_metadata( - *, - config: Any, - data: Any, - mode: str, - strategy: str, - input_row_count: int, - preview_num_records: int | None, - model_configs: list[Any], -) -> None: - """Record sanitized run/config metadata once per anonymizer run.""" - collector = current_collector() - if collector is None: - return - - detect = getattr(config, "detect", None) - source = str(getattr(data, "source", "")) - collector.record( - "run", - mode=mode, - strategy=strategy, - input_row_count=input_row_count, - preview_num_records=preview_num_records, - source_hash=collector.record_hash(row_index="source", text=source), - input_source=_source_metadata(source), - input_text_column=str(getattr(data, "text_column", "")), - input_has_id_column=bool(getattr(data, "id_column", None)), - input_has_data_summary=bool(getattr(data, "data_summary", None)), - detect=_detect_config_metadata(detect), - replace=_replace_config_metadata(getattr(config, "replace", None)), - rewrite=_rewrite_config_metadata(getattr(config, "rewrite", None)), - models=[_model_config_metadata(model_config) for model_config in model_configs], - 
runtime=_runtime_metadata(), - ) - - -def record_record_metrics( - dataframe: pd.DataFrame, - *, - mode: str, - strategy: str, - text_column: str, - validation_max_entities_per_call: int, -) -> None: - """Record per-row count, length, and nominal-call metrics from a trace DataFrame.""" - collector = current_collector() - if collector is None or not collector.record_level: - return - - ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None) - columns = set(dataframe.columns) - for row_index, row in dataframe.iterrows(): - final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES)) - collector.record( - "record", - **_base_record_fields( - collector=collector, - row_index=row_index, - row=row, - text_column=text_column, - mode=mode, - strategy=strategy, - ), - **_entity_record_fields(row, final_entities=final_entities, ground_truth_column=ground_truth_column), - **_replacement_record_fields(row, columns=columns, final_entities=final_entities), - **_rewrite_record_fields(row, columns=columns), - **_original_value_leak_record_fields(row, columns=columns, final_entities=final_entities), - **_llm_record_fields( - row, - columns=columns, - mode=mode, - strategy=strategy, - final_entity_count=len(final_entities), - validation_max_entities_per_call=validation_max_entities_per_call, - ), - ) - - -def _detect_config_metadata(detect: Any | None) -> dict[str, Any]: - entity_labels = getattr(detect, "entity_labels", None) - if entity_labels is None: - from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS - - entity_label_count = len(DEFAULT_ENTITY_LABELS) - else: - entity_label_count = len(entity_labels) - return { - "gliner_threshold": getattr(detect, "gliner_threshold", None), - "entity_label_source": "custom" if entity_labels is not None else "default", - "entity_label_count": entity_label_count, - "entity_labels": list(entity_labels) if entity_labels is not None else None, - "validation_max_entities_per_call": 
getattr(detect, "validation_max_entities_per_call", None), - "validation_excerpt_window_chars": getattr(detect, "validation_excerpt_window_chars", None), - } - - -def _base_record_fields( - *, - collector: MeasurementCollector, - row_index: object, - row: Any, - text_column: str, - mode: str, - strategy: str, -) -> dict[str, Any]: - text = str(row.get(text_column, "")) - text_length_tokens = _count_text_tokens(text) - return { - "mode": mode, - "strategy": strategy, - "row_index": _safe_row_index(row_index), - "record_hash": collector.record_hash(row_index=row_index, text=text), - "text_length_chars": len(text), - "text_length_chars_bucket": _size_bucket(len(text)), - "text_length_tokens": text_length_tokens, - "text_length_tokens_bucket": _size_bucket(text_length_tokens), - } - - -def _entity_record_fields( - row: Any, - *, - final_entities: list[dict[str, Any]], - ground_truth_column: str | None, -) -> dict[str, Any]: - ground_truth_entities = ( - _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None - ) - return { - "final_entity_count": len(final_entities), - "final_entity_label_counts": dict( - sorted(Counter(e.get("label", "") for e in final_entities if e.get("label")).items()) - ), - **_entity_ground_truth_metrics(final_entities, ground_truth_entities), - } - - -def _replacement_record_fields( - row: Any, - *, - columns: set[str], - final_entities: list[dict[str, Any]], -) -> dict[str, Any]: - from anonymizer.engine.constants import COL_REPLACEMENT_MAP - - if COL_REPLACEMENT_MAP not in columns: - return {} - raw_map = row.get(COL_REPLACEMENT_MAP) - return { - **_replacement_map_metrics(raw_map), - **_replacement_coverage_metrics(raw_map, final_entities), - **_replacement_collision_metrics(raw_map, final_entities), - } - - -def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: - from anonymizer.engine.constants import ( - COL_ANY_HIGH_LEAKED, - COL_LEAKAGE_MASS, - COL_NEEDS_HUMAN_REVIEW, - 
COL_NEEDS_REPAIR, - COL_UTILITY_SCORE, - COL_WEIGHTED_LEAKAGE_RATE, - ) - - return { - "utility_score": _coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in columns else None, - "leakage_mass": _coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in columns else None, - "weighted_leakage_rate": ( - _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE)) if COL_WEIGHTED_LEAKAGE_RATE in columns else None - ), - "any_high_leaked": _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in columns else None, - "needs_human_review": ( - _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in columns else None - ), - "needs_repair": _coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in columns else None, - } - - -def _original_value_leak_record_fields( - row: Any, - *, - columns: set[str], - final_entities: list[dict[str, Any]], -) -> dict[str, Any]: - output_column = _output_text_column(columns) - if output_column is None: - return {"original_value_leak_count": None, "original_value_leak_label_counts": {}} - output_text = str(row.get(output_column, "")) - leaked = [ - entity - for entity in final_entities - if entity.get("value") and _output_contains_original_value(output_text, str(entity.get("value"))) - ] - return { - "original_value_leak_count": len(leaked), - "original_value_leak_label_counts": dict( - sorted(Counter(str(entity.get("label") or "") for entity in leaked if entity.get("label")).items()) - ), - } - - -def _output_contains_original_value(output_text: str, value: str) -> bool: - if _needs_boundary_sensitive_leak_match(value): - return _contains_with_alnum_boundaries(output_text, value) - return value in output_text - - -def _needs_boundary_sensitive_leak_match(value: str) -> bool: - return len(value) <= 4 or value.isdigit() - - -def _contains_with_alnum_boundaries(output_text: str, value: str) -> bool: - start = 0 - while True: - match_start = output_text.find(value, start) - if match_start < 0: - 
return False - match_end = match_start + len(value) - if _has_alnum_boundaries(output_text, match_start, match_end): - return True - start = match_start + 1 - - -def _has_alnum_boundaries(text: str, start: int, end: int) -> bool: - before_is_alnum = start > 0 and text[start - 1].isalnum() - after_is_alnum = end < len(text) and text[end].isalnum() - return not before_is_alnum and not after_is_alnum - - -def _output_text_column(columns: set[str]) -> str | None: - from anonymizer.engine.constants import COL_REPLACED_TEXT, COL_REWRITTEN_TEXT - - if COL_REPLACED_TEXT in columns: - return COL_REPLACED_TEXT - if COL_REWRITTEN_TEXT in columns: - return COL_REWRITTEN_TEXT - return None - - -def _llm_record_fields( - row: Any, - *, - columns: set[str], - mode: str, - strategy: str, - final_entity_count: int, - validation_max_entities_per_call: int, -) -> dict[str, Any]: - from anonymizer.engine.constants import COL_REPAIR_ITERATIONS - - detected_candidate_count = _detected_candidate_count(row, columns=columns) - validation_chunk_count = _validation_chunk_count( - detected_candidate_count, - validation_max_entities_per_call=validation_max_entities_per_call, - ) - grouped_entity_count = _grouped_entity_count(row, columns=columns, final_entity_count=final_entity_count) - repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0) - replace_map_generation_uses_llm = _replace_map_generation_uses_llm(row, columns=columns) - calls_by_stage = estimate_llm_calls_by_stage( - mode=mode, - strategy=strategy, - has_grouped_entities=grouped_entity_count > 0, - validation_chunk_count=validation_chunk_count, - repair_iterations=repair_iterations, - replace_map_generation_uses_llm=replace_map_generation_uses_llm, - ) - total_estimated = ( - sum(calls_by_stage.values()) if all(value is not None for value in calls_by_stage.values()) else None - ) - return { - "detected_candidate_count": detected_candidate_count, - "validation_chunk_count": validation_chunk_count, - 
"repair_iterations": repair_iterations if mode == "rewrite" else 0, - "llm_calls_estimated_by_stage": calls_by_stage, - "llm_calls_estimated_total": total_estimated, - } - - -def _replace_map_generation_uses_llm(row: Any, *, columns: set[str]) -> bool: - del row, columns - return True - - -def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None: - from anonymizer.engine.constants import COL_SEED_VALIDATION_CANDIDATES - - if COL_SEED_VALIDATION_CANDIDATES not in columns: - return None - return _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",)) - - -def _grouped_entity_count(row: Any, *, columns: set[str], final_entity_count: int) -> int: - from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE - - if COL_ENTITIES_BY_VALUE not in columns: - return final_entity_count - return _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",)) - - -def _source_metadata(source: str) -> dict[str, Any]: - parsed = urlparse(source) - if parsed.scheme in {"http", "https"}: - return { - "kind": "remote_file", - "scheme": parsed.scheme, - "suffix": Path(parsed.path).suffix.lower() or None, - } - if parsed.scheme == "file": - return { - "kind": "local_file", - "scheme": "file", - "suffix": Path(parsed.path).suffix.lower() or None, - } - return { - "kind": "local_file" if source else "unknown", - "scheme": None, - "suffix": Path(source).suffix.lower() or None, - } - - -def _replace_config_metadata(replace_config: Any | None) -> dict[str, Any] | None: - if replace_config is None: - return None - - metadata: dict[str, Any] = { - "strategy": type(replace_config).__name__, - "has_instructions": bool(getattr(replace_config, "instructions", None)), - } - for attr in ("normalize_label", "algorithm", "digest_length"): - if hasattr(replace_config, attr): - metadata[attr] = getattr(replace_config, attr) - if hasattr(replace_config, "format_template"): - 
metadata["has_format_template"] = True - return metadata - - -def _rewrite_config_metadata(rewrite_config: Any | None) -> dict[str, Any] | None: - if rewrite_config is None: - return None - return { - "risk_tolerance": _enum_value(getattr(rewrite_config, "risk_tolerance", None)), - "max_repair_iterations": getattr(rewrite_config, "max_repair_iterations", None), - "strict_entity_protection": getattr(rewrite_config, "strict_entity_protection", None), - "has_privacy_goal": bool(getattr(rewrite_config, "privacy_goal", None)), - "has_instructions": bool(getattr(rewrite_config, "instructions", None)), - } - - -def _model_config_metadata(model_config: Any) -> dict[str, Any]: - inference_parameters = getattr(model_config, "inference_parameters", None) - return { - "alias": getattr(model_config, "alias", None), - "model": getattr(model_config, "model", None), - "provider": _enum_value(getattr(model_config, "provider", None)), - "base_url": bool(getattr(model_config, "base_url", None)), - "max_parallel_requests": getattr(inference_parameters, "max_parallel_requests", None), - } - - -def _runtime_metadata() -> dict[str, Any]: - try: - anonymizer_version = version("nemo-anonymizer") - except PackageNotFoundError: - anonymizer_version = None - return { - "anonymizer_version": anonymizer_version, - "measurement_schema_version": MEASUREMENT_SCHEMA_VERSION, - "platform_machine": platform.machine(), - "platform_system": platform.system(), - "python_version": platform.python_version(), - } - - -def _enum_value(value: Any) -> Any: - return getattr(value, "value", value) - - -def _throughput_fields( - *, - elapsed_sec: float, - input_row_count: int | None, - output_row_count: int | None, - total_tokens: int | None, - total_requests: int | None, - successful_requests: int | None, -) -> dict[str, float | None]: - return { - "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec), - "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec), - "observed_tokens_per_sec": 
_safe_rate(total_tokens, elapsed_sec), - "observed_requests_per_sec": _safe_rate(total_requests, elapsed_sec), - "observed_tokens_per_successful_request": _safe_ratio(total_tokens, successful_requests), - } - - -def _row_throughput_fields( - *, - elapsed_sec: float, - input_row_count: int | None, - output_row_count: int | None, -) -> dict[str, float | None]: - if input_row_count is not None and input_row_count < 0: - input_row_count = None - if output_row_count is not None and output_row_count < 0: - output_row_count = None - return { - "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec), - "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec), - } - - -def estimate_llm_calls_by_stage( - *, - mode: str, - strategy: str, - has_grouped_entities: bool, - validation_chunk_count: int | None, - repair_iterations: int = 0, - replace_map_generation_uses_llm: bool = True, -) -> dict[str, int | None]: - """Estimate nominal model calls for one record, split by workflow stage.""" - detection_calls = None if validation_chunk_count is None else 2 + validation_chunk_count - replace_map_generation = 0 - if replace_map_generation_uses_llm and has_grouped_entities and (mode == "rewrite" or strategy == "Substitute"): - replace_map_generation = 1 - - if mode != "rewrite": - return { - "entity_detection": detection_calls, - "replace_map_generation": replace_map_generation, - } - - rewrite_body_calls = has_grouped_entities - return { - "entity_detection": detection_calls, - "latent_entity_detection": 1 if rewrite_body_calls else 0, - "replace_map_generation": replace_map_generation, - "rewrite_pipeline": 5 if rewrite_body_calls else 0, - "rewrite_evaluate": 3 * (1 + repair_iterations) if rewrite_body_calls else 0, - "rewrite_repair": repair_iterations if rewrite_body_calls else 0, - "rewrite_final_judge": 1 if rewrite_body_calls else 0, - } - - -def _write_collector_safely( - *, - config: MeasurementConfig, - collector: MeasurementCollector, - body_error: 
BaseException | None, -) -> None: - try: - config.write_collector(collector) - except Exception as exc: - logger.warning("Failed to write Anonymizer measurement records (%s)", type(exc).__name__) - if body_error is None and config.fail_on_write_error: - raise - - -def _close_collector_safely( - *, - config: MeasurementConfig, - collector: MeasurementCollector, - body_error: BaseException | None, -) -> None: - try: - collector.close() - except Exception as exc: - logger.warning("Failed to close Anonymizer measurement stream (%s)", type(exc).__name__) - if body_error is None and config.fail_on_write_error: - raise - - -def _measurement_env_error_message(exc: SettingsError | ValidationError, *, prefix: str) -> str: - fields: set[str] = set() - if isinstance(exc, ValidationError): - for error in exc.errors(include_input=False): - loc = error.get("loc", ()) - if loc: - fields.add(str(loc[0]).upper()) - else: - error_text = str(exc).lower() - for field_name in _MeasurementEnvSettings.model_fields: - if field_name in error_text: - fields.add(field_name.upper()) - - if fields: - env_fields = ", ".join(f"{prefix}{field}" for field in sorted(fields)) - return f"Invalid Anonymizer measurement environment configuration for: {env_fields}" - return "Invalid Anonymizer measurement environment configuration" - - -def _load_measurement_env_settings(*, prefix: str) -> _MeasurementEnvSettings: - settings_factory = cast(Any, _MeasurementEnvSettings) - return cast(_MeasurementEnvSettings, settings_factory(_env_prefix=prefix)) - - -def _validation_chunk_count( - detected_candidate_count: int | None, - *, - validation_max_entities_per_call: int, -) -> int | None: - if detected_candidate_count is None: - return None - if detected_candidate_count <= 0: - return 0 - return int(math.ceil(detected_candidate_count / validation_max_entities_per_call)) - - -def _safe_row_index(row_index: object) -> int | None: - if isinstance(row_index, bool): - return None - if isinstance(row_index, Integral): 
- return int(row_index) - return None - - -def _entities_from_raw(raw: object) -> list[dict[str, Any]]: - payload = _coerce_payload(raw) - if isinstance(payload, Mapping): - items = cast(Mapping[str, Any], payload).get("entities", []) - elif isinstance(payload, list): - items = payload - else: - items = [] - return [dict(cast(Mapping[str, Any], item)) for item in items if isinstance(item, Mapping)] - - -def _entity_ground_truth_metrics( - final_entities: list[dict[str, Any]], - ground_truth_entities: list[dict[str, Any]] | None, -) -> dict[str, Any]: - if ground_truth_entities is None: - return { - "ground_truth_entity_count": None, - "ground_truth_entity_label_counts": None, - "entity_true_positive_count": None, - "entity_false_positive_count": None, - "entity_false_negative_count": None, - "entity_precision": None, - "entity_recall": None, - "entity_f1": None, - "entity_relaxed_gt_found_count": None, - "entity_relaxed_detected_tp_count": None, - "entity_relaxed_label_compatible_gt_found_count": None, - "entity_relaxed_label_compatible_detected_tp_count": None, - "entity_relaxed_precision": None, - "entity_relaxed_recall": None, - "entity_relaxed_f1": None, - "entity_relaxed_label_compatible_precision": None, - "entity_relaxed_label_compatible_recall": None, - "entity_relaxed_label_compatible_f1": None, - } - - predicted = _entity_identity_counts(final_entities) - expected = _entity_identity_counts(ground_truth_entities) - true_positive = sum((predicted & expected).values()) - false_positive = sum((predicted - expected).values()) - false_negative = sum((expected - predicted).values()) - precision = _safe_ratio(true_positive, true_positive + false_positive) - recall = _safe_ratio(true_positive, true_positive + false_negative) - return { - "ground_truth_entity_count": len(ground_truth_entities), - "ground_truth_entity_label_counts": dict( - sorted(Counter(e.get("label", "") for e in ground_truth_entities if e.get("label")).items()) - ), - 
"entity_true_positive_count": true_positive, - "entity_false_positive_count": false_positive, - "entity_false_negative_count": false_negative, - "entity_precision": precision, - "entity_recall": recall, - "entity_f1": _f1(precision, recall), - **_entity_relaxed_ground_truth_metrics(final_entities, ground_truth_entities), - } - - -def _entity_identity_counts(entities: list[dict[str, Any]]) -> Counter[tuple[str, str]]: - identities: Counter[tuple[str, str]] = Counter() - for entity in entities: - label = entity.get("label") - value = entity.get("value") - if label is None or value is None: - continue - identities[(str(value), str(label))] += 1 - return identities - - -def _entity_relaxed_ground_truth_metrics( - final_entities: list[dict[str, Any]], - ground_truth_entities: list[dict[str, Any]], -) -> dict[str, Any]: - relaxed_match_count = _relaxed_entity_match_count(final_entities, ground_truth_entities) - label_compatible_match_count = _relaxed_entity_match_count( - final_entities, - ground_truth_entities, - require_label_compatible=True, - ) - gt_found = relaxed_match_count - detected_tp = relaxed_match_count - label_compatible_gt_found = label_compatible_match_count - label_compatible_detected_tp = label_compatible_match_count - precision = _safe_ratio(detected_tp, len(final_entities)) - recall = _safe_ratio(gt_found, len(ground_truth_entities)) - label_compatible_precision = _safe_ratio(label_compatible_detected_tp, len(final_entities)) - label_compatible_recall = _safe_ratio(label_compatible_gt_found, len(ground_truth_entities)) - return { - "entity_relaxed_gt_found_count": gt_found, - "entity_relaxed_detected_tp_count": detected_tp, - "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, - "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, - "entity_relaxed_precision": precision, - "entity_relaxed_recall": recall, - "entity_relaxed_f1": _f1(precision, recall), - "entity_relaxed_label_compatible_precision": 
label_compatible_precision, - "entity_relaxed_label_compatible_recall": label_compatible_recall, - "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), - } - - -def _relaxed_entity_match_count( - final_entities: list[dict[str, Any]], - ground_truth_entities: list[dict[str, Any]], - *, - require_label_compatible: bool = False, -) -> int: - matches_by_ground_truth = [ - [ - final_index - for final_index, final_entity in enumerate(final_entities) - if _entities_match_relaxed( - final_entity, - ground_truth_entity, - require_label_compatible=require_label_compatible, - ) - ] - for ground_truth_entity in ground_truth_entities - ] - matched_ground_truth_by_final: dict[int, int] = {} - - def assign(ground_truth_index: int, seen: set[int]) -> bool: - for final_index in matches_by_ground_truth[ground_truth_index]: - if final_index in seen: - continue - seen.add(final_index) - if final_index not in matched_ground_truth_by_final or assign( - matched_ground_truth_by_final[final_index], - seen, - ): - matched_ground_truth_by_final[final_index] = ground_truth_index - return True - return False - - return sum(1 for ground_truth_index in range(len(ground_truth_entities)) if assign(ground_truth_index, set())) - - -def _entities_match_relaxed( - left: dict[str, Any], - right: dict[str, Any], - *, - require_label_compatible: bool, -) -> bool: - if require_label_compatible and not _entity_labels_compatible(left.get("label"), right.get("label")): - return False - left_span = _entity_span(left) - right_span = _entity_span(right) - if left_span is not None and right_span is not None: - return left_span[0] < right_span[1] and right_span[0] < left_span[1] - left_value = left.get("value") - right_value = right.get("value") - return left_value is not None and right_value is not None and str(left_value) == str(right_value) - - -def _entity_span(entity: dict[str, Any]) -> tuple[int, int] | None: - start = _coerce_float(entity.get("start_position", 
entity.get("start"))) - end = _coerce_float(entity.get("end_position", entity.get("end"))) - if start is None or end is None: - return None - start_int = int(start) - end_int = int(end) - if start_int < 0 or end_int <= start_int: - return None - return start_int, end_int - - -def _entity_labels_compatible(left: object, right: object) -> bool: - left_key = _entity_label_key(left) - right_key = _entity_label_key(right) - return left_key is not None and right_key is not None and left_key == right_key - - -def _entity_label_key(label: object) -> str | None: - if label is None: - return None - normalized = str(label).strip().lower() - if not normalized: - return None - return _ENTITY_LABEL_EQUIVALENCE.get(normalized, normalized) - - -def _replacement_map_metrics(raw: object) -> dict[str, Any]: - replacement_maps = _replacement_maps_from_raw(raw) - synthetic_values = [] - for item in replacement_maps: - synthetic = item.get("replacement", item.get("synthetic")) - if synthetic is not None: - synthetic_values.append(str(synthetic)) - return { - "replacement_count": len(replacement_maps), - "replacement_label_counts": dict( - sorted(Counter(item.get("label", "") for item in replacement_maps if item.get("label")).items()) - ), - "replacement_duplicate_value_count": max(0, len(synthetic_values) - len(set(synthetic_values))), - } - - -def _replacement_coverage_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: - replacement_original_values = { - str(original) - for item in _replacement_maps_from_raw(raw) - if (original := item.get("original")) is not None and str(original) - } - missing_entities = [ - entity - for entity in final_entities - if entity.get("value") and str(entity.get("value")) not in replacement_original_values - ] - missing_values = {str(entity.get("value")) for entity in missing_entities if entity.get("value")} - return { - "replacement_missing_final_entity_count": len(missing_entities), - 
"replacement_missing_final_entity_label_counts": dict( - sorted( - Counter(str(entity.get("label") or "") for entity in missing_entities if entity.get("label")).items() - ) - ), - "replacement_missing_final_value_count": len(missing_values), - } - - -def _replacement_collision_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: - synthetic_values = { - str(synthetic) - for item in _replacement_maps_from_raw(raw) - if (synthetic := item.get("replacement", item.get("synthetic"))) is not None and str(synthetic) - } - collided_entities = [ - entity for entity in final_entities if entity.get("value") and str(entity.get("value")) in synthetic_values - ] - collided_values = {str(entity.get("value")) for entity in collided_entities if entity.get("value")} - return { - "replacement_synthetic_original_collision_count": len(collided_entities), - "replacement_synthetic_original_collision_label_counts": dict( - sorted( - Counter(str(entity.get("label") or "") for entity in collided_entities if entity.get("label")).items() - ) - ), - "replacement_synthetic_original_collision_value_count": len(collided_values), - } - - -def _replacement_maps_from_raw(raw: object) -> list[Mapping[str, Any]]: - payload = _coerce_payload(raw) - if isinstance(payload, Mapping): - replacements_raw = cast(Mapping[str, Any], payload).get("replacements") - tolist = getattr(replacements_raw, "tolist", None) - if callable(tolist): - replacements_raw = tolist() - replacements = replacements_raw if isinstance(replacements_raw, list) else [] - elif isinstance(payload, list): - replacements = payload - else: - replacements = [] - return [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] - - -def _count_items(raw: object, *, primary_key: str, fallback_keys: tuple[str, ...] 
= ()) -> int: - payload = _coerce_payload(raw) - if isinstance(payload, Mapping): - payload_map = cast(Mapping[str, Any], payload) - for key in (primary_key, *fallback_keys): - items = payload_map.get(key) - if isinstance(items, list): - return len(items) - return 0 - if isinstance(payload, list): - return len(payload) - return 0 - - -def _coerce_payload(raw: object) -> object: - model_dump = getattr(raw, "model_dump", None) - if callable(model_dump): - return model_dump(mode="python") - if isinstance(raw, str): - try: - return json.loads(raw) - except json.JSONDecodeError: - return {} - if raw is None: - return {} - return raw - - -def _coerce_int(raw: object, *, default: int) -> int: - try: - return int(cast(Any, raw)) - except (TypeError, ValueError): - return default - - -def _coerce_float(raw: object) -> float | None: - try: - value = float(cast(Any, raw)) - except (TypeError, ValueError): - return None - return None if math.isnan(value) else value - - -def _coerce_bool(raw: object) -> bool | None: - if raw is None: - return None - if isinstance(raw, float) and math.isnan(raw): - return None - if isinstance(raw, bool): - return raw - if isinstance(raw, str): - lowered = raw.strip().lower() - if lowered in {"true", "1", "yes"}: - return True - if lowered in {"false", "0", "no"}: - return False - return None - try: - return bool(cast(Any, raw)) - except (TypeError, ValueError): - return None - - -def _safe_rate(numerator: int | float | None, elapsed_sec: float) -> float | None: - if numerator is None or elapsed_sec <= 0: - return None - return float(numerator) / elapsed_sec - - -def _safe_ratio(numerator: int | float | None, denominator: int | float | None) -> float | None: - if numerator is None or denominator is None or denominator == 0: - return None - return float(numerator) / float(denominator) - - -def _f1(precision: float | None, recall: float | None) -> float | None: - if precision is None or recall is None or precision + recall == 0: - return None - 
return 2 * precision * recall / (precision + recall) - - -def _size_bucket(value: int) -> str: - if value == 0: - return "0" - for upper in (128, 512, 2048, 8192): - if value < upper: - return f"1-{upper - 1}" if upper == 128 else f"{upper // 4}-{upper - 1}" - return "8192+" - - -def _count_text_tokens(text: str) -> int: - try: - import tiktoken - - tokenizer = tiktoken.get_encoding("cl100k_base") - return len(tokenizer.encode(text, disallowed_special=())) - except Exception: - return len(text.split()) - - -def _summarize_model_usage(model_usage: Mapping[str, Any] | None) -> dict[str, int | None]: - totals = _empty_model_usage_totals() - for usage in (model_usage or {}).values(): - if not isinstance(usage, Mapping): - continue - _add_model_usage_totals(totals, usage) - - if totals["total_tokens"] == 0: - totals["total_tokens"] = totals["input_tokens"] + totals["output_tokens"] - if totals["total_requests"] == 0: - totals["total_requests"] = totals["successful_requests"] + totals["failed_requests"] - - return { - "observed_input_tokens": totals["input_tokens"], - "observed_output_tokens": totals["output_tokens"], - "observed_total_tokens": totals["total_tokens"], - "observed_reasoning_tokens": totals["reasoning_tokens"] if totals["has_reasoning_tokens"] else None, - "observed_successful_requests": totals["successful_requests"], - "observed_failed_requests": totals["failed_requests"], - "observed_total_requests": totals["total_requests"], - } - - -def _empty_model_usage_totals() -> dict[str, int | bool]: - return { - "input_tokens": 0, - "output_tokens": 0, - "total_tokens": 0, - "reasoning_tokens": 0, - "has_reasoning_tokens": False, - "successful_requests": 0, - "failed_requests": 0, - "total_requests": 0, - } - - -def _add_model_usage_totals(totals: dict[str, int | bool], usage: Mapping[str, Any]) -> None: - token_usage = usage.get("token_usage") - if isinstance(token_usage, Mapping): - totals["input_tokens"] += _coerce_int(token_usage.get("input_tokens"), 
default=0) - totals["output_tokens"] += _coerce_int(token_usage.get("output_tokens"), default=0) - totals["total_tokens"] += _coerce_int(token_usage.get("total_tokens"), default=0) - if token_usage.get("reasoning_tokens") is not None: - totals["has_reasoning_tokens"] = True - totals["reasoning_tokens"] += _coerce_int(token_usage.get("reasoning_tokens"), default=0) - - request_usage = usage.get("request_usage") - if isinstance(request_usage, Mapping): - totals["successful_requests"] += _coerce_int(request_usage.get("successful_requests"), default=0) - totals["failed_requests"] += _coerce_int(request_usage.get("failed_requests"), default=0) - totals["total_requests"] += _coerce_int(request_usage.get("total_requests"), default=0) - - -def _json_safe(value: object) -> Any: - if isinstance(value, dict): - return {str(k): _json_safe(v) for k, v in value.items()} - if isinstance(value, list): - return [_json_safe(v) for v in value] - if isinstance(value, tuple): - return [_json_safe(v) for v in value] - if isinstance(value, set): - return sorted((_json_safe(v) for v in value), key=str) - if isinstance(value, float) and not math.isfinite(value): - return None - if isinstance(value, (str, int, float, bool)) or value is None: - return value - return str(value) diff --git a/src/anonymizer/measurement/__init__.py b/src/anonymizer/measurement/__init__.py new file mode 100644 index 00000000..944ba35e --- /dev/null +++ b/src/anonymizer/measurement/__init__.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from anonymizer.measurement.collector import MeasurementCollector +from anonymizer.measurement.config import MeasurementConfig +from anonymizer.measurement.constants import ( + DD_TRACE_MODES, + DEFAULT_MEASUREMENT_ENV_PREFIX, + MEASUREMENT_SCHEMA_VERSION, + DDTraceMode, +) +from anonymizer.measurement.metrics.llm_calls import estimate_llm_calls_by_stage +from anonymizer.measurement.recorders import ( + record_model_workflow, + record_ndd_workflow, + record_run_metadata, + record_stage, + stage_timer, +) +from anonymizer.measurement.records.row import record_record_metrics +from anonymizer.measurement.session import configured_measurement_session, current_collector, measurement_session + +__all__ = [ + "DD_TRACE_MODES", + "DDTraceMode", + "DEFAULT_MEASUREMENT_ENV_PREFIX", + "MEASUREMENT_SCHEMA_VERSION", + "MeasurementCollector", + "MeasurementConfig", + "configured_measurement_session", + "current_collector", + "estimate_llm_calls_by_stage", + "measurement_session", + "record_model_workflow", + "record_ndd_workflow", + "record_record_metrics", + "record_run_metadata", + "record_stage", + "stage_timer", +] diff --git a/src/anonymizer/measurement/_coerce.py b/src/anonymizer/measurement/_coerce.py new file mode 100644 index 00000000..05d2cfcb --- /dev/null +++ b/src/anonymizer/measurement/_coerce.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import math +from collections.abc import Mapping +from numbers import Integral +from typing import Any, cast + + +def _safe_row_index(row_index: object) -> int | None: + if isinstance(row_index, bool): + return None + if isinstance(row_index, Integral): + return int(row_index) + return None + + +def _count_items(raw: object, *, primary_key: str, fallback_keys: tuple[str, ...] 
= ()) -> int: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + payload_map = cast(Mapping[str, Any], payload) + for key in (primary_key, *fallback_keys): + items = payload_map.get(key) + if isinstance(items, list): + return len(items) + return 0 + if isinstance(payload, list): + return len(payload) + return 0 + + +def _coerce_payload(raw: object) -> object: + model_dump = getattr(raw, "model_dump", None) + if callable(model_dump): + return model_dump(mode="python") + if isinstance(raw, str): + try: + return json.loads(raw) + except json.JSONDecodeError: + return {} + if raw is None: + return {} + return raw + + +def _coerce_int(raw: object, *, default: int) -> int: + try: + return int(cast(Any, raw)) + except (TypeError, ValueError): + return default + + +def _coerce_float(raw: object) -> float | None: + try: + value = float(cast(Any, raw)) + except (TypeError, ValueError): + return None + return None if math.isnan(value) else value + + +def _coerce_bool(raw: object) -> bool | None: + if raw is None: + return None + if isinstance(raw, float) and math.isnan(raw): + return None + if isinstance(raw, bool): + return raw + if isinstance(raw, str): + lowered = raw.strip().lower() + if lowered in {"true", "1", "yes"}: + return True + if lowered in {"false", "0", "no"}: + return False + return None + try: + return bool(cast(Any, raw)) + except (TypeError, ValueError): + return None + + +def _safe_rate(numerator: int | float | None, elapsed_sec: float) -> float | None: + if numerator is None or elapsed_sec <= 0: + return None + return float(numerator) / elapsed_sec + + +def _safe_ratio(numerator: int | float | None, denominator: int | float | None) -> float | None: + if numerator is None or denominator is None or denominator == 0: + return None + return float(numerator) / float(denominator) + + +def _f1(precision: float | None, recall: float | None) -> float | None: + if precision is None or recall is None or precision + recall == 0: + return None + 
return 2 * precision * recall / (precision + recall) + + +def _size_bucket(value: int) -> str: + if value == 0: + return "0" + for upper in (128, 512, 2048, 8192): + if value < upper: + return f"1-{upper - 1}" if upper == 128 else f"{upper // 4}-{upper - 1}" + return "8192+" + + +def _count_text_tokens(text: str) -> int: + try: + import tiktoken + + tokenizer = tiktoken.get_encoding("cl100k_base") + return len(tokenizer.encode(text, disallowed_special=())) + except Exception: + return len(text.split()) + + +def _json_safe(value: object) -> Any: + if isinstance(value, dict): + return {str(k): _json_safe(v) for k, v in value.items()} + if isinstance(value, list): + return [_json_safe(v) for v in value] + if isinstance(value, tuple): + return [_json_safe(v) for v in value] + if isinstance(value, set): + return sorted((_json_safe(v) for v in value), key=str) + if isinstance(value, float) and not math.isfinite(value): + return None + if isinstance(value, (str, int, float, bool)) or value is None: + return value + return str(value) diff --git a/src/anonymizer/measurement/collector.py b/src/anonymizer/measurement/collector.py new file mode 100644 index 00000000..704b6dfa --- /dev/null +++ b/src/anonymizer/measurement/collector.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +import hmac +import json +import logging +import secrets +import time +import uuid +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, cast + +from anonymizer.measurement._coerce import _json_safe +from anonymizer.measurement.constants import MEASUREMENT_SCHEMA_VERSION, DDTraceMode +from anonymizer.measurement.sinks import _JsonlMeasurementWriter, _JsonMeasurementWriter, _MeasurementSink + +if TYPE_CHECKING: + import pandas as pd + +logger = logging.getLogger("anonymizer.measurement") + + +class MeasurementCollector: + """In-memory collector for local benchmark and throughput records. + + Records contain counts, labels, lengths, aliases, timings, and run-scoped + HMACs. They must not contain raw text, entity values, prompts, generated + outputs, replacement maps, provider secrets, or API keys. + """ + + def __init__( + self, + *, + run_id: str | None = None, + record_hash_key: bytes | str | None = None, + record_level: bool = True, + run_tags: Mapping[str, Any] | None = None, + record_sink: _MeasurementSink | None = None, + keep_records: bool = True, + dd_trace_mode: DDTraceMode = "none", + dd_trace_sink: _MeasurementSink | None = None, + dd_task_trace_sink: _MeasurementSink | None = None, + fail_on_write_error: bool = False, + ) -> None: + self.run_id = run_id or uuid.uuid4().hex + self.record_level = record_level + self.run_tags = cast(dict[str, Any], _json_safe(dict(run_tags or {}))) + self._record_sink = record_sink + self._keep_records = keep_records + self._dd_trace_mode = dd_trace_mode + self._dd_trace_sink = dd_trace_sink + self._dd_task_trace_sink = dd_task_trace_sink + self._fail_on_write_error = fail_on_write_error + self._sink_failed = False + self._dd_trace_failed = False + self._dd_task_trace_failed = False + if record_hash_key is None: + self._record_hash_key = secrets.token_bytes(32) + elif 
isinstance(record_hash_key, str): + self._record_hash_key = record_hash_key.encode("utf-8") + else: + self._record_hash_key = bytes(record_hash_key) + self._records: list[dict[str, Any]] = [] + + @property + def records(self) -> list[dict[str, Any]]: + """Return a shallow copy of collected measurement records.""" + return list(self._records) + + def record(self, record_type: str, **fields: Any) -> None: + """Append one machine-readable measurement record.""" + record = { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": record_type, + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + safe_record = _json_safe(record) + if self._keep_records: + self._records.append(safe_record) + if self._record_sink is not None: + self._write_record_to_sink(safe_record) + + def close(self) -> None: + """Close any streaming measurement sink attached to this collector.""" + close_error: Exception | None = None + for sink in (self._record_sink, self._dd_trace_sink, self._dd_task_trace_sink): + if sink is None: + continue + try: + sink.close() + except Exception as exc: + if close_error is None: + close_error = exc + if close_error is not None: + raise close_error + + @property + def dd_trace_mode(self) -> DDTraceMode: + return self._dd_trace_mode + + @property + def dd_trace_enabled(self) -> bool: + return self._dd_trace_mode != "none" and self._dd_trace_sink is not None + + @property + def dd_task_trace_enabled(self) -> bool: + return self._dd_task_trace_sink is not None + + def record_dd_message_trace(self, **fields: Any) -> None: + """Write an explicitly opt-in DataDesigner message trace record. + + These records may contain raw prompts, input text, model outputs, and + PII. They are intentionally written to a separate trace sink and are + never appended to the safe measurement record list. 
+ """ + if not self.dd_trace_enabled or self._dd_trace_failed: + return + + record = _json_safe( + { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": "dd_message_trace", + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + ) + try: + cast(_MeasurementSink, self._dd_trace_sink).write_record(record) + except Exception: + self._dd_trace_failed = True + logger.warning("Failed to write DataDesigner message trace records") + if self._fail_on_write_error: + raise + + def record_dd_task_trace(self, **fields: Any) -> None: + """Write an opt-in sanitized DataDesigner scheduler task trace record.""" + if not self.dd_task_trace_enabled or self._dd_task_trace_failed: + return + + record = _json_safe( + { + **fields, + "schema_version": MEASUREMENT_SCHEMA_VERSION, + "record_type": "dd_task_trace", + "run_id": self.run_id, + "run_tags": self.run_tags, + "timestamp_unix_sec": time.time(), + } + ) + try: + cast(_MeasurementSink, self._dd_task_trace_sink).write_record(record) + except Exception: + self._dd_task_trace_failed = True + logger.warning("Failed to write DataDesigner task trace records") + if self._fail_on_write_error: + raise + + def _write_record_to_sink(self, record: dict[str, Any]) -> None: + if self._sink_failed: + return + try: + cast(_MeasurementSink, self._record_sink).write_record(record) + except Exception: + self._sink_failed = True + logger.warning("Failed to stream Anonymizer measurement records") + if self._fail_on_write_error: + raise + + def record_hash(self, *, row_index: object, text: str) -> str: + """Return a run-scoped HMAC for joining records without storing text.""" + serialized = json.dumps( + {"row_index": str(row_index), "text": text}, + default=str, + sort_keys=True, + separators=(",", ":"), + ) + return hmac.new(self._record_hash_key, serialized.encode("utf-8"), hashlib.sha256).hexdigest() + + def write_jsonl(self, path: str | Path) -> None: + """Write records as 
newline-delimited JSON.""" + _JsonlMeasurementWriter().write(self._records, path) + + def write_json(self, path: str | Path) -> None: + """Write records as a JSON array.""" + _JsonMeasurementWriter().write(self._records, path) + + def to_dataframe(self) -> pd.DataFrame: + """Return records as a pandas DataFrame for benchmark tooling.""" + import pandas as pd + + return pd.DataFrame(self._records) diff --git a/src/anonymizer/measurement/config.py b/src/anonymizer/measurement/config.py new file mode 100644 index 00000000..623f3549 --- /dev/null +++ b/src/anonymizer/measurement/config.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal, cast + +from pydantic import Field, ValidationError +from pydantic_settings import BaseSettings, SettingsConfigDict, SettingsError + +from anonymizer.measurement.collector import MeasurementCollector +from anonymizer.measurement.constants import DD_TRACE_MODES, DEFAULT_MEASUREMENT_ENV_PREFIX, DDTraceMode +from anonymizer.measurement.sinks import _writer_for_format + + +class _MeasurementEnvSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix=DEFAULT_MEASUREMENT_ENV_PREFIX, + env_ignore_empty=True, + extra="ignore", + ) + + output_path: str | None = None + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | None = None + dd_task_trace_path: str | None = None + fail_on_write_error: bool = False + run_id: str | None = None + run_tags: dict[str, Any] = Field(default_factory=dict) + + +@dataclass(frozen=True) +class MeasurementConfig: + """Configuration for writing structured measurement records around a 
run.""" + + output_path: str | Path + output_format: Literal["jsonl", "json"] = "jsonl" + record_level: bool = True + streaming: bool = False + keep_records: bool = True + dd_trace: DDTraceMode = "none" + dd_trace_path: str | Path | None = None + dd_task_trace_path: str | Path | None = None + run_id: str | None = None + record_hash_key: bytes | str | None = None + run_tags: Mapping[str, Any] | None = None + fail_on_write_error: bool = False + + def __post_init__(self) -> None: + if self.output_format not in {"jsonl", "json"}: + raise ValueError("output_format must be 'jsonl' or 'json'") + if self.streaming and self.output_format != "jsonl": + raise ValueError("streaming measurement output only supports jsonl") + if self.dd_trace not in DD_TRACE_MODES: + raise ValueError("dd_trace must be 'none', 'last_message', or 'all_messages'") + if self.dd_trace != "none" and self.dd_trace_path is None: + raise ValueError("dd_trace_path is required when dd_trace is enabled") + + @classmethod + def from_env(cls, *, prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX) -> MeasurementConfig | None: + """Build measurement config from environment variables, or None if output is unset. + + This is intentionally opt-in. Anonymizer API and CLI calls do not read + measurement environment variables unless benchmark/tooling code calls this + helper explicitly. 
+ """ + try: + settings = _load_measurement_env_settings(prefix=prefix) + except (SettingsError, ValidationError) as exc: + raise ValueError(_measurement_env_error_message(exc, prefix=prefix)) from None + + if settings.output_path is None: + return None + return cls( + output_path=settings.output_path, + output_format=settings.output_format, + record_level=settings.record_level, + streaming=settings.streaming, + keep_records=settings.keep_records, + dd_trace=settings.dd_trace, + dd_trace_path=settings.dd_trace_path, + dd_task_trace_path=settings.dd_task_trace_path, + run_id=settings.run_id, + run_tags=settings.run_tags, + fail_on_write_error=settings.fail_on_write_error, + ) + + @classmethod + def from_sources( + cls, + explicit: MeasurementConfig | None = None, + *, + env: bool = False, + prefix: str = DEFAULT_MEASUREMENT_ENV_PREFIX, + ) -> MeasurementConfig | None: + """Resolve measurement config from explicit config first, then optional env.""" + if explicit is not None: + return explicit + if env: + return cls.from_env(prefix=prefix) + return None + + def write_collector(self, collector: MeasurementCollector) -> None: + """Write a collector using this config's output format.""" + _writer_for_format(self.output_format).write(collector.records, self.output_path) + + +def _measurement_env_error_message(exc: SettingsError | ValidationError, *, prefix: str) -> str: + fields: set[str] = set() + if isinstance(exc, ValidationError): + for error in exc.errors(include_input=False): + loc = error.get("loc", ()) + if loc: + fields.add(str(loc[0]).upper()) + else: + error_text = str(exc).lower() + for field_name in _MeasurementEnvSettings.model_fields: + if field_name in error_text: + fields.add(field_name.upper()) + + if fields: + env_fields = ", ".join(f"{prefix}{field}" for field in sorted(fields)) + return f"Invalid Anonymizer measurement environment configuration for: {env_fields}" + return "Invalid Anonymizer measurement environment configuration" + + +def 
_load_measurement_env_settings(*, prefix: str) -> _MeasurementEnvSettings: + settings_factory = cast(Any, _MeasurementEnvSettings) + return cast(_MeasurementEnvSettings, settings_factory(_env_prefix=prefix)) diff --git a/src/anonymizer/measurement/constants.py b/src/anonymizer/measurement/constants.py new file mode 100644 index 00000000..0fd1ff89 --- /dev/null +++ b/src/anonymizer/measurement/constants.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Literal + +MEASUREMENT_SCHEMA_VERSION = 1 +DEFAULT_MEASUREMENT_ENV_PREFIX = "ANONYMIZER_MEASUREMENT_" +DD_TRACE_MODES = {"none", "last_message", "all_messages"} +DDTraceMode = Literal["none", "last_message", "all_messages"] diff --git a/src/anonymizer/measurement/metrics/__init__.py b/src/anonymizer/measurement/metrics/__init__.py new file mode 100644 index 00000000..3d2894b7 --- /dev/null +++ b/src/anonymizer/measurement/metrics/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations diff --git a/src/anonymizer/measurement/metrics/entities.py b/src/anonymizer/measurement/metrics/entities.py new file mode 100644 index 00000000..da2bfe3a --- /dev/null +++ b/src/anonymizer/measurement/metrics/entities.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter +from collections.abc import Mapping +from typing import Any, cast + +from anonymizer.measurement._coerce import _coerce_float, _coerce_payload, _f1, _safe_ratio + +_GROUND_TRUTH_ENTITY_COLUMNS = ("ground_truth_entities", "gt_entities", "expected_entities") +_ENTITY_LABEL_EQUIVALENCE_CLASSES = ( + frozenset( + { + "access_token", + "api_key", + "auth_token", + "bearer_token", + "password", + "secret_key", + "session_id", + "unique_id", + "user_id", + } + ), + frozenset({"full_name", "person_name", "user", "user_name", "username"}), + frozenset({"phone", "phone_number", "telephone"}), + frozenset({"email", "email_address"}), + frozenset({"cookie", "http_cookie", "session_cookie"}), +) +_ENTITY_LABEL_EQUIVALENCE: dict[str, str] = { + label: sorted(labels)[0] for labels in _ENTITY_LABEL_EQUIVALENCE_CLASSES for label in labels +} + + +def _entities_from_raw(raw: object) -> list[dict[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + items = cast(Mapping[str, Any], payload).get("entities", []) + elif isinstance(payload, list): + items = payload + else: + items = [] + return [dict(cast(Mapping[str, Any], item)) for item in items if isinstance(item, Mapping)] + + +def _entity_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]] | None, +) -> dict[str, Any]: + if ground_truth_entities is None: + return { + "ground_truth_entity_count": None, + "ground_truth_entity_label_counts": None, + "entity_true_positive_count": None, + "entity_false_positive_count": None, + "entity_false_negative_count": None, + "entity_precision": None, + "entity_recall": None, + "entity_f1": None, + "entity_relaxed_gt_found_count": None, + "entity_relaxed_detected_tp_count": None, + "entity_relaxed_label_compatible_gt_found_count": None, + "entity_relaxed_label_compatible_detected_tp_count": None, + 
"entity_relaxed_precision": None, + "entity_relaxed_recall": None, + "entity_relaxed_f1": None, + "entity_relaxed_label_compatible_precision": None, + "entity_relaxed_label_compatible_recall": None, + "entity_relaxed_label_compatible_f1": None, + } + + predicted = _entity_identity_counts(final_entities) + expected = _entity_identity_counts(ground_truth_entities) + true_positive = sum((predicted & expected).values()) + false_positive = sum((predicted - expected).values()) + false_negative = sum((expected - predicted).values()) + precision = _safe_ratio(true_positive, true_positive + false_positive) + recall = _safe_ratio(true_positive, true_positive + false_negative) + return { + "ground_truth_entity_count": len(ground_truth_entities), + "ground_truth_entity_label_counts": dict( + sorted(Counter(e.get("label", "") for e in ground_truth_entities if e.get("label")).items()) + ), + "entity_true_positive_count": true_positive, + "entity_false_positive_count": false_positive, + "entity_false_negative_count": false_negative, + "entity_precision": precision, + "entity_recall": recall, + "entity_f1": _f1(precision, recall), + **_entity_relaxed_ground_truth_metrics(final_entities, ground_truth_entities), + } + + +def _entity_identity_counts(entities: list[dict[str, Any]]) -> Counter[tuple[str, str]]: + identities: Counter[tuple[str, str]] = Counter() + for entity in entities: + label = entity.get("label") + value = entity.get("value") + if label is None or value is None: + continue + identities[(str(value), str(label))] += 1 + return identities + + +def _entity_relaxed_ground_truth_metrics( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]], +) -> dict[str, Any]: + relaxed_match_count = _relaxed_entity_match_count(final_entities, ground_truth_entities) + label_compatible_match_count = _relaxed_entity_match_count( + final_entities, + ground_truth_entities, + require_label_compatible=True, + ) + gt_found = relaxed_match_count + detected_tp = 
relaxed_match_count + label_compatible_gt_found = label_compatible_match_count + label_compatible_detected_tp = label_compatible_match_count + precision = _safe_ratio(detected_tp, len(final_entities)) + recall = _safe_ratio(gt_found, len(ground_truth_entities)) + label_compatible_precision = _safe_ratio(label_compatible_detected_tp, len(final_entities)) + label_compatible_recall = _safe_ratio(label_compatible_gt_found, len(ground_truth_entities)) + return { + "entity_relaxed_gt_found_count": gt_found, + "entity_relaxed_detected_tp_count": detected_tp, + "entity_relaxed_label_compatible_gt_found_count": label_compatible_gt_found, + "entity_relaxed_label_compatible_detected_tp_count": label_compatible_detected_tp, + "entity_relaxed_precision": precision, + "entity_relaxed_recall": recall, + "entity_relaxed_f1": _f1(precision, recall), + "entity_relaxed_label_compatible_precision": label_compatible_precision, + "entity_relaxed_label_compatible_recall": label_compatible_recall, + "entity_relaxed_label_compatible_f1": _f1(label_compatible_precision, label_compatible_recall), + } + + +def _relaxed_entity_match_count( + final_entities: list[dict[str, Any]], + ground_truth_entities: list[dict[str, Any]], + *, + require_label_compatible: bool = False, +) -> int: + matches_by_ground_truth = [ + [ + final_index + for final_index, final_entity in enumerate(final_entities) + if _entities_match_relaxed( + final_entity, + ground_truth_entity, + require_label_compatible=require_label_compatible, + ) + ] + for ground_truth_entity in ground_truth_entities + ] + matched_ground_truth_by_final: dict[int, int] = {} + + def assign(ground_truth_index: int, seen: set[int]) -> bool: + for final_index in matches_by_ground_truth[ground_truth_index]: + if final_index in seen: + continue + seen.add(final_index) + if final_index not in matched_ground_truth_by_final or assign( + matched_ground_truth_by_final[final_index], + seen, + ): + matched_ground_truth_by_final[final_index] = 
ground_truth_index + return True + return False + + return sum(1 for ground_truth_index in range(len(ground_truth_entities)) if assign(ground_truth_index, set())) + + +def _entities_match_relaxed( + left: dict[str, Any], + right: dict[str, Any], + *, + require_label_compatible: bool, +) -> bool: + if require_label_compatible and not _entity_labels_compatible(left.get("label"), right.get("label")): + return False + left_span = _entity_span(left) + right_span = _entity_span(right) + if left_span is not None and right_span is not None: + return left_span[0] < right_span[1] and right_span[0] < left_span[1] + left_value = left.get("value") + right_value = right.get("value") + return left_value is not None and right_value is not None and str(left_value) == str(right_value) + + +def _entity_span(entity: dict[str, Any]) -> tuple[int, int] | None: + start = _coerce_float(entity.get("start_position", entity.get("start"))) + end = _coerce_float(entity.get("end_position", entity.get("end"))) + if start is None or end is None: + return None + start_int = int(start) + end_int = int(end) + if start_int < 0 or end_int <= start_int: + return None + return start_int, end_int + + +def _entity_labels_compatible(left: object, right: object) -> bool: + left_key = _entity_label_key(left) + right_key = _entity_label_key(right) + return left_key is not None and right_key is not None and left_key == right_key + + +def _entity_label_key(label: object) -> str | None: + if label is None: + return None + normalized = str(label).strip().lower() + if not normalized: + return None + return _ENTITY_LABEL_EQUIVALENCE.get(normalized, normalized) diff --git a/src/anonymizer/measurement/metrics/llm_calls.py b/src/anonymizer/measurement/metrics/llm_calls.py new file mode 100644 index 00000000..545c61cf --- /dev/null +++ b/src/anonymizer/measurement/metrics/llm_calls.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math + + +def estimate_llm_calls_by_stage( + *, + mode: str, + strategy: str, + has_grouped_entities: bool, + validation_chunk_count: int | None, + repair_iterations: int = 0, + replace_map_generation_uses_llm: bool = True, +) -> dict[str, int | None]: + """Estimate nominal model calls for one record, split by workflow stage.""" + detection_calls = None if validation_chunk_count is None else 2 + validation_chunk_count + replace_map_generation = 0 + if replace_map_generation_uses_llm and has_grouped_entities and (mode == "rewrite" or strategy == "Substitute"): + replace_map_generation = 1 + + if mode != "rewrite": + return { + "entity_detection": detection_calls, + "replace_map_generation": replace_map_generation, + } + + rewrite_body_calls = has_grouped_entities + return { + "entity_detection": detection_calls, + "latent_entity_detection": 1 if rewrite_body_calls else 0, + "replace_map_generation": replace_map_generation, + "rewrite_pipeline": 5 if rewrite_body_calls else 0, + "rewrite_evaluate": 3 * (1 + repair_iterations) if rewrite_body_calls else 0, + "rewrite_repair": repair_iterations if rewrite_body_calls else 0, + "rewrite_final_judge": 1 if rewrite_body_calls else 0, + } + + +def _validation_chunk_count( + detected_candidate_count: int | None, + *, + validation_max_entities_per_call: int, +) -> int | None: + if detected_candidate_count is None: + return None + if detected_candidate_count <= 0: + return 0 + return int(math.ceil(detected_candidate_count / validation_max_entities_per_call)) diff --git a/src/anonymizer/measurement/metrics/replacements.py b/src/anonymizer/measurement/metrics/replacements.py new file mode 100644 index 00000000..01ed9a84 --- /dev/null +++ b/src/anonymizer/measurement/metrics/replacements.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter +from collections.abc import Mapping +from typing import Any, cast + +from anonymizer.measurement._coerce import _coerce_payload + + +def _replacement_map_metrics(raw: object) -> dict[str, Any]: + replacement_maps = _replacement_maps_from_raw(raw) + synthetic_values = [] + for item in replacement_maps: + synthetic = item.get("replacement", item.get("synthetic")) + if synthetic is not None: + synthetic_values.append(str(synthetic)) + return { + "replacement_count": len(replacement_maps), + "replacement_label_counts": dict( + sorted(Counter(item.get("label", "") for item in replacement_maps if item.get("label")).items()) + ), + "replacement_duplicate_value_count": max(0, len(synthetic_values) - len(set(synthetic_values))), + } + + +def _replacement_coverage_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + replacement_original_values = { + str(original) + for item in _replacement_maps_from_raw(raw) + if (original := item.get("original")) is not None and str(original) + } + missing_entities = [ + entity + for entity in final_entities + if entity.get("value") and str(entity.get("value")) not in replacement_original_values + ] + missing_values = {str(entity.get("value")) for entity in missing_entities if entity.get("value")} + return { + "replacement_missing_final_entity_count": len(missing_entities), + "replacement_missing_final_entity_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in missing_entities if entity.get("label")).items() + ) + ), + "replacement_missing_final_value_count": len(missing_values), + } + + +def _replacement_collision_metrics(raw: object, final_entities: list[dict[str, Any]]) -> dict[str, Any]: + synthetic_values = { + str(synthetic) + for item in _replacement_maps_from_raw(raw) + if (synthetic := item.get("replacement", item.get("synthetic"))) is not None and 
str(synthetic) + } + collided_entities = [ + entity for entity in final_entities if entity.get("value") and str(entity.get("value")) in synthetic_values + ] + collided_values = {str(entity.get("value")) for entity in collided_entities if entity.get("value")} + return { + "replacement_synthetic_original_collision_count": len(collided_entities), + "replacement_synthetic_original_collision_label_counts": dict( + sorted( + Counter(str(entity.get("label") or "") for entity in collided_entities if entity.get("label")).items() + ) + ), + "replacement_synthetic_original_collision_value_count": len(collided_values), + } + + +def _replacement_maps_from_raw(raw: object) -> list[Mapping[str, Any]]: + payload = _coerce_payload(raw) + if isinstance(payload, Mapping): + replacements_raw = cast(Mapping[str, Any], payload).get("replacements") + tolist = getattr(replacements_raw, "tolist", None) + if callable(tolist): + replacements_raw = tolist() + replacements = replacements_raw if isinstance(replacements_raw, list) else [] + elif isinstance(payload, list): + replacements = payload + else: + replacements = [] + return [cast(Mapping[str, Any], item) for item in replacements if isinstance(item, Mapping)] diff --git a/src/anonymizer/measurement/metrics/rewrite.py b/src/anonymizer/measurement/metrics/rewrite.py new file mode 100644 index 00000000..ea68ec49 --- /dev/null +++ b/src/anonymizer/measurement/metrics/rewrite.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter +from typing import Any + +from anonymizer.measurement._coerce import _coerce_bool, _coerce_float + + +def _rewrite_record_fields(row: Any, *, columns: set[str]) -> dict[str, Any]: + from anonymizer.engine.constants import ( + COL_ANY_HIGH_LEAKED, + COL_LEAKAGE_MASS, + COL_NEEDS_HUMAN_REVIEW, + COL_NEEDS_REPAIR, + COL_UTILITY_SCORE, + COL_WEIGHTED_LEAKAGE_RATE, + ) + + return { + "utility_score": _coerce_float(row.get(COL_UTILITY_SCORE)) if COL_UTILITY_SCORE in columns else None, + "leakage_mass": _coerce_float(row.get(COL_LEAKAGE_MASS)) if COL_LEAKAGE_MASS in columns else None, + "weighted_leakage_rate": ( + _coerce_float(row.get(COL_WEIGHTED_LEAKAGE_RATE)) if COL_WEIGHTED_LEAKAGE_RATE in columns else None + ), + "any_high_leaked": _coerce_bool(row.get(COL_ANY_HIGH_LEAKED)) if COL_ANY_HIGH_LEAKED in columns else None, + "needs_human_review": ( + _coerce_bool(row.get(COL_NEEDS_HUMAN_REVIEW)) if COL_NEEDS_HUMAN_REVIEW in columns else None + ), + "needs_repair": _coerce_bool(row.get(COL_NEEDS_REPAIR)) if COL_NEEDS_REPAIR in columns else None, + } + + +def _original_value_leak_record_fields( + row: Any, + *, + columns: set[str], + final_entities: list[dict[str, Any]], +) -> dict[str, Any]: + output_column = _output_text_column(columns) + if output_column is None: + return {"original_value_leak_count": None, "original_value_leak_label_counts": {}} + output_text = str(row.get(output_column, "")) + leaked = [ + entity + for entity in final_entities + if entity.get("value") and _output_contains_original_value(output_text, str(entity.get("value"))) + ] + return { + "original_value_leak_count": len(leaked), + "original_value_leak_label_counts": dict( + sorted(Counter(str(entity.get("label") or "") for entity in leaked if entity.get("label")).items()) + ), + } + + +def _output_contains_original_value(output_text: str, value: str) -> bool: + if 
_needs_boundary_sensitive_leak_match(value): + return _contains_with_alnum_boundaries(output_text, value) + return value in output_text + + +def _needs_boundary_sensitive_leak_match(value: str) -> bool: + return len(value) <= 4 or value.isdigit() + + +def _contains_with_alnum_boundaries(output_text: str, value: str) -> bool: + start = 0 + while True: + match_start = output_text.find(value, start) + if match_start < 0: + return False + match_end = match_start + len(value) + if _has_alnum_boundaries(output_text, match_start, match_end): + return True + start = match_start + 1 + + +def _has_alnum_boundaries(text: str, start: int, end: int) -> bool: + before_is_alnum = start > 0 and text[start - 1].isalnum() + after_is_alnum = end < len(text) and text[end].isalnum() + return not before_is_alnum and not after_is_alnum + + +def _output_text_column(columns: set[str]) -> str | None: + from anonymizer.engine.constants import COL_REPLACED_TEXT, COL_REWRITTEN_TEXT + + if COL_REPLACED_TEXT in columns: + return COL_REPLACED_TEXT + if COL_REWRITTEN_TEXT in columns: + return COL_REWRITTEN_TEXT + return None diff --git a/src/anonymizer/measurement/recorders.py b/src/anonymizer/measurement/recorders.py new file mode 100644 index 00000000..1a1bbf3f --- /dev/null +++ b/src/anonymizer/measurement/recorders.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0

import time
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from typing import Any

from anonymizer.measurement._coerce import _coerce_int
from anonymizer.measurement.records.model import _model_workflow_fields, _row_throughput_fields, _summarize_model_usage
from anonymizer.measurement.records.run import (
    _detect_config_metadata,
    _model_config_metadata,
    _replace_config_metadata,
    _rewrite_config_metadata,
    _runtime_metadata,
    _source_metadata,
)
from anonymizer.measurement.session import current_collector


def _emit_stage_record(
    collector: Any,
    *,
    stage: str,
    status: str,
    elapsed_sec: float,
    fields: Mapping[str, Any],
) -> None:
    """Write one ``"stage"`` record, adding row-throughput fields derived from ``fields``."""
    collector.record(
        "stage",
        stage=stage,
        status=status,
        elapsed_sec=elapsed_sec,
        **fields,
        **_row_throughput_fields(
            elapsed_sec=elapsed_sec,
            # -1 marks "count not supplied"; _row_throughput_fields maps negatives to None.
            input_row_count=_coerce_int(fields.get("input_row_count"), default=-1),
            output_row_count=_coerce_int(fields.get("output_row_count"), default=-1),
        ),
    )


@contextmanager
def stage_timer(stage: str, **fields: Any) -> Iterator[dict[str, Any]]:
    """Record wall time for a stage when collection is active.

    Yields the ``fields`` dict itself, so the body may add or update entries
    (for example ``output_row_count``) before the record is written on exit.
    The record status is ``"error"`` when the body raises, else ``"completed"``;
    the exception is always re-raised.
    """
    collector = current_collector()
    if collector is None:
        # Measurement disabled: still hand back the dict so callers need no branch.
        yield fields
        return

    started = time.perf_counter()
    status = "completed"
    try:
        yield fields
    except BaseException:
        status = "error"
        raise
    finally:
        _emit_stage_record(
            collector,
            stage=stage,
            status=status,
            elapsed_sec=time.perf_counter() - started,
            fields=fields,
        )


def record_stage(stage: str, *, elapsed_sec: float, status: str = "completed", **fields: Any) -> None:
    """Record a pre-timed stage measurement if collection is active."""
    collector = current_collector()
    if collector is None:
        return
    _emit_stage_record(collector, stage=stage, status=status, elapsed_sec=elapsed_sec, fields=fields)


def record_ndd_workflow(
    *,
    workflow_name: str,
    model_aliases: list[str],
    input_row_count: int,
    output_row_count: int | None,
    failed_record_count: int | None,
    elapsed_sec: float,
    status: str = "completed",
    seed_row_count: int | None = None,
    preview_num_records: int | None = None,
    column_count: int | None = None,
    column_names: list[str] | None = None,
    model_usage: Mapping[str, Any] | None = None,
) -> None:
    """Record one DataDesigner workflow execution through the adapter boundary."""
    _record_model_workflow(
        workflow_name=workflow_name,
        model_aliases=model_aliases,
        input_row_count=input_row_count,
        output_row_count=output_row_count,
        failed_record_count=failed_record_count,
        elapsed_sec=elapsed_sec,
        status=status,
        seed_row_count=seed_row_count,
        preview_num_records=preview_num_records,
        column_count=column_count,
        column_names=column_names,
        model_usage=model_usage,
        record_type="ndd_workflow",
        extra_fields=None,
    )


def record_model_workflow(
    *,
    workflow_name: str,
    model_aliases: list[str],
    input_row_count: int,
    output_row_count: int | None,
    failed_record_count: int | None,
    elapsed_sec: float,
    status: str = "completed",
    seed_row_count: int | None = None,
    preview_num_records: int | None = None,
    column_count: int | None = None,
    column_names: list[str] | None = None,
    model_usage: Mapping[str, Any] | None = None,
    extra_fields: Mapping[str, Any] | None = None,
) -> None:
    """Record one sanitized model-backed workflow execution.

    Use this for non-DataDesigner model calls that still need benchmark
    accounting. Raw prompts, text, responses, and replacement values do not
    belong in ``model_usage``.
    """
    _record_model_workflow(
        workflow_name=workflow_name,
        model_aliases=model_aliases,
        input_row_count=input_row_count,
        output_row_count=output_row_count,
        failed_record_count=failed_record_count,
        elapsed_sec=elapsed_sec,
        status=status,
        seed_row_count=seed_row_count,
        preview_num_records=preview_num_records,
        column_count=column_count,
        column_names=column_names,
        model_usage=model_usage,
        record_type="model_workflow",
        extra_fields=extra_fields,
    )


def _record_model_workflow(
    *,
    workflow_name: str,
    model_aliases: list[str],
    input_row_count: int,
    output_row_count: int | None,
    failed_record_count: int | None,
    elapsed_sec: float,
    status: str,
    seed_row_count: int | None,
    preview_num_records: int | None,
    column_count: int | None,
    column_names: list[str] | None,
    model_usage: Mapping[str, Any] | None,
    record_type: str,
    extra_fields: Mapping[str, Any] | None,
) -> None:
    """Shared implementation behind ``record_ndd_workflow`` and ``record_model_workflow``.

    Does nothing when no collector is active. Model aliases are de-duplicated
    and sorted; ``extra_fields`` are merged last, so they override the standard
    keys on a name clash.
    """
    collector = current_collector()
    if collector is None:
        return

    observed_usage = _summarize_model_usage(model_usage)
    workflow_fields = {
        "workflow_name": workflow_name,
        "status": status,
        "model_aliases": sorted(set(model_aliases)),
        "input_row_count": input_row_count,
        "seed_row_count": seed_row_count,
        "output_row_count": output_row_count,
        "failed_record_count": failed_record_count,
        "elapsed_sec": elapsed_sec,
        "preview_num_records": preview_num_records,
        "column_count": column_count,
        "column_names": column_names or [],
        "model_usage": dict(model_usage or {}),
        **dict(extra_fields or {}),
    }
    collector.record(record_type, **_model_workflow_fields(workflow_fields, observed_usage))


def record_run_metadata(
    *,
    config: Any,
    data: Any,
    mode: str,
    strategy: str,
    input_row_count: int,
    preview_num_records: int | None,
    model_configs: list[Any],
) -> None:
    """Record sanitized run/config metadata once per anonymizer run.

    The input source is recorded only as a keyed hash plus its kind, scheme
    and suffix (see ``_source_metadata``); the source string itself is not
    written.
    """
    collector = current_collector()
    if collector is None:
        return

    detect = getattr(config, "detect", None)
    source = str(getattr(data, "source", ""))
    collector.record(
        "run",
        mode=mode,
        strategy=strategy,
        input_row_count=input_row_count,
        preview_num_records=preview_num_records,
        source_hash=collector.record_hash(row_index="source", text=source),
        input_source=_source_metadata(source),
        input_text_column=str(getattr(data, "text_column", "")),
        input_has_id_column=bool(getattr(data, "id_column", None)),
        input_has_data_summary=bool(getattr(data, "data_summary", None)),
        detect=_detect_config_metadata(detect),
        replace=_replace_config_metadata(getattr(config, "replace", None)),
        rewrite=_rewrite_config_metadata(getattr(config, "rewrite", None)),
        models=[_model_config_metadata(model_config) for model_config in model_configs],
        runtime=_runtime_metadata(),
    )


# --- patch metadata that shared these lines (kept for reference) -------------
# diff --git a/src/anonymizer/measurement/records/__init__.py b/src/anonymizer/measurement/records/__init__.py
# new file mode 100644
# index 00000000..3d2894b7
# --- /dev/null
# +++ b/src/anonymizer/measurement/records/__init__.py
# @@ -0,0 +1,4 @@
#   # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#   # SPDX-License-Identifier: Apache-2.0
#
#   from __future__ import annotations
# diff --git a/src/anonymizer/measurement/records/model.py b/src/anonymizer/measurement/records/model.py
# new file mode 100644
# index 00000000..826270e5
# --- /dev/null
# +++ b/src/anonymizer/measurement/records/model.py
# @@ -0,0 +1,115 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# ---------------------------------------------------------------------------
# src/anonymizer/measurement/records/model.py
# ---------------------------------------------------------------------------

from collections.abc import Mapping
from typing import Any, cast

from anonymizer.measurement._coerce import _coerce_int, _safe_rate, _safe_ratio


def _model_workflow_fields(fields: dict[str, Any], observed_usage: dict[str, int | None]) -> dict[str, Any]:
    """Merge workflow fields with observed usage and the rates derived from both.

    Later keys win: ``observed_usage`` overrides ``fields``, and the derived
    rate/throughput keys override both.
    """
    return {
        **fields,
        **observed_usage,
        "observed_failed_request_rate": _safe_ratio(
            observed_usage["observed_failed_requests"],
            observed_usage["observed_total_requests"],
        ),
        **_throughput_fields(
            elapsed_sec=cast(float, fields["elapsed_sec"]),
            input_row_count=cast(int, fields["input_row_count"]),
            output_row_count=cast(int | None, fields["output_row_count"]),
            total_tokens=observed_usage["observed_total_tokens"],
            total_requests=observed_usage["observed_total_requests"],
            successful_requests=observed_usage["observed_successful_requests"],
        ),
    }


def _throughput_fields(
    *,
    elapsed_sec: float,
    input_row_count: int | None,
    output_row_count: int | None,
    total_tokens: int | None,
    total_requests: int | None,
    successful_requests: int | None,
) -> dict[str, float | None]:
    """Return per-second row/token/request rates and tokens per successful request."""
    return {
        "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec),
        "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec),
        "observed_tokens_per_sec": _safe_rate(total_tokens, elapsed_sec),
        "observed_requests_per_sec": _safe_rate(total_requests, elapsed_sec),
        "observed_tokens_per_successful_request": _safe_ratio(total_tokens, successful_requests),
    }


def _row_throughput_fields(
    *,
    elapsed_sec: float,
    input_row_count: int | None,
    output_row_count: int | None,
) -> dict[str, float | None]:
    """Return input/output rows per second; a negative count is treated as unknown (``None``)."""
    if input_row_count is not None and input_row_count < 0:
        input_row_count = None
    if output_row_count is not None and output_row_count < 0:
        output_row_count = None
    return {
        "input_rows_per_sec": _safe_rate(input_row_count, elapsed_sec),
        "output_rows_per_sec": _safe_rate(output_row_count, elapsed_sec),
    }


def _summarize_model_usage(model_usage: Mapping[str, Any] | None) -> dict[str, int | None]:
    """Sum token and request usage across every entry of ``model_usage``.

    Entries that are not mappings are ignored. When the summed ``total_tokens``
    (or ``total_requests``) is zero it is rebuilt from its parts.
    ``observed_reasoning_tokens`` is ``None`` unless at least one entry
    reported a non-``None`` ``reasoning_tokens``.
    """
    totals = _empty_model_usage_totals()
    for usage in (model_usage or {}).values():
        if isinstance(usage, Mapping):
            _add_model_usage_totals(totals, usage)

    # Fall back to the component sums when no explicit totals were reported.
    if totals["total_tokens"] == 0:
        totals["total_tokens"] = totals["input_tokens"] + totals["output_tokens"]
    if totals["total_requests"] == 0:
        totals["total_requests"] = totals["successful_requests"] + totals["failed_requests"]

    return {
        "observed_input_tokens": totals["input_tokens"],
        "observed_output_tokens": totals["output_tokens"],
        "observed_total_tokens": totals["total_tokens"],
        "observed_reasoning_tokens": totals["reasoning_tokens"] if totals["has_reasoning_tokens"] else None,
        "observed_successful_requests": totals["successful_requests"],
        "observed_failed_requests": totals["failed_requests"],
        "observed_total_requests": totals["total_requests"],
    }


def _empty_model_usage_totals() -> dict[str, int | bool]:
    """Return a zeroed accumulator for ``_add_model_usage_totals``."""
    return {
        "input_tokens": 0,
        "output_tokens": 0,
        "total_tokens": 0,
        "reasoning_tokens": 0,
        "has_reasoning_tokens": False,
        "successful_requests": 0,
        "failed_requests": 0,
        "total_requests": 0,
    }


def _add_model_usage_totals(totals: dict[str, int | bool], usage: Mapping[str, Any]) -> None:
    """Add one usage entry's ``token_usage`` / ``request_usage`` counters to ``totals`` in place."""
    token_usage = usage.get("token_usage")
    if isinstance(token_usage, Mapping):
        for key in ("input_tokens", "output_tokens", "total_tokens"):
            totals[key] += _coerce_int(token_usage.get(key), default=0)
        reasoning_tokens = token_usage.get("reasoning_tokens")
        if reasoning_tokens is not None:
            # Remember that reasoning tokens were reported at all, so a genuine
            # zero can be told apart from "not reported".
            totals["has_reasoning_tokens"] = True
            totals["reasoning_tokens"] += _coerce_int(reasoning_tokens, default=0)

    request_usage = usage.get("request_usage")
    if isinstance(request_usage, Mapping):
        for key in ("successful_requests", "failed_requests", "total_requests"):
            totals[key] += _coerce_int(request_usage.get(key), default=0)


# --- patch metadata that shared these lines (kept for reference) -------------
# diff --git a/src/anonymizer/measurement/records/row.py b/src/anonymizer/measurement/records/row.py
# new file mode 100644
# index 00000000..e8b64106
# --- /dev/null
# +++ b/src/anonymizer/measurement/records/row.py
# @@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# ---------------------------------------------------------------------------
# src/anonymizer/measurement/records/row.py
# ---------------------------------------------------------------------------

from collections import Counter
from typing import TYPE_CHECKING, Any

from anonymizer.engine.constants import COL_FINAL_ENTITIES
from anonymizer.measurement._coerce import (
    _coerce_int,
    _count_items,
    _count_text_tokens,
    _safe_row_index,
    _size_bucket,
)
from anonymizer.measurement.metrics.entities import (
    _GROUND_TRUTH_ENTITY_COLUMNS,
    _entities_from_raw,
    _entity_ground_truth_metrics,
)
from anonymizer.measurement.metrics.llm_calls import _validation_chunk_count, estimate_llm_calls_by_stage
from anonymizer.measurement.metrics.replacements import (
    _replacement_collision_metrics,
    _replacement_coverage_metrics,
    _replacement_map_metrics,
)
from anonymizer.measurement.metrics.rewrite import _original_value_leak_record_fields, _rewrite_record_fields
from anonymizer.measurement.session import current_collector

if TYPE_CHECKING:
    import pandas as pd

    from anonymizer.measurement.collector import MeasurementCollector


def record_record_metrics(
    dataframe: "pd.DataFrame",
    *,
    mode: str,
    strategy: str,
    text_column: str,
    validation_max_entities_per_call: int,
) -> None:
    """Record per-row count, length, and nominal-call metrics from a trace DataFrame.

    Does nothing unless a collector is active and has ``record_level`` enabled.
    """
    collector = current_collector()
    if collector is None or not collector.record_level:
        return

    columns = set(dataframe.columns)
    ground_truth_column = next((col for col in _GROUND_TRUTH_ENTITY_COLUMNS if col in dataframe.columns), None)
    for row_index, row in dataframe.iterrows():
        final_entities = _entities_from_raw(row.get(COL_FINAL_ENTITIES))
        collector.record(
            "record",
            **_base_record_fields(
                collector=collector,
                row_index=row_index,
                row=row,
                text_column=text_column,
                mode=mode,
                strategy=strategy,
            ),
            **_entity_record_fields(row, final_entities=final_entities, ground_truth_column=ground_truth_column),
            **_replacement_record_fields(row, columns=columns, final_entities=final_entities),
            **_rewrite_record_fields(row, columns=columns),
            **_original_value_leak_record_fields(row, columns=columns, final_entities=final_entities),
            **_llm_record_fields(
                row,
                columns=columns,
                mode=mode,
                strategy=strategy,
                final_entity_count=len(final_entities),
                validation_max_entities_per_call=validation_max_entities_per_call,
            ),
        )


def _base_record_fields(
    *,
    collector: "MeasurementCollector",
    row_index: object,
    row: Any,
    text_column: str,
    mode: str,
    strategy: str,
) -> dict[str, Any]:
    """Return identity and size fields for one row: mode, strategy, index, hash, and text lengths."""
    text = str(row.get(text_column, ""))
    char_count = len(text)
    token_count = _count_text_tokens(text)
    return {
        "mode": mode,
        "strategy": strategy,
        "row_index": _safe_row_index(row_index),
        "record_hash": collector.record_hash(row_index=row_index, text=text),
        "text_length_chars": char_count,
        "text_length_chars_bucket": _size_bucket(char_count),
        "text_length_tokens": token_count,
        "text_length_tokens_bucket": _size_bucket(token_count),
    }


def _entity_record_fields(
    row: Any,
    *,
    final_entities: list[dict[str, Any]],
    ground_truth_column: str | None,
) -> dict[str, Any]:
    """Return the entity count, per-label counts, and ground-truth comparison metrics for one row."""
    ground_truth_entities = (
        _entities_from_raw(row.get(ground_truth_column)) if ground_truth_column is not None else None
    )
    # Labels are stringified, as in the original-value-leak label counts, so the
    # keys are always comparable when sorted.
    label_counts = Counter(str(entity.get("label")) for entity in final_entities if entity.get("label"))
    return {
        "final_entity_count": len(final_entities),
        "final_entity_label_counts": dict(sorted(label_counts.items())),
        **_entity_ground_truth_metrics(final_entities, ground_truth_entities),
    }


def _replacement_record_fields(
    row: Any,
    *,
    columns: set[str],
    final_entities: list[dict[str, Any]],
) -> dict[str, Any]:
    """Return replacement-map metrics, or ``{}`` when the row has no replacement-map column."""
    from anonymizer.engine.constants import COL_REPLACEMENT_MAP

    if COL_REPLACEMENT_MAP not in columns:
        return {}
    raw_map = row.get(COL_REPLACEMENT_MAP)
    return {
        **_replacement_map_metrics(raw_map),
        **_replacement_coverage_metrics(raw_map, final_entities),
        **_replacement_collision_metrics(raw_map, final_entities),
    }


def _llm_record_fields(
    row: Any,
    *,
    columns: set[str],
    mode: str,
    strategy: str,
    final_entity_count: int,
    validation_max_entities_per_call: int,
) -> dict[str, Any]:
    """Return candidate/chunk counts and the estimated LLM calls per stage for one row.

    ``llm_calls_estimated_total`` is ``None`` when any per-stage estimate is
    ``None``.
    """
    from anonymizer.engine.constants import COL_REPAIR_ITERATIONS

    detected_candidate_count = _detected_candidate_count(row, columns=columns)
    validation_chunk_count = _validation_chunk_count(
        detected_candidate_count,
        validation_max_entities_per_call=validation_max_entities_per_call,
    )
    grouped_entity_count = _grouped_entity_count(row, columns=columns, final_entity_count=final_entity_count)
    repair_iterations = _coerce_int(row.get(COL_REPAIR_ITERATIONS, 0), default=0)
    calls_by_stage = estimate_llm_calls_by_stage(
        mode=mode,
        strategy=strategy,
        has_grouped_entities=grouped_entity_count > 0,
        validation_chunk_count=validation_chunk_count,
        repair_iterations=repair_iterations,
        replace_map_generation_uses_llm=_replace_map_generation_uses_llm(row, columns=columns),
    )
    stage_estimates = list(calls_by_stage.values())
    total_estimated = sum(stage_estimates) if all(value is not None for value in stage_estimates) else None
    return {
        "detected_candidate_count": detected_candidate_count,
        "validation_chunk_count": validation_chunk_count,
        # Repair iterations are only reported for rewrite mode.
        "repair_iterations": repair_iterations if mode == "rewrite" else 0,
        "llm_calls_estimated_by_stage": calls_by_stage,
        "llm_calls_estimated_total": total_estimated,
    }


def _replace_map_generation_uses_llm(row: Any, *, columns: set[str]) -> bool:
    """Return whether replacement-map generation counts as an LLM call (currently always True)."""
    del row, columns  # accepted for a stable call signature; not inspected
    return True


def _detected_candidate_count(row: Any, *, columns: set[str]) -> int | None:
    """Return the number of seed validation candidates, or ``None`` when that column is absent."""
    from anonymizer.engine.constants import COL_SEED_VALIDATION_CANDIDATES

    if COL_SEED_VALIDATION_CANDIDATES not in columns:
        return None
    return _count_items(row.get(COL_SEED_VALIDATION_CANDIDATES), primary_key="candidates", fallback_keys=("entities",))


def _grouped_entity_count(row: Any, *, columns: set[str], final_entity_count: int) -> int:
    """Return the grouped-entity count, falling back to ``final_entity_count`` when the column is absent."""
    from anonymizer.engine.constants import COL_ENTITIES_BY_VALUE

    if COL_ENTITIES_BY_VALUE not in columns:
        return final_entity_count
    return _count_items(row.get(COL_ENTITIES_BY_VALUE), primary_key="entities_by_value", fallback_keys=("entities",))


# --- patch metadata that shared these lines (kept for reference) -------------
# diff --git a/src/anonymizer/measurement/records/run.py b/src/anonymizer/measurement/records/run.py
# new file mode 100644
# index 00000000..4a9e95b9
# --- /dev/null
# +++ b/src/anonymizer/measurement/records/run.py
# @@ -0,0 +1,108 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import platform +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from anonymizer.measurement.constants import MEASUREMENT_SCHEMA_VERSION + + +def _detect_config_metadata(detect: Any | None) -> dict[str, Any]: + entity_labels = getattr(detect, "entity_labels", None) + if entity_labels is None: + from anonymizer.engine.constants import DEFAULT_ENTITY_LABELS + + entity_label_count = len(DEFAULT_ENTITY_LABELS) + else: + entity_label_count = len(entity_labels) + return { + "gliner_threshold": getattr(detect, "gliner_threshold", None), + "entity_label_source": "custom" if entity_labels is not None else "default", + "entity_label_count": entity_label_count, + "entity_labels": list(entity_labels) if entity_labels is not None else None, + "validation_max_entities_per_call": getattr(detect, "validation_max_entities_per_call", None), + "validation_excerpt_window_chars": getattr(detect, "validation_excerpt_window_chars", None), + } + + +def _source_metadata(source: str) -> dict[str, Any]: + parsed = urlparse(source) + if parsed.scheme in {"http", "https"}: + return { + "kind": "remote_file", + "scheme": parsed.scheme, + "suffix": Path(parsed.path).suffix.lower() or None, + } + if parsed.scheme == "file": + return { + "kind": "local_file", + "scheme": "file", + "suffix": Path(parsed.path).suffix.lower() or None, + } + return { + "kind": "local_file" if source else "unknown", + "scheme": None, + "suffix": Path(source).suffix.lower() or None, + } + + +def _replace_config_metadata(replace_config: Any | None) -> dict[str, Any] | None: + if replace_config is None: + return None + + metadata: dict[str, Any] = { + "strategy": type(replace_config).__name__, + "has_instructions": bool(getattr(replace_config, "instructions", None)), + } + for attr in ("normalize_label", "algorithm", "digest_length"): + if 
hasattr(replace_config, attr): + metadata[attr] = getattr(replace_config, attr) + if hasattr(replace_config, "format_template"): + metadata["has_format_template"] = True + return metadata + + +def _rewrite_config_metadata(rewrite_config: Any | None) -> dict[str, Any] | None: + if rewrite_config is None: + return None + return { + "risk_tolerance": _enum_value(getattr(rewrite_config, "risk_tolerance", None)), + "max_repair_iterations": getattr(rewrite_config, "max_repair_iterations", None), + "strict_entity_protection": getattr(rewrite_config, "strict_entity_protection", None), + "has_privacy_goal": bool(getattr(rewrite_config, "privacy_goal", None)), + "has_instructions": bool(getattr(rewrite_config, "instructions", None)), + } + + +def _model_config_metadata(model_config: Any) -> dict[str, Any]: + inference_parameters = getattr(model_config, "inference_parameters", None) + return { + "alias": getattr(model_config, "alias", None), + "model": getattr(model_config, "model", None), + "provider": _enum_value(getattr(model_config, "provider", None)), + "base_url": bool(getattr(model_config, "base_url", None)), + "max_parallel_requests": getattr(inference_parameters, "max_parallel_requests", None), + } + + +def _runtime_metadata() -> dict[str, Any]: + try: + anonymizer_version = version("nemo-anonymizer") + except PackageNotFoundError: + anonymizer_version = None + return { + "anonymizer_version": anonymizer_version, + "measurement_schema_version": MEASUREMENT_SCHEMA_VERSION, + "platform_machine": platform.machine(), + "platform_system": platform.system(), + "python_version": platform.python_version(), + } + + +def _enum_value(value: Any) -> Any: + return getattr(value, "value", value) diff --git a/src/anonymizer/measurement/session.py b/src/anonymizer/measurement/session.py new file mode 100644 index 00000000..1d985d4d --- /dev/null +++ b/src/anonymizer/measurement/session.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & 
# AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar

from anonymizer.measurement.collector import MeasurementCollector
from anonymizer.measurement.config import MeasurementConfig
from anonymizer.measurement.sinks import _JsonlMeasurementSink

logger = logging.getLogger("anonymizer.measurement")

# Collector for the current context; ``None`` means measurement is disabled.
_ACTIVE_COLLECTOR: ContextVar[MeasurementCollector | None] = ContextVar(
    "anonymizer_measurement_collector",
    default=None,
)


@contextmanager
def measurement_session(collector: MeasurementCollector | None = None) -> Iterator[MeasurementCollector]:
    """Activate a collector for code running in this context.

    A fresh ``MeasurementCollector`` is created when none is supplied. The
    previously active collector is restored on exit.
    """
    active = collector or MeasurementCollector()
    token = _ACTIVE_COLLECTOR.set(active)
    try:
        yield active
    finally:
        _ACTIVE_COLLECTOR.reset(token)


@contextmanager
def configured_measurement_session(config: MeasurementConfig | None) -> Iterator[MeasurementCollector | None]:
    """Activate and persist a collector when a measurement config is provided.

    Yields ``None`` (and does nothing else) when ``config`` is ``None``.
    In streaming mode the collector is only closed on exit; otherwise the
    collected records are written via ``config.write_collector`` and the
    collector is then closed.

    Raises:
        ValueError: if ``dd_trace`` is enabled without a ``dd_trace_path``.
            This is checked before any output file is opened.
    """
    if config is None:
        yield None
        return

    # Validate before opening anything, so a bad config neither truncates the
    # output file nor leaves an open file handle behind.
    if config.dd_trace != "none" and config.dd_trace_path is None:
        raise ValueError("dd_trace_path is required when dd_trace is enabled")

    opened_sinks: list[_JsonlMeasurementSink] = []

    def _open_sink(path) -> _JsonlMeasurementSink:
        opened = _JsonlMeasurementSink(path)
        opened_sinks.append(opened)
        return opened

    try:
        sink = _open_sink(config.output_path) if config.streaming else None
        dd_trace_sink = _open_sink(config.dd_trace_path) if config.dd_trace != "none" else None
        dd_task_trace_sink = _open_sink(config.dd_task_trace_path) if config.dd_task_trace_path else None
        collector = MeasurementCollector(
            run_id=config.run_id,
            record_hash_key=config.record_hash_key,
            record_level=config.record_level,
            run_tags=config.run_tags,
            record_sink=sink,
            keep_records=config.keep_records,
            dd_trace_mode=config.dd_trace,
            dd_trace_sink=dd_trace_sink,
            dd_task_trace_sink=dd_task_trace_sink,
            fail_on_write_error=config.fail_on_write_error,
        )
    except BaseException:
        # Setup failed part-way: release whatever was already opened.
        for opened in opened_sinks:
            opened.close()
        raise

    with measurement_session(collector):
        body_error: BaseException | None = None
        try:
            yield collector
        except BaseException as exc:
            body_error = exc
            raise
        finally:
            if config.streaming:
                _close_collector_safely(config=config, collector=collector, body_error=body_error)
            else:
                write_error: BaseException | None = None
                try:
                    _write_collector_safely(config=config, collector=collector, body_error=body_error)
                except BaseException as exc:
                    write_error = exc
                    raise
                finally:
                    # Always close, and never let a close failure mask an
                    # earlier body or write error.
                    _close_collector_safely(
                        config=config,
                        collector=collector,
                        body_error=body_error or write_error,
                    )


def current_collector() -> MeasurementCollector | None:
    """Return the active collector, if measurement is enabled."""
    return _ACTIVE_COLLECTOR.get()


def _write_collector_safely(
    *,
    config: MeasurementConfig,
    collector: MeasurementCollector,
    body_error: BaseException | None,
) -> None:
    """Write the collector's records, logging failures.

    A write failure is re-raised only when the session body itself succeeded
    and ``config.fail_on_write_error`` is set. Only the exception type is
    logged, not its message.
    """
    try:
        config.write_collector(collector)
    except Exception as exc:
        logger.warning("Failed to write Anonymizer measurement records (%s)", type(exc).__name__)
        if body_error is None and config.fail_on_write_error:
            raise


def _close_collector_safely(
    *,
    config: MeasurementConfig,
    collector: MeasurementCollector,
    body_error: BaseException | None,
) -> None:
    """Close the collector, logging failures.

    A close failure is re-raised only when no earlier error occurred and
    ``config.fail_on_write_error`` is set. Only the exception type is logged,
    not its message.
    """
    try:
        collector.close()
    except Exception as exc:
        logger.warning("Failed to close Anonymizer measurement stream (%s)", type(exc).__name__)
        if body_error is None and config.fail_on_write_error:
            raise


# --- patch metadata that shared these lines (kept for reference) -------------
# diff --git a/src/anonymizer/measurement/sinks.py b/src/anonymizer/measurement/sinks.py
# new file mode 100644
# index 00000000..0465ddbf
# --- /dev/null
# +++ b/src/anonymizer/measurement/sinks.py
# @@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Literal, Protocol + + +class _MeasurementWriter(Protocol): + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: ... + + +class _MeasurementSink(Protocol): + def write_record(self, record: dict[str, Any]) -> None: ... + + def close(self) -> None: ... + + +class _JsonlMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + for record in records: + f.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + +class _JsonMeasurementWriter: + def write(self, records: list[dict[str, Any]], path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(records, f, ensure_ascii=True, indent=2, sort_keys=True) + + +def _writer_for_format(output_format: Literal["jsonl", "json"]) -> _MeasurementWriter: + if output_format == "json": + return _JsonMeasurementWriter() + return _JsonlMeasurementWriter() + + +class _JsonlMeasurementSink: + def __init__(self, path: str | Path) -> None: + output_path = Path(path) + output_path.parent.mkdir(parents=True, exist_ok=True) + self._file = output_path.open("w", encoding="utf-8", buffering=1) + + def write_record(self, record: dict[str, Any]) -> None: + self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + + def close(self) -> None: + self._file.close() diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 3da61579..67b2e828 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -12,6 +12,56 @@ Use these tools when you need evidence about cost, latency, reliability, or anonymization quality. 
They are not product entry points and the benchmark-only strategy knobs are not public Anonymizer defaults. +## Quick export to DataFrames or CSV + +Start here when you have a `measurements.jsonl` file and want to analyze it in +pandas, Polars, a spreadsheet, or another local tool. + +```bash +uv run python tools/measurement/export_measurements.py \ + benchmark-runs/suite/measurements.jsonl \ + --output benchmark-runs/suite/tables \ + --overwrite +``` + +By default, the exporter writes one Parquet table per measurement record type +plus `manifest.json`: + +- `run.parquet` +- `stage.parquet` +- `record.parquet` +- `ndd_workflow.parquet` when DataDesigner adapter records are present +- `model_workflow.parquet` when direct model workflow records are present + +Use CSV or JSONL when those are easier to inspect: + +```bash +uv run python tools/measurement/export_measurements.py \ + benchmark-runs/suite/measurements.jsonl \ + --output benchmark-runs/suite/tables-csv \ + --format csv \ + --overwrite +``` + +Then load the tables directly: + +```python +import pandas as pd + +records = pd.read_parquet("benchmark-runs/suite/tables/record.parquet") +stages = pd.read_parquet("benchmark-runs/suite/tables/stage.parquet") +ndd = pd.read_parquet("benchmark-runs/suite/tables/ndd_workflow.parquet") +``` + +You can also read the raw log, but the exporter is the better default because +it splits records by `record_type` and normalizes nested fields into columns. + +```python +import pandas as pd + +raw = pd.read_json("benchmark-runs/suite/measurements.jsonl", lines=True) +``` + ## System overview The measurement system has three layers: @@ -65,22 +115,6 @@ This is intentionally composition-based. New analysis tools should declare their own row models and call the shared helpers rather than inheriting from a common analyzer base class. 
-```bash -uv run python tools/measurement/export_measurements.py measurements.jsonl --output tables -``` - -By default, `export_measurements.py` writes Parquet files plus -`manifest.json`: - -- `run.parquet` -- `stage.parquet` -- `record.parquet` -- `ndd_workflow.parquet` when DataDesigner adapter records are present -- `model_workflow.parquet` when direct model workflow records are present - -Use `--format csv` or `--format jsonl` for non-Parquet output, and -`--overwrite` to replace existing output files. - ## Benchmark runner `run_benchmarks.py` runs repeatable Anonymizer workloads and writes the same From fabdc78798b1162af77323124a59e8213df2c9b5 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Tue, 9 Jun 2026 22:15:55 +0000 Subject: [PATCH 20/26] Limit measurement PR to core benchmark tools Signed-off-by: Aaron Gonzales --- tests/tools/test_compare_strategy_pairs.py | 1473 -------------- tests/tools/test_dd_parser_compat.py | 93 - tests/tools/test_dd_trace_analysis.py | 160 -- tests/tools/test_detection_strategies.py | 1093 ----------- tests/tools/test_direct_detection_probe.py | 147 -- tests/tools/test_extract_signature_deltas.py | 190 -- tests/tools/test_measurement_tools.py | 524 +---- .../tools/test_screen_strategy_comparisons.py | 974 ---------- .../test_staged_detection_output_analysis.py | 182 -- tests/tools/test_staged_detection_probe.py | 579 ------ tools/measurement/README.md | 273 +-- tools/measurement/analyze_dd_traces.py | 359 ---- .../analyze_staged_detection_output.py | 392 ---- tools/measurement/compare_strategy_pairs.py | 1438 -------------- tools/measurement/dd_parser_compat.py | 108 -- tools/measurement/detection_strategies.py | 1695 ----------------- tools/measurement/direct_detection_probe.py | 555 ------ tools/measurement/extract_signature_deltas.py | 513 ----- tools/measurement/run_benchmarks.py | 320 +--- .../screen_strategy_comparisons.py | 1107 ----------- tools/measurement/staged_detection_probe.py | 1375 ------------- 21 files 
changed, 21 insertions(+), 13529 deletions(-) delete mode 100644 tests/tools/test_compare_strategy_pairs.py delete mode 100644 tests/tools/test_dd_parser_compat.py delete mode 100644 tests/tools/test_dd_trace_analysis.py delete mode 100644 tests/tools/test_detection_strategies.py delete mode 100644 tests/tools/test_direct_detection_probe.py delete mode 100644 tests/tools/test_extract_signature_deltas.py delete mode 100644 tests/tools/test_screen_strategy_comparisons.py delete mode 100644 tests/tools/test_staged_detection_output_analysis.py delete mode 100644 tests/tools/test_staged_detection_probe.py delete mode 100644 tools/measurement/analyze_dd_traces.py delete mode 100644 tools/measurement/analyze_staged_detection_output.py delete mode 100644 tools/measurement/compare_strategy_pairs.py delete mode 100644 tools/measurement/dd_parser_compat.py delete mode 100644 tools/measurement/detection_strategies.py delete mode 100644 tools/measurement/direct_detection_probe.py delete mode 100644 tools/measurement/extract_signature_deltas.py delete mode 100644 tools/measurement/screen_strategy_comparisons.py delete mode 100644 tools/measurement/staged_detection_probe.py diff --git a/tests/tools/test_compare_strategy_pairs.py b/tests/tools/test_compare_strategy_pairs.py deleted file mode 100644 index 4576fe81..00000000 --- a/tests/tools/test_compare_strategy_pairs.py +++ /dev/null @@ -1,1473 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -import pandas as pd -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -def test_compare_case_analysis_by_strategy_reports_safety_and_cost_deltas() -> None: - tool = load_tool("measurement_compare_strategy_pairs", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py") - table = pd.DataFrame( - [ - { - "workload_id": "shell-1", - "config_id": "shell-no-augment", - "experimental_detection_strategy": "no_augment", - "experimental_replacement_strategy": "default", - "case_id": "shell__base", - "pipeline_elapsed_sec": 4.5, - "observed_total_requests": 3, - "observed_total_tokens": 2875, - "observed_failed_requests": 1, - "final_entity_count": 4, - "seed_validation_candidate_count": 5, - "augmented_entity_count": 2, - "augmented_new_final_value_count": 1, - "artifact_final_detector_entity_count": 4, - }, - { - "workload_id": "shell-1", - "config_id": "shell-native", - "experimental_detection_strategy": "native_candidate_validate_no_augment", - "experimental_replacement_strategy": "custom_replacement_strategy", - "case_id": "shell__candidate", - "pipeline_elapsed_sec": 0.8, - "observed_total_requests": 1, - "observed_total_tokens": 101, - "observed_failed_requests": 0, - "final_entity_count": 8, - "seed_validation_candidate_count": 0, - "augmented_entity_count": 0, - "augmented_new_final_value_count": 0, - "artifact_final_augmenter_entity_count": 8, - }, - { - "workload_id": "legal-1", - "config_id": "legal-no-augment", - "experimental_detection_strategy": 
"no_augment", - "experimental_replacement_strategy": "default", - "case_id": "legal__base", - "pipeline_elapsed_sec": 21, - "observed_total_requests": 3, - "observed_total_tokens": 8847, - "observed_failed_requests": 1, - "final_entity_count": 26, - "seed_validation_candidate_count": 40, - "artifact_final_detector_entity_count": 26, - }, - { - "workload_id": "legal-1", - "config_id": "legal-native", - "experimental_detection_strategy": "native_candidate_validate_no_augment", - "experimental_replacement_strategy": "custom_replacement_strategy", - "case_id": "legal__candidate", - "pipeline_elapsed_sec": 20.9, - "observed_total_requests": 3, - "observed_total_tokens": 8847, - "observed_failed_requests": 1, - "final_entity_count": 26, - "seed_validation_candidate_count": 40, - "artifact_final_detector_entity_count": 26, - }, - ] - ) - - rows = tool.compare_case_analysis( - table, - baseline_strategy="no_augment", - candidate_strategy="native_candidate_validate_no_augment", - ) - - by_workload = {row.workload_id: row for row in rows} - shell = by_workload["shell-1"] - assert shell.baseline_replacement_strategy == "default" - assert shell.candidate_replacement_strategy == "custom_replacement_strategy" - assert shell.final_entity_count_delta == 4 - assert shell.observed_total_requests_delta == -2 - assert shell.observed_total_tokens_delta == -2774 - assert shell.seed_validation_candidate_count_delta == -5 - assert shell.baseline_augmented_entity_count == 2 - assert shell.candidate_augmented_entity_count == 0 - assert shell.augmented_entity_count_delta == -2 - assert shell.baseline_augmented_new_final_value_count == 1 - assert shell.candidate_augmented_new_final_value_count == 0 - assert shell.augmented_new_final_value_count_delta == -1 - assert shell.candidate_augmenter_entity_count == 8 - assert shell.safety_verdict == "review" - assert shell.performance_verdict == "improved" - assert shell.candidate_verdict == "review" - assert shell.flags == 
["no_candidate_detector_entities"] - - legal = by_workload["legal-1"] - assert legal.baseline_replacement_strategy == "default" - assert legal.candidate_replacement_strategy == "custom_replacement_strategy" - assert legal.final_entity_count_delta == 0 - assert legal.observed_total_tokens_delta == 0 - assert legal.candidate_detector_entity_count == 26 - assert legal.safety_verdict == "pass" - assert legal.performance_verdict == "improved" - assert legal.candidate_verdict == "candidate_viable" - assert legal.flags == [] - - -def test_compare_case_analysis_rejects_ambiguous_strategy_selector() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_ambiguous", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-1", - "config_id": "base-a", - "experimental_detection_strategy": "no_augment", - "case_id": "a", - }, - { - "workload_id": "shell-1", - "config_id": "base-b", - "experimental_detection_strategy": "no_augment", - "case_id": "b", - }, - { - "workload_id": "shell-1", - "config_id": "candidate", - "experimental_detection_strategy": "detector_only", - "case_id": "c", - }, - ] - ) - - with pytest.raises(ValueError, match="baseline selector matched multiple configs"): - tool.compare_case_analysis(table, baseline_strategy="no_augment", candidate_strategy="detector_only") - - -def test_compare_case_analysis_rejects_candidate_synthetic_original_collisions() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_replacement_collisions", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-1", - "config_id": "baseline", - "experimental_detection_strategy": "default", - "case_id": "base", - "pipeline_elapsed_sec": 20, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 4, - "replacement_synthetic_original_collision_count": 0, - 
"replacement_synthetic_original_collision_value_count": 0, - }, - { - "workload_id": "legal-1", - "config_id": "candidate", - "experimental_detection_strategy": "candidate", - "case_id": "cand", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 2, - "observed_total_tokens": 1000, - "final_entity_count": 4, - "replacement_synthetic_original_collision_count": 1, - "replacement_synthetic_original_collision_value_count": 1, - "replacement_synthetic_original_collision_label_counts.date": 1, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.candidate_replacement_synthetic_original_collision_count == 1 - assert row.replacement_synthetic_original_collision_count_delta == 1 - assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} - assert "candidate_replacement_synthetic_original_collision" in row.flags - assert row.value_protection_verdict == "fail" - assert row.safety_verdict == "fail" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "reject" - - -def test_compare_case_analysis_rejects_candidate_missing_replacement_map_entries() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_missing_replacement_map_entries", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "config_id": "baseline", - "experimental_detection_strategy": "default", - "case_id": "base", - "pipeline_elapsed_sec": 20, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 4, - "replacement_missing_final_entity_count": 0, - "replacement_missing_final_value_count": 0, - }, - { - "workload_id": "structured-identifiers", - "config_id": "candidate", - "experimental_detection_strategy": "default", - "case_id": "cand", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 
2, - "observed_total_tokens": 1000, - "final_entity_count": 4, - "replacement_missing_final_entity_count": 1, - "replacement_missing_final_value_count": 1, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.candidate_replacement_missing_final_entity_count == 1 - assert row.replacement_missing_final_entity_count_delta == 1 - assert "candidate_replacement_missing_final_entity" in row.flags - assert row.value_protection_verdict == "fail" - assert row.safety_verdict == "fail" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "reject" - - -def test_compare_case_tables_allows_candidate_from_separate_run() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_cross_run", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - baseline = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "legal-no-augment", - "experimental_detection_strategy": "no_augment", - "case_id": "legal__base", - "observed_total_tokens": 55790, - "final_entity_count": 193, - "artifact_final_detector_entity_count": 172, - } - ] - ) - candidate = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "legal-native-validate", - "experimental_detection_strategy": "detector_native_validate_no_augment", - "case_id": "legal__candidate", - "observed_total_tokens": 55805, - "final_entity_count": 193, - "artifact_final_detector_entity_count": 172, - } - ] - ) - - rows = tool.compare_case_tables( - baseline, - candidate, - baseline_strategy="no_augment", - candidate_strategy="detector_native_validate_no_augment", - ) - - assert len(rows) == 1 - assert rows[0].workload_id == "legal-5" - assert rows[0].baseline_config_id == "legal-no-augment" - assert rows[0].candidate_config_id == "legal-native-validate" - assert rows[0].observed_total_tokens_delta == 15 - assert rows[0].flags == ["token_increase"] - - -def 
test_compare_case_analysis_preserves_augmentation_contribution_deltas() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_augmentation", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-2", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "augmented_entity_count": 8, - "augmented_new_final_value_count": 3, - "artifact_final_augmenter_entity_count": 3, - "artifact_final_entity_signature_hashes": ["a", "b", "c"], - }, - { - "workload_id": "legal-2", - "config_id": "no-augment", - "experimental_detection_strategy": "no_augment", - "case_id": "no-augment-r0", - "augmented_entity_count": 0, - "augmented_new_final_value_count": 0, - "artifact_final_augmenter_entity_count": 0, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="no-augment") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_augmented_entity_count == 8 - assert row.candidate_augmented_entity_count == 0 - assert row.augmented_entity_count_delta == -8 - assert row.baseline_augmented_new_final_value_count == 3 - assert row.candidate_augmented_new_final_value_count == 0 - assert row.augmented_new_final_value_count_delta == -3 - assert row.baseline_augmenter_entity_count == 3 - assert row.candidate_augmenter_entity_count == 0 - - -def test_compare_case_analysis_review_gates_detector_only_candidate_shell_case() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_detector_only", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "bio-1", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "bio__default", - "pipeline_elapsed_sec": 20, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 2, - 
"artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "bio-1", - "config_id": "detector-only", - "experimental_detection_strategy": "detector_only", - "case_id": "bio__detector", - "pipeline_elapsed_sec": 5, - "observed_total_requests": 1, - "observed_total_tokens": 1000, - "final_entity_count": 2, - "artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_strategy="default", candidate_strategy="detector_only") - - assert len(rows) == 1 - row = rows[0] - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - assert row.flags == ["candidate_skips_llm_validation"] - - -def test_compare_case_analysis_review_gates_detector_only_candidates() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_detector_only", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-1", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "shell__default", - "pipeline_elapsed_sec": 8, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 2, - "artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "shell-1", - "config_id": "detector-only", - "experimental_detection_strategy": "detector_only", - "case_id": "shell__candidate", - "pipeline_elapsed_sec": 1, - "observed_total_requests": 1, - "observed_total_tokens": 200, - "final_entity_count": 2, - "artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis( - table, - baseline_strategy="default", - candidate_strategy="detector_only", - ) - - assert len(rows) == 1 - row = rows[0] - assert 
row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - assert row.flags == ["candidate_skips_llm_validation"] - - -def test_compare_case_analysis_review_gates_non_detector_sources_when_signatures_match() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_non_detector_sources", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-1", - "config_id": "native-source-default", - "experimental_detection_strategy": "default", - "case_id": "shell__default", - "pipeline_elapsed_sec": 21.4, - "observed_total_requests": 8, - "observed_total_tokens": 9854, - "final_entity_count": 2, - "artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, - }, - { - "workload_id": "shell-1", - "config_id": "native-source-candidate", - "experimental_detection_strategy": "native_single_pass", - "case_id": "shell__candidate", - "pipeline_elapsed_sec": 0.001, - "observed_total_requests": 0, - "observed_total_tokens": 0, - "final_entity_count": 2, - "artifact_final_augmenter_entity_count": 2, - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, - }, - ] - ) - - rows = tool.compare_case_analysis( - table, - baseline_config="native-source-default", - candidate_config="native-source-candidate", - ) - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_final_entity_signature_count == 0 - assert row.candidate_only_final_entity_signature_count == 0 - assert row.shared_final_entity_signature_label_counts == {"api_key": 1, "password": 1} - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "pass" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert 
row.candidate_verdict == "review" - assert row.flags == ["no_candidate_detector_entities"] - - -def test_compare_case_analysis_flags_signature_loss_even_when_counts_match() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_signature_loss", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "bio-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "bio__default", - "final_entity_count": 3, - "artifact_final_entity_signature_hashes": '["a","b","c"]', - "artifact_final_entity_signature_labels": '{"a":"first_name","b":"city","c":"first_name"}', - }, - { - "workload_id": "bio-5", - "config_id": "no-augment", - "experimental_detection_strategy": "no_augment", - "case_id": "bio__no_augment", - "final_entity_count": 3, - "artifact_final_entity_signature_hashes": ["a", "b", "d"], - "artifact_final_entity_signature_labels": {"a": "first_name", "b": "city", "d": "last_name"}, - }, - ] - ) - - rows = tool.compare_case_analysis( - table, - baseline_config="default", - candidate_config="no-augment", - ) - - assert len(rows) == 1 - row = rows[0] - assert row.final_entity_count_delta == 0 - assert row.baseline_only_final_entity_signature_count == 1 - assert row.candidate_only_final_entity_signature_count == 1 - assert row.shared_final_entity_signature_count == 2 - assert row.baseline_only_final_entity_signature_label_counts == {"first_name": 1} - assert row.candidate_only_final_entity_signature_label_counts == {"last_name": 1} - assert row.shared_final_entity_signature_label_counts == {"city": 1, "first_name": 1} - assert row.safety_verdict == "fail" - assert row.performance_verdict == "unknown" - assert row.candidate_verdict == "reject" - assert row.flags == ["entity_signature_loss"] - - -def test_compare_case_analysis_treats_baseline_subspan_as_candidate_covered() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_candidate_span_coverage", - 
REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 2, - "artifact_final_detector_entity_count": 2, - "artifact_final_entity_signature_hashes": ["api-token", "pin"], - "artifact_final_entity_signature_labels": {"api-token": "api_key", "pin": "pin"}, - "artifact_final_entity_signature_details": { - "api-token": { - "label": "api_key", - "source": "augmenter", - "row_index": 0, - "start_position": 42, - "end_position": 66, - "value_hash": "token-hash", - "value_length": 24, - }, - "pin": { - "label": "pin", - "source": "detector", - "row_index": 0, - "start_position": 90, - "end_position": 95, - "value_hash": "pin-hash", - "value_length": 5, - }, - }, - }, - { - "workload_id": "structured-identifiers", - "config_id": "native-local", - "experimental_detection_strategy": "native_single_pass", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 0.01, - "observed_total_requests": 0, - "observed_total_tokens": 0, - "final_entity_count": 2, - "artifact_final_augmenter_entity_count": 2, - "artifact_final_entity_signature_hashes": ["cookie", "pin"], - "artifact_final_entity_signature_labels": {"cookie": "http_cookie", "pin": "pin"}, - "artifact_final_entity_signature_details": { - "cookie": { - "label": "http_cookie", - "source": "native", - "row_index": 0, - "start_position": 30, - "end_position": 80, - "value_hash": "cookie-hash", - "value_length": 50, - }, - "pin": { - "label": "pin", - "source": "native", - "row_index": 0, - "start_position": 90, - "end_position": 95, - "value_hash": "pin-hash", - "value_length": 5, - }, - }, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") - - assert len(rows) == 1 
- row = rows[0] - assert row.baseline_only_final_entity_signature_count == 1 - assert row.baseline_only_candidate_covered_signature_count == 1 - assert row.baseline_only_candidate_uncovered_signature_count == 0 - assert row.baseline_only_candidate_covered_signature_label_counts == {"api_key": 1} - assert row.baseline_only_candidate_uncovered_signature_label_counts == {} - assert "entity_signature_loss" not in row.flags - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_compare_case_analysis_review_gates_covered_label_mismatch() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_label_mismatch", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-row", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 4, - "observed_total_tokens": 1000, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["dob"], - "artifact_final_entity_signature_labels": {"dob": "date_of_birth"}, - "artifact_final_entity_signature_details": { - "dob": { - "label": "date_of_birth", - "source": "detector", - "row_index": 0, - "start_position": 20, - "end_position": 35, - "value_hash": "date-hash", - "value_length": 15, - }, - }, - }, - { - "workload_id": "legal-row", - "config_id": "native-validation", - "experimental_detection_strategy": "detector_native_validate_no_augment", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 2, - "observed_total_requests": 2, - "observed_total_tokens": 300, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["date"], - "artifact_final_entity_signature_labels": {"date": "date"}, - "artifact_final_entity_signature_details": { - "date": { - "label": 
"date", - "source": "detector", - "row_index": 0, - "start_position": 20, - "end_position": 35, - "value_hash": "date-hash", - "value_length": 15, - }, - }, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-validation") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_candidate_covered_signature_count == 1 - assert row.baseline_only_candidate_uncovered_signature_count == 0 - assert row.baseline_only_candidate_label_mismatch_signature_count == 1 - assert row.baseline_only_candidate_label_mismatch_signature_label_counts == {"date_of_birth": 1} - assert "entity_signature_loss" not in row.flags - assert "covered_label_mismatch" in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_compare_case_analysis_treats_high_overlap_candidate_span_as_covered() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_candidate_span_overlap", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["token-assignment"], - "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, - "artifact_final_entity_signature_details": { - "token-assignment": { - "label": "http_cookie", - "source": "augmenter", - "row_index": 0, - "start_position": 20, - "end_position": 68, - "value_hash": "token-assignment-hash", - "value_length": 48, - }, - }, - }, - { - "workload_id": "structured-identifiers", - 
"config_id": "native-local", - "experimental_detection_strategy": "native_single_pass", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 0.01, - "observed_total_requests": 0, - "observed_total_tokens": 0, - "final_entity_count": 1, - "artifact_final_augmenter_entity_count": 1, - "artifact_final_entity_signature_hashes": ["token-value"], - "artifact_final_entity_signature_labels": {"token-value": "api_key"}, - "artifact_final_entity_signature_details": { - "token-value": { - "label": "api_key", - "source": "native", - "row_index": 0, - "start_position": 26, - "end_position": 69, - "value_hash": "token-value-hash", - "value_length": 43, - }, - }, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_final_entity_signature_count == 1 - assert row.baseline_only_candidate_covered_signature_count == 1 - assert row.baseline_only_candidate_overlapping_signature_count == 1 - assert row.baseline_only_candidate_uncovered_signature_count == 0 - assert row.baseline_only_candidate_overlapping_signature_label_counts == {"http_cookie": 1} - assert row.baseline_only_candidate_uncovered_signature_label_counts == {} - assert "entity_signature_loss" not in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_compare_case_analysis_treats_small_assignment_prefix_gap_as_boundary_overlap() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_candidate_span_boundary", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 10, - 
"observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["login-assignment"], - "artifact_final_entity_signature_labels": {"login-assignment": "unique_id"}, - "artifact_final_entity_signature_details": { - "login-assignment": { - "label": "unique_id", - "source": "detector", - "row_index": 0, - "start_position": 20, - "end_position": 38, - "value_hash": "login-assignment-hash", - "value_length": 18, - }, - }, - }, - { - "workload_id": "structured-identifiers", - "config_id": "native-local", - "experimental_detection_strategy": "native_single_pass", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 0.01, - "observed_total_requests": 0, - "observed_total_tokens": 0, - "final_entity_count": 1, - "artifact_final_augmenter_entity_count": 1, - "artifact_final_entity_signature_hashes": ["login-value"], - "artifact_final_entity_signature_labels": {"login-value": "user_name"}, - "artifact_final_entity_signature_details": { - "login-value": { - "label": "user_name", - "source": "native", - "row_index": 0, - "start_position": 26, - "end_position": 38, - "value_hash": "login-value-hash", - "value_length": 12, - }, - }, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="native-local") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_candidate_covered_signature_count == 1 - assert row.baseline_only_candidate_overlapping_signature_count == 1 - assert row.baseline_only_candidate_uncovered_signature_count == 0 - assert "entity_signature_loss" not in row.flags - assert "span_boundary_mismatch" in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def 
test_compare_case_analysis_flags_replacement_only_detection_instability() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_replacement_only_detection_instability", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "config_id": "dd-substitute", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 10, - "observed_total_requests": 4, - "observed_total_tokens": 4000, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["token-assignment"], - "artifact_final_entity_signature_labels": {"token-assignment": "http_cookie"}, - "artifact_final_entity_signature_details": { - "token-assignment": { - "label": "http_cookie", - "source": "detector", - "row_index": 0, - "start_position": 20, - "end_position": 70, - "value_hash": "token-assignment-hash", - "value_length": 50, - }, - }, - }, - { - "workload_id": "structured-identifiers", - "config_id": "local-substitute", - "experimental_detection_strategy": "default", - "experimental_replacement_strategy": "custom_replacement_strategy", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 7, - "observed_total_requests": 3, - "observed_total_tokens": 3000, - "final_entity_count": 1, - "artifact_final_detector_entity_count": 1, - "artifact_final_entity_signature_hashes": ["token-value"], - "artifact_final_entity_signature_labels": {"token-value": "api_key"}, - "artifact_final_entity_signature_details": { - "token-value": { - "label": "api_key", - "source": "detector", - "row_index": 0, - "start_position": 26, - "end_position": 70, - "value_hash": "token-value-hash", - "value_length": 44, - }, - }, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="dd-substitute", candidate_config="local-substitute") - - assert len(rows) == 1 - row = rows[0] - 
assert "covered_label_mismatch" in row.flags - assert "replacement_only_detection_instability" in row.flags - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - - -def test_compare_case_analysis_rejects_candidate_original_value_leaks() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_original_value_leak", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "structured-secrets", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "structured__default", - "pipeline_elapsed_sec": 10, - "observed_total_tokens": 1000, - "final_entity_count": 2, - "original_value_leak_count": 0, - "original_value_leak_record_count": 0, - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, - }, - { - "workload_id": "structured-secrets", - "config_id": "candidate", - "experimental_detection_strategy": "native_single_pass", - "case_id": "structured__candidate", - "pipeline_elapsed_sec": 1, - "observed_total_tokens": 0, - "final_entity_count": 2, - "original_value_leak_count": 2, - "original_value_leak_record_count": 1, - "original_value_leak_label_counts.api_key": 1, - "original_value_leak_label_counts.password": 1, - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "api_key", "b": "password"}, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.candidate_original_value_leak_count == 2 - assert row.candidate_original_value_leak_record_count == 1 - assert row.original_value_leak_count_delta == 2 - assert row.candidate_original_value_leak_label_counts == 
{"api_key": 1, "password": 1} - assert row.safety_verdict == "fail" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "reject" - assert row.flags == ["candidate_original_value_leak"] - - -def test_compare_case_analysis_marks_clean_but_expensive_candidate_for_review() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_verdicts", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "legal__default", - "pipeline_elapsed_sec": 20, - "observed_total_tokens": 1000, - "final_entity_count": 12, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "legal-5", - "config_id": "candidate", - "experimental_detection_strategy": "candidate", - "case_id": "legal__candidate", - "pipeline_elapsed_sec": 25, - "observed_total_tokens": 1200, - "final_entity_count": 12, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - assert rows[0].safety_verdict == "pass" - assert rows[0].performance_verdict == "regressed" - assert rows[0].candidate_verdict == "review" - assert rows[0].flags == ["token_increase"] - - -def test_compare_case_analysis_separates_bridge_fallbacks_from_provider_failures() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_trace_adjusted", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "bio-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "bio__default", - "pipeline_elapsed_sec": 20, - "observed_total_requests": 4, - "observed_failed_requests": 1, - "observed_bridge_fallback_requests": 1, - "observed_non_bridge_total_requests": 3, - "observed_non_bridge_failed_requests": 0, - 
"final_entity_count": 12, - "artifact_final_detector_entity_count": 12, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "bio-5", - "config_id": "windowed", - "experimental_detection_strategy": "windowed", - "case_id": "bio__windowed", - "pipeline_elapsed_sec": 15, - "observed_total_requests": 6, - "observed_failed_requests": 2, - "observed_bridge_fallback_requests": 2, - "observed_non_bridge_total_requests": 4, - "observed_non_bridge_failed_requests": 0, - "final_entity_count": 12, - "artifact_final_detector_entity_count": 12, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="windowed") - - assert len(rows) == 1 - row = rows[0] - assert row.observed_failed_requests_delta == 1 - assert row.observed_bridge_fallback_requests_delta == 1 - assert row.observed_non_bridge_failed_requests_delta == 0 - assert "bridge_fallback_increase" in row.flags - assert "failed_request_increase" not in row.flags - - -def test_compare_case_analysis_flags_non_bridge_provider_failure_increase() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_non_bridge_failure", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "legal__default", - "observed_total_requests": 5, - "observed_failed_requests": 1, - "observed_bridge_fallback_requests": 1, - "observed_non_bridge_total_requests": 4, - "observed_non_bridge_failed_requests": 0, - "final_entity_count": 10, - "artifact_final_detector_entity_count": 10, - "artifact_final_entity_signature_hashes": ["a"], - }, - { - "workload_id": "legal-5", - "config_id": "candidate", - "experimental_detection_strategy": "candidate", - "case_id": "legal__candidate", - "observed_total_requests": 5, - "observed_failed_requests": 2, - 
"observed_bridge_fallback_requests": 1, - "observed_non_bridge_total_requests": 4, - "observed_non_bridge_failed_requests": 1, - "final_entity_count": 10, - "artifact_final_detector_entity_count": 10, - "artifact_final_entity_signature_hashes": ["a"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.observed_bridge_fallback_requests_delta == 0 - assert row.observed_non_bridge_failed_requests_delta == 1 - assert "bridge_fallback_increase" not in row.flags - assert "failed_request_increase" in row.flags - assert row.safety_verdict == "review" - assert row.performance_verdict == "unchanged" - assert row.candidate_verdict == "review" - - -def test_compare_case_analysis_rejects_candidate_case_failures() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_case_failure", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 8, - "case_failed": False, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "shell-5", - "config_id": "default", - "experimental_detection_strategy": "default", - "case_id": "default-r1", - "pipeline_elapsed_sec": 8, - "case_failed": False, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "shell-5", - "config_id": "candidate", - "experimental_detection_strategy": "detector_only", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 2, - "case_failed": False, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "shell-5", - "config_id": "candidate", - "experimental_detection_strategy": "detector_only", - "case_id": "candidate-r1", - "pipeline_elapsed_sec": 0.2, - "case_failed": True, - "artifact_final_entity_signature_hashes": 
["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_failed_case_count == 0 - assert row.candidate_failed_case_count == 1 - assert row.failed_case_count_delta == 1 - assert row.safety_verdict == "fail" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "reject" - assert "candidate_case_failures" in row.flags - - -def test_compare_case_analysis_counts_model_workflow_errors_as_case_failures() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_model_workflow_failure", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-5", - "config_id": "default", - "case_id": "default-r0", - "pipeline_elapsed_sec": 8, - "error_stage_count": 0, - "error_ndd_workflow_count": 0, - "error_model_workflow_count": 0, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - { - "workload_id": "shell-5", - "config_id": "candidate", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 2, - "error_stage_count": 0, - "error_ndd_workflow_count": 0, - "error_model_workflow_count": 1, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_failed_case_count == 0 - assert row.candidate_failed_case_count == 1 - assert row.safety_verdict == "fail" - assert row.candidate_verdict == "reject" - assert "candidate_case_failures" in row.flags - - -def test_compare_case_analysis_review_gates_baseline_case_failures() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_baseline_case_failure", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "bio-5", - "config_id": "baseline", - "case_id": 
"baseline-r0", - "pipeline_elapsed_sec": 30, - "case_failed": True, - "artifact_final_entity_signature_hashes": [], - }, - { - "workload_id": "bio-5", - "config_id": "candidate", - "case_id": "candidate-r0", - "pipeline_elapsed_sec": 20, - "case_failed": False, - "artifact_final_entity_signature_hashes": ["a", "b"], - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="baseline", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_failed_case_count == 1 - assert row.candidate_failed_case_count == 0 - assert row.failed_case_count_delta == -1 - assert row.safety_verdict == "review" - assert row.performance_verdict == "improved" - assert row.candidate_verdict == "review" - assert row.flags == ["baseline_case_failures"] - - -def test_compare_case_analysis_flags_repeated_signature_instability() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_stability", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "default", - "case_id": "default-r0", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - { - "workload_id": "legal-5", - "config_id": "default", - "case_id": "default-r1", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - { - "workload_id": "legal-5", - "config_id": "candidate", - "case_id": "candidate-r0", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - { - "workload_id": "legal-5", - "config_id": "candidate", - "case_id": "candidate-r1", - "artifact_final_entity_signature_hashes": ["a"], - "artifact_final_entity_signature_labels": {"a": "person"}, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", 
candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_final_entity_signature_count == 0 - assert row.candidate_only_final_entity_signature_count == 0 - assert row.baseline_stable_final_entity_signature_count == 2 - assert row.candidate_stable_final_entity_signature_count == 1 - assert row.stable_final_entity_signature_count_delta == -1 - assert row.baseline_stable_candidate_unstable_final_entity_signature_count == 1 - assert row.baseline_stable_candidate_unstable_final_entity_signature_label_counts == {"date": 1} - assert row.safety_verdict == "fail" - assert row.candidate_verdict == "reject" - assert row.flags == ["stable_entity_signature_loss"] - - -def test_compare_case_analysis_does_not_infer_stability_loss_from_single_candidate_run() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_single_candidate_stability", - REPO_ROOT / "tools/measurement/compare_strategy_pairs.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-5", - "config_id": "default", - "case_id": "default-r0", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - { - "workload_id": "legal-5", - "config_id": "default", - "case_id": "default-r1", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - { - "workload_id": "legal-5", - "config_id": "candidate", - "case_id": "candidate-r0", - "artifact_final_entity_signature_hashes": ["a", "b"], - "artifact_final_entity_signature_labels": {"a": "person", "b": "date"}, - }, - ] - ) - - rows = tool.compare_case_analysis(table, baseline_config="default", candidate_config="candidate") - - assert len(rows) == 1 - row = rows[0] - assert row.baseline_only_final_entity_signature_count == 0 - assert row.candidate_only_final_entity_signature_count == 0 - assert row.baseline_stable_final_entity_signature_count is None - assert 
row.candidate_stable_final_entity_signature_count is None - assert row.stable_final_entity_signature_count_delta is None - assert row.baseline_stable_candidate_unstable_final_entity_signature_count is None - assert row.safety_verdict == "pass" - assert "stable_entity_signature_loss" not in row.flags - - -def test_compare_strategy_pairs_writes_csv(tmp_path: Path) -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_export", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" - ) - rows = [ - tool.ComparisonRow( - workload_id="shell-1", - baseline_config_id="base", - candidate_config_id="candidate", - baseline_replacement_strategy="default", - candidate_replacement_strategy="custom_replacement_strategy", - baseline_case_count=1, - candidate_case_count=1, - value_protection_verdict="review", - signature_parity_verdict="pass", - safety_verdict="review", - performance_verdict="improved", - candidate_verdict="review", - baseline_final_entity_count=4, - candidate_final_entity_count=8, - final_entity_count_delta=4, - flags=["candidate_skips_llm_validation"], - ) - ] - - output = tmp_path / "comparison.csv" - tool.write_comparisons(rows, output, tool.ExportFormat.csv) - - exported = pd.read_csv(output) - assert exported["workload_id"].tolist() == ["shell-1"] - assert exported["candidate_replacement_strategy"].tolist() == ["custom_replacement_strategy"] - assert exported["final_entity_count_delta"].tolist() == [4] - assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] - - -def test_compare_strategy_pairs_summarizes_candidate_verdicts() -> None: - tool = load_tool( - "measurement_compare_strategy_pairs_summary", REPO_ROOT / "tools/measurement/compare_strategy_pairs.py" - ) - rows = [ - tool.ComparisonRow( - workload_id="legal-1", - baseline_config_id="legal-default", - candidate_config_id="legal-no-augment", - baseline_case_count=1, - candidate_case_count=1, - value_protection_verdict="pass", - signature_parity_verdict="pass", - 
safety_verdict="pass", - performance_verdict="improved", - candidate_verdict="candidate_viable", - ), - tool.ComparisonRow( - workload_id="bio-1", - baseline_config_id="bio-default", - candidate_config_id="bio-no-augment", - baseline_case_count=1, - candidate_case_count=1, - value_protection_verdict="fail", - signature_parity_verdict="fail", - safety_verdict="fail", - performance_verdict="improved", - candidate_verdict="reject", - baseline_only_final_entity_signature_label_counts={"first_name": 2}, - ), - tool.ComparisonRow( - workload_id="shell-1", - baseline_config_id="shell-default", - candidate_config_id="shell-detector-only", - baseline_case_count=1, - candidate_case_count=1, - value_protection_verdict="review", - signature_parity_verdict="pass", - safety_verdict="review", - performance_verdict="improved", - candidate_verdict="review", - ), - ] - - summary = tool.summarize_comparisons(rows) - - assert summary.comparison_count == 3 - assert summary.value_protection_verdict_counts == {"fail": 1, "pass": 1, "review": 1} - assert summary.signature_parity_verdict_counts == {"fail": 1, "pass": 2, "review": 0} - assert summary.safety_verdict_counts == {"fail": 1, "pass": 1, "review": 1} - assert summary.performance_verdict_counts["improved"] == 3 - assert summary.candidate_verdict_counts == {"candidate_viable": 1, "reject": 1, "review": 1} - assert summary.candidate_viable_workloads == ["legal-1"] - assert summary.review_workloads == ["shell-1"] - assert summary.rejected_workloads == ["bio-1"] - - rendered = tool.render_result( - tool.ComparisonResult( - input_path="case_analysis.csv", - baseline_selector="strategy:default", - candidate_selector="strategy:no_augment", - summary=summary, - comparisons=rows, - ), - json_output=False, - ) - assert "Compared 3 workload(s): viable=1, review=1, reject=1" in rendered - assert ( - "- bio-1: verdict=reject (safety=fail, value_protection=fail, signature_parity=fail, performance=improved)" - ) in rendered - assert "elapsed 
unknown->unknown" in rendered - assert "lost_labels=first_name:2" in rendered diff --git a/tests/tools/test_dd_parser_compat.py b/tests/tools/test_dd_parser_compat.py deleted file mode 100644 index b9b0241c..00000000 --- a/tests/tools/test_dd_parser_compat.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -from data_designer.engine.models.recipes import response_recipes as recipes -from pydantic import BaseModel - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -class TinyPayload(BaseModel): - value: str - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -def test_raw_json_compat_accepts_pydantic_raw_json_and_restores_parser() -> None: - tool = load_tool("measurement_dd_parser_compat_pydantic", REPO_ROOT / "tools/measurement/dd_parser_compat.py") - original = recipes.PydanticResponseRecipe._build_parser_fn - - with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): - recipe = recipes.PydanticResponseRecipe(TinyPayload) - parsed = recipe._build_parser_fn()('{"value": "ok"}') - - assert parsed.value == "ok" - assert recipes.PydanticResponseRecipe._build_parser_fn is original - - -def test_raw_json_compat_accepts_pydantic_json_after_reasoning_prefix() -> None: - tool = load_tool( - "measurement_dd_parser_compat_pydantic_reasoning", REPO_ROOT / "tools/measurement/dd_parser_compat.py" - ) - - with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): - recipe = recipes.PydanticResponseRecipe(TinyPayload) - parsed = 
recipe.parse('reasoning text\n\n\n\n{"value": "ok"}') - - assert parsed.value == "ok" - - -def test_raw_json_compat_accepts_structured_raw_json_and_restores_parser() -> None: - tool = load_tool("measurement_dd_parser_compat_structured", REPO_ROOT / "tools/measurement/dd_parser_compat.py") - original = recipes.StructuredResponseRecipe._build_parser_fn - schema = { - "type": "object", - "properties": {"value": {"type": "string"}}, - "required": ["value"], - } - - with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): - recipe = recipes.StructuredResponseRecipe(schema) - parsed = recipe._build_parser_fn()('{"value": "ok"}') - - assert parsed == {"value": "ok"} - assert recipes.StructuredResponseRecipe._build_parser_fn is original - - -def test_raw_json_compat_uses_outermost_embedded_json_object() -> None: - tool = load_tool("measurement_dd_parser_compat_outer_json", REPO_ROOT / "tools/measurement/dd_parser_compat.py") - schema = { - "type": "object", - "properties": { - "items": { - "type": "array", - "items": { - "type": "object", - "properties": {"value": {"type": "string"}}, - "required": ["value"], - }, - } - }, - "required": ["items"], - } - - with tool.dd_parser_compat_context(tool.DDParserCompatMode.raw_json): - recipe = recipes.StructuredResponseRecipe(schema) - parsed = recipe.parse('text before {"items": [{"value": "first"}, {"value": "last"}]}') - - assert parsed == {"items": [{"value": "first"}, {"value": "last"}]} diff --git a/tests/tools/test_dd_trace_analysis.py b/tests/tools/test_dd_trace_analysis.py deleted file mode 100644 index 82bb5457..00000000 --- a/tests/tools/test_dd_trace_analysis.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import json -import sys -from pathlib import Path -from types import ModuleType - -import pandas as pd - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: - path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") - - -def test_analyze_dd_traces_classifies_response_shape_without_raw_content(tmp_path: Path) -> None: - tool = load_tool("measurement_dd_trace_analysis", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") - trace_path = tmp_path / "trace.jsonl" - _write_jsonl( - trace_path, - [ - { - "record_type": "dd_message_trace", - "run_id": "case-1", - "run_tags": { - "suite_id": "suite", - "workload_id": "legal", - "config_id": "default", - "case_id": "case-1", - }, - "workflow_name": "entity-detection", - "model_alias": "validator", - "model_name": "nvidia/nemotron-3-super", - "model_provider_name": "local-vllm", - "status": "completed", - "error_type": None, - "elapsed_sec": 3.0, - "messages": [{"role": "user", "content": [{"type": "text", "text": "secret prompt text"}]}], - "response": {"content": '{"decisions": []}', "reasoning_content": None, "tool_calls": []}, - "usage": {"input_tokens": 11, "output_tokens": 7, "total_tokens": 18}, - }, - { - "record_type": "dd_message_trace", - "run_id": "case-1", - "run_tags": { - "suite_id": "suite", - "workload_id": "legal", - "config_id": "default", - "case_id": "case-1", - }, - "workflow_name": "entity-detection", - "model_alias": "augmenter", - "model_name": 
"nvidia/nemotron-3-super", - "model_provider_name": "local-vllm", - "status": "completed", - "error_type": None, - "elapsed_sec": 4.0, - "messages": [{"role": "user", "content": [{"type": "text", "text": "another secret prompt"}]}], - "response": { - "content": 'reasoning\n\n```json\n{"entities": []}\n```', - "reasoning_content": None, - "tool_calls": [], - }, - "usage": {"input_tokens": 13, "output_tokens": 17, "total_tokens": 30}, - }, - { - "record_type": "other", - "response": {"content": "ignored raw text"}, - }, - ], - ) - - result = tool.analyze_trace_path(trace_path) - - assert result.trace_record_count == 2 - rows = {row.model_alias: row for row in result.rows} - assert rows["validator"].response_shape == "raw_json" - assert rows["validator"].response_has_embedded_json is True - assert rows["validator"].prompt_chars == len("secret prompt text") - assert rows["augmenter"].response_shape == "fenced_json" - assert rows["augmenter"].response_has_thinking is True - assert rows["augmenter"].response_chars > 0 - serialized = result.model_dump_json() - assert "secret prompt text" not in serialized - assert "another secret prompt" not in serialized - assert "decisions" not in serialized - assert "entities" not in serialized - - group = next(row for row in result.groups if row.response_shape == "fenced_json") - assert group.model_alias == "augmenter" - assert group.model_name == "nvidia/nemotron-3-super" - assert group.trace_record_count == 1 - assert group.sum_total_tokens == 30 - assert group.error_count == 0 - - -def test_analyze_dd_traces_reads_trace_directory_and_exports_tables(tmp_path: Path) -> None: - tool = load_tool("measurement_dd_trace_analysis_export", REPO_ROOT / "tools/measurement/analyze_dd_traces.py") - trace_dir = tmp_path / "traces" - trace_dir.mkdir() - _write_jsonl( - trace_dir / "a.jsonl", - [ - { - "record_type": "dd_message_trace", - "run_id": "case-a", - "workflow_name": "entity-detection", - "model_name": "nvidia/nemotron-3-super", - 
"status": "error", - "error_type": "ParserException", - "elapsed_sec": 1.0, - "response": None, - "usage": None, - } - ], - ) - _write_jsonl( - trace_dir / "b.jsonl", - [ - { - "record_type": "dd_message_trace", - "run_id": "case-b", - "workflow_name": "entity-detection", - "model_name": "nvidia/gliner-pii", - "status": "completed", - "elapsed_sec": 0.2, - "response": {"content": "plain text", "reasoning_content": None, "tool_calls": []}, - "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, - } - ], - ) - - result = tool.analyze_trace_path(trace_dir) - output_dir = tmp_path / "analysis" - tool.write_analysis_tables(result, output_dir, tool.ExportFormat.csv) - - assert result.trace_record_count == 2 - rows = {row.run_id: row for row in result.rows} - assert rows["case-a"].status == "error" - assert rows["case-a"].response_shape == "none" - assert rows["case-b"].response_shape == "text" - assert pd.read_csv(output_dir / "trace_analysis.csv")["run_id"].tolist() == ["case-a", "case-b"] - assert pd.read_csv(output_dir / "trace_group_analysis.csv")["trace_record_count"].sum() == 2 - assert (output_dir / "manifest.json").exists() diff --git a/tests/tools/test_detection_strategies.py b/tests/tools/test_detection_strategies.py deleted file mode 100644 index 5ef9affd..00000000 --- a/tests/tools/test_detection_strategies.py +++ /dev/null @@ -1,1093 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import json -import sys -import threading -import time -from pathlib import Path -from types import ModuleType, SimpleNamespace -from unittest.mock import Mock - -import pandas as pd - -from anonymizer.engine.constants import ( - COL_DETECTED_ENTITIES, - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_TAG_NOTATION, - COL_TAGGED_TEXT, - COL_TEXT, -) -from anonymizer.engine.detection.detection_workflow import EntityDetectionWorkflow -from anonymizer.engine.ndd.model_loader import load_default_model_selection -from anonymizer.engine.schemas import EntitiesSchema -from anonymizer.measurement import MeasurementCollector, measurement_session - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def test_native_candidate_validate_no_augment_strategy_skips_data_designer_and_augmentation() -> None: - tool = load_tool( - "measurement_detection_strategies_native_candidate_validate", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SequencedClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - self.outputs = [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - ] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content=self.outputs.pop(0), - elapsed_sec=0.1, - usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, - ) - - adapter = Mock() 
- client = SequencedClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=False, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "direct_seed"), - ] - assert len(client.prompts) == 2 - assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-native-candidate-validate-no-augment" - assert record["observed_total_requests"] == 2 - assert record["observed_total_tokens"] == 28 - - -def test_detector_native_validate_no_augment_strategy_reuses_detector_seed_and_direct_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_native_validate", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice works at NVIDIA." 
- seed_row = { - COL_TEXT: text, - COL_RAW_DETECTED: "", - COL_SEED_ENTITIES: EntitiesSchema( - entities=[ - { - "id": "first_name_0_5", - "value": "Alice", - "label": "first_name", - "start_position": 0, - "end_position": 5, - "score": 0.9, - "source": "detector", - } - ] - ).model_dump(mode="json"), - COL_TAG_NOTATION: "xml", - } - tool.prepare_validation_inputs(seed_row) - - class ValidationClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) - client = ValidationClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_called_once() - assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( - "entity-detection-detector-native-validate-no-augment-seed" - ) - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "detector"), - ] - assert len(client.prompts) == 1 
- assert all("Find additional sensitive entities" not in prompt for prompt in client.prompts) - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - - -def test_detector_native_validate_no_augment_ignores_invalid_reclass_labels() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_native_validate_invalid_label", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice works at NVIDIA." - seed_row = { - COL_TEXT: text, - COL_RAW_DETECTED: "", - COL_SEED_ENTITIES: EntitiesSchema( - entities=[ - { - "id": "first_name_0_5", - "value": "Alice", - "label": "first_name", - "start_position": 0, - "end_position": 5, - "score": 0.9, - "source": "detector", - } - ] - ).model_dump(mode="json"), - COL_TAG_NOTATION: "xml", - } - tool.prepare_validation_inputs(seed_row) - - class ValidationClient: - def complete(self, request): # type: ignore[no-untyped-def] - return SimpleNamespace( - content=( - '{"decisions": [' - '{"id": "first_name_0_5", "decision": "reclass", ' - '"proposed_label": "drop", "reason": "invalid label"}' - "]}" - ), - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, - native_client=ValidationClient(), - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities 
- assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "detector"), - ] - - -def test_detector_native_validate_native_augment_uses_direct_validation_and_augmentation() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_native_augment", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice works at NVIDIA." - seed_row = { - COL_TEXT: text, - COL_RAW_DETECTED: "", - COL_SEED_ENTITIES: EntitiesSchema( - entities=[ - { - "id": "first_name_0_5", - "value": "Alice", - "label": "first_name", - "start_position": 0, - "end_position": 5, - "score": 0.9, - "source": "detector", - } - ] - ).model_dump(mode="json"), - COL_TAG_NOTATION: "xml", - } - tool.prepare_validation_inputs(seed_row) - - class DirectClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - if "Find additional sensitive entities" in request.prompt: - return SimpleNamespace( - content='{"entities": [{"value": "NVIDIA", "label": "organization_name"}]}', - elapsed_sec=0.2, - usage={"prompt_tokens": 50, "completion_tokens": 6, "total_tokens": 56}, - ) - return SimpleNamespace( - content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - adapter.run_workflow.return_value = SimpleNamespace(dataframe=pd.DataFrame([seed_row]), failed_records=[]) - client = DirectClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - 
pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_called_once() - assert adapter.run_workflow.call_args.kwargs["workflow_name"] == ( - "entity-detection-detector-native-validate-native-augment-seed" - ) - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "detector"), - ("organization_name", "NVIDIA", "augmenter"), - ] - assert len(client.prompts) == 2 - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - assert records[0]["workflow_name"] == "entity-detection-detector-native-validate-native-augment" - assert records[0]["observed_total_requests"] == 2 - assert records[0]["observed_total_tokens"] == 104 - - -def test_detector_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_native_validate_parallel", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - def seed_row(text: str, value: str) -> dict[str, object]: - row = { - COL_TEXT: text, - COL_RAW_DETECTED: "", - COL_SEED_ENTITIES: EntitiesSchema( - entities=[ - { - "id": f"first_name_0_{len(value)}", - "value": value, - "label": "first_name", - "start_position": 0, - "end_position": len(value), - "score": 0.9, - "source": "detector", - } - ] - ).model_dump(mode="json"), - COL_TAG_NOTATION: "xml", - } - tool.prepare_validation_inputs(row) - return row - - class ValidationClient: - def __init__(self) -> None: - self.lock = threading.Lock() - self.active_count = 0 - self.max_active_count = 0 - - def complete(self, request): # type: ignore[no-untyped-def] - with 
self.lock: - self.active_count += 1 - self.max_active_count = max(self.max_active_count, self.active_count) - time.sleep(0.05) - with self.lock: - self.active_count -= 1 - candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" - return SimpleNamespace( - content=json.dumps({"decisions": [{"id": candidate_id, "decision": "keep", "reason": "real name"}]}), - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - seed_df = pd.DataFrame( - [ - seed_row("Alice works at NVIDIA.", "Alice"), - seed_row("Bob works at NVIDIA.", "Bob"), - ], - index=[10, 4], - ) - adapter = Mock() - adapter.run_workflow.return_value = SimpleNamespace(dataframe=seed_df, failed_records=[]) - client = ValidationClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - assert client.max_active_count > 1 - assert list(result.dataframe.index) == [10, 4] - entities_by_row = [ - [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] - for raw in result.dataframe[COL_DETECTED_ENTITIES] - ] - assert entities_by_row == [ - [("first_name", "Alice", "detector")], - [("first_name", "Bob", "detector")], - ] - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert [record["observed_total_requests"] for record in records] 
== [1, 1] - assert [record["observed_total_tokens"] for record in records] == [48, 48] - record = records[0] - assert record["workflow_name"] == "entity-detection-detector-native-validate-no-augment" - assert record["observed_total_requests"] == 1 - assert record["observed_total_tokens"] == 48 - - -def test_gliner_native_validate_no_augment_strategy_bypasses_data_designer() -> None: - tool = load_tool( - "measurement_detection_strategies_gliner_native_validate", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice works at NVIDIA." - - class GlinerSeedClient: - def detect(self, request): # type: ignore[no-untyped-def] - assert request.text == text - return SimpleNamespace( - content=json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "first_name", - "start": 0, - "end": 5, - "score": 0.99, - } - ] - } - ), - elapsed_sec=0.2, - usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, - ) - - class ValidationClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - seed_client = GlinerSeedClient() - validation_client = ValidationClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - native_client=validation_client, - gliner_seed_client=seed_client, - native_runtime=tool.NativeDetectionRuntime( - model="test/native", - provider="test-native-provider", - gliner_model="test/gliner", - gliner_provider="test-gliner-provider", - ), - ): - workflow = 
EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "detector"), - ] - assert len(validation_client.prompts) == 1 - assert all("Find additional sensitive entities" not in prompt for prompt in validation_client.prompts) - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-gliner-native-validate-no-augment" - assert record["observed_total_requests"] == 2 - assert record["observed_total_tokens"] == 73 - assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] - assert record["model_usage"]["gliner-direct"]["model_name"] == "test/gliner" - assert record["model_usage"]["gliner-direct"]["model_provider_name"] == "test-gliner-provider" - assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 - assert record["model_usage"]["native-direct"]["model_name"] == "test/native" - assert record["model_usage"]["native-direct"]["model_provider_name"] == "test-native-provider" - assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 48 - - -def test_gliner_native_validate_no_augment_parallel_rows_preserve_order_and_measurements() -> None: - tool = load_tool( - "measurement_detection_strategies_gliner_native_parallel", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class GlinerSeedClient: - def detect(self, request): # type: 
ignore[no-untyped-def] - value = str(request.text).split()[0] - return SimpleNamespace( - content=json.dumps( - { - "entities": [ - { - "text": value, - "label": "first_name", - "start": 0, - "end": len(value), - "score": 0.99, - } - ] - } - ), - elapsed_sec=0.2, - usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, - ) - - class ValidationClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - candidate_id = "first_name_0_3" if '"value":"Bob"' in request.prompt else "first_name_0_5" - return SimpleNamespace( - content=json.dumps({"decisions": [{"id": candidate_id, "decision": "keep", "reason": "real name"}]}), - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - collector = MeasurementCollector(record_hash_key="test-key") - dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Bob works at NVIDIA."]}, index=[10, 4]) - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - native_client=ValidationClient(), - gliner_seed_client=GlinerSeedClient(), - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - dataframe, - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - assert list(result.dataframe.index) == [10, 4] - entities_by_row = [ - [(entity.label, entity.value, entity.source) for entity in EntitiesSchema.from_raw(raw).entities] - for raw in result.dataframe[COL_DETECTED_ENTITIES] - ] - assert entities_by_row == [ - [("first_name", "Alice", "detector")], - [("first_name", "Bob", 
"detector")], - ] - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert [record["observed_total_requests"] for record in records] == [2, 2] - assert [record["observed_total_tokens"] for record in records] == [73, 73] - - -def test_gliner_native_validate_no_augment_parallel_rows_keep_failed_records() -> None: - tool = load_tool( - "measurement_detection_strategies_gliner_native_parallel_failure", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class GlinerSeedClient: - def detect(self, request): # type: ignore[no-untyped-def] - if str(request.text).startswith("Broken"): - raise RuntimeError("seed unavailable") - return SimpleNamespace( - content=json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "first_name", - "start": 0, - "end": 5, - "score": 0.99, - } - ] - } - ), - elapsed_sec=0.2, - usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, - ) - - class ValidationClient: - def complete(self, _request): # type: ignore[no-untyped-def] - return SimpleNamespace( - content='{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - collector = MeasurementCollector(record_hash_key="test-key") - dataframe = pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA.", "Broken row"]}, index=[0, 1]) - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - native_client=ValidationClient(), - gliner_seed_client=GlinerSeedClient(), - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - dataframe, - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name"], - ) - - 
assert list(result.dataframe.index) == [0] - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("first_name", "Alice")] - assert [(failed.record_id, failed.step) for failed in result.failed_records] == [ - ("1", "entity-detection-gliner-native-validate-no-augment") - ] - assert "seed unavailable" in result.failed_records[0].reason - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - assert records[0]["observed_total_requests"] == 2 - - -def test_gliner_native_validate_native_augment_strategy_bypasses_data_designer() -> None: - tool = load_tool( - "measurement_detection_strategies_gliner_native_augment", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice works at NVIDIA." - - class GlinerSeedClient: - def detect(self, request): # type: ignore[no-untyped-def] - assert request.text == text - return SimpleNamespace( - content=json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "first_name", - "start": 0, - "end": 5, - "score": 0.99, - } - ] - } - ), - elapsed_sec=0.2, - usage={"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, - ) - - class NativeClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - self.outputs = [ - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content=self.outputs.pop(0), - elapsed_sec=0.1, - usage={"prompt_tokens": 40, "completion_tokens": 8, "total_tokens": 48}, - ) - - adapter = Mock() - seed_client = GlinerSeedClient() - native_client = NativeClient() - collector = MeasurementCollector(record_hash_key="test-key") - - with 
measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment, - native_client=native_client, - gliner_seed_client=seed_client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - validation_single_chunk_full_text=True, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "detector"), - ("organization_name", "NVIDIA", "augmenter"), - ] - assert any("Find additional sensitive entities" in prompt for prompt in native_client.prompts) - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-gliner-native-validate-native-augment" - assert record["observed_total_requests"] == 3 - assert record["observed_total_tokens"] == 121 - assert sorted(record["model_usage"]) == ["gliner-direct", "native-direct"] - assert record["model_usage"]["gliner-direct"]["token_usage"]["total_tokens"] == 25 - assert record["model_usage"]["native-direct"]["token_usage"]["total_tokens"] == 96 - - -def test_native_single_pass_strategy_runs_one_direct_call_without_data_designer() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SinglePassClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - 
self.prompts.append(request.prompt) - return SimpleNamespace( - content=json.dumps( - { - "entities": [ - {"value": "Alice", "label": "first_name", "start": 0, "end": 5}, - {"value": "NVIDIA", "label": "organization_name", "start": 15, "end": 21}, - ] - } - ), - elapsed_sec=0.1, - usage={}, - ) - - adapter = Mock() - client = SinglePassClient() - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name", "organization_name"], - ) - - adapter.run_workflow.assert_not_called() - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("first_name", "Alice", "direct_single_pass"), - ("organization_name", "NVIDIA", "direct_single_pass"), - ] - assert len(client.prompts) == 1 - assert '"start"' in client.prompts[0] - assert '"end"' in client.prompts[0] - - -def test_native_single_pass_recall_strategy_uses_label_examples() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass_recall", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SinglePassClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', - elapsed_sec=0.1, - usage={}, - ) - - client = SinglePassClient() - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass_recall, - 
native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["person", "email"], - ) - - assert len(client.prompts) == 1 - assert "- person" in client.prompts[0] - assert "- email:" in client.prompts[0] - assert "Bias toward high recall" in client.prompts[0] - - -def test_native_single_pass_values_strategy_uses_value_only_prompt() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass_values", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SinglePassClient: - def __init__(self) -> None: - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - return SimpleNamespace( - content='{"entities": [{"value": "Alice", "label": "person"}]}', - elapsed_sec=0.1, - usage={}, - ) - - client = SinglePassClient() - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass_values, - native_client=client, - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice met Alice."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.value, entity.start_position, entity.end_position) for entity in entities] == [ - ("Alice", 0, 5), - ("Alice", 10, 15), - ] - assert len(client.prompts) == 1 - assert '"start"' not in client.prompts[0] - assert '"end"' not in client.prompts[0] - assert '{"entities": [{"value": "exact substring", "label": "one_allowed_label"' in client.prompts[0] - 
- -def test_native_single_pass_strategy_records_direct_model_usage() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass_usage", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class SinglePassClient: - def complete(self, _request): # type: ignore[no-untyped-def] - return SimpleNamespace( - content='{"entities": [{"value": "Alice", "label": "first_name", "start": 0, "end": 5}]}', - elapsed_sec=0.1, - usage={"prompt_tokens": 20, "completion_tokens": 7, "total_tokens": 27}, - ) - - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass, - native_client=SinglePassClient(), - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name"], - ) - - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-native-single-pass" - assert record["observed_total_requests"] == 1 - assert record["observed_successful_requests"] == 1 - assert record["observed_failed_requests"] == 0 - assert record["observed_input_tokens"] == 20 - assert record["observed_output_tokens"] == 7 - assert record["observed_total_tokens"] == 27 - - -def test_native_single_pass_strategy_uses_only_native_spans() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass_native_spans", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - text = "Alice logged in.\nPassword: SuperSecret123!\n" - - class SinglePassClient: - def complete(self, _request): # type: ignore[no-untyped-def] - return 
SimpleNamespace( - content='{"entities": [{"value": "Alice", "label": "person", "start": 0, "end": 5}]}', - elapsed_sec=0.1, - usage={}, - ) - - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass, - native_client=SinglePassClient(), - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["person", "password"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value, entity.source) for entity in entities] == [ - ("person", "Alice", "direct_single_pass"), - ] - - -def test_native_single_pass_strategy_records_parser_errors_as_failures() -> None: - tool = load_tool( - "measurement_detection_strategies_native_single_pass_parser_error", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - - class InvalidJsonClient: - def complete(self, _request): # type: ignore[no-untyped-def] - return SimpleNamespace( - content="not json", - elapsed_sec=0.1, - usage={"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11}, - ) - - collector = MeasurementCollector(record_hash_key="test-key") - - with measurement_session(collector): - with tool.experimental_detection_strategy_context( - tool.ExperimentalDetectionStrategy.native_single_pass, - native_client=InvalidJsonClient(), - ): - workflow = EntityDetectionWorkflow(adapter=Mock()) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at NVIDIA."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name"], - ) - - assert result.dataframe.empty - assert len(result.failed_records) == 1 - assert result.failed_records[0].step == 
"entity-detection-native-single-pass" - records = [record for record in collector.records if record["record_type"] == "model_workflow"] - assert len(records) == 1 - record = records[0] - assert record["workflow_name"] == "entity-detection-native-single-pass" - assert record["status"] == "error" - assert record["failed_record_count"] == 1 - assert record["observed_successful_requests"] == 1 - assert record["observed_failed_requests"] == 0 - - -def test_detector_only_strategy_finalizes_gliner_spans_without_validation_or_augmentation() -> None: - tool = load_tool( - "measurement_detection_strategies_detector_only", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - adapter = Mock() - text = "Alice met Alice at the lab." - start = text.index("Alice") - - def fake_run_workflow(dataframe: pd.DataFrame, *, columns: list, **kwargs: object) -> SimpleNamespace: - assert [column.name for column in columns] == [ - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_DETECTED_ENTITIES, - ] - assert kwargs["workflow_name"] == "entity-detection-detector-only" - row = { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_RAW_DETECTED: json.dumps( - { - "entities": [ - { - "text": "Alice", - "label": "person", - "start": start, - "end": start + len("Alice"), - "score": 0.99, - } - ] - } - ), - } - row = columns[1].generator_function(row) - row = columns[2].generator_function(row) - seed_entities = json.loads(row[COL_SEED_ENTITIES_JSON]) - assert [(entity["label"], entity["value"]) for entity in seed_entities] == [("person", "Alice")] - row = columns[3].generator_function(row) - return SimpleNamespace(dataframe=pd.DataFrame([row]), failed_records=[]) - - adapter.run_workflow.side_effect = fake_run_workflow - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.detector_only): - workflow = EntityDetectionWorkflow(adapter=adapter) - result = workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: [text]}), - 
model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["person"], - ) - - entities = EntitiesSchema.from_raw(result.dataframe[COL_DETECTED_ENTITIES].iloc[0]).entities - assert [(entity.label, entity.value) for entity in entities] == [("person", "Alice"), ("person", "Alice")] - assert "Alice" in result.dataframe[COL_TAGGED_TEXT].iloc[0] - - -def test_compact_validation_strategy_disables_full_text_single_chunk_validation() -> None: - tool = load_tool( - "measurement_detection_strategies_compact_validation", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - original = EntityDetectionWorkflow.detect_and_validate_entities - calls = [] - - def fake_original( - self: EntityDetectionWorkflow, - dataframe: pd.DataFrame, - **kwargs: object, - ) -> object: - calls.append(kwargs) - return SimpleNamespace( - dataframe=pd.DataFrame( - [ - { - COL_TEXT: dataframe[COL_TEXT].iloc[0], - COL_DETECTED_ENTITIES: EntitiesSchema(entities=[]).model_dump(mode="json"), - } - ] - ), - failed_records=[], - ) - - EntityDetectionWorkflow.detect_and_validate_entities = fake_original # type: ignore[method-assign] - try: - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.compact_validation): - workflow = EntityDetectionWorkflow(adapter=Mock()) - workflow.detect_and_validate_entities( - pd.DataFrame({COL_TEXT: ["Alice works at Acme."]}), - model_configs=[], - selected_models=load_default_model_selection().detection, - gliner_detection_threshold=0.3, - entity_labels=["first_name"], - ) - finally: - EntityDetectionWorkflow.detect_and_validate_entities = original # type: ignore[method-assign] - - assert EntityDetectionWorkflow.detect_and_validate_entities is original - assert calls[0]["validation_single_chunk_full_text"] is False - - -def test_prose_augment_focus_extends_and_restores_augment_prompt() -> None: - tool = load_tool( - 
"measurement_detection_strategies_prose_augment_focus", - REPO_ROOT / "tools/measurement/detection_strategies.py", - ) - before = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) - - with tool.experimental_detection_strategy_context(tool.ExperimentalDetectionStrategy.prose_augment_focus): - inside = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) - - after = tool.dw._get_augment_prompt(data_summary=None, labels=["organization_name"], strict_labels=True) - assert "Contextual prose recall focus" not in before - assert "Contextual prose recall focus" in inside - assert "organization and institution names" in inside - assert after == before diff --git a/tests/tools/test_direct_detection_probe.py b/tests/tools/test_direct_detection_probe.py deleted file mode 100644 index c37960d6..00000000 --- a/tests/tools/test_direct_detection_probe.py +++ /dev/null @@ -1,147 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -class FakeClient: - def complete(self, request): # type: ignore[no-untyped-def] - module = sys.modules["measurement_direct_detection_probe_case"] - return module.DirectCompletion( - content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - elapsed_sec=1.25, - usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, - ) - - -def test_direct_detection_case_uses_direct_client_and_canonical_artifacts() -> None: - tool = load_tool( - "measurement_direct_detection_probe_case", - REPO_ROOT / "tools/measurement/direct_detection_probe.py", - ) - - result = tool.run_direct_detection_case( - tool.DirectDetectionRequest( - case_id="case-1", - text="Alice met Alice.", - labels=["first_name"], - row_index=0, - prompt_mode=tool.PromptMode.compact, - ), - client=FakeClient(), - ) - - assert result.status == tool.CaseStatus.completed - assert result.elapsed_sec == 1.25 - assert result.usage == {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14} - assert result.allowed_suggestion_count == 1 - assert result.final_entity_count == 2 - assert result.final_entity_signature_count == 2 - assert result.final_label_counts == {"first_name": 2} - assert result.artifact.final_source_counts == {"direct_llm": 2} - - -def test_finalize_suggestions_filters_labels_and_deduplicates_signature_hashes() -> None: - tool = load_tool( - 
"measurement_direct_detection_probe_finalize", - REPO_ROOT / "tools/measurement/direct_detection_probe.py", - ) - - artifact = tool.finalize_direct_suggestions( - text="Alice works at NVIDIA.", - suggestions=[ - {"value": "Alice", "label": "first_name"}, - {"value": "NVIDIA", "label": "organization_name"}, - {"value": "works", "label": "unsupported"}, - ], - labels=["first_name", "organization_name"], - row_index=0, - workflow_name="direct-detection", - ) - - assert artifact.final_entity_count == 2 - assert artifact.final_label_counts == {"first_name": 1, "organization_name": 1} - assert artifact.weak_api_key_shape_count == 0 - assert set(artifact.final_entity_signature_labels.values()) == {"first_name", "organization_name"} - - -def test_signature_comparison_counts_shared_baseline_only_and_direct_only_labels() -> None: - tool = load_tool( - "measurement_direct_detection_probe_comparison", - REPO_ROOT / "tools/measurement/direct_detection_probe.py", - ) - - comparison = tool.compare_signature_sets( - baseline_hashes={"shared", "baseline-only"}, - baseline_labels={"shared": "person", "baseline-only": "date"}, - direct_hashes={"shared", "direct-only"}, - direct_labels={"shared": "person", "direct-only": "city"}, - ) - - assert comparison.shared_final_entity_signature_count == 1 - assert comparison.baseline_only_final_entity_signature_count == 1 - assert comparison.direct_only_final_entity_signature_count == 1 - assert comparison.baseline_only_label_counts == {"date": 1} - assert comparison.direct_only_label_counts == {"city": 1} - - -def test_baseline_comparison_skips_rows_without_signature_hashes() -> None: - tool = load_tool( - "measurement_direct_detection_probe_missing_signatures", - REPO_ROOT / "tools/measurement/direct_detection_probe.py", - ) - - class LocalFakeClient: - def complete(self, request): # type: ignore[no-untyped-def] - return tool.DirectCompletion( - content='{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - 
elapsed_sec=1.25, - usage={"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}, - ) - - case = tool.run_direct_detection_case( - tool.DirectDetectionRequest( - case_id="case-1", - text="Alice met Alice.", - labels=["first_name"], - row_index=0, - prompt_mode=tool.PromptMode.compact, - ), - client=LocalFakeClient(), - ) - - compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 2}) - - assert compared.comparison is None - - -def test_baseline_artifact_reader_rejects_duplicate_row_indexes(tmp_path: Path) -> None: - tool = load_tool( - "measurement_direct_detection_probe_duplicate_baseline", - REPO_ROOT / "tools/measurement/direct_detection_probe.py", - ) - baseline_path = tmp_path / "baseline.jsonl" - baseline_path.write_text('{"row_index": 0}\n{"row_index": 0}\n', encoding="utf-8") - - with pytest.raises(ValueError, match="multiple rows for row_index=0"): - tool._read_baseline_artifacts(baseline_path) diff --git a/tests/tools/test_extract_signature_deltas.py b/tests/tools/test_extract_signature_deltas.py deleted file mode 100644 index 922d73eb..00000000 --- a/tests/tools/test_extract_signature_deltas.py +++ /dev/null @@ -1,190 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -import pandas as pd - -from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def _entity(value: str, label: str, start: int, end: int, source: str = "detector") -> dict[str, object]: - return {"value": value, "label": label, "start_position": start, "end_position": end, "source": source} - - -def _write_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: - parquet_file = root / "entity-detection" / "parquet-files" / "batch_00000.parquet" - parquet_file.parent.mkdir(parents=True) - pd.DataFrame([{COL_TEXT: text, COL_DETECTED_ENTITIES: {"entities": entities}}]).to_parquet(parquet_file) - artifact_path = root / "detection-artifacts.jsonl" - row = tool.build_detection_artifact_row_from_entities( - workflow_name="entity-detection", - batch_file="entity-detection/parquet-files/batch_00000.parquet", - row_index=0, - seed_entities=[], - seed_validation_candidate_count=0, - merged_validation_candidate_count=0, - augmented_entities=[], - final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], - ).model_dump() - pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) - return artifact_path - - -def _write_seed_only_artifact_case(root: Path, tool: ModuleType, entities: list[dict[str, object]], text: str) -> Path: - parquet_file = root / 
"entity-detection-seed" / "parquet-files" / "batch_00000.parquet" - parquet_file.parent.mkdir(parents=True) - pd.DataFrame([{COL_TEXT: text}]).to_parquet(parquet_file) - artifact_path = root / "detection-artifacts.jsonl" - row = tool.build_detection_artifact_row_from_entities( - workflow_name="entity-detection-seed", - batch_file="entity-detection-seed/parquet-files/batch_00000.parquet", - row_index=0, - seed_entities=[], - seed_validation_candidate_count=0, - merged_validation_candidate_count=0, - augmented_entities=[], - final_entities=[tool.EntitySchema.model_validate(entity) for entity in entities], - ).model_dump() - pd.json_normalize([{**_case_metadata(), **row}], sep=".").to_json(artifact_path, orient="records", lines=True) - return artifact_path - - -def _case_metadata() -> dict[str, object]: - return { - "suite_id": "suite", - "workload_id": "bio", - "config_id": "config", - "case_id": "bio__config__r000", - "run_id": "bio__config__r000", - "repetition": 0, - } - - -def test_extract_signature_deltas_masks_candidate_only_context(tmp_path: Path) -> None: - analyzer = load_tool( - "measurement_detection_artifact_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" - ) - tool = load_tool( - "measurement_extract_signature_deltas", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" - ) - baseline_root = tmp_path / "baseline" - candidate_root = tmp_path / "candidate" - baseline = _write_artifact_case(baseline_root, analyzer, [_entity("Alice", "person", 0, 5)], "Alice met NASA") - candidate = _write_artifact_case( - candidate_root, - analyzer, - [_entity("Alice", "person", 0, 5), _entity("NASA", "organization_name", 10, 14, "augmenter")], - "Alice met NASA", - ) - - result = tool.extract_signature_deltas( - baseline, - candidate, - baseline_artifact_root=baseline_root, - candidate_artifact_root=candidate_root, - ) - - assert result.delta_count == 1 - row = result.rows[0] - assert row.side == "candidate_only" - assert row.label == 
"organization_name" - assert row.source == "augmenter" - assert row.resolution == "parquet" - assert "NASA" not in row.masked_context - assert "[ORGANIZATION_NAME:" in row.masked_context - - -def test_extract_signature_deltas_recovers_artifact_detail_context(tmp_path: Path) -> None: - analyzer = load_tool( - "measurement_detection_artifact_context_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" - ) - tool = load_tool( - "measurement_extract_signature_deltas_context", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" - ) - baseline_root = tmp_path / "baseline" - candidate_root = tmp_path / "candidate" - baseline = _write_artifact_case(baseline_root, analyzer, [], "The applicant was born in 1990.") - candidate = _write_artifact_case(candidate_root, analyzer, [], "The applicant was born in 1990.") - detail_entity = analyzer.EntitySchema( - value="1990", label="date_of_birth", start_position=26, end_position=30, source="rule" - ) - detail_row = analyzer.build_detection_artifact_row_from_entities( - workflow_name="entity-detection", - batch_file="entity-detection/parquet-files/batch_00000.parquet", - row_index=0, - seed_entities=[], - seed_validation_candidate_count=0, - merged_validation_candidate_count=0, - augmented_entities=[], - final_entities=[detail_entity], - ).model_dump() - pd.json_normalize([{**_case_metadata(), **detail_row}], sep=".").to_json(candidate, orient="records", lines=True) - - result = tool.extract_signature_deltas( - baseline, - candidate, - baseline_artifact_root=baseline_root, - candidate_artifact_root=candidate_root, - ) - - assert result.delta_count == 1 - row = result.rows[0] - assert row.label == "date_of_birth" - assert row.source == "rule" - assert row.resolution == "artifact_details" - assert "1990" not in row.masked_context - assert "[DATE_OF_BIRTH:" in row.masked_context - - -def test_extract_signature_deltas_uses_signature_details_when_final_parquet_is_unavailable(tmp_path: Path) -> None: - analyzer 
= load_tool( - "measurement_detection_artifact_detail_builder", REPO_ROOT / "tools/measurement/analyze_detection_artifacts.py" - ) - tool = load_tool( - "measurement_extract_signature_deltas_details", REPO_ROOT / "tools/measurement/extract_signature_deltas.py" - ) - baseline_root = tmp_path / "baseline" - candidate_root = tmp_path / "candidate" - baseline = _write_seed_only_artifact_case(baseline_root, analyzer, [], "Alice met NASA") - candidate = _write_seed_only_artifact_case( - candidate_root, - analyzer, - [_entity("NASA", "organization_name", 10, 14, "detector")], - "Alice met NASA", - ) - - result = tool.extract_signature_deltas( - baseline, - candidate, - baseline_artifact_root=baseline_root, - candidate_artifact_root=candidate_root, - ) - - assert result.delta_count == 1 - row = result.rows[0] - assert row.label == "organization_name" - assert row.source == "detector" - assert row.resolution == "artifact_details" - assert row.start_position == 10 - assert row.end_position == 14 - assert "NASA" not in row.masked_context - assert "[ORGANIZATION_NAME:" in row.masked_context diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index f27078df..7566974f 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -4,7 +4,6 @@ from __future__ import annotations import importlib.util -import json import sys from collections.abc import Iterator from contextlib import contextmanager @@ -18,7 +17,6 @@ from pydantic import ValidationError from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT -from anonymizer.engine.constants import COL_FINAL_ENTITIES REPO_ROOT = Path(__file__).resolve().parents[2] @@ -44,7 +42,6 @@ def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[ "trace_dir": tmp_path / "traces", "dd_task_trace": False, "task_trace_dir": tmp_path / "task-traces", - "dd_parser_compat": spec.dd_parser_compat, "artifact_path": tmp_path / "artifacts", } @@ 
-210,128 +207,6 @@ def test_benchmark_case_detection_artifact_analysis_adds_case_metadata(tmp_path: assert "sk-test" not in output_path.read_text(encoding="utf-8") -def _stale_detection_artifact_payload() -> dict[str, Any]: - return { - "workflow_name": "entity-detection", - "batch_file": "entity-detection/parquet-files/batch_00000.parquet", - "row_index": 0, - "seed_entity_count": 1, - "seed_validation_candidate_count": 1, - "merged_validation_candidate_count": 1, - "augmented_entity_count": 0, - "final_entity_count": 1, - "augmented_duplicate_seed_value_count": 0, - "augmented_new_value_count": 0, - "augmented_new_final_value_count": 0, - "weak_api_key_shape_count": 0, - "final_entity_signature_count": 1, - "final_entity_signature_hashes": ["stale"], - "final_entity_signature_labels": {"stale": "person"}, - "weak_api_key_shape_label_counts": {}, - "final_label_counts": {"person": 1}, - "final_source_counts": {"detector": 1}, - } - - -def _final_trace_dataframe_with_rule_entity() -> pd.DataFrame: - return pd.DataFrame( - { - COL_FINAL_ENTITIES: [ - { - "entities": [ - { - "value": "Alice", - "label": "person", - "start_position": 0, - "end_position": 5, - "source": "detector", - }, - { - "value": "1990", - "label": "date_of_birth", - "start_position": 25, - "end_position": 29, - "source": "rule", - }, - ] - } - ] - } - ) - - -def test_benchmark_patches_detection_artifacts_from_final_trace_dataframe(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_patch_result_artifacts", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - output_path = tmp_path / "raw" / "case.detection-artifacts.jsonl" - tool.write_detection_artifact_payloads([_stale_detection_artifact_payload()], output_path) - - result = tool.patch_case_detection_artifacts_from_trace_dataframe( - output_path, - _final_trace_dataframe_with_rule_entity(), - ) - - assert result == output_path - text = output_path.read_text(encoding="utf-8") - assert "Alice" not in text - assert 
"1990" not in text - row = json.loads(text) - assert row["final_entity_count"] == 2 - assert row["final_entity_signature_count"] == 2 - assert row["final_label_counts.person"] == 1 - assert row["final_label_counts.date_of_birth"] == 1 - assert row["final_source_counts.detector"] == 1 - assert row["final_source_counts.rule"] == 1 - assert any(key.startswith("final_entity_signature_details.") for key in row) - assert "final_entity_signature_labels.stale" not in row - - -def test_benchmark_no_export_disables_trace_detection_artifact_sidecar(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_no_export_trace_artifact", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - spec = tool.BenchmarkSpec( - suite_id="no-export-suite", - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[ - tool.ConfigSpec( - id="native-single-pass-redact", - replace="redact", - experimental_detection_strategy="native_single_pass", - ) - ], - ) - case = tool.BenchmarkCase( - suite_id="no-export-suite", - workload_id="input", - config_id="native-single-pass-redact", - repetition=0, - case_id="input__native-single-pass-redact__r000", - ) - contexts = _minimal_case_contexts(tool, spec, tmp_path) - paths = tool._case_run_paths(case, contexts=contexts, export_detection_artifacts=False) - input_path = tmp_path / "input.csv" - pd.DataFrame({"text": ["Alice"]}).to_csv(input_path, index=False) - execution = tool._CaseExecution( - input_data=tool.AnonymizerInput(source=str(input_path)), - trace_dataframe=_final_trace_dataframe_with_rule_entity(), - ) - - result = tool._case_detection_artifact_path( - contexts, - paths, - case=case, - config=spec.configs[0], - execution=execution, - ) - - assert result is None - assert not paths.artifact_output_path.exists() - - def test_run_suite_records_detection_artifact_analysis_path( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -462,15 +337,12 @@ def 
test_benchmark_case_retries_transient_errors_and_records_attempts( ) pd.DataFrame({"text": ["Alice"]}).to_csv(tmp_path / "input.csv", index=False) - def fake_execute_case(*_args: Any, raw_path: Path, **_kwargs: Any) -> Any: + def fake_execute_case(*_args: Any, raw_path: Path, **_kwargs: Any) -> None: attempts.append(raw_path) if len(attempts) == 1: raise RuntimeError("transient provider health check failure") raw_path.parent.mkdir(parents=True, exist_ok=True) raw_path.write_text('{"record_type":"run"}\n', encoding="utf-8") - return tool._CaseExecution( - input_data=tool.AnonymizerInput(source=str(tmp_path / "input.csv"), text_column="text") - ) monkeypatch.setattr(tool, "_execute_case", fake_execute_case) @@ -923,186 +795,6 @@ def test_benchmark_preflight_accepts_provider_config_path(tmp_path: Path) -> Non tool.preflight_suite(spec, spec_path=spec_path) -def test_benchmark_preflight_rejects_native_strategy_without_runtime( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_native_runtime_required", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) - monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) - _copy_biography_data(tmp_path) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: native-runtime-suite -workloads: - - id: input - source: input.csv - text_column: biography -configs: - - id: native-single-pass - experimental_detection_strategy: native_single_pass - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="native_runtime.runtime_id"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_native_runtime_config_override_merges_suite_defaults() -> None: - tool = load_tool( - "measurement_benchmark_tool_native_runtime_merge", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - 
config = tool.ConfigSpec( - id="native-single-pass", - experimental_detection_strategy="native_single_pass", - native_runtime=tool.NativeRuntimeSpec(model="config-model", max_workers=2), - replace="redact", - ) - spec = tool.BenchmarkSpec( - suite_id="native-runtime-suite", - native_runtime=tool.NativeRuntimeSpec( - runtime_id="suite-runtime", - endpoint="http://suite-endpoint/v1", - model="suite-model", - provider="suite-provider", - max_workers=4, - ), - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[config], - ) - - runtime = tool._native_detection_runtime(spec, config) - - assert runtime.endpoint == "http://suite-endpoint/v1" - assert runtime.model == "config-model" - assert runtime.provider == "suite-provider" - assert runtime.max_workers == 2 - - -def test_benchmark_preflight_rejects_native_strategy_without_endpoint_or_model( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_native_runtime_endpoint_required", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", raising=False) - monkeypatch.delenv("ANONYMIZER_BENCH_NATIVE_MODEL", raising=False) - _copy_biography_data(tmp_path) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: native-runtime-suite -native_runtime: - runtime_id: native-runtime -workloads: - - id: input - source: input.csv - text_column: biography -configs: - - id: native-single-pass - experimental_detection_strategy: native_single_pass - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - with pytest.raises(ValueError, match="native_runtime.endpoint"): - tool.preflight_suite(spec, spec_path=spec_path) - - -def test_benchmark_native_runtime_resolves_endpoint_and_model_from_env( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_native_runtime_env", - REPO_ROOT / 
"tools/measurement/run_benchmarks.py", - ) - monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_ENDPOINT", "http://runtime-from-env/v1") - monkeypatch.setenv("ANONYMIZER_BENCH_NATIVE_MODEL", "env-model") - _copy_biography_data(tmp_path) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: native-runtime-suite -native_runtime: - runtime_id: env-runtime -workloads: - - id: input - source: input.csv - text_column: biography -configs: - - id: native-single-pass - experimental_detection_strategy: native_single_pass - replace: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - runtime = tool._native_detection_runtime(spec, spec.configs[0]) - tags = tool._run_tags( - tool.BenchmarkCase( - suite_id="native-runtime-suite", - workload_id="input", - config_id="native-single-pass", - repetition=0, - case_id="input__native-single-pass__r000", - ), - spec, - ) - - assert runtime.endpoint == "http://runtime-from-env/v1" - assert runtime.model == "env-model" - assert tags["native_runtime_id"] == "env-runtime" - assert "native_endpoint" not in tags - assert tags["native_endpoint_env"] == "ANONYMIZER_BENCH_NATIVE_ENDPOINT" - assert tags["native_model"] == "env-model" - - -def test_benchmark_preflight_skips_inactive_native_configs(tmp_path: Path) -> None: - tool = load_tool( - "measurement_benchmark_tool_inactive_native_runtime", - REPO_ROOT / "tools/measurement/run_benchmarks.py", - ) - _copy_biography_data(tmp_path) - spec_path = tmp_path / "suite.yaml" - spec_path.write_text( - """ -suite_id: inactive-native-suite -workloads: - - id: input - source: input.csv - text_column: biography -configs: - - id: redact - replace: redact - - id: inactive-native - experimental_detection_strategy: native_single_pass - replace: redact -matrix: - - workload: input - config: redact -""", - encoding="utf-8", - ) - spec = tool.load_spec(spec_path) - - tool.preflight_suite(spec, spec_path=spec_path) - - def 
test_benchmark_case_passes_dd_trace_config_to_measurement_session( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -1149,7 +841,6 @@ def run(self, *, config: Any, data: Any) -> None: spec=spec, base_dir=tmp_path, dd_trace=tool.DDTraceMode.all_messages, - dd_parser_compat=tool.DDParserCompatMode.none, ) assert len(captured) == 1 @@ -1158,216 +849,3 @@ def run(self, *, config: Any, data: Any) -> None: assert captured[0].dd_task_trace_path == task_trace_path assert captured[0].streaming is True assert captured[0].keep_records is False - - -def test_benchmark_config_accepts_experimental_detection_strategy() -> None: - tool = load_tool( - "measurement_benchmark_tool_detection_strategy_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - - detector_only = tool.ConfigSpec( - id="detector-only", - replace="redact", - experimental_detection_strategy="detector_only", - ) - - assert detector_only.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.detector_only - anonymizer_config = tool.build_anonymizer_config(detector_only) - assert not hasattr(anonymizer_config.detect, "experimental_detection_strategy") - - native_candidate_validate = tool.ConfigSpec( - id="native-candidate-validate", - replace="redact", - experimental_detection_strategy="native_candidate_validate_no_augment", - ) - - assert ( - native_candidate_validate.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.native_candidate_validate_no_augment - ) - - detector_native_validate = tool.ConfigSpec( - id="detector-native-validate", - replace="redact", - experimental_detection_strategy="detector_native_validate_no_augment", - ) - - assert ( - detector_native_validate.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.detector_native_validate_no_augment - ) - - detector_native_augment = tool.ConfigSpec( - id="detector-native-augment", - replace="redact", - experimental_detection_strategy="detector_native_validate_native_augment", - ) - - 
assert ( - detector_native_augment.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.detector_native_validate_native_augment - ) - - gliner_native_validate = tool.ConfigSpec( - id="gliner-native-validate", - replace="redact", - experimental_detection_strategy="gliner_native_validate_no_augment", - ) - - assert ( - gliner_native_validate.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.gliner_native_validate_no_augment - ) - - gliner_native_augment = tool.ConfigSpec( - id="gliner-native-augment", - replace="redact", - experimental_detection_strategy="gliner_native_validate_native_augment", - ) - - assert ( - gliner_native_augment.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.gliner_native_validate_native_augment - ) - - native_single_pass = tool.ConfigSpec( - id="native-single-pass", - replace="redact", - experimental_detection_strategy="native_single_pass", - ) - - assert native_single_pass.experimental_detection_strategy == tool.ExperimentalDetectionStrategy.native_single_pass - - native_single_pass_recall = tool.ConfigSpec( - id="native-single-pass-recall", - replace="redact", - experimental_detection_strategy="native_single_pass_recall", - ) - - assert ( - native_single_pass_recall.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.native_single_pass_recall - ) - - native_single_pass_values = tool.ConfigSpec( - id="native-single-pass-values", - replace="redact", - experimental_detection_strategy="native_single_pass_values", - ) - - assert ( - native_single_pass_values.experimental_detection_strategy - == tool.ExperimentalDetectionStrategy.native_single_pass_values - ) - - native_single_pass_values_recall = tool.ConfigSpec( - id="native-single-pass-values-recall", - replace="redact", - experimental_detection_strategy="native_single_pass_values_recall", - ) - - assert ( - native_single_pass_values_recall.experimental_detection_strategy - == 
tool.ExperimentalDetectionStrategy.native_single_pass_values_recall - ) - - -def test_benchmark_spec_accepts_dd_parser_compat() -> None: - tool = load_tool( - "measurement_benchmark_tool_dd_parser_compat_config", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - - spec = tool.BenchmarkSpec( - suite_id="raw-json-suite", - dd_parser_compat="raw_json", - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) - - assert spec.dd_parser_compat == tool.DDParserCompatMode.raw_json - - -def test_benchmark_case_enters_experimental_detection_strategy_context( - monkeypatch: pytest.MonkeyPatch, - tmp_path: Path, -) -> None: - tool = load_tool( - "measurement_benchmark_tool_detection_strategy_case", REPO_ROOT / "tools/measurement/run_benchmarks.py" - ) - captured_measurements: list[Any] = [] - captured_parser_compat: list[Any] = [] - captured_strategies: list[Any] = [] - captured_context_kwargs: list[dict[str, Any]] = [] - - @contextmanager - def fake_measurement_session(config: Any) -> Iterator[None]: - captured_measurements.append(config) - yield None - - @contextmanager - def fake_detection_strategy_context(strategy: Any, **kwargs: Any) -> Iterator[None]: - captured_strategies.append(strategy) - captured_context_kwargs.append(kwargs) - yield None - - @contextmanager - def fake_dd_parser_compat_context(mode: Any) -> Iterator[None]: - captured_parser_compat.append(mode) - yield None - - class FakeAnonymizer: - def run(self, *, config: Any, data: Any) -> None: - assert config.replace is not None - assert data.text_column == "text" - - monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) - monkeypatch.setattr(tool, "dd_parser_compat_context", fake_dd_parser_compat_context) - monkeypatch.setattr(tool, "experimental_detection_strategy_context", fake_detection_strategy_context) - - spec = tool.BenchmarkSpec( - suite_id="native-suite", - dd_parser_compat="raw_json", - 
native_runtime=tool.NativeRuntimeSpec( - runtime_id="native-test", - endpoint="http://runtime.example/v1", - model="test-model", - ), - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[ - tool.ConfigSpec( - id="native-single-pass-redact", - replace="redact", - experimental_detection_strategy="native_single_pass", - ) - ], - ) - pd.DataFrame({"text": ["token=sk-test-AAAAAAAAAAAAAAAAAAAAAAAA"]}).to_csv(tmp_path / "input.csv", index=False) - case = tool.BenchmarkCase( - suite_id="native-suite", - workload_id="input", - config_id="native-single-pass-redact", - repetition=0, - case_id="input__native-single-pass-redact__r000", - ) - - tool._execute_case( - FakeAnonymizer(), - spec.workloads[0], - spec.configs[0], - raw_path=tmp_path / "raw" / "input__native-single-pass-redact__r000.jsonl", - trace_path=None, - task_trace_path=None, - case=case, - spec=spec, - base_dir=tmp_path, - dd_trace=tool.DDTraceMode.none, - dd_parser_compat=spec.dd_parser_compat, - ) - - assert captured_parser_compat == [tool.DDParserCompatMode.raw_json] - assert captured_strategies == [tool.ExperimentalDetectionStrategy.native_single_pass] - assert captured_context_kwargs[0]["native_runtime"].endpoint == "http://runtime.example/v1" - assert captured_context_kwargs[0]["native_runtime"].model == "test-model" - assert captured_measurements[0].run_tags["experimental_detection_strategy"] == "native_single_pass" - assert captured_measurements[0].run_tags["native_runtime_id"] == "native-test" - assert captured_measurements[0].run_tags["dd_parser_compat"] == "raw_json" diff --git a/tests/tools/test_screen_strategy_comparisons.py b/tests/tools/test_screen_strategy_comparisons.py deleted file mode 100644 index ee690cac..00000000 --- a/tests/tools/test_screen_strategy_comparisons.py +++ /dev/null @@ -1,974 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -import pandas as pd - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -def test_screen_strategy_comparisons_reads_comparison_csvs_only(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - analysis_dir = tmp_path / "analysis" - analysis_dir.mkdir() - pd.DataFrame( - [ - { - "workload_id": "shell-3", - "baseline_config_id": "default", - "candidate_config_id": "detector-only", - "baseline_strategy": "default", - "candidate_strategy": "detector_only", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "baseline_case_count": 3, - "candidate_case_count": 3, - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "pipeline_elapsed_sec_delta_pct": -99.0, - "observed_total_requests_delta": -12, - "observed_total_tokens_delta": -11000, - "augmented_entity_count_delta": -3, - "augmented_new_final_value_count_delta": 0, - "baseline_only_final_entity_signature_count": 0, - "baseline_stable_candidate_unstable_final_entity_signature_count": 0, - "baseline_stable_final_entity_signature_count": 12, - "candidate_stable_final_entity_signature_count": 12, - "shared_stable_final_entity_signature_count": 12, - "flags": '["candidate_skips_llm_validation"]', - }, - { - "workload_id": "legal-2", - "baseline_config_id": "default", - "candidate_config_id": "no-augment", - 
"baseline_strategy": "default", - "candidate_strategy": "no_augment", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "default", - "safety_verdict": "fail", - "performance_verdict": "improved", - "candidate_verdict": "reject", - "pipeline_elapsed_sec_delta_pct": -10.0, - "observed_total_requests_delta": -2, - "observed_total_tokens_delta": -10000, - "augmented_entity_count_delta": -5.5, - "augmented_new_final_value_count_delta": -2, - "baseline_only_final_entity_signature_count": 2, - "baseline_stable_candidate_unstable_final_entity_signature_count": 1, - "baseline_only_final_entity_signature_label_counts.first_name": 2, - "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, - "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', - }, - ] - ).to_csv(analysis_dir / "default-vs-candidates.csv", index=False) - pd.DataFrame([{"workload_id": "shell-3", "config_id": "default"}]).to_csv( - analysis_dir / "group_analysis.csv", - index=False, - ) - pd.DataFrame( - [ - { - "source_path": "analysis/default-vs-candidates.csv", - "workload_id": "shell-3", - "baseline_config_id": "default", - "candidate_config_id": "detector-only", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - } - ] - ).to_csv(analysis_dir / "strategy-screen.csv", index=False) - (analysis_dir / "empty.csv").write_text("", encoding="utf-8") - - result = tool.screen_comparison_paths([analysis_dir]) - - assert result.scanned_file_count == 4 - assert result.comparison_file_count == 1 - assert result.row_count == 2 - assert result.summary.candidate_verdict_counts == {"reject": 1, "review": 1} - assert result.summary.review_count == 1 - assert result.summary.reject_count == 1 - assert result.summary.viable_count == 0 - legal = next(row for row in result.rows if row.workload_id == "legal-2") - assert legal.workload_family == "legal" - assert legal.flags == ["entity_signature_loss", 
"stable_entity_signature_loss"] - assert legal.baseline_only_label_counts == {"first_name": 2} - assert legal.stable_lost_label_counts == {"date": 1} - assert legal.augmented_new_final_value_count_delta == -2 - shell = next(row for row in result.rows if row.workload_id == "shell-3") - assert shell.baseline_replacement_strategy == "default" - assert shell.candidate_replacement_strategy == "custom_replacement_strategy" - assert shell.baseline_case_count == 3 - assert shell.candidate_case_count == 3 - assert shell.shared_stable_final_entity_signature_count == 12 - detector_local = next( - group - for group in result.groups - if group.group_key == "strategy:detector_only|replacement:custom_replacement_strategy" - ) - assert detector_local.candidate_replacement_strategy == "custom_replacement_strategy" - assert detector_local.row_count == 1 - no_augment = next(group for group in result.groups if group.group_key == "strategy:no_augment") - assert no_augment.row_count == 1 - assert no_augment.reject_count == 1 - assert no_augment.has_conflicting_verdicts is False - - -def test_screen_strategy_comparisons_writes_csv(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_export", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - rows = [ - tool.ScreenRow( - source_path="analysis/default-vs-detector-only.csv", - workload_id="shell", - baseline_config_id="default", - candidate_config_id="detector-only", - baseline_replacement_strategy="default", - candidate_replacement_strategy="custom_replacement_strategy", - safety_verdict="review", - performance_verdict="improved", - candidate_verdict="review", - flags=["candidate_skips_llm_validation"], - ) - ] - - output = tmp_path / "screen.csv" - tool.write_rows(rows, output, tool.ExportFormat.csv) - - exported = pd.read_csv(output) - assert exported["workload_id"].tolist() == ["shell"] - assert exported["candidate_replacement_strategy"].tolist() == ["custom_replacement_strategy"] 
- assert exported["flags"].tolist() == ['["candidate_skips_llm_validation"]'] - - -def test_screen_strategy_comparisons_dedupes_exact_rows(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_dedupe", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - comparison = pd.DataFrame( - [ - { - "workload_id": "bio", - "baseline_config_id": "default", - "candidate_config_id": "candidate", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": "candidate_viable", - "pipeline_elapsed_sec_delta_pct": -5.0, - } - ] - ) - comparison.to_csv(tmp_path / "a.csv", index=False) - comparison.to_csv(tmp_path / "b.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - assert result.scanned_file_count == 2 - assert result.comparison_file_count == 2 - assert result.row_count == 1 - assert result.duplicate_row_count == 1 - assert result.summary.viable_count == 1 - - -def test_screen_strategy_comparisons_surfaces_evidence_level_counts(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_evidence_level", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "baseline_config_id": "default", - "candidate_config_id": "local-substitute", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "value_protection_verdict": "pass", - "signature_parity_verdict": "review", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - }, - { - "workload_id": "structured-identifiers", - "baseline_config_id": "default", - "candidate_config_id": "local-substitute-legacy", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": 
"candidate_viable", - "shared_stable_final_entity_signature_count": 17, - }, - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - assert result.summary.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} - assert {row.evidence_level for row in result.rows} == {"split_verdicts", "stable_signatures"} - group = result.groups[0] - assert group.evidence_level_counts == {"split_verdicts": 1, "stable_signatures": 1} - assert group.split_verdict_candidate_verdict_counts == {"review": 1} - rendered = tool.render_result(result, json_output=False, limit=10) - assert "evidence=split_verdicts" in rendered - assert "evidence_counts=split_verdicts:1,stable_signatures:1" in rendered - - -def test_screen_strategy_comparisons_surfaces_candidate_original_value_leaks(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_original_value_leaks", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-secrets", - "baseline_config_id": "default", - "candidate_config_id": "native-single-pass", - "candidate_strategy": "native_single_pass", - "safety_verdict": "fail", - "performance_verdict": "improved", - "candidate_verdict": "reject", - "candidate_original_value_leak_count": 2, - "candidate_original_value_leak_record_count": 1, - "original_value_leak_count_delta": 2, - "candidate_original_value_leak_label_counts.api_key": 1, - "candidate_original_value_leak_label_counts.password": 1, - "candidate_replacement_synthetic_original_collision_count": 1, - "candidate_replacement_synthetic_original_collision_value_count": 1, - "replacement_synthetic_original_collision_count_delta": 1, - "candidate_replacement_synthetic_original_collision_label_counts.date": 1, - "flags": ('["candidate_original_value_leak", "candidate_replacement_synthetic_original_collision"]'), - } - ] - ).to_csv(tmp_path / "comparison.csv", 
index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - assert result.summary.reject_count == 1 - row = result.rows[0] - assert row.candidate_original_value_leak_count == 2 - assert row.candidate_original_value_leak_record_count == 1 - assert row.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} - assert row.candidate_replacement_synthetic_original_collision_count == 1 - assert row.candidate_replacement_synthetic_original_collision_value_count == 1 - assert row.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} - group = result.groups[0] - assert group.sum_candidate_original_value_leak_count == 2 - assert group.sum_candidate_original_value_leak_record_count == 1 - assert group.candidate_original_value_leak_label_counts == {"api_key": 1, "password": 1} - assert group.sum_candidate_replacement_synthetic_original_collision_count == 1 - assert group.sum_candidate_replacement_synthetic_original_collision_value_count == 1 - assert group.candidate_replacement_synthetic_original_collision_label_counts == {"date": 1} - rendered = tool.render_result(result, json_output=False, limit=10) - assert "candidate_original_value_leaks=2.0" in rendered - assert "candidate_replacement_collisions=1.0" in rendered - assert "leak_labels=api_key:1,password:1" in rendered - assert "collision_labels=date:1" in rendered - - -def test_screen_strategy_comparisons_surfaces_label_policy_review(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_label_policy_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "legal-r1", - "baseline_config_id": "legal-default", - "candidate_config_id": "legal-detector-native-validate", - "candidate_strategy": "detector_native_validate_no_augment", - "value_protection_verdict": "pass", - "signature_parity_verdict": "review", - "safety_verdict": "review", - "performance_verdict": "improved", - 
"candidate_verdict": "review", - "pipeline_elapsed_sec_delta_pct": -26.5, - "observed_total_requests_delta": -1, - "observed_total_tokens_delta": -5366, - "flags": '["covered_label_mismatch"]', - "baseline_only_candidate_label_mismatch_signature_label_counts.date_of_birth": 1, - } - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - assert result.summary.value_protection_verdict_counts == {"pass": 1} - assert result.summary.signature_parity_verdict_counts == {"review": 1} - row = result.rows[0] - assert row.value_protection_verdict == "pass" - assert row.signature_parity_verdict == "review" - assert row.label_mismatch_label_counts == {"date_of_birth": 1} - group = result.groups[0] - assert group.value_protection_verdict_counts == {"pass": 1} - assert group.signature_parity_verdict_counts == {"review": 1} - assert group.label_mismatch_label_counts == {"date_of_birth": 1} - assert group.recommendation == "label_policy_review" - rendered = tool.render_result(result, json_output=False, limit=10) - assert "value_protection=pass" in rendered - assert "signature_parity=review" in rendered - assert "label_mismatch=date_of_birth:1" in rendered - - -def test_screen_strategy_comparisons_surfaces_reliability_review(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_reliability_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "baseline_config_id": "default-substitute", - "candidate_config_id": "local-substitute", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "value_protection_verdict": "pass", - "signature_parity_verdict": "pass", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "flags": '["failed_request_increase"]', - } - ] - ).to_csv(tmp_path / 
"comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - group = result.groups[0] - assert group.flag_counts == {"failed_request_increase": 1} - assert group.recommendation == "reliability_review" - - -def test_screen_strategy_comparisons_surfaces_replacement_replay_review(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_replacement_replay_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "structured-identifiers", - "baseline_config_id": "default-substitute", - "candidate_config_id": "local-substitute", - "baseline_strategy": "default", - "candidate_strategy": "default", - "baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "value_protection_verdict": "pass", - "signature_parity_verdict": "review", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "flags": '["covered_label_mismatch", "replacement_only_detection_instability"]', - "baseline_only_candidate_label_mismatch_signature_label_counts.api_key": 1, - } - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - group = result.groups[0] - assert group.flag_counts == { - "covered_label_mismatch": 1, - "replacement_only_detection_instability": 1, - } - assert group.recommendation == "replacement_replay_review" - - -def test_screen_strategy_comparisons_surfaces_baseline_defect_improvement_review(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_baseline_defect_improvement", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - pd.DataFrame( - [ - { - "workload_id": "biography-supported-structured", - "baseline_config_id": "dd-substitute", - "candidate_config_id": "local-substitute", - "baseline_strategy": "default", - "candidate_strategy": "default", - 
"baseline_replacement_strategy": "default", - "candidate_replacement_strategy": "custom_replacement_strategy", - "value_protection_verdict": "pass", - "signature_parity_verdict": "review", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "baseline_replacement_missing_final_entity_count": 2, - "candidate_replacement_missing_final_entity_count": 0, - "baseline_original_value_leak_count": 2, - "candidate_original_value_leak_count": 0, - "baseline_duplicate_synthetic_replacement_count": 1, - "candidate_duplicate_synthetic_replacement_count": 0, - "baseline_replacement_missing_final_entity_label_counts.organization_name": 1, - "baseline_replacement_missing_final_entity_label_counts.religious_belief": 1, - "baseline_original_value_leak_label_counts.organization_name": 2, - "flags": ( - '["baseline_replacement_missing_final_entity", "baseline_original_value_leak", ' - '"candidate_covers_baseline_replacement_missing_final_entity", ' - '"candidate_covers_baseline_original_value_leak", "replacement_count_delta"]' - ), - } - ] - ).to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - row = result.rows[0] - assert row.baseline_replacement_missing_final_entity_count == 2 - assert row.candidate_replacement_missing_final_entity_count == 0 - assert row.baseline_original_value_leak_count == 2 - assert row.candidate_original_value_leak_count == 0 - assert row.baseline_duplicate_synthetic_replacement_count == 1 - assert row.candidate_duplicate_synthetic_replacement_count == 0 - assert row.baseline_replacement_missing_final_entity_label_counts == { - "organization_name": 1, - "religious_belief": 1, - } - assert row.baseline_original_value_leak_label_counts == {"organization_name": 2} - group = result.groups[0] - assert group.baseline_defect_improvement_count == 1 - assert group.sum_baseline_replacement_missing_final_entity_count == 2 - assert 
group.sum_candidate_replacement_missing_final_entity_count == 0 - assert group.sum_baseline_original_value_leak_count == 2 - assert group.sum_candidate_original_value_leak_count == 0 - assert group.sum_baseline_duplicate_synthetic_replacement_count == 1 - assert group.sum_candidate_duplicate_synthetic_replacement_count == 0 - assert group.baseline_replacement_missing_final_entity_label_counts == { - "organization_name": 1, - "religious_belief": 1, - } - assert group.baseline_original_value_leak_label_counts == {"organization_name": 2} - assert group.recommendation == "candidate_covers_baseline_defects" - rendered = tool.render_result(result, json_output=False, limit=10) - assert "baseline_defect_improvements=1" in rendered - assert "baseline_missing_replacements=2.0" in rendered - assert "candidate_missing_replacements=0.0" in rendered - assert "baseline_missing_labels=organization_name:1,religious_belief:1" in rendered - - -def test_screen_strategy_comparisons_groups_default_detection_by_replacement_strategy() -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_replacement_group", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - row = tool.ScreenRow( - source_path="comparison.csv", - workload_id="structured-identifiers", - baseline_config_id="substitute-dd", - candidate_config_id="substitute-local", - candidate_strategy="default", - baseline_replacement_strategy="default", - candidate_replacement_strategy="custom_replacement_strategy", - safety_verdict="pass", - performance_verdict="improved", - candidate_verdict="candidate_viable", - ) - - assert tool.group_base_for_row(row, config_aliases={}) == "replacement:custom_replacement_strategy" - - -def test_screen_strategy_comparisons_keeps_generic_review_without_leak_metrics() -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_generic_review", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - group = tool.ScreenGroup( - 
group_key="strategy:detector_only", - candidate_strategy="detector_only", - row_count=1, - review_count=1, - performance_verdict_counts={"improved": 1}, - flag_counts={"candidate_skips_llm_validation": 1}, - ) - - assert tool.group_recommendation(group) == "review_only" - - -def test_screen_strategy_comparisons_filters_source_paths(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_source_filters", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - current_dir = tmp_path / "analysis-current-csv" - current_dir.mkdir() - stale_dir = tmp_path / "analysis" - stale_dir.mkdir() - pd.DataFrame( - [ - { - "workload_id": "bio-current", - "baseline_config_id": "default", - "candidate_config_id": "candidate", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": "candidate_viable", - } - ] - ).to_csv(current_dir / "current-comparison.csv", index=False) - pd.DataFrame( - [ - { - "workload_id": "bio-stale", - "baseline_config_id": "default", - "candidate_config_id": "candidate", - "safety_verdict": "fail", - "performance_verdict": "regressed", - "candidate_verdict": "reject", - } - ] - ).to_csv(stale_dir / "stale-comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path], source_includes=["analysis-current-csv"]) - - assert result.scanned_file_count == 1 - assert result.comparison_file_count == 1 - assert [row.workload_id for row in result.rows] == ["bio-current"] - - result = tool.screen_comparison_paths( - [tmp_path], - source_includes=["analysis"], - source_excludes=["analysis-current-csv"], - ) - - assert result.scanned_file_count == 1 - assert result.comparison_file_count == 1 - assert [row.workload_id for row in result.rows] == ["bio-stale"] - - -def test_screen_strategy_comparisons_groups_candidate_strategy_conflicts(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_groups", - REPO_ROOT / 
"tools/measurement/screen_strategy_comparisons.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "legal-small", - "baseline_config_id": "default", - "candidate_config_id": "no-augment-small", - "candidate_strategy": "no_augment", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": "candidate_viable", - "pipeline_elapsed_sec_delta_pct": -7.0, - "observed_total_requests_delta": -2, - "observed_total_tokens_delta": -8000, - "baseline_case_count": 3, - "candidate_case_count": 2, - "shared_stable_final_entity_signature_count": 20, - "flags": "[]", - }, - { - "workload_id": "legal-offset", - "baseline_config_id": "default", - "candidate_config_id": "no-augment-offset", - "candidate_strategy": "no_augment", - "safety_verdict": "fail", - "performance_verdict": "improved", - "candidate_verdict": "reject", - "pipeline_elapsed_sec_delta_pct": -10.0, - "observed_total_requests_delta": -1, - "observed_total_tokens_delta": -10000, - "baseline_case_count": 2, - "candidate_case_count": 2, - "shared_stable_final_entity_signature_count": 14, - "baseline_only_final_entity_signature_label_counts.first_name": 2, - "baseline_stable_candidate_unstable_final_entity_signature_label_counts.date": 1, - "flags": '["entity_signature_loss", "stable_entity_signature_loss"]', - }, - { - "workload_id": "shell", - "baseline_config_id": "default", - "candidate_config_id": "detector-only", - "candidate_strategy": "detector_only", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "pipeline_elapsed_sec_delta_pct": -99.9, - "observed_total_tokens_delta": -11000, - "flags": '["candidate_skips_llm_validation"]', - }, - ] - ) - table.to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path]) - - groups = {group.group_key: group for group in result.groups} - assert list(groups) == ["strategy:detector_only", "strategy:no_augment"] - no_augment = 
groups["strategy:no_augment"] - assert no_augment.row_count == 2 - assert no_augment.viable_count == 1 - assert no_augment.reject_count == 1 - assert no_augment.has_conflicting_verdicts is True - assert no_augment.recommendation == "conflicting_evidence" - assert no_augment.best_pipeline_elapsed_sec_delta_pct == -10.0 - assert no_augment.best_observed_total_tokens_delta == -10000 - assert no_augment.best_observed_total_requests_delta == -2 - assert no_augment.worst_pipeline_elapsed_sec_delta_pct == -7.0 - assert no_augment.worst_observed_total_tokens_delta == -8000 - assert no_augment.worst_observed_total_requests_delta == -1 - assert no_augment.min_baseline_case_count == 2 - assert no_augment.min_candidate_case_count == 2 - assert no_augment.min_shared_stable_final_entity_signature_count == 14 - assert no_augment.flag_counts == {"entity_signature_loss": 1, "stable_entity_signature_loss": 1} - assert no_augment.baseline_only_label_counts == {"first_name": 2} - assert no_augment.stable_lost_label_counts == {"date": 1} - - -def test_screen_strategy_comparisons_groups_default_strategy_by_config(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_default_config_groups", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "biographies-r5", - "baseline_config_id": "biography-default", - "candidate_config_id": "biography-augment-temp07", - "candidate_strategy": "default", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": "candidate_viable", - "pipeline_elapsed_sec_delta_pct": -6.0, - "observed_total_tokens_delta": -325, - "flags": "[]", - }, - { - "workload_id": "biographies-r5-offset5", - "baseline_config_id": "biography-default", - "candidate_config_id": "biography-hybrid", - "candidate_strategy": "default", - "safety_verdict": "fail", - "performance_verdict": "regressed", - "candidate_verdict": "reject", - 
"pipeline_elapsed_sec_delta_pct": 10.0, - "observed_total_tokens_delta": 100, - "baseline_only_final_entity_signature_label_counts.university": 1, - "flags": '["entity_signature_loss"]', - }, - ] - ) - table.to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) - - groups = {group.group_key: group for group in result.groups} - assert list(groups) == [ - "config:biography-augment-temp07|family:biographies", - "config:biography-hybrid|family:biographies", - ] - assert groups["config:biography-augment-temp07|family:biographies"].recommendation == "single_slice_viable" - assert groups["config:biography-hybrid|family:biographies"].recommendation == "reject" - - -def test_screen_strategy_comparisons_groups_default_strategy_by_config_alias(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_config_alias_groups", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "biographies-r2", - "baseline_config_id": "biography-default", - "candidate_config_id": "biography-hybrid-augment-temp07", - "candidate_strategy": "default", - "safety_verdict": "pass", - "performance_verdict": "improved", - "candidate_verdict": "candidate_viable", - "pipeline_elapsed_sec_delta_pct": -10.4, - "baseline_case_count": 2, - "candidate_case_count": 2, - "shared_stable_final_entity_signature_count": 48, - "flags": "[]", - }, - { - "workload_id": "biographies-r5-offset5", - "baseline_config_id": "biography-default", - "candidate_config_id": "biography-augment-temp07", - "candidate_strategy": "default", - "safety_verdict": "fail", - "performance_verdict": "mixed", - "candidate_verdict": "reject", - "pipeline_elapsed_sec_delta_pct": 16.0, - "baseline_case_count": 2, - "candidate_case_count": 2, - "shared_stable_final_entity_signature_count": 116, - 
"baseline_stable_candidate_unstable_final_entity_signature_label_counts.university": 1, - "flags": '["stable_entity_signature_loss"]', - }, - ] - ) - table.to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths( - [tmp_path], - group_by=tool.GroupBy.strategy_workload_family, - config_aliases={ - "biography-hybrid-augment-temp07": "biography-temp07-routing", - "biography-augment-temp07": "biography-temp07-routing", - }, - ) - - groups = {group.group_key: group for group in result.groups} - group = groups["alias:biography-temp07-routing|family:biographies"] - assert group.row_count == 2 - assert group.candidate_config_ids == ["biography-augment-temp07", "biography-hybrid-augment-temp07"] - assert group.viable_count == 1 - assert group.reject_count == 1 - assert group.recommendation == "conflicting_evidence" - assert group.min_baseline_case_count == 2 - assert group.min_candidate_case_count == 2 - assert group.min_shared_stable_final_entity_signature_count == 48 - assert group.stable_lost_label_counts == {"university": 1} - - -def test_screen_strategy_comparisons_can_group_by_strategy_and_workload_family(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_family_groups", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - table = pd.DataFrame( - [ - { - "workload_id": "shell-secrets-3", - "baseline_config_id": "default", - "candidate_config_id": "detector-only-shell", - "candidate_strategy": "detector_only", - "safety_verdict": "review", - "performance_verdict": "improved", - "candidate_verdict": "review", - "pipeline_elapsed_sec_delta_pct": -99.9, - "observed_total_tokens_delta": -11000, - "flags": '["candidate_skips_llm_validation"]', - }, - { - "workload_id": "biographies-r5-offset5", - "baseline_config_id": "default", - "candidate_config_id": "detector-only-bio", - "candidate_strategy": "detector_only", - "safety_verdict": "fail", - "performance_verdict": "improved", - 
"candidate_verdict": "reject", - "pipeline_elapsed_sec_delta_pct": -90.0, - "observed_total_tokens_delta": -17000, - "baseline_only_final_entity_signature_label_counts.first_name": 4, - "flags": '["entity_signature_loss"]', - }, - ] - ) - table.to_csv(tmp_path / "comparison.csv", index=False) - - result = tool.screen_comparison_paths([tmp_path], group_by=tool.GroupBy.strategy_workload_family) - - groups = {group.group_key: group for group in result.groups} - assert list(groups) == ["strategy:detector_only|family:shell-secrets", "strategy:detector_only|family:biographies"] - assert groups["strategy:detector_only|family:shell-secrets"].recommendation == "review_only" - assert groups["strategy:detector_only|family:shell-secrets"].workload_families == ["shell-secrets"] - assert groups["strategy:detector_only|family:biographies"].recommendation == "reject" - assert groups["strategy:detector_only|family:biographies"].baseline_only_label_counts == {"first_name": 4} - - -def test_workload_family_normalizes_slice_and_offset_suffixes() -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_workload_family", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - - assert tool.workload_family("legal-r2-offset1") == "legal" - assert tool.workload_family("legal-slice-2") == "legal" - assert tool.workload_family("biographies-r5-offset5") == "biographies" - assert tool.workload_family("shell-secrets-3") == "shell-secrets" - assert tool.workload_family("support-ticket") == "support-ticket" - - -def test_screen_strategy_comparisons_writes_group_csv(tmp_path: Path) -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_group_export", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - groups = [ - tool.ScreenGroup( - group_key="strategy:no_augment", - candidate_strategy="no_augment", - row_count=2, - viable_count=1, - reject_count=1, - has_conflicting_verdicts=True, - performance_verdict_counts={"improved": 1, 
"regressed": 1}, - flag_counts={"entity_signature_loss": 1}, - ) - ] - - output = tmp_path / "groups.csv" - tool.write_groups(groups, output, tool.ExportFormat.csv) - - exported = pd.read_csv(output) - assert exported["group_key"].tolist() == ["strategy:no_augment"] - assert exported["has_conflicting_verdicts"].tolist() == [True] - assert exported["recommendation"].tolist() == ["conflicting_evidence"] - assert exported["performance_verdict_counts"].tolist() == ['{"improved": 1, "regressed": 1}'] - assert exported["flag_counts"].tolist() == ['{"entity_signature_loss": 1}'] - - -def test_screen_strategy_group_recommendations() -> None: - tool = load_tool( - "measurement_screen_strategy_comparisons_recommendations", - REPO_ROOT / "tools/measurement/screen_strategy_comparisons.py", - ) - - assert ( - tool.group_recommendation( - tool.ScreenGroup(group_key="viable", row_count=1, viable_count=1), - ) - == "single_slice_viable" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup(group_key="repeated-viable", row_count=2, viable_count=2), - ) - == "candidate_family_viable" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup(group_key="promising", row_count=2, viable_count=1, review_count=1), - ) - == "promising_needs_review" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="weak-promising", - row_count=2, - viable_count=1, - review_count=1, - evidence_level_counts={"signature_counts": 1, "stable_signatures": 1}, - ), - ) - == "needs_split_verdict_rerun" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="partial-split-review", - row_count=2, - review_count=2, - evidence_level_counts={"split_verdicts": 1, "stable_signatures": 1}, - split_verdict_candidate_verdict_counts={"review": 1}, - ), - ) - == "needs_split_verdict_rerun" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="split-review-with-legacy-viable", - row_count=2, - viable_count=1, - review_count=1, - 
evidence_level_counts={"signature_counts": 1, "split_verdicts": 1}, - split_verdict_candidate_verdict_counts={"review": 1}, - ), - ) - == "needs_viable_split_verdict" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="split-viable-with-review", - row_count=2, - viable_count=1, - review_count=1, - evidence_level_counts={"split_verdicts": 2}, - split_verdict_candidate_verdict_counts={"candidate_viable": 1, "review": 1}, - ), - ) - == "promising_needs_review" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="review", - row_count=2, - review_count=2, - performance_verdict_counts={"improved": 2}, - ), - ) - == "review_only" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="review-mixed", - row_count=2, - review_count=2, - performance_verdict_counts={"mixed": 2}, - ), - ) - == "review_mixed_performance" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="review-improved-and-mixed", - row_count=2, - review_count=2, - performance_verdict_counts={"improved": 1, "mixed": 1}, - ), - ) - == "review_mixed_performance" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup( - group_key="review-regressed", - row_count=2, - review_count=2, - performance_verdict_counts={"regressed": 2}, - ), - ) - == "no_performance_win" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup(group_key="reject", row_count=2, reject_count=2), - ) - == "reject" - ) - assert ( - tool.group_recommendation( - tool.ScreenGroup(group_key="conflict", row_count=2, viable_count=1, reject_count=1), - ) - == "conflicting_evidence" - ) diff --git a/tests/tools/test_staged_detection_output_analysis.py b/tests/tools/test_staged_detection_output_analysis.py deleted file mode 100644 index 4c791501..00000000 --- a/tests/tools/test_staged_detection_output_analysis.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import json -import sys -from pathlib import Path -from types import ModuleType - -import pandas as pd -import pytest - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -def _write_jsonl(path: Path, rows: list[dict[str, object]]) -> None: - path.write_text("".join(json.dumps(row) + "\n" for row in rows), encoding="utf-8") - - -def test_analyze_staged_detection_output_summarizes_native_detection_probe(tmp_path: Path) -> None: - tool = load_tool( - "measurement_staged_detection_output_analysis", - REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", - ) - output_dir = tmp_path / "staged" - output_dir.mkdir() - _write_jsonl( - output_dir / "staged-detection-cases.jsonl", - [ - { - "record_type": "staged_detection_case", - "case_id": "shell-row-0", - "row_index": 0, - "seed_source": "gliner", - "status": "completed", - "elapsed_sec": 0.002, - "model_elapsed_sec": 0.0, - "model_phase_count": 0, - "model_request_count": 0, - "final_entity_count": 5, - "final_entity_signature_count": 5, - "final_label_counts": {"api_key": 2, "email": 1, "password": 1, "url": 1}, - "total_usage": {}, - "comparison": { - "baseline_final_entity_signature_count": 5, - "shared_final_entity_signature_count": 5, - "baseline_only_final_entity_signature_count": 0, - "direct_only_final_entity_signature_count": 0, - "baseline_only_label_counts": {}, - "direct_only_label_counts": {}, - }, - }, - { - "record_type": "staged_detection_case", - "case_id": "bio-row-0", - "row_index": 0, - "seed_source": "direct_llm", - "status": 
"completed", - "elapsed_sec": 10.0, - "model_elapsed_sec": 9.5, - "model_phase_count": 3, - "model_request_count": 3, - "final_entity_count": 3, - "final_entity_signature_count": 3, - "final_label_counts": {"person": 2, "api_key": 1}, - "total_usage": {"prompt_tokens": 100, "completion_tokens": 20, "total_tokens": 120}, - "comparison": { - "baseline_final_entity_signature_count": 4, - "shared_final_entity_signature_count": 2, - "baseline_only_final_entity_signature_count": 2, - "direct_only_final_entity_signature_count": 1, - "baseline_only_label_counts": {"city": 1, "person": 1}, - "direct_only_label_counts": {"api_key": 1}, - }, - }, - { - "record_type": "staged_detection_case", - "case_id": "bio-row-1", - "row_index": 1, - "seed_source": "direct_llm", - "status": "error", - "elapsed_sec": 1.0, - "model_elapsed_sec": 0.8, - "model_phase_count": 1, - "model_request_count": 1, - "final_entity_count": 0, - "final_entity_signature_count": 0, - "total_usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}, - "comparison": None, - "error": "provider failed", - }, - ], - ) - - result = tool.analyze_staged_detection_output(output_dir) - - assert result.case_count == 3 - assert result.group_count == 2 - groups = {row.seed_source: row for row in result.groups} - assert groups["gliner"].case_count == 1 - assert groups["gliner"].completed_case_count == 1 - assert groups["gliner"].model_elapsed_sec_sum == 0.0 - assert groups["gliner"].model_request_count_sum == 0 - assert groups["gliner"].baseline_shared_signature_rate == 1.0 - assert groups["direct_llm"].case_count == 2 - assert groups["direct_llm"].completed_case_count == 1 - assert groups["direct_llm"].error_case_count == 1 - assert groups["direct_llm"].elapsed_sec_sum == pytest.approx(11.0) - assert groups["direct_llm"].model_elapsed_sec_sum == pytest.approx(10.3) - assert groups["direct_llm"].model_request_count_sum == 4 - assert groups["direct_llm"].total_tokens_sum == 132 - assert 
groups["direct_llm"].baseline_final_entity_signature_count_sum == 4 - assert groups["direct_llm"].shared_final_entity_signature_count_sum == 2 - assert groups["direct_llm"].baseline_only_final_entity_signature_count_sum == 2 - assert groups["direct_llm"].direct_only_final_entity_signature_count_sum == 1 - assert groups["direct_llm"].baseline_shared_signature_rate == pytest.approx(0.5) - - label_deltas = {(row.seed_source, row.delta_type, row.label): row.count for row in result.label_deltas} - assert label_deltas == { - ("direct_llm", "baseline_only", "city"): 1, - ("direct_llm", "baseline_only", "person"): 1, - ("direct_llm", "direct_only", "api_key"): 1, - } - - -def test_staged_detection_output_analysis_writes_csv_tables(tmp_path: Path) -> None: - tool = load_tool( - "measurement_staged_detection_output_analysis_export", - REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", - ) - cases_path = tmp_path / "cases.jsonl" - _write_jsonl( - cases_path, - [ - { - "case_id": "case-0", - "row_index": 0, - "seed_source": "gliner", - "status": "completed", - "elapsed_sec": 0.01, - "model_elapsed_sec": 0.0, - "model_request_count": 0, - "total_usage": {}, - } - ], - ) - - result = tool.analyze_staged_detection_output(cases_path) - export = tool.write_analysis_tables(result, tmp_path / "analysis", tool.ExportFormat.csv) - - assert Path(export.manifest_path).exists() - assert [table.table for table in export.tables] == [ - "case_analysis", - "group_analysis", - "label_delta_analysis", - ] - case_table = pd.read_csv(tmp_path / "analysis" / "case_analysis.csv") - assert case_table.loc[0, "case_id"] == "case-0" - assert case_table.loc[0, "model_request_count"] == 0 - - -def test_staged_detection_output_analysis_rejects_missing_case_file(tmp_path: Path) -> None: - tool = load_tool( - "measurement_staged_detection_output_analysis_missing_input", - REPO_ROOT / "tools/measurement/analyze_staged_detection_output.py", - ) - - with pytest.raises(ValueError, match="input 
path does not exist"): - tool.analyze_staged_detection_output(tmp_path / "missing") diff --git a/tests/tools/test_staged_detection_probe.py b/tests/tools/test_staged_detection_probe.py deleted file mode 100644 index 29987f90..00000000 --- a/tests/tools/test_staged_detection_probe.py +++ /dev/null @@ -1,579 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import importlib.util -import sys -from pathlib import Path -from types import ModuleType - -from anonymizer.engine.schemas import ValidationCandidateSchema - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def load_tool(module_name: str, path: Path) -> ModuleType: - spec = importlib.util.spec_from_file_location(module_name, path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - sys.path.insert(0, str(path.parent)) - spec.loader.exec_module(module) - return module - - -class SequencedClient: - def __init__(self, tool: ModuleType, outputs: list[str]) -> None: - self._tool = tool - self._outputs = list(outputs) - self.prompts: list[str] = [] - - def complete(self, request): # type: ignore[no-untyped-def] - self.prompts.append(request.prompt) - content = self._outputs.pop(0) - return self._tool.DirectCompletion( - content=content, - elapsed_sec=0.5, - usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, - ) - - -class StaticSeedClient: - def __init__(self, tool: ModuleType, content: str) -> None: - self._tool = tool - self.content = content - self.requests = [] - - def detect(self, request): # type: ignore[no-untyped-def] - self.requests.append(request) - return self._tool.DirectCompletion( - content=self.content, - elapsed_sec=0.25, - usage={"prompt_tokens": 1, "completion_tokens": 3, "total_tokens": 4}, - ) - - -def 
test_staged_detection_reuses_validation_and_augmentation_flow() -> None: - tool = load_tool( - "measurement_staged_detection_probe_case", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works at NVIDIA.", - labels=["first_name", "organization_name"], - row_index=0, - ), - client=client, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_suggestion_count == 1 - assert result.seed_entity_count == 1 - assert result.validation_decision_count == 1 - assert result.augmented_suggestion_count == 1 - assert result.final_entity_count == 2 - assert result.final_label_counts == {"first_name": 1, "organization_name": 1} - assert result.artifact.final_source_counts == {"direct_seed": 1, "augmenter": 1} - assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=True) - assert result.phase_skip_reasons == tool.PhaseSkipReasons() - assert result.model_phase_count == 3 - assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=1) - assert result.model_request_count == 3 - assert result.total_usage == {"prompt_tokens": 30, "completion_tokens": 15, "total_tokens": 45} - assert len(client.prompts) == 3 - - -def test_staged_detection_can_disable_augmentation_phase() -> None: - tool = load_tool( - "measurement_staged_detection_probe_disable_augmentation", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - 
'{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works at NVIDIA.", - labels=["first_name", "organization_name"], - row_index=0, - ), - client=client, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_label_counts == {"first_name": 1} - assert result.phase_model_work == tool.PhaseModelWork(seed=True, validation=True, augmentation=False) - assert result.phase_skip_reasons == tool.PhaseSkipReasons(augmentation="disabled") - assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=1, augmentation=0) - assert result.model_request_count == 2 - assert len(client.prompts) == 2 - - -def test_staged_detection_execution_exposes_output_row_without_serializing_it() -> None: - tool = load_tool( - "measurement_staged_detection_probe_execution", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": []}', - ], - ) - - execution = tool.execute_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works remotely.", - labels=["first_name"], - row_index=0, - ), - client=client, - ) - - assert execution.case.status == tool.CaseStatus.completed - assert execution.row[tool.COL_DETECTED_ENTITIES]["entities"][0]["value"] == "Alice" - assert "row" not in execution.case.model_dump() - - -def test_staged_detection_validation_drop_removes_seed_entity() -> None: - tool = load_tool( - "measurement_staged_detection_probe_drop", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "name", "label": "first_name", 
"reason": "placeholder"}]}', - '{"decisions": [{"id": "first_name_3_7", "decision": "drop", "reason": "placeholder"}]}', - '{"entities": []}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="my name is hidden", - labels=["first_name"], - row_index=0, - ), - client=client, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_entity_count == 1 - assert result.validation_decision_count == 1 - assert result.final_entity_count == 0 - assert result.final_label_counts == {} - - -def test_staged_detection_invalid_reclass_label_keeps_seed_label() -> None: - tool = load_tool( - "measurement_staged_detection_probe_invalid_reclass_label", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - ( - '{"decisions": [' - '{"id": "first_name_0_5", "decision": "reclass", ' - '"proposed_label": "drop", "reason": "invalid label"}' - "]}" - ), - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works remotely.", - labels=["first_name"], - row_index=0, - ), - client=client, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_entity_count == 1 - assert result.final_label_counts == {"first_name": 1} - - -def test_staged_detection_discards_invalid_augmentation_labels() -> None: - tool = load_tool( - "measurement_staged_detection_probe_invalid_augmentation_label", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "drop", "reason": "invalid label"}]}', - ], - ) - - result = 
tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works at NVIDIA.", - labels=["first_name", "organization_name"], - row_index=0, - ), - client=client, - ) - - assert result.status == tool.CaseStatus.completed - assert result.augmented_suggestion_count == 0 - assert result.final_entity_count == 1 - assert result.final_label_counts == {"first_name": 1} - - -def test_staged_detection_preserves_more_specific_seed_label_on_native_reclass() -> None: - tool = load_tool( - "measurement_staged_detection_probe_specific_label_reclass", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "birth date"}]}', - ( - '{"decisions": [' - '{"id": "date_of_birth_26_41", "decision": "reclass", ' - '"proposed_label": "date", "reason": "date expression"}' - "]}" - ), - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="The applicant was born on 23 October 1992.", - labels=["date", "date_of_birth"], - row_index=0, - ), - client=client, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_entity_count == 1 - assert result.final_label_counts == {"date_of_birth": 1} - - -def test_staged_detection_allows_date_of_birth_reclass_without_birth_context() -> None: - tool = load_tool( - "measurement_staged_detection_probe_generic_date_reclass", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "23 October 1992", "label": "date_of_birth", "reason": "date"}]}', - ( - '{"decisions": [' - '{"id": "date_of_birth_3_18", "decision": "reclass", ' - '"proposed_label": "date", "reason": "filing date"}' - "]}" - ), - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="On 23 October 
1992 the applicant filed an action.", - labels=["date", "date_of_birth"], - row_index=0, - ), - client=client, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_entity_count == 1 - assert result.final_label_counts == {"date": 1} - - -def test_staged_detection_demotes_native_birth_date_without_birth_context() -> None: - tool = load_tool( - "measurement_staged_detection_probe_generic_date_birth_label", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - '{"entities": [{"value": "23 October 1992", "label": "date", "reason": "date"}]}', - ( - '{"decisions": [' - '{"id": "date_3_18", "decision": "reclass", ' - '"proposed_label": "date_of_birth", "reason": "ambiguous date"}' - "]}" - ), - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="On 23 October 1992 the applicant filed an action.", - labels=["date", "date_of_birth"], - row_index=0, - ), - client=client, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_entity_count == 1 - assert result.final_label_counts == {"date": 1} - - -def test_staged_detection_can_seed_from_direct_gliner_payload_without_llm_seed_prompt() -> None: - tool = load_tool( - "measurement_staged_detection_probe_gliner_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - seed_client = StaticSeedClient( - tool, - '{"entities": [{"text": "Alice", "label": "first_name", "start": 0, "end": 5, "score": 0.99}]}', - ) - llm_client = SequencedClient( - tool, - [ - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"entities": [{"value": "NVIDIA", "label": "organization_name", "reason": "employer"}]}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works at NVIDIA.", - labels=["first_name", 
"organization_name"], - row_index=0, - ), - client=llm_client, - seed_client=seed_client, - seed_source=tool.SeedSource.gliner, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.gliner - assert result.seed_entity_count == 1 - assert result.final_label_counts == {"first_name": 1, "organization_name": 1} - assert result.total_usage == {"prompt_tokens": 21, "completion_tokens": 13, "total_tokens": 34} - assert len(seed_client.requests) == 1 - assert len(llm_client.prompts) == 2 - - -def test_staged_detection_promotes_gliner_date_seed_in_birth_context() -> None: - tool = load_tool( - "measurement_staged_detection_probe_gliner_birth_date_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - seed_client = StaticSeedClient( - tool, - ('{"entities": [{"text": "23 October 1992", "label": "date", "start": 26, "end": 41, "score": 0.99}]}'), - ) - llm_client = SequencedClient( - tool, - ['{"decisions": [{"id": "date_of_birth_26_41", "decision": "keep", "reason": "birth date"}]}'], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="The applicant was born on 23 October 1992.", - labels=["date", "date_of_birth"], - row_index=0, - ), - client=llm_client, - seed_client=seed_client, - seed_source=tool.SeedSource.gliner, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.gliner - assert result.seed_entity_count == 1 - assert result.final_label_counts == {"date_of_birth": 1} - - -def test_staged_detection_keeps_gliner_date_seed_without_birth_context() -> None: - tool = load_tool( - "measurement_staged_detection_probe_gliner_generic_date_seed", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - seed_client = StaticSeedClient( - tool, - ('{"entities": [{"text": "23 October 1992", "label": "date", "start": 3, "end": 18, "score": 0.99}]}'), - ) - llm_client = 
SequencedClient( - tool, - ['{"decisions": [{"id": "date_3_18", "decision": "keep", "reason": "filing date"}]}'], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="On 23 October 1992 the applicant filed an action.", - labels=["date", "date_of_birth"], - row_index=0, - ), - client=llm_client, - seed_client=seed_client, - seed_source=tool.SeedSource.gliner, - skip_augmentation=True, - ) - - assert result.status == tool.CaseStatus.completed - assert result.seed_source == tool.SeedSource.gliner - assert result.seed_entity_count == 1 - assert result.final_label_counts == {"date": 1} - - -def test_staged_detection_validation_prompt_preserves_degree_label_guidance() -> None: - tool = load_tool( - "measurement_staged_detection_probe_degree_guidance", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - request = tool.StagedDetectionRequest( - case_id="case-1", - text="He earned his Bachelor of Science in physics.", - labels=["degree", "education_level", "field_of_study"], - ) - candidates = tool.ValidationCandidatesSchema( - candidates=[ - ValidationCandidateSchema( - id="education_level_14_33", - value="Bachelor of Science", - label="education_level", - context_before="He earned his ", - context_after=" in physics.", - ) - ] - ) - - prompt = tool._validation_prompt(request, candidates) - - assert "degree" in prompt - assert "education_level" in prompt - assert "Bachelor of Science" in prompt - assert "Prefer degree for credential names" in prompt - - -def test_staged_detection_augmentation_prompt_discourages_grouped_person_and_surname_spans() -> None: - tool = load_tool( - "measurement_staged_detection_probe_augmentation_guidance", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - request = tool.StagedDetectionRequest( - case_id="case-1", - text="Her parents, Mark and Linda, live near West Baker Drive and Baker's grocery.", - labels=["first_name", "last_name", "organization_name", 
"place_name", "company_name"], - ) - - prompt = tool._augmentation_prompt(request, {tool.COL_SEED_ENTITIES_JSON: "[]"}) - - assert "split personal names connected by 'and'" in prompt - assert "Do not label a list of people as organization_name" in prompt - assert "also return the surname substring as last_name" in prompt - - -def test_staged_detection_baseline_comparison_skips_rows_without_signature_hashes() -> None: - tool = load_tool( - "measurement_staged_detection_probe_missing_baseline_signatures", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - llm_client = SequencedClient( - tool, - [ - '{"entities": [{"value": "Alice", "label": "first_name", "reason": "person name"}]}', - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "person name"}]}', - '{"entities": []}', - ], - ) - case = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works remotely.", - labels=["first_name"], - row_index=0, - ), - client=llm_client, - ) - - compared = tool._case_with_comparison(case, {"row_index": 0, "final_entity_count": 1}) - - assert compared.comparison is None - - -def test_staged_detection_can_chunk_validation_into_local_excerpts() -> None: - tool = load_tool( - "measurement_staged_detection_probe_chunked_validation", - REPO_ROOT / "tools/measurement/staged_detection_probe.py", - ) - client = SequencedClient( - tool, - [ - ( - '{"entities": [' - '{"value": "Alice", "label": "first_name", "reason": "name"},' - '{"value": "Paris", "label": "city", "reason": "city"}' - "]}" - ), - '{"decisions": [{"id": "first_name_0_5", "decision": "keep", "reason": "real name"}]}', - '{"decisions": [{"id": "city_61_66", "decision": "keep", "reason": "city"}]}', - '{"entities": []}', - ], - ) - - result = tool.run_staged_detection_case( - tool.StagedDetectionRequest( - case_id="case-1", - text="Alice works in a very long remote biography before moving to Paris.", - labels=["first_name", "city"], - 
row_index=0, - ), - client=client, - validation_prompt_mode=tool.ValidationPromptMode.chunked_excerpt, - validation_max_entities_per_call=1, - validation_excerpt_window_chars=8, - ) - - assert result.status == tool.CaseStatus.completed - assert result.final_label_counts == {"city": 1, "first_name": 1} - assert result.phase_model_requests == tool.PhaseModelRequests(seed=1, validation=2, augmentation=1) - assert result.model_phase_count == 3 - assert result.model_request_count == 4 - assert len(client.prompts) == 4 - assert "Alice" in client.prompts[1] - assert "Paris" not in client.prompts[1] - assert "Paris" in client.prompts[2] diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 67b2e828..4ae539ba 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -3,14 +3,12 @@ # Measurement tools -This directory contains developer tools for measuring Anonymizer runs, -exporting measurement JSONL to tables, and comparing benchmark strategies. -Run the tools inside the project environment, either with an activated venv or -through `uv run`. +This directory contains developer tools for measuring Anonymizer runs and +exporting measurement JSONL to tables. Run the tools inside the project +environment, either with an activated venv or through `uv run`. Use these tools when you need evidence about cost, latency, reliability, or -anonymization quality. They are not product entry points and the benchmark-only -strategy knobs are not public Anonymizer defaults. +anonymization quality. They are not product entry points. ## Quick export to DataFrames or CSV @@ -68,11 +66,9 @@ The measurement system has three layers: - Instrumentation in Anonymizer emits JSONL records for runs, stages, DataDesigner workflows, direct model workflows, and per-record safety metrics. -- Benchmark runners and probes create repeatable workloads and write those JSONL - records plus optional sidecars such as detection artifacts and DataDesigner - traces. 
-- Analysis tools convert raw run artifacts into case, group, model, and - comparison tables. +- Benchmark runners create repeatable workloads and write those JSONL records + plus optional sidecars such as detection artifacts and DataDesigner traces. +- Analysis tools convert raw run artifacts into case, group, and model tables. External/distributed execution is a separate boundary. Detection export APIs are responsible for building DataDesigner configs that an external runtime can @@ -86,19 +82,12 @@ orchestration or distributed DataDesigner execution. | --- | --- | | Export raw measurement JSONL to tables | `export_measurements.py` | | Run repeatable Anonymizer suites | `run_benchmarks.py` | -| Inspect DataDesigner traces | `analyze_dd_traces.py` | | Inspect detection artifact sidecars | `analyze_detection_artifacts.py` | -| Probe one direct detection prompt | `direct_detection_probe.py` | -| Probe staged seed/validate/augment paths | `staged_detection_probe.py` | -| Analyze staged probe outputs | `analyze_staged_detection_output.py` | | Analyze benchmark output directories | `analyze_benchmark_output.py` | -| Compare a candidate against a baseline | `compare_strategy_pairs.py` | -| Screen many comparison files | `screen_strategy_comparisons.py` | -| Extract exact-signature deltas | `extract_signature_deltas.py` | -Most workflows start with `run_benchmarks.py`, then -`analyze_benchmark_output.py`, then either `compare_strategy_pairs.py` or -`screen_strategy_comparisons.py`. +Most workflows start with `run_benchmarks.py`, then either export the raw +measurement log with `export_measurements.py` or summarize the benchmark output +directory with `analyze_benchmark_output.py`. 
## Implementation shape @@ -187,7 +176,6 @@ Example: suite_id: biography-smoke model_configs: ./model-configs.yaml model_providers: ./providers.yaml -dd_parser_compat: none case_retries: 1 case_retry_backoff_sec: 10 workloads: @@ -236,9 +224,9 @@ pass `--no-export` when only raw measurement JSONL is needed. Before starting a real run, the benchmark runner performs cheap preflight checks: suite/config parsing, local dataset existence, CSV/Parquet text-column -metadata, provider YAML shape, native runtime requirements, and active -model-alias references. `--dry-run` runs those same checks, expands the planned -matrix, and skips output-dir writes and model work. +metadata, provider YAML shape, and active model-alias references. `--dry-run` +runs those same checks, expands the planned matrix, and skips output-dir writes +and model work. Use `case_retries` and `case_retry_backoff_sec` for long-running suites on shared model endpoints. Retries are disabled by default. When enabled, a failed @@ -246,94 +234,6 @@ case is retried with the same `case_id` and output paths; the final case still records `attempt_count` and `attempt_errors` in `summary.json`. `--fail-fast` remains fail-fast and bypasses retries. -## Benchmark-only detection strategies - -Configs may set `experimental_detection_strategy` for benchmark-only pipeline -probes. These values are not public `Detect` config fields, and they should not -be treated as safe defaults across arbitrary data. - -```yaml -configs: - - id: native-single-pass - experimental_detection_strategy: native_single_pass - replace: redact -``` - -Supported values: - -- `default`: run the normal Anonymizer detection pipeline. -- `no_augment`: run GLiNER detection and validation, but skip LLM augmentation. -- `detector_only`: run only GLiNER detection and local finalization. This skips - LLM validation and LLM augmentation. 
-- `native_candidate_validate_no_augment`: run a benchmark-only native staged - detector without DataDesigner using direct OpenAI-compatible calls for seed - extraction and validation, then skip augmentation. -- `detector_native_validate_no_augment`: run the normal GLiNER detector seed - through Anonymizer/DataDesigner, then bypass DataDesigner validation and - augmentation with direct OpenAI-compatible validation calls. -- `detector_native_validate_native_augment`: run the normal GLiNER detector seed - through Anonymizer/DataDesigner, then bypass DataDesigner validation and - augmentation with direct OpenAI-compatible validation and augmentation calls. -- `gliner_native_validate_no_augment`: run a direct hosted-GLiNER seed without - DataDesigner, validate those detector candidates with direct - OpenAI-compatible calls, and skip augmentation. -- `gliner_native_validate_native_augment`: run a direct hosted-GLiNER seed - without DataDesigner, validate those detector candidates with direct - OpenAI-compatible calls, then run direct native augmentation. -- `native_single_pass`: run a benchmark-only native detector without - DataDesigner using one direct OpenAI-compatible provider call per row. The - model must return exact values plus `start`/`end` offsets; local code - validates offsets, resolves overlaps, and records parser/runtime failures as - `model_workflow` errors. -- `native_single_pass_recall`: the same one-call native detector with a - recall-oriented prompt that includes Anonymizer's label examples and stronger - high-recall guidance. -- `native_single_pass_values`: the same one-call native detector, but with the - value-only prompt shape from `direct_detection_probe.py`. The model returns - exact values and labels only; local code resolves every occurrence of each - returned value into spans. -- `native_single_pass_values_recall`: the value-only one-call detector with the - recall-oriented prompt from `direct_detection_probe.py`. 
- -Native benchmark strategies require an explicit runtime. Set top-level -`native_runtime.endpoint` and `native_runtime.model`, set the standard -`ANONYMIZER_BENCH_NATIVE_ENDPOINT` and `ANONYMIZER_BENCH_NATIVE_MODEL` -environment variables, or override runtime fields per config with -`configs[].native_runtime`. GLiNER-seeded native strategies also require -`native_runtime.gliner_endpoint` and `native_runtime.gliner_model`, or the -standard `ANONYMIZER_BENCH_GLINER_ENDPOINT` and -`ANONYMIZER_BENCH_GLINER_MODEL` environment variables. The runner records -runtime id, alias, provider, model, and env-variable names as run tags; raw -endpoint URLs are not emitted into measurement tables. - -```yaml -native_runtime: - runtime_id: local-vllm-json - endpoint_env: ANONYMIZER_BENCH_NATIVE_ENDPOINT - model_env: ANONYMIZER_BENCH_NATIVE_MODEL - provider: local-vllm - alias: native-direct -configs: - - id: native-single-pass - experimental_detection_strategy: native_single_pass - replace: redact -``` - -Use `detector_only` only as a lower-bound ablation. It skips the LLM validation -pass that drops false positives and reclassifies ambiguous spans. A faster run -that loses baseline signatures is a rejection. - -Use staged native strategies when the question is "can direct provider calls -replace part of DataDesigner orchestration?" They still need repeated signature, -leak, label-mismatch, parser, and reliability gates before any workload-specific -promotion. - -Use one-call native strategies for the more aggressive "collapse detection to -one call" experiment. They are often faster when the prompt works, but they are -more parser- and recall-sensitive. Any malformed JSON response becomes a failed -case in analysis, and any missed baseline signature should be treated as a -rejection rather than a latency win. - ## DataDesigner traces For debugging DataDesigner calls, pass `--dd-trace last-message` or @@ -355,21 +255,6 @@ side effects. 
That covers `LLMTextColumnConfig` and runs can detect this gap. Full custom-column message tracing would need a DataDesigner hook; it is a good candidate for an upstream DataDesigner PR. -Summarize traced calls without copying raw prompts or responses into analysis -output: - -```bash -uv run python tools/measurement/analyze_dd_traces.py \ - benchmark-runs/suite-id/traces \ - --output benchmark-runs/suite-id/trace-analysis \ - --format csv -``` - -This writes `trace_analysis.*` and `trace_group_analysis.*`. The row table -captures run tags, workflow/model metadata, status, elapsed time, prompt and -response lengths, token counts, and response-shape flags such as `raw_json`, -`fenced_json`, `embedded_json`, `text`, and `none`. - ## DataDesigner Scheduler Task Traces Pass `--dd-task-trace` to collect sanitized DataDesigner async scheduler task @@ -390,73 +275,6 @@ uv run python tools/measurement/run_benchmarks.py \ --dd-task-trace ``` -## Direct probes - -`direct_detection_probe.py` calls a local OpenAI-compatible endpoint directly -for a small slice of records. It is useful for prompt experiments before adding -a benchmark strategy. - -```bash -uv run python tools/measurement/direct_detection_probe.py \ - docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column text \ - --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ - --model nvidia/nemotron-3-super \ - --labels person,email,api_key,password \ - --row-limit 5 \ - --output /tmp/direct-detection-probe -``` - -`staged_detection_probe.py` runs a no-DataDesigner staged detector outside the -main benchmark harness. It can compare seed extraction, validation, and -augmentation boundaries before integrating a strategy into `run_benchmarks.py`. 
- -```bash -uv run python tools/measurement/staged_detection_probe.py \ - docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column text \ - --endpoint http://gpu-dev-pod-serve-svc:8000/v1 \ - --model nvidia/nemotron-3-super \ - --labels person,email,api_key,password \ - --row-limit 5 \ - --output /tmp/staged-detection-probe -``` - -Useful staged options: - -- `--seed-source direct-llm`: use direct LLM seed extraction. -- `--seed-source gliner`: use direct hosted GLiNER seeding. -- `--skip-augmentation`: disable augmentation for any seed source. This is an - ablation for measuring how much recall the augmentation phase carries. -- `--validation-prompt-mode chunked-excerpt`: split seed validation candidates - into chunks of `--validation-max-entities-per-call` and send each chunk with a - tagged local excerpt bounded by `--validation-excerpt-window-chars`. - -The staged tool writes `staged-detection-cases.jsonl`, -`staged-detection-artifacts.jsonl`, and `summary.json`. Case rows include -per-phase usage for seed extraction, validation, and augmentation, true case -wall time in `elapsed_sec`, model-call time in `model_elapsed_sec`, -`phase_model_work`, `phase_skip_reasons`, `phase_model_requests`, -`model_phase_count`, `model_request_count`, total usage, and optional baseline -signature deltas. Treat `summary.json` as a sensitive debug artifact because it -records the resolved endpoint/model runtime used for the probe. - -Summarize staged probe outputs: - -```bash -uv run python tools/measurement/analyze_staged_detection_output.py \ - /tmp/staged-detection-probe \ - --output /tmp/staged-detection-probe/analysis \ - --format csv -``` - -The analyzer accepts either the staged output directory or the -`staged-detection-cases.jsonl` file directly. It writes per-case, per-seed -source group, and label-delta tables. 
Use `group_analysis.csv` for latency, -token, request, and signature-overlap totals; use `label_delta_analysis.csv` to -see which labels account for baseline-only misses or direct-only additions. The -analysis tables omit raw text and raw entity values. - ## Benchmark analysis `analyze_benchmark_output.py` joins `measurements.jsonl`, optional @@ -481,39 +299,6 @@ Use `--detection-artifacts` to provide an explicit detection artifact JSONL sidecar. Otherwise, the analyzer reads `detection-artifacts.jsonl` in the benchmark directory when present. -`compare_strategy_pairs.py` compares baseline/candidate case rows: - -```bash -uv run python tools/measurement/compare_strategy_pairs.py \ - benchmark-runs/suite-id/analysis/case_analysis.csv \ - --baseline-config default \ - --candidate-config native-single-pass \ - --output benchmark-runs/suite-id/analysis/default-vs-native-single-pass.csv -``` - -When one CSV does not contain both arms, pass `--candidate-case-analysis`: - -```bash -uv run python tools/measurement/compare_strategy_pairs.py \ - baseline/analysis/case_analysis.csv \ - --candidate-case-analysis candidate/analysis/case_analysis.csv \ - --baseline-strategy default \ - --candidate-strategy detector_native_validate_no_augment \ - --output comparison.csv -``` - -`screen_strategy_comparisons.py` screens many comparison CSVs: - -```bash -uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ \ - --output benchmark-runs/strategy-screen.csv -``` - -Use `--group-by strategy_workload_family` when the same candidate behaves -differently across workload families. Use `--config-aliases aliases.json` to -group related config IDs, such as temperature or validation-window variants of -the same strategy. - ## Pandas patterns Analysis tables are regular CSV/Parquet files. 
A typical local workflow: @@ -527,7 +312,6 @@ groups = pd.read_parquet("benchmark-runs/suite/analysis/group_analysis.parquet") cols = [ "workload_id", "config_id", - "experimental_detection_strategy", "median_pipeline_elapsed_sec", "median_observed_total_requests", "median_observed_total_tokens", @@ -543,26 +327,6 @@ failures = cases[ print(failures[["case_id", "config_id", "observed_failed_requests", "dd_trace_error_count"]]) ``` -Compare a candidate against a baseline: - -```python -comparison = pd.read_csv("benchmark-runs/suite/analysis/default-vs-native.csv") -candidate_rows = comparison[ - ["workload_id", "candidate_verdict", "safety_verdict", "performance_verdict", "flags"] -] -print(candidate_rows) -``` - -Find candidate-specific misses: - -```python -loss_cols = [ - column for column in comparison.columns - if column.startswith("baseline_only_final_entity_signature_label_counts.") -] -print(comparison[["workload_id", *loss_cols]].fillna(0)) -``` - ## Metric interpretation Use metrics as signals, not as a single score. @@ -615,14 +379,3 @@ Safety and replacement: - `replacement_synthetic_original_collision_count`: final entity occurrences whose original value was reused as a synthetic replacement value elsewhere in the same record. -- `baseline_only_candidate_covered_signature_count`, - `baseline_only_candidate_overlapping_signature_count`, and - `baseline_only_candidate_uncovered_signature_count`: comparison-only fields - from `compare_strategy_pairs.py`. These split exact signature deltas into - covered, boundary-overlap, and uncovered losses. -- `candidate_verdict`: `candidate_viable`, `review`, or `reject`. - -Treat `candidate_viable` as a promotion candidate, not as an automatic default. -It means the sampled comparison passed the current gates and improved at least -one performance metric without regressing another. Re-run candidates on the -target workload family, with repetitions, before changing production defaults. 
diff --git a/tools/measurement/analyze_dd_traces.py b/tools/measurement/analyze_dd_traces.py deleted file mode 100644 index b0fcae55..00000000 --- a/tools/measurement/analyze_dd_traces.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Analyze DataDesigner message traces without emitting raw prompt or response text. - -Usage: - uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces - uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces --output analysis --format csv - uv run python tools/measurement/analyze_dd_traces.py benchmark-runs/suite-id/traces --json -""" - -from __future__ import annotations - -import json -import logging -import re -import sys -from pathlib import Path -from typing import Annotated, Any - -import cyclopts -import pandas as pd -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from measurement_tools.stats import median_or_none as _median_or_none -from measurement_tools.stats import none_if_nan as _none_if_nan -from measurement_tools.stats import sum_int_or_zero as _sum_int_or_zero -from measurement_tools.tables import AnalysisExportResult, ExportFormat, ModelTableSpec -from measurement_tools.tables import write_analysis_tables as _write_analysis_table_specs -from pydantic import BaseModel, Field, computed_field - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.dd_traces") - -JSON_FENCE_RE = re.compile(r"```\s*json\b", re.IGNORECASE) - - -class TraceAnalysisRow(BaseModel): - trace_file: str - line_number: int - suite_id: str | None = None - workload_id: str | None = None - config_id: str | None = None - case_id: str | None = None - run_id: str - workflow_name: str | None = None - model_alias: str | None = None - model_name: str | None = None - model_provider_name: str | None = 
None - status: str | None = None - error_type: str | None = None - elapsed_sec: float | None = None - message_count: int = 0 - prompt_chars: int = 0 - response_shape: str - response_chars: int = 0 - response_has_thinking: bool = False - response_has_json_fence: bool = False - response_has_embedded_json: bool = False - input_tokens: int | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class TraceGroupAnalysisRow(BaseModel): - suite_id: str | None = None - workload_id: str | None = None - config_id: str | None = None - workflow_name: str | None = None - model_alias: str | None = None - model_name: str | None = None - model_provider_name: str | None = None - status: str | None = None - error_type: str | None = None - response_shape: str - trace_record_count: int - error_count: int = 0 - thinking_count: int = 0 - json_fence_count: int = 0 - embedded_json_count: int = 0 - median_elapsed_sec: float | None = None - median_prompt_chars: float | None = None - median_response_chars: float | None = None - sum_input_tokens: int = 0 - sum_output_tokens: int = 0 - sum_total_tokens: int = 0 - - -class TraceAnalysis(BaseModel): - trace_path: str - rows: list[TraceAnalysisRow] = Field(default_factory=list) - groups: list[TraceGroupAnalysisRow] = Field(default_factory=list) - - @computed_field - @property - def trace_record_count(self) -> int: - return len(self.rows) - - @computed_field - @property - def group_count(self) -> int: - return len(self.groups) - - -def analyze_trace_path(trace_path: Path) -> TraceAnalysis: - rows = [ - _trace_row(record, trace_file=path, line_number=line_number) - for path in _iter_trace_files(trace_path) - for line_number, record in _iter_trace_records(path) - ] - return TraceAnalysis( - trace_path=str(trace_path), - rows=rows, - groups=build_trace_group_rows(rows), - ) - - -def _iter_trace_files(trace_path: Path) -> list[Path]: - if not trace_path.exists(): - raise ValueError(f"trace path does not exist: 
{trace_path}") - if trace_path.is_file(): - return [trace_path] - if not trace_path.is_dir(): - raise ValueError(f"trace path is not a file or directory: {trace_path}") - return sorted(trace_path.rglob("*.jsonl")) - - -def _iter_trace_records(path: Path) -> list[tuple[int, dict[str, Any]]]: - records: list[tuple[int, dict[str, Any]]] = [] - for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1): - if not line.strip(): - continue - payload = json.loads(line) - if isinstance(payload, dict) and payload.get("record_type") == "dd_message_trace": - records.append((line_number, payload)) - return records - - -def _trace_row(record: dict[str, Any], *, trace_file: Path, line_number: int) -> TraceAnalysisRow: - response = record.get("response") if isinstance(record.get("response"), dict) else None - response_content = _string_or_none(response.get("content")) if response is not None else None - reasoning_content = _string_or_none(response.get("reasoning_content")) if response is not None else None - run_tags = record.get("run_tags") if isinstance(record.get("run_tags"), dict) else {} - usage = record.get("usage") if isinstance(record.get("usage"), dict) else {} - return TraceAnalysisRow( - trace_file=str(trace_file), - line_number=line_number, - suite_id=_string_or_none(run_tags.get("suite_id")), - workload_id=_string_or_none(run_tags.get("workload_id")), - config_id=_string_or_none(run_tags.get("config_id")), - case_id=_string_or_none(run_tags.get("case_id")), - run_id=str(record.get("run_id") or ""), - workflow_name=_string_or_none(record.get("workflow_name")), - model_alias=_string_or_none(record.get("model_alias")), - model_name=_string_or_none(record.get("model_name")), - model_provider_name=_string_or_none(record.get("model_provider_name")), - status=_string_or_none(record.get("status")), - error_type=_string_or_none(record.get("error_type")), - elapsed_sec=_float_or_none(record.get("elapsed_sec")), - 
message_count=_message_count(record.get("messages")), - prompt_chars=_message_text_chars(record.get("messages")), - response_shape=_response_shape(response_content), - response_chars=len(response_content or ""), - response_has_thinking=_has_thinking(response_content, reasoning_content), - response_has_json_fence=_has_json_fence(response_content), - response_has_embedded_json=_has_embedded_json(response_content), - input_tokens=_int_or_none(usage.get("input_tokens")), - output_tokens=_int_or_none(usage.get("output_tokens")), - total_tokens=_int_or_none(usage.get("total_tokens")), - ) - - -def build_trace_group_rows(rows: list[TraceAnalysisRow]) -> list[TraceGroupAnalysisRow]: - if not rows: - return [] - table = pd.DataFrame([row.model_dump() for row in rows]) - group_columns = [ - "suite_id", - "workload_id", - "config_id", - "workflow_name", - "model_alias", - "model_name", - "model_provider_name", - "status", - "error_type", - "response_shape", - ] - return [ - _build_trace_group_row(keys, group) for keys, group in table.groupby(group_columns, dropna=False, sort=True) - ] - - -def _build_trace_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> TraceGroupAnalysisRow: - ( - suite_id, - workload_id, - config_id, - workflow_name, - model_alias, - model_name, - provider_name, - status, - error_type, - response_shape, - ) = keys - return TraceGroupAnalysisRow( - suite_id=_none_if_nan(suite_id), - workload_id=_none_if_nan(workload_id), - config_id=_none_if_nan(config_id), - workflow_name=_none_if_nan(workflow_name), - model_alias=_none_if_nan(model_alias), - model_name=_none_if_nan(model_name), - model_provider_name=_none_if_nan(provider_name), - status=_none_if_nan(status), - error_type=_none_if_nan(error_type), - response_shape=str(response_shape), - trace_record_count=len(group), - error_count=int((group["status"] == "error").sum()), - thinking_count=int(group["response_has_thinking"].sum()), - json_fence_count=int(group["response_has_json_fence"].sum()), - 
embedded_json_count=int(group["response_has_embedded_json"].sum()), - median_elapsed_sec=_median_or_none(group, "elapsed_sec"), - median_prompt_chars=_median_or_none(group, "prompt_chars"), - median_response_chars=_median_or_none(group, "response_chars"), - sum_input_tokens=_sum_int_or_zero(group, "input_tokens"), - sum_output_tokens=_sum_int_or_zero(group, "output_tokens"), - sum_total_tokens=_sum_int_or_zero(group, "total_tokens"), - ) - - -def _response_shape(content: str | None) -> str: - if not content: - return "none" - if _has_json_fence(content): - return "fenced_json" - stripped = content.strip() - try: - json.loads(stripped) - except json.JSONDecodeError: - return "embedded_json" if _has_embedded_json(content) else "text" - return "raw_json" - - -def _has_thinking(content: str | None, reasoning_content: str | None) -> bool: - return bool(reasoning_content) or "" in (content or "") - - -def _has_json_fence(content: str | None) -> bool: - return bool(content and JSON_FENCE_RE.search(content)) - - -def _has_embedded_json(content: str | None) -> bool: - if not content: - return False - decoder = json.JSONDecoder() - for start, char in enumerate(content): - if char not in "{[": - continue - try: - parsed, _ = decoder.raw_decode(content, start) - except json.JSONDecodeError: - continue - if isinstance(parsed, dict | list): - return True - return False - - -def _message_count(messages: object) -> int: - return len(messages) if isinstance(messages, list) else 0 - - -def _message_text_chars(messages: object) -> int: - if not isinstance(messages, list): - return 0 - return sum(_message_content_chars(message) for message in messages) - - -def _message_content_chars(message: object) -> int: - if not isinstance(message, dict): - return 0 - return _content_chars(message.get("content")) - - -def _content_chars(content: object) -> int: - if isinstance(content, str): - return len(content) - if isinstance(content, list): - return sum(_content_item_chars(item) for item in 
content) - return 0 - - -def _content_item_chars(item: object) -> int: - if isinstance(item, str): - return len(item) - if isinstance(item, dict): - return len(str(item.get("text") or "")) - return 0 - - -def _string_or_none(value: object) -> str | None: - return str(value) if value is not None and not pd.isna(value) else None - - -def _float_or_none(value: object) -> float | None: - return float(value) if value is not None and not pd.isna(value) else None - - -def _int_or_none(value: object) -> int | None: - return int(value) if value is not None and not pd.isna(value) else None - - -def write_analysis_tables(result: TraceAnalysis, output_dir: Path, export_format: ExportFormat) -> AnalysisExportResult: - return _write_analysis_table_specs( - output_dir, - export_format, - [ - ModelTableSpec("trace_analysis", result.rows, TraceAnalysisRow), - ModelTableSpec("trace_group_analysis", result.groups, TraceGroupAnalysisRow), - ], - ) - - -def render_result(result: TraceAnalysis, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - lines = [f"Analyzed {result.trace_record_count} trace record(s) across {result.group_count} group(s)"] - for group in result.groups: - model_label = group.model_alias or group.model_name - label = f"{group.workload_id}/{group.config_id}/{group.workflow_name}/{model_label}/{group.response_shape}" - lines.append( - f"- {label}: traces={group.trace_record_count}, errors={group.error_count}, " - f"thinking={group.thinking_count}, json_fence={group.json_fence_count}, " - f"tokens={group.sum_total_tokens}" - ) - return "\n".join(lines) - - -@app.default -def main( - trace_path: Path, - *, - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = 
LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - result = analyze_trace_path(trace_path) - if output is not None: - write_analysis_tables(result, output, format) - except (OSError, ValueError, json.JSONDecodeError) as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/analyze_staged_detection_output.py b/tools/measurement/analyze_staged_detection_output.py deleted file mode 100644 index d74e0f32..00000000 --- a/tools/measurement/analyze_staged_detection_output.py +++ /dev/null @@ -1,392 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Analyze benchmark-only DD-free staged detection probe outputs. - -Usage: - uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe - uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe/staged-detection-cases.jsonl - uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe --output analysis --format csv - uv run python tools/measurement/analyze_staged_detection_output.py /tmp/staged-probe --json -""" - -from __future__ import annotations - -import json -import logging -import sys -from collections import Counter, defaultdict -from pathlib import Path -from typing import Annotated, Any - -import cyclopts -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from measurement_tools.tables import AnalysisExportResult, ExportFormat, ModelTableSpec -from measurement_tools.tables import write_analysis_tables as _write_analysis_table_specs -from pydantic import BaseModel, Field, computed_field - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.staged_detection_output") - - -class 
StagedCaseAnalysisRow(BaseModel): - source_path: str - case_id: str - row_index: int | None = None - seed_source: str | None = None - status: str | None = None - case_failed: bool = False - elapsed_sec: float | None = None - model_elapsed_sec: float | None = None - model_phase_count: int = 0 - model_request_count: int = 0 - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - seed_entity_count: int = 0 - validation_candidate_count: int = 0 - validation_decision_count: int = 0 - augmented_suggestion_count: int = 0 - final_entity_count: int = 0 - final_entity_signature_count: int = 0 - final_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_final_entity_signature_count: int | None = None - shared_final_entity_signature_count: int | None = None - baseline_only_final_entity_signature_count: int | None = None - direct_only_final_entity_signature_count: int | None = None - baseline_shared_signature_rate: float | None = None - baseline_loss_signature_rate: float | None = None - baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) - direct_only_label_counts: dict[str, int] = Field(default_factory=dict) - error: str | None = None - - -class StagedGroupAnalysisRow(BaseModel): - seed_source: str | None = None - case_count: int - completed_case_count: int = 0 - error_case_count: int = 0 - failed_case_rate: float | None = None - elapsed_sec_sum: float | None = None - elapsed_sec_mean: float | None = None - model_elapsed_sec_sum: float | None = None - model_elapsed_sec_mean: float | None = None - model_phase_count_sum: int = 0 - model_request_count_sum: int = 0 - prompt_tokens_sum: int = 0 - completion_tokens_sum: int = 0 - total_tokens_sum: int = 0 - final_entity_count_sum: int = 0 - final_entity_signature_count_sum: int = 0 - baseline_final_entity_signature_count_sum: int = 0 - shared_final_entity_signature_count_sum: int = 0 - baseline_only_final_entity_signature_count_sum: int = 0 - 
direct_only_final_entity_signature_count_sum: int = 0 - baseline_shared_signature_rate: float | None = None - baseline_loss_signature_rate: float | None = None - - -class LabelDeltaAnalysisRow(BaseModel): - seed_source: str | None = None - delta_type: str - label: str - count: int - - -class StagedDetectionOutputAnalysis(BaseModel): - source_path: str - cases: list[StagedCaseAnalysisRow] = Field(default_factory=list) - groups: list[StagedGroupAnalysisRow] = Field(default_factory=list) - label_deltas: list[LabelDeltaAnalysisRow] = Field(default_factory=list) - - @computed_field - @property - def case_count(self) -> int: - return len(self.cases) - - @computed_field - @property - def group_count(self) -> int: - return len(self.groups) - - @computed_field - @property - def label_delta_count(self) -> int: - return len(self.label_deltas) - - -def analyze_staged_detection_output(input_path: Path) -> StagedDetectionOutputAnalysis: - case_path = resolve_case_path(input_path) - case_rows = [_build_case_row(row, source_path=case_path) for row in read_case_records(case_path)] - return StagedDetectionOutputAnalysis( - source_path=str(case_path), - cases=case_rows, - groups=build_group_rows(case_rows), - label_deltas=build_label_delta_rows(case_rows), - ) - - -def resolve_case_path(input_path: Path) -> Path: - if not input_path.exists(): - raise ValueError(f"input path does not exist: {input_path}") - if input_path.is_dir(): - input_path = input_path / "staged-detection-cases.jsonl" - if not input_path.exists(): - raise ValueError(f"staged detection case file does not exist: {input_path}") - if input_path.is_dir(): - raise ValueError(f"input path is a directory: {input_path}") - return input_path - - -def read_case_records(case_path: Path) -> list[dict[str, Any]]: - records: list[dict[str, Any]] = [] - for line in case_path.read_text(encoding="utf-8").splitlines(): - if not line.strip(): - continue - record = json.loads(line) - if not isinstance(record, dict): - raise 
ValueError(f"JSONL row is not an object in {case_path}") - if record.get("record_type") in (None, "staged_detection_case"): - records.append(record) - return records - - -def _build_case_row(record: dict[str, Any], *, source_path: Path) -> StagedCaseAnalysisRow: - comparison = _dict_value(record.get("comparison")) - baseline_count = _optional_int(comparison.get("baseline_final_entity_signature_count")) - shared_count = _optional_int(comparison.get("shared_final_entity_signature_count")) - baseline_only_count = _optional_int(comparison.get("baseline_only_final_entity_signature_count")) - return StagedCaseAnalysisRow( - source_path=str(source_path), - case_id=str(record.get("case_id") or ""), - row_index=_optional_int(record.get("row_index")), - seed_source=_optional_str(record.get("seed_source")), - status=_optional_str(record.get("status")), - case_failed=str(record.get("status")).lower() == "error", - elapsed_sec=_optional_float(record.get("elapsed_sec")), - model_elapsed_sec=_optional_float(record.get("model_elapsed_sec")), - model_phase_count=_int_value(record.get("model_phase_count")), - model_request_count=_int_value(record.get("model_request_count")), - **_usage_fields(_dict_value(record.get("total_usage"))), - **_entity_count_fields(record), - baseline_final_entity_signature_count=baseline_count, - shared_final_entity_signature_count=shared_count, - baseline_only_final_entity_signature_count=baseline_only_count, - direct_only_final_entity_signature_count=_optional_int( - comparison.get("direct_only_final_entity_signature_count") - ), - baseline_shared_signature_rate=_rate(shared_count, baseline_count), - baseline_loss_signature_rate=_rate(baseline_only_count, baseline_count), - baseline_only_label_counts=_counter_dict(comparison.get("baseline_only_label_counts")), - direct_only_label_counts=_counter_dict(comparison.get("direct_only_label_counts")), - error=_optional_str(record.get("error")), - ) - - -def _usage_fields(usage: dict[str, Any]) -> dict[str, 
int]: - return { - "prompt_tokens": _int_value(usage.get("prompt_tokens")), - "completion_tokens": _int_value(usage.get("completion_tokens")), - "total_tokens": _int_value(usage.get("total_tokens")), - } - - -def _entity_count_fields(record: dict[str, Any]) -> dict[str, int | dict[str, int]]: - return { - "seed_entity_count": _int_value(record.get("seed_entity_count")), - "validation_candidate_count": _int_value(record.get("validation_candidate_count")), - "validation_decision_count": _int_value(record.get("validation_decision_count")), - "augmented_suggestion_count": _int_value(record.get("augmented_suggestion_count")), - "final_entity_count": _int_value(record.get("final_entity_count")), - "final_entity_signature_count": _int_value(record.get("final_entity_signature_count")), - "final_label_counts": _counter_dict(record.get("final_label_counts")), - } - - -def build_group_rows(cases: list[StagedCaseAnalysisRow]) -> list[StagedGroupAnalysisRow]: - groups: defaultdict[str | None, list[StagedCaseAnalysisRow]] = defaultdict(list) - for case in cases: - groups[case.seed_source].append(case) - return [_build_group_row(seed_source, rows) for seed_source, rows in sorted(groups.items(), key=_group_sort_key)] - - -def _build_group_row(seed_source: str | None, rows: list[StagedCaseAnalysisRow]) -> StagedGroupAnalysisRow: - case_count = len(rows) - error_count = sum(1 for row in rows if row.case_failed) - baseline_total = _sum_optional_int(rows, "baseline_final_entity_signature_count") - shared_total = _sum_optional_int(rows, "shared_final_entity_signature_count") - baseline_only_total = _sum_optional_int(rows, "baseline_only_final_entity_signature_count") - model_request_count = sum(row.model_request_count for row in rows) - return StagedGroupAnalysisRow( - seed_source=seed_source, - case_count=case_count, - completed_case_count=case_count - error_count, - error_case_count=error_count, - failed_case_rate=_rate(error_count, case_count), - 
elapsed_sec_sum=_sum_optional_float(rows, "elapsed_sec"), - elapsed_sec_mean=_mean_optional_float(rows, "elapsed_sec"), - model_elapsed_sec_sum=_sum_optional_float(rows, "model_elapsed_sec"), - model_elapsed_sec_mean=_mean_optional_float(rows, "model_elapsed_sec"), - model_phase_count_sum=sum(row.model_phase_count for row in rows), - model_request_count_sum=model_request_count, - prompt_tokens_sum=sum(row.prompt_tokens for row in rows), - completion_tokens_sum=sum(row.completion_tokens for row in rows), - total_tokens_sum=sum(row.total_tokens for row in rows), - final_entity_count_sum=sum(row.final_entity_count for row in rows), - final_entity_signature_count_sum=sum(row.final_entity_signature_count for row in rows), - baseline_final_entity_signature_count_sum=baseline_total, - shared_final_entity_signature_count_sum=shared_total, - baseline_only_final_entity_signature_count_sum=baseline_only_total, - direct_only_final_entity_signature_count_sum=_sum_optional_int( - rows, "direct_only_final_entity_signature_count" - ), - baseline_shared_signature_rate=_rate(shared_total, baseline_total), - baseline_loss_signature_rate=_rate(baseline_only_total, baseline_total), - ) - - -def _group_sort_key(item: tuple[str | None, list[StagedCaseAnalysisRow]]) -> str: - return item[0] or "" - - -def build_label_delta_rows(cases: list[StagedCaseAnalysisRow]) -> list[LabelDeltaAnalysisRow]: - counts: Counter[tuple[str | None, str, str]] = Counter() - for case in cases: - for label, count in case.baseline_only_label_counts.items(): - counts[(case.seed_source, "baseline_only", label)] += count - for label, count in case.direct_only_label_counts.items(): - counts[(case.seed_source, "direct_only", label)] += count - return [ - LabelDeltaAnalysisRow(seed_source=seed_source, delta_type=delta_type, label=label, count=count) - for (seed_source, delta_type, label), count in sorted(counts.items(), key=_label_delta_sort_key) - ] - - -def _label_delta_sort_key(item: tuple[tuple[str | None, str, 
str], int]) -> tuple[str, str, str]: - (seed_source, delta_type, label), _ = item - return (seed_source or "", delta_type, label) - - -def _sum_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: - values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] - return sum(values) if values else None - - -def _mean_optional_float(rows: list[StagedCaseAnalysisRow], field_name: str) -> float | None: - values = [value for value in (_optional_float(getattr(row, field_name)) for row in rows) if value is not None] - return sum(values) / len(values) if values else None - - -def _sum_optional_int(rows: list[StagedCaseAnalysisRow], field_name: str) -> int: - return sum(value for value in (_optional_int(getattr(row, field_name)) for row in rows) if value is not None) - - -def _rate(numerator: object, denominator: object) -> float | None: - numerator_value = _optional_float(numerator) - denominator_value = _optional_float(denominator) - if numerator_value is None or denominator_value is None or denominator_value <= 0: - return None - return numerator_value / denominator_value - - -def _dict_value(raw: object) -> dict[str, Any]: - return raw if isinstance(raw, dict) else {} - - -def _counter_dict(raw: object) -> dict[str, int]: - return {str(key): _int_value(value) for key, value in _dict_value(raw).items()} - - -def _optional_str(raw: object) -> str | None: - return str(raw) if raw is not None else None - - -def _optional_float(raw: object) -> float | None: - if raw is None: - return None - return float(raw) - - -def _optional_int(raw: object) -> int | None: - if raw is None: - return None - return int(float(raw)) - - -def _int_value(raw: object) -> int: - return _optional_int(raw) or 0 - - -def write_analysis_tables( - result: StagedDetectionOutputAnalysis, - output_dir: Path, - export_format: ExportFormat, -) -> AnalysisExportResult: - return _write_analysis_table_specs( - output_dir, - 
export_format, - [ - ModelTableSpec("case_analysis", result.cases, StagedCaseAnalysisRow), - ModelTableSpec("group_analysis", result.groups, StagedGroupAnalysisRow), - ModelTableSpec("label_delta_analysis", result.label_deltas, LabelDeltaAnalysisRow), - ], - ) - - -def render_result(result: StagedDetectionOutputAnalysis, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - lines = [f"Analyzed {result.case_count} staged detection case(s) across {result.group_count} group(s)"] - lines.extend(_render_group_line(group, result.label_deltas) for group in result.groups) - return "\n".join(lines) - - -def _render_group_line(group: StagedGroupAnalysisRow, label_deltas: list[LabelDeltaAnalysisRow]) -> str: - label = group.seed_source or "" - lost = _top_labels(label_deltas, seed_source=group.seed_source, delta_type="baseline_only") - return ( - f"- {label}: cases={group.case_count}, errors={group.error_case_count}, " - f"elapsed_sum={_fmt_float(group.elapsed_sec_sum)}s, " - f"model_elapsed_sum={_fmt_float(group.model_elapsed_sec_sum)}s, " - f"requests={group.model_request_count_sum}, tokens={group.total_tokens_sum}, " - f"shared={group.shared_final_entity_signature_count_sum}/" - f"{group.baseline_final_entity_signature_count_sum}, " - f"baseline_only={group.baseline_only_final_entity_signature_count_sum}, " - f"direct_only={group.direct_only_final_entity_signature_count_sum}, lost_labels={lost}" - ) - - -def _top_labels(label_deltas: list[LabelDeltaAnalysisRow], *, seed_source: str | None, delta_type: str) -> str: - matches = [delta for delta in label_deltas if delta.seed_source == seed_source and delta.delta_type == delta_type] - if not matches: - return "{}" - items = sorted(matches, key=lambda delta: (-delta.count, delta.label))[:8] - return "{" + ", ".join(f"{item.label}:{item.count}" for item in items) + "}" - - -def _fmt_float(value: float | None) -> str: - return "n/a" if value is None else f"{value:.3f}" - - -@app.default -def 
main( - input_path: Path, - *, - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.parquet, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - result = analyze_staged_detection_output(input_path) - if output is not None: - write_analysis_tables(result, output, format) - except ValueError as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/compare_strategy_pairs.py b/tools/measurement/compare_strategy_pairs.py deleted file mode 100644 index e9075e2c..00000000 --- a/tools/measurement/compare_strategy_pairs.py +++ /dev/null @@ -1,1438 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Compare benchmark case-analysis rows for baseline/candidate strategy pairs. 
- -Usage: - uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.csv \ - --baseline-strategy default --candidate-strategy detector_native_validate_no_augment - uv run python tools/measurement/compare_strategy_pairs.py analysis/case_analysis.parquet \ - --baseline-config default --candidate-config no-augment --output comparisons.csv - uv run python tools/measurement/compare_strategy_pairs.py baseline/case_analysis.csv \ - --candidate-case-analysis candidate/case_analysis.csv \ - --baseline-strategy default --candidate-strategy native_single_pass -""" - -from __future__ import annotations - -import ast -import json -import logging -import sys -from enum import StrEnum -from pathlib import Path -from typing import Annotated - -import cyclopts -import pandas as pd -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from measurement_tools.tables import ExportFormat -from pydantic import BaseModel, Field - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.strategy_pairs") - -_SIGNATURE_DETAIL_FIELDS = { - "label", - "source", - "row_index", - "start_position", - "end_position", - "value_length", -} - - -class SafetyVerdict(StrEnum): - passed = "pass" - review = "review" - fail = "fail" - - -class PerformanceVerdict(StrEnum): - improved = "improved" - mixed = "mixed" - regressed = "regressed" - unchanged = "unchanged" - unknown = "unknown" - - -class CandidateVerdict(StrEnum): - candidate_viable = "candidate_viable" - review = "review" - reject = "reject" - - -_MIN_CANDIDATE_OVERLAP_RATIO = 0.8 -_MAX_CANDIDATE_BOUNDARY_GAP_CHARS = 8 -_MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS = 8 -_BOUNDARY_NORMALIZED_BASELINE_LABELS = {"api_key", "http_cookie", "unique_id"} - - -class ComparisonSummary(BaseModel): - comparison_count: int = 0 - value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) - signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) - 
safety_verdict_counts: dict[str, int] = Field(default_factory=dict) - performance_verdict_counts: dict[str, int] = Field(default_factory=dict) - candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) - candidate_viable_workloads: list[str] = Field(default_factory=list) - review_workloads: list[str] = Field(default_factory=list) - rejected_workloads: list[str] = Field(default_factory=list) - - -class ComparisonRow(BaseModel): - workload_id: str - baseline_config_id: str - candidate_config_id: str - baseline_strategy: str | None = None - candidate_strategy: str | None = None - baseline_replacement_strategy: str | None = None - candidate_replacement_strategy: str | None = None - baseline_case_count: int - candidate_case_count: int - baseline_failed_case_count: float | None = None - candidate_failed_case_count: float | None = None - failed_case_count_delta: float | None = None - baseline_failed_case_rate: float | None = None - candidate_failed_case_rate: float | None = None - failed_case_rate_delta: float | None = None - baseline_pipeline_elapsed_sec: float | None = None - candidate_pipeline_elapsed_sec: float | None = None - pipeline_elapsed_sec_delta: float | None = None - pipeline_elapsed_sec_delta_pct: float | None = None - baseline_observed_total_requests: float | None = None - candidate_observed_total_requests: float | None = None - observed_total_requests_delta: float | None = None - baseline_observed_total_tokens: float | None = None - candidate_observed_total_tokens: float | None = None - observed_total_tokens_delta: float | None = None - baseline_observed_failed_requests: float | None = None - candidate_observed_failed_requests: float | None = None - observed_failed_requests_delta: float | None = None - baseline_observed_bridge_fallback_requests: float | None = None - candidate_observed_bridge_fallback_requests: float | None = None - observed_bridge_fallback_requests_delta: float | None = None - baseline_observed_non_bridge_total_requests: 
float | None = None - candidate_observed_non_bridge_total_requests: float | None = None - observed_non_bridge_total_requests_delta: float | None = None - baseline_observed_non_bridge_failed_requests: float | None = None - candidate_observed_non_bridge_failed_requests: float | None = None - observed_non_bridge_failed_requests_delta: float | None = None - baseline_original_value_leak_count: float | None = None - candidate_original_value_leak_count: float | None = None - original_value_leak_count_delta: float | None = None - baseline_original_value_leak_record_count: float | None = None - candidate_original_value_leak_record_count: float | None = None - original_value_leak_record_count_delta: float | None = None - baseline_replacement_missing_final_entity_count: float | None = None - candidate_replacement_missing_final_entity_count: float | None = None - replacement_missing_final_entity_count_delta: float | None = None - baseline_replacement_missing_final_value_count: float | None = None - candidate_replacement_missing_final_value_count: float | None = None - replacement_missing_final_value_count_delta: float | None = None - baseline_replacement_synthetic_original_collision_count: float | None = None - candidate_replacement_synthetic_original_collision_count: float | None = None - replacement_synthetic_original_collision_count_delta: float | None = None - baseline_replacement_synthetic_original_collision_value_count: float | None = None - candidate_replacement_synthetic_original_collision_value_count: float | None = None - replacement_synthetic_original_collision_value_count_delta: float | None = None - value_protection_verdict: SafetyVerdict | None = None - signature_parity_verdict: SafetyVerdict | None = None - safety_verdict: SafetyVerdict | None = None - performance_verdict: PerformanceVerdict | None = None - candidate_verdict: CandidateVerdict | None = None - baseline_final_entity_count: float | None = None - candidate_final_entity_count: float | None = None - 
final_entity_count_delta: float | None = None - baseline_seed_validation_candidate_count: float | None = None - candidate_seed_validation_candidate_count: float | None = None - seed_validation_candidate_count_delta: float | None = None - baseline_augmented_entity_count: float | None = None - candidate_augmented_entity_count: float | None = None - augmented_entity_count_delta: float | None = None - baseline_augmented_new_final_value_count: float | None = None - candidate_augmented_new_final_value_count: float | None = None - augmented_new_final_value_count_delta: float | None = None - baseline_detector_entity_count: float | None = None - candidate_detector_entity_count: float | None = None - baseline_augmenter_entity_count: float | None = None - candidate_augmenter_entity_count: float | None = None - baseline_only_final_entity_signature_count: int | None = None - candidate_only_final_entity_signature_count: int | None = None - shared_final_entity_signature_count: int | None = None - baseline_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_only_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - shared_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_only_candidate_covered_signature_count: int | None = None - baseline_only_candidate_overlapping_signature_count: int | None = None - baseline_only_candidate_uncovered_signature_count: int | None = None - baseline_only_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_only_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_only_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_only_candidate_label_mismatch_signature_count: int | None = None - baseline_only_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) - 
baseline_stable_final_entity_signature_count: int | None = None - candidate_stable_final_entity_signature_count: int | None = None - stable_final_entity_signature_count_delta: int | None = None - baseline_stable_candidate_unstable_final_entity_signature_count: int | None = None - candidate_stable_baseline_unstable_final_entity_signature_count: int | None = None - shared_stable_final_entity_signature_count: int | None = None - baseline_stable_candidate_covered_signature_count: int | None = None - baseline_stable_candidate_overlapping_signature_count: int | None = None - baseline_stable_candidate_uncovered_signature_count: int | None = None - baseline_stable_candidate_covered_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_stable_candidate_overlapping_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_stable_candidate_uncovered_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_stable_candidate_label_mismatch_signature_count: int | None = None - baseline_stable_candidate_label_mismatch_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_stable_candidate_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_stable_baseline_unstable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - shared_stable_final_entity_signature_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - flags: list[str] = Field(default_factory=list) - - -class ComparisonResult(BaseModel): - 
input_path: str - candidate_input_path: str | None = None - baseline_selector: str - candidate_selector: str - summary: ComparisonSummary = Field(default_factory=ComparisonSummary) - comparisons: list[ComparisonRow] = Field(default_factory=list) - - @property - def comparison_count(self) -> int: - return len(self.comparisons) - - -def read_case_analysis(path: Path) -> pd.DataFrame: - if not path.exists() or path.is_dir(): - raise ValueError(f"case analysis path is not a file: {path}") - if path.suffix == ".parquet": - table = pd.read_parquet(path) - elif path.suffix == ".csv": - table = pd.read_csv(path) - elif path.suffix == ".jsonl": - table = pd.read_json(path, lines=True) - else: - raise ValueError(f"unsupported case analysis format: {path.suffix}") - _validate_case_analysis_columns(table) - return table - - -def _validate_case_analysis_columns(table: pd.DataFrame) -> None: - required = {"workload_id", "config_id", "case_id"} - missing = sorted(required - set(table.columns)) - if missing: - raise ValueError(f"case analysis is missing required column(s): {', '.join(missing)}") - - -def compare_case_analysis( - table: pd.DataFrame, - *, - baseline_config: str | None = None, - candidate_config: str | None = None, - baseline_strategy: str | None = None, - candidate_strategy: str | None = None, -) -> list[ComparisonRow]: - return compare_case_tables( - table, - table, - baseline_config=baseline_config, - candidate_config=candidate_config, - baseline_strategy=baseline_strategy, - candidate_strategy=candidate_strategy, - ) - - -def compare_case_tables( - baseline_table: pd.DataFrame, - candidate_table: pd.DataFrame, - *, - baseline_config: str | None = None, - candidate_config: str | None = None, - baseline_strategy: str | None = None, - candidate_strategy: str | None = None, -) -> list[ComparisonRow]: - baseline = _select_rows( - baseline_table, config_id=baseline_config, strategy=baseline_strategy, selector_name="baseline" - ) - candidate = _select_rows( - 
candidate_table, config_id=candidate_config, strategy=candidate_strategy, selector_name="candidate" - ) - return [ - _compare_workload(workload_id, baseline, candidate) for workload_id in _common_workloads(baseline, candidate) - ] - - -def summarize_comparisons(rows: list[ComparisonRow]) -> ComparisonSummary: - return ComparisonSummary( - comparison_count=len(rows), - value_protection_verdict_counts=_verdict_counts(rows, "value_protection_verdict", list(SafetyVerdict)), - signature_parity_verdict_counts=_verdict_counts(rows, "signature_parity_verdict", list(SafetyVerdict)), - safety_verdict_counts=_verdict_counts(rows, "safety_verdict", list(SafetyVerdict)), - performance_verdict_counts=_verdict_counts(rows, "performance_verdict", list(PerformanceVerdict)), - candidate_verdict_counts=_verdict_counts(rows, "candidate_verdict", list(CandidateVerdict)), - candidate_viable_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.candidate_viable), - review_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.review), - rejected_workloads=_workloads_by_candidate_verdict(rows, CandidateVerdict.reject), - ) - - -def _verdict_counts(rows: list[ComparisonRow], field: str, values: list[StrEnum]) -> dict[str, int]: - counts = {value.value: 0 for value in values} - for row in rows: - verdict = _verdict_value(getattr(row, field)) - if verdict is not None: - counts[verdict] = counts.get(verdict, 0) + 1 - return counts - - -def _workloads_by_candidate_verdict(rows: list[ComparisonRow], verdict: CandidateVerdict) -> list[str]: - return sorted(row.workload_id for row in rows if _verdict_value(row.candidate_verdict) == verdict.value) - - -def _verdict_value(value: object) -> str | None: - if value is None: - return None - if isinstance(value, StrEnum): - return value.value - return str(value) - - -def _select_rows( - table: pd.DataFrame, - *, - config_id: str | None, - strategy: str | None, - selector_name: str, -) -> pd.DataFrame: - if (config_id is None) == 
(strategy is None): - raise ValueError(f"{selector_name} selector must specify exactly one of config or strategy") - column, value = ("config_id", config_id) if config_id is not None else ("experimental_detection_strategy", strategy) - if column not in table.columns: - raise ValueError(f"case analysis is missing selector column: {column}") - selected = table[table[column].astype(str) == str(value)] - if selected.empty: - raise ValueError(f"{selector_name} selector matched no rows: {column}={value}") - _validate_unique_config_per_workload(selected, selector_name=selector_name) - return selected - - -def _validate_unique_config_per_workload(rows: pd.DataFrame, *, selector_name: str) -> None: - counts = rows.groupby("workload_id")["config_id"].nunique() - ambiguous = sorted(str(index) for index, count in counts.items() if count > 1) - if ambiguous: - raise ValueError(f"{selector_name} selector matched multiple configs for workload(s): {', '.join(ambiguous)}") - - -def _common_workloads(baseline: pd.DataFrame, candidate: pd.DataFrame) -> list[str]: - baseline_workloads = set(str(value) for value in baseline["workload_id"].dropna()) - candidate_workloads = set(str(value) for value in candidate["workload_id"].dropna()) - common = sorted(baseline_workloads & candidate_workloads) - if not common: - raise ValueError("baseline and candidate selectors have no workloads in common") - return common - - -def _compare_workload(workload_id: str, baseline: pd.DataFrame, candidate: pd.DataFrame) -> ComparisonRow: - baseline_summary = _summarize_selector(baseline[baseline["workload_id"].astype(str) == workload_id]) - candidate_summary = _summarize_selector(candidate[candidate["workload_id"].astype(str) == workload_id]) - return ComparisonRow( - workload_id=workload_id, - baseline_config_id=str(baseline_summary["config_id"]), - candidate_config_id=str(candidate_summary["config_id"]), - baseline_strategy=_optional_string(baseline_summary.get("experimental_detection_strategy")), - 
candidate_strategy=_optional_string(candidate_summary.get("experimental_detection_strategy")), - baseline_replacement_strategy=_optional_string(baseline_summary.get("experimental_replacement_strategy")), - candidate_replacement_strategy=_optional_string(candidate_summary.get("experimental_replacement_strategy")), - baseline_case_count=int(baseline_summary["case_count"]), - candidate_case_count=int(candidate_summary["case_count"]), - **_comparison_metrics(baseline_summary, candidate_summary), - ) - - -def _summarize_selector(rows: pd.DataFrame) -> dict[str, object]: - if rows.empty: - raise ValueError("selector has no rows for workload") - case_count = int(rows["case_id"].nunique()) - summary: dict[str, object] = { - "config_id": _single_string(rows, "config_id"), - "experimental_detection_strategy": _single_string(rows, "experimental_detection_strategy"), - "experimental_replacement_strategy": _single_string(rows, "experimental_replacement_strategy"), - "case_count": case_count, - } - for column in _NUMERIC_COLUMNS: - summary[column] = _median_or_none(rows, column) - summary["replacement_missing_final_entity_count"] = _sum_or_none(rows, "replacement_missing_final_entity_count") - summary["replacement_missing_final_value_count"] = _sum_or_none(rows, "replacement_missing_final_value_count") - summary["replacement_synthetic_original_collision_count"] = _sum_or_none( - rows, - "replacement_synthetic_original_collision_count", - ) - summary["replacement_synthetic_original_collision_value_count"] = _sum_or_none( - rows, - "replacement_synthetic_original_collision_value_count", - ) - summary["original_value_leak_count"] = _sum_or_none(rows, "original_value_leak_count") - summary["original_value_leak_record_count"] = _sum_or_none(rows, "original_value_leak_record_count") - summary["original_value_leak_label_counts"] = _sum_label_counts(rows, "original_value_leak_label_counts") - summary["replacement_synthetic_original_collision_label_counts"] = _sum_label_counts( - rows, - 
"replacement_synthetic_original_collision_label_counts", - ) - summary["failed_case_count"] = _failed_case_count(rows) - summary["failed_case_rate"] = _rate(summary["failed_case_count"], case_count) - summary["artifact_final_entity_signature_hashes"] = _signature_hashes(rows) - summary["stable_artifact_final_entity_signature_hashes"] = _stable_signature_hashes(rows) if case_count > 1 else [] - summary["artifact_final_entity_signature_labels"] = _signature_labels(rows) - summary["artifact_final_entity_signature_details"] = _signature_details(rows) - return summary - - -def _single_string(rows: pd.DataFrame, column: str) -> str | None: - if column not in rows.columns: - return None - values = sorted({str(value) for value in rows[column].dropna()}) - return values[0] if values else None - - -_NUMERIC_COLUMNS = [ - "pipeline_elapsed_sec", - "observed_total_requests", - "observed_total_tokens", - "observed_failed_requests", - "observed_bridge_fallback_requests", - "observed_non_bridge_total_requests", - "observed_non_bridge_failed_requests", - "replacement_missing_final_entity_count", - "replacement_missing_final_value_count", - "replacement_synthetic_original_collision_count", - "replacement_synthetic_original_collision_value_count", - "final_entity_count", - "seed_validation_candidate_count", - "augmented_entity_count", - "augmented_new_final_value_count", - "artifact_final_detector_entity_count", - "artifact_final_augmenter_entity_count", -] - - -def _median_or_none(rows: pd.DataFrame, column: str) -> float | None: - if column not in rows.columns: - return None - values = pd.to_numeric(rows[column], errors="coerce").dropna() - return float(values.median()) if not values.empty else None - - -def _sum_or_none(rows: pd.DataFrame, column: str) -> float | None: - if column not in rows.columns: - return None - values = pd.to_numeric(rows[column], errors="coerce").dropna() - return float(values.sum()) if not values.empty else None - - -def _comparison_metrics(baseline: 
dict[str, object], candidate: dict[str, object]) -> dict[str, object]: - metrics = _named_metric_deltas(baseline, candidate) - metrics.update(_source_counts(baseline, candidate)) - metrics.update(_entity_signature_deltas(baseline, candidate)) - metrics["flags"] = _comparison_flags( - metrics, - baseline_strategy=_optional_string(baseline.get("experimental_detection_strategy")), - candidate_strategy=_optional_string(candidate.get("experimental_detection_strategy")), - baseline_replacement_strategy=_optional_string(baseline.get("experimental_replacement_strategy")), - candidate_replacement_strategy=_optional_string(candidate.get("experimental_replacement_strategy")), - ) - metrics.update(_comparison_verdicts(metrics)) - return metrics - - -def _named_metric_deltas(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: - return { - **_metric_delta("failed_case_count", baseline, candidate), - **_metric_delta("failed_case_rate", baseline, candidate), - **_metric_delta("pipeline_elapsed_sec", baseline, candidate, include_pct=True), - **_metric_delta("observed_total_requests", baseline, candidate), - **_metric_delta("observed_total_tokens", baseline, candidate), - **_metric_delta("observed_failed_requests", baseline, candidate), - **_metric_delta("observed_bridge_fallback_requests", baseline, candidate), - **_metric_delta("observed_non_bridge_total_requests", baseline, candidate), - **_metric_delta("observed_non_bridge_failed_requests", baseline, candidate), - **_metric_delta("original_value_leak_count", baseline, candidate), - **_metric_delta("original_value_leak_record_count", baseline, candidate), - **_metric_delta("replacement_missing_final_entity_count", baseline, candidate), - **_metric_delta("replacement_missing_final_value_count", baseline, candidate), - **_metric_delta("replacement_synthetic_original_collision_count", baseline, candidate), - **_metric_delta("replacement_synthetic_original_collision_value_count", baseline, candidate), - 
**_metric_delta("final_entity_count", baseline, candidate), - **_metric_delta("seed_validation_candidate_count", baseline, candidate), - **_metric_delta("augmented_entity_count", baseline, candidate), - **_metric_delta("augmented_new_final_value_count", baseline, candidate), - } - - -def _metric_delta( - name: str, - baseline: dict[str, object], - candidate: dict[str, object], - *, - include_pct: bool = False, -) -> dict[str, float | None]: - base = _optional_float(baseline.get(name)) - cand = _optional_float(candidate.get(name)) - values = {f"baseline_{name}": base, f"candidate_{name}": cand, f"{name}_delta": _delta(base, cand)} - if include_pct: - values[f"{name}_delta_pct"] = _delta_pct(base, cand) - return values - - -def _source_counts(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: - return { - "baseline_detector_entity_count": _optional_float(baseline.get("artifact_final_detector_entity_count")), - "candidate_detector_entity_count": _optional_float(candidate.get("artifact_final_detector_entity_count")), - "baseline_augmenter_entity_count": _optional_float(baseline.get("artifact_final_augmenter_entity_count")), - "candidate_augmenter_entity_count": _optional_float(candidate.get("artifact_final_augmenter_entity_count")), - "baseline_original_value_leak_label_counts": _coerce_count_map( - baseline.get("original_value_leak_label_counts") - ), - "candidate_original_value_leak_label_counts": _coerce_count_map( - candidate.get("original_value_leak_label_counts") - ), - "baseline_replacement_synthetic_original_collision_label_counts": _coerce_count_map( - baseline.get("replacement_synthetic_original_collision_label_counts") - ), - "candidate_replacement_synthetic_original_collision_label_counts": _coerce_count_map( - candidate.get("replacement_synthetic_original_collision_label_counts") - ), - } - - -def _comparison_flags( - metrics: dict[str, object], - *, - baseline_strategy: str | None, - candidate_strategy: str | None, - 
baseline_replacement_strategy: str | None, - candidate_replacement_strategy: str | None, -) -> list[str]: - flags: list[str] = [] - _append_if_positive(flags, metrics, "baseline_failed_case_count", "baseline_case_failures") - _append_if_positive(flags, metrics, "candidate_failed_case_count", "candidate_case_failures") - _append_if_positive(flags, metrics, "baseline_original_value_leak_count", "baseline_original_value_leak") - _append_if_positive(flags, metrics, "candidate_original_value_leak_count", "candidate_original_value_leak") - _append_if_positive( - flags, - metrics, - "baseline_replacement_synthetic_original_collision_count", - "baseline_replacement_synthetic_original_collision", - ) - _append_if_positive( - flags, - metrics, - "candidate_replacement_synthetic_original_collision_count", - "candidate_replacement_synthetic_original_collision", - ) - _append_if_positive( - flags, - metrics, - "baseline_replacement_missing_final_entity_count", - "baseline_replacement_missing_final_entity", - ) - _append_if_positive( - flags, - metrics, - "candidate_replacement_missing_final_entity_count", - "candidate_replacement_missing_final_entity", - ) - _append_if_negative(flags, metrics, "final_entity_count_delta", "entity_count_loss") - _append_if_positive(flags, metrics, _signature_loss_metric(metrics), "entity_signature_loss") - _append_if_positive( - flags, - metrics, - "baseline_only_candidate_overlapping_signature_count", - "span_boundary_mismatch", - ) - _append_if_positive( - flags, - metrics, - "baseline_only_candidate_label_mismatch_signature_count", - "covered_label_mismatch", - ) - _append_if_positive( - flags, - metrics, - _stable_signature_loss_metric(metrics), - "stable_entity_signature_loss", - ) - failed_request_delta = ( - "observed_non_bridge_failed_requests_delta" - if _has_metric_pair(metrics, "observed_non_bridge_failed_requests") - else "observed_failed_requests_delta" - ) - _append_if_positive(flags, metrics, failed_request_delta, 
"failed_request_increase") - _append_if_positive(flags, metrics, "observed_bridge_fallback_requests_delta", "bridge_fallback_increase") - _append_if_positive(flags, metrics, "observed_total_tokens_delta", "token_increase") - _append_if_positive(flags, metrics, "observed_total_requests_delta", "request_increase") - if _candidate_lacks_detector_entities(metrics): - flags.append("no_candidate_detector_entities") - if candidate_strategy in _SKIPS_LLM_VALIDATION_STRATEGIES: - flags.append("candidate_skips_llm_validation") - if _replacement_only_detection_instability( - flags, - baseline_strategy=baseline_strategy, - candidate_strategy=candidate_strategy, - baseline_replacement_strategy=baseline_replacement_strategy, - candidate_replacement_strategy=candidate_replacement_strategy, - ): - flags.append("replacement_only_detection_instability") - return flags - - -_DETECTION_INSTABILITY_FLAGS = { - "covered_label_mismatch", - "entity_count_loss", - "entity_signature_loss", - "span_boundary_mismatch", - "stable_entity_signature_loss", -} - - -def _replacement_only_detection_instability( - flags: list[str], - *, - baseline_strategy: str | None, - candidate_strategy: str | None, - baseline_replacement_strategy: str | None, - candidate_replacement_strategy: str | None, -) -> bool: - if baseline_strategy != candidate_strategy: - return False - if not baseline_replacement_strategy or not candidate_replacement_strategy: - return False - if baseline_replacement_strategy == candidate_replacement_strategy: - return False - return bool(set(flags) & _DETECTION_INSTABILITY_FLAGS) - - -def _signature_loss_metric(metrics: dict[str, object]) -> str: - if metrics.get("baseline_only_candidate_uncovered_signature_count") is not None: - return "baseline_only_candidate_uncovered_signature_count" - return "baseline_only_final_entity_signature_count" - - -def _stable_signature_loss_metric(metrics: dict[str, object]) -> str: - if metrics.get("baseline_stable_candidate_uncovered_signature_count") 
is not None: - return "baseline_stable_candidate_uncovered_signature_count" - return "baseline_stable_candidate_unstable_final_entity_signature_count" - - -_SKIPS_LLM_VALIDATION_STRATEGIES = {"detector_only"} - - -def _has_metric_pair(metrics: dict[str, object], name: str) -> bool: - return ( - _optional_float(metrics.get(f"baseline_{name}")) is not None - and _optional_float(metrics.get(f"candidate_{name}")) is not None - ) - - -def _comparison_verdicts(metrics: dict[str, object]) -> dict[str, str]: - value_protection_verdict = _value_protection_verdict(metrics) - signature_parity_verdict = _signature_parity_verdict(metrics) - safety_verdict = _safety_verdict(metrics) - performance_verdict = _performance_verdict(metrics) - return { - "value_protection_verdict": value_protection_verdict.value, - "signature_parity_verdict": signature_parity_verdict.value, - "safety_verdict": safety_verdict.value, - "performance_verdict": performance_verdict.value, - "candidate_verdict": _candidate_verdict(safety_verdict, performance_verdict).value, - } - - -def _value_protection_verdict(metrics: dict[str, object]) -> SafetyVerdict: - flags = set(_coerce_flag_list(metrics.get("flags"))) - if flags & { - "candidate_case_failures", - "candidate_original_value_leak", - "candidate_replacement_missing_final_entity", - "candidate_replacement_synthetic_original_collision", - "entity_signature_loss", - "stable_entity_signature_loss", - }: - return SafetyVerdict.fail - if _entity_count_loss_without_signature_artifacts(flags, metrics): - return SafetyVerdict.fail - if flags & { - "baseline_case_failures", - "baseline_original_value_leak", - "baseline_replacement_missing_final_entity", - "baseline_replacement_synthetic_original_collision", - }: - return SafetyVerdict.review - return SafetyVerdict.passed - - -def _signature_parity_verdict(metrics: dict[str, object]) -> SafetyVerdict: - flags = set(_coerce_flag_list(metrics.get("flags"))) - if flags & {"candidate_case_failures", 
"entity_signature_loss", "stable_entity_signature_loss"}: - return SafetyVerdict.fail - if _entity_count_loss_without_signature_artifacts(flags, metrics): - return SafetyVerdict.fail - if flags & { - "span_boundary_mismatch", - "covered_label_mismatch", - "entity_count_loss", - "baseline_case_failures", - }: - return SafetyVerdict.review - return SafetyVerdict.passed - - -def _safety_verdict(metrics: dict[str, object]) -> SafetyVerdict: - flags = set(_coerce_flag_list(metrics.get("flags"))) - if flags & { - "candidate_case_failures", - "candidate_original_value_leak", - "candidate_replacement_missing_final_entity", - "candidate_replacement_synthetic_original_collision", - "entity_signature_loss", - "stable_entity_signature_loss", - }: - return SafetyVerdict.fail - if _entity_count_loss_without_signature_artifacts(flags, metrics): - return SafetyVerdict.fail - if flags & { - "no_candidate_detector_entities", - "candidate_skips_llm_validation", - "failed_request_increase", - "bridge_fallback_increase", - "span_boundary_mismatch", - "covered_label_mismatch", - }: - return SafetyVerdict.review - if flags & { - "baseline_case_failures", - "baseline_original_value_leak", - "baseline_replacement_missing_final_entity", - "baseline_replacement_synthetic_original_collision", - "entity_count_loss", - }: - return SafetyVerdict.review - return SafetyVerdict.passed - - -def _entity_count_loss_without_signature_artifacts(flags: set[str], metrics: dict[str, object]) -> bool: - return "entity_count_loss" in flags and metrics.get("baseline_only_final_entity_signature_count") is None - - -def _performance_verdict(metrics: dict[str, object]) -> PerformanceVerdict: - deltas = [ - _optional_float(metrics.get("pipeline_elapsed_sec_delta")), - _optional_float(metrics.get("observed_total_requests_delta")), - _optional_float(metrics.get("observed_total_tokens_delta")), - ] - known = [value for value in deltas if value is not None] - if not known: - return PerformanceVerdict.unknown - 
improved = any(value < 0 for value in known) - regressed = any(value > 0 for value in known) - if improved and regressed: - return PerformanceVerdict.mixed - if improved: - return PerformanceVerdict.improved - if regressed: - return PerformanceVerdict.regressed - return PerformanceVerdict.unchanged - - -def _candidate_verdict( - safety_verdict: SafetyVerdict, - performance_verdict: PerformanceVerdict, -) -> CandidateVerdict: - if safety_verdict == SafetyVerdict.fail: - return CandidateVerdict.reject - if safety_verdict == SafetyVerdict.passed and performance_verdict == PerformanceVerdict.improved: - return CandidateVerdict.candidate_viable - return CandidateVerdict.review - - -def _coerce_flag_list(value: object) -> list[str]: - if isinstance(value, list): - return [str(item) for item in value] - return [] - - -def _candidate_lacks_detector_entities(metrics: dict[str, object]) -> bool: - final_count = _optional_float(metrics.get("candidate_final_entity_count")) - if final_count is None or final_count <= 0: - return False - detector_count = _optional_float(metrics.get("candidate_detector_entity_count")) - if detector_count is not None: - return detector_count == 0 - non_detector_count = _known_non_detector_candidate_count(metrics) - return non_detector_count is not None and non_detector_count >= final_count - - -def _known_non_detector_candidate_count(metrics: dict[str, object]) -> float | None: - known_counts = [ - _optional_float(metrics.get("candidate_augmenter_entity_count")), - ] - if all(value is None for value in known_counts): - return None - return sum(value or 0.0 for value in known_counts) - - -def _entity_signature_deltas(baseline: dict[str, object], candidate: dict[str, object]) -> dict[str, object]: - baseline_signatures = set(_coerce_signature_list(baseline.get("artifact_final_entity_signature_hashes"))) - candidate_signatures = set(_coerce_signature_list(candidate.get("artifact_final_entity_signature_hashes"))) - baseline_labels = 
_coerce_signature_labels(baseline.get("artifact_final_entity_signature_labels")) - candidate_labels = _coerce_signature_labels(candidate.get("artifact_final_entity_signature_labels")) - baseline_details = _coerce_signature_details(baseline.get("artifact_final_entity_signature_details")) - candidate_details = _coerce_signature_details(candidate.get("artifact_final_entity_signature_details")) - baseline_signatures.update(baseline_labels) - candidate_signatures.update(candidate_labels) - if not baseline_signatures and not candidate_signatures: - return { - "baseline_only_final_entity_signature_count": None, - "candidate_only_final_entity_signature_count": None, - "shared_final_entity_signature_count": None, - "baseline_only_final_entity_signature_label_counts": {}, - "candidate_only_final_entity_signature_label_counts": {}, - "shared_final_entity_signature_label_counts": {}, - "baseline_only_candidate_covered_signature_count": None, - "baseline_only_candidate_overlapping_signature_count": None, - "baseline_only_candidate_uncovered_signature_count": None, - "baseline_only_candidate_covered_signature_label_counts": {}, - "baseline_only_candidate_overlapping_signature_label_counts": {}, - "baseline_only_candidate_uncovered_signature_label_counts": {}, - "baseline_only_candidate_label_mismatch_signature_count": None, - "baseline_only_candidate_label_mismatch_signature_label_counts": {}, - } - baseline_only = baseline_signatures - candidate_signatures - candidate_only = candidate_signatures - baseline_signatures - shared = baseline_signatures & candidate_signatures - coverage = _candidate_span_coverage( - baseline_only, - baseline_details=baseline_details, - candidate_details=candidate_details, - baseline_labels=baseline_labels, - ) - return { - "baseline_only_final_entity_signature_count": len(baseline_only), - "candidate_only_final_entity_signature_count": len(candidate_only), - "shared_final_entity_signature_count": len(shared), - 
"baseline_only_final_entity_signature_label_counts": _signature_label_counts( - baseline_only, - baseline_labels, - ), - "candidate_only_final_entity_signature_label_counts": _signature_label_counts( - candidate_only, - candidate_labels, - ), - "shared_final_entity_signature_label_counts": _signature_label_counts(shared, baseline_labels), - **coverage, - **_stable_entity_signature_deltas( - baseline, - candidate, - baseline_labels, - candidate_labels, - baseline_details, - candidate_details, - ), - } - - -def _stable_entity_signature_deltas( - baseline: dict[str, object], - candidate: dict[str, object], - baseline_labels: dict[str, str], - candidate_labels: dict[str, str], - baseline_details: dict[str, dict[str, object]], - candidate_details: dict[str, dict[str, object]], -) -> dict[str, object]: - baseline_case_count = _optional_float(baseline.get("case_count")) - candidate_case_count = _optional_float(candidate.get("case_count")) - if ( - baseline_case_count is None - or candidate_case_count is None - or baseline_case_count < 2 - or candidate_case_count < 2 - ): - return _empty_stable_signature_deltas() - baseline_stable = set(_coerce_signature_list(baseline.get("stable_artifact_final_entity_signature_hashes"))) - candidate_stable = set(_coerce_signature_list(candidate.get("stable_artifact_final_entity_signature_hashes"))) - if not baseline_stable and not candidate_stable: - return _empty_stable_signature_deltas() - baseline_stable_candidate_unstable = baseline_stable - candidate_stable - candidate_stable_baseline_unstable = candidate_stable - baseline_stable - shared_stable = baseline_stable & candidate_stable - stable_candidate_details = { - signature: detail for signature, detail in candidate_details.items() if signature in candidate_stable - } - coverage = _candidate_span_coverage( - baseline_stable_candidate_unstable, - baseline_details=baseline_details, - candidate_details=stable_candidate_details, - baseline_labels=baseline_labels, - 
prefix="baseline_stable_candidate", - ) - return { - "baseline_stable_final_entity_signature_count": len(baseline_stable), - "candidate_stable_final_entity_signature_count": len(candidate_stable), - "stable_final_entity_signature_count_delta": len(candidate_stable) - len(baseline_stable), - "baseline_stable_candidate_unstable_final_entity_signature_count": len(baseline_stable_candidate_unstable), - "candidate_stable_baseline_unstable_final_entity_signature_count": len(candidate_stable_baseline_unstable), - "shared_stable_final_entity_signature_count": len(shared_stable), - "baseline_stable_candidate_unstable_final_entity_signature_label_counts": _signature_label_counts( - baseline_stable_candidate_unstable, - baseline_labels, - ), - "candidate_stable_baseline_unstable_final_entity_signature_label_counts": _signature_label_counts( - candidate_stable_baseline_unstable, - candidate_labels, - ), - "shared_stable_final_entity_signature_label_counts": _signature_label_counts(shared_stable, baseline_labels), - **coverage, - } - - -def _candidate_span_coverage( - baseline_signatures: set[str], - *, - baseline_details: dict[str, dict[str, object]], - candidate_details: dict[str, dict[str, object]], - baseline_labels: dict[str, str], - prefix: str = "baseline_only_candidate", -) -> dict[str, object]: - contained: set[str] = set() - overlapping: set[str] = set() - label_mismatch: set[str] = set() - for signature in baseline_signatures: - match, labels_mismatch = _candidate_span_match_kind(baseline_details.get(signature), candidate_details.values()) - if match == "contained": - contained.add(signature) - elif match == "overlapping": - overlapping.add(signature) - if labels_mismatch: - label_mismatch.add(signature) - covered = contained | overlapping - uncovered = baseline_signatures - covered - return { - f"{prefix}_covered_signature_count": len(covered), - f"{prefix}_overlapping_signature_count": len(overlapping), - f"{prefix}_uncovered_signature_count": len(uncovered), - 
f"{prefix}_covered_signature_label_counts": _signature_label_counts(covered, baseline_labels), - f"{prefix}_overlapping_signature_label_counts": _signature_label_counts(overlapping, baseline_labels), - f"{prefix}_uncovered_signature_label_counts": _signature_label_counts(uncovered, baseline_labels), - f"{prefix}_label_mismatch_signature_count": len(label_mismatch), - f"{prefix}_label_mismatch_signature_label_counts": _signature_label_counts(label_mismatch, baseline_labels), - } - - -def _candidate_span_match_kind( - baseline_detail: dict[str, object] | None, - candidate_details: object, -) -> tuple[str | None, bool]: - baseline_row = _optional_int(_detail_value(baseline_detail, "row_index")) - baseline_start = _optional_int(_detail_value(baseline_detail, "start_position")) - baseline_end = _optional_int(_detail_value(baseline_detail, "end_position")) - baseline_label = _optional_string(_detail_value(baseline_detail, "label")) - if baseline_row is None or baseline_start is None or baseline_end is None: - return None, False - baseline_length = baseline_end - baseline_start - if baseline_length <= 0: - return None, False - first_mismatched_match: tuple[str, bool] | None = None - for candidate_detail in candidate_details: - candidate_row = _optional_int(_detail_value(candidate_detail, "row_index")) - candidate_start = _optional_int(_detail_value(candidate_detail, "start_position")) - candidate_end = _optional_int(_detail_value(candidate_detail, "end_position")) - if candidate_row != baseline_row or candidate_start is None or candidate_end is None: - continue - candidate_label = _optional_string(_detail_value(candidate_detail, "label")) - labels_mismatch = _labels_mismatch(baseline_label, candidate_label) - if candidate_start <= baseline_start and candidate_end >= baseline_end: - if not labels_mismatch: - return "contained", False - first_mismatched_match = first_mismatched_match or ("contained", True) - continue - overlap = max(0, min(baseline_end, candidate_end) - 
max(baseline_start, candidate_start)) - if overlap / baseline_length >= _MIN_CANDIDATE_OVERLAP_RATIO: - if not labels_mismatch: - return "overlapping", False - first_mismatched_match = first_mismatched_match or ("overlapping", True) - continue - if _is_small_boundary_gap( - baseline_start=baseline_start, - baseline_end=baseline_end, - baseline_label=baseline_label, - candidate_start=candidate_start, - candidate_end=candidate_end, - overlap=overlap, - ): - if not labels_mismatch: - return "overlapping", False - first_mismatched_match = first_mismatched_match or ("overlapping", True) - return first_mismatched_match or (None, False) - - -def _labels_mismatch(baseline_label: str | None, candidate_label: str | None) -> bool: - return baseline_label is not None and candidate_label is not None and baseline_label != candidate_label - - -def _is_small_boundary_gap( - *, - baseline_start: int, - baseline_end: int, - baseline_label: str | None, - candidate_start: int, - candidate_end: int, - overlap: int, -) -> bool: - if baseline_label not in _BOUNDARY_NORMALIZED_BASELINE_LABELS: - return False - if overlap < _MIN_CANDIDATE_BOUNDARY_OVERLAP_CHARS: - return False - omitted_left = max(0, candidate_start - baseline_start) - omitted_right = max(0, baseline_end - candidate_end) - return omitted_left + omitted_right <= _MAX_CANDIDATE_BOUNDARY_GAP_CHARS - - -def _detail_value(detail: object, key: str) -> object: - if not isinstance(detail, dict): - return None - return detail.get(key) - - -def _empty_stable_signature_deltas() -> dict[str, object]: - return { - "baseline_stable_final_entity_signature_count": None, - "candidate_stable_final_entity_signature_count": None, - "stable_final_entity_signature_count_delta": None, - "baseline_stable_candidate_unstable_final_entity_signature_count": None, - "candidate_stable_baseline_unstable_final_entity_signature_count": None, - "shared_stable_final_entity_signature_count": None, - "baseline_stable_candidate_covered_signature_count": None, 
- "baseline_stable_candidate_overlapping_signature_count": None, - "baseline_stable_candidate_uncovered_signature_count": None, - "baseline_stable_candidate_covered_signature_label_counts": {}, - "baseline_stable_candidate_overlapping_signature_label_counts": {}, - "baseline_stable_candidate_uncovered_signature_label_counts": {}, - "baseline_stable_candidate_label_mismatch_signature_count": None, - "baseline_stable_candidate_label_mismatch_signature_label_counts": {}, - "baseline_stable_candidate_unstable_final_entity_signature_label_counts": {}, - "candidate_stable_baseline_unstable_final_entity_signature_label_counts": {}, - "shared_stable_final_entity_signature_label_counts": {}, - } - - -def _signature_hashes(rows: pd.DataFrame) -> list[str]: - hashes: set[str] = set() - for signature_set in _signature_hash_sets(rows): - hashes.update(signature_set) - return sorted(hashes) - - -def _stable_signature_hashes(rows: pd.DataFrame) -> list[str]: - signature_sets = _signature_hash_sets(rows) - if not signature_sets: - return [] - return sorted(set.intersection(*signature_sets)) - - -def _signature_hash_sets(rows: pd.DataFrame) -> list[set[str]]: - if "artifact_final_entity_signature_hashes" not in rows.columns: - return [] - return [set(_coerce_signature_list(value)) for value in rows["artifact_final_entity_signature_hashes"].tolist()] - - -def _signature_labels(rows: pd.DataFrame) -> dict[str, str]: - labels: dict[str, str] = {} - if "artifact_final_entity_signature_labels" in rows.columns: - for value in rows["artifact_final_entity_signature_labels"].tolist(): - labels.update(_coerce_signature_labels(value)) - prefix = "artifact_final_entity_signature_labels." 
- for column in rows.columns: - if not column.startswith(prefix): - continue - signature_hash = column.removeprefix(prefix) - for value in rows[column].tolist(): - if not _is_missing_cell(value): - labels[signature_hash] = str(value) - return dict(sorted(labels.items())) - - -def _signature_details(rows: pd.DataFrame) -> dict[str, dict[str, object]]: - details: dict[str, dict[str, object]] = {} - if "artifact_final_entity_signature_details" in rows.columns: - for value in rows["artifact_final_entity_signature_details"].tolist(): - details.update(_coerce_signature_details(value)) - prefix = "artifact_final_entity_signature_details." - for column in rows.columns: - if not column.startswith(prefix): - continue - remainder = column.removeprefix(prefix) - signature_hash, _, field = remainder.partition(".") - if not signature_hash or not field: - continue - if field not in _SIGNATURE_DETAIL_FIELDS: - continue - for value in rows[column].tolist(): - if not _is_missing_cell(value): - details.setdefault(signature_hash, {})[field] = _json_scalar(value) - return dict(sorted(details.items())) - - -def _sum_label_counts(rows: pd.DataFrame, field: str) -> dict[str, int]: - counts: dict[str, int] = {} - if field in rows.columns: - for value in rows[field].tolist(): - _merge_count_map(counts, _coerce_count_map(value)) - prefix = f"{field}." 
- for column in rows.columns: - if not column.startswith(prefix): - continue - total = _sum_or_none(rows, column) - if total: - counts[column.removeprefix(prefix)] = counts.get(column.removeprefix(prefix), 0) + int(total) - return dict(sorted(counts.items())) - - -def _failed_case_count(rows: pd.DataFrame) -> int: - if "case_failed" in rows.columns: - return int(rows["case_failed"].map(_coerce_bool).sum()) - error_columns = [ - column - for column in ("error_stage_count", "error_ndd_workflow_count", "error_model_workflow_count") - if column in rows.columns - ] - if not error_columns: - return 0 - error_counts = rows[error_columns].apply(pd.to_numeric, errors="coerce").fillna(0).sum(axis=1) - return int((error_counts > 0).sum()) - - -def _coerce_bool(value: object) -> bool: - if _is_missing_cell(value): - return False - if isinstance(value, bool): - return value - if isinstance(value, int | float): - return value != 0 - text = str(value).strip().lower() - return text in {"1", "true", "t", "yes", "y"} - - -def _rate(count: object, total: object) -> float | None: - count_value = _optional_float(count) - total_value = _optional_float(total) - if count_value is None or total_value is None or total_value <= 0: - return None - return count_value / total_value - - -def _signature_label_counts(signatures: set[str], labels: dict[str, str]) -> dict[str, int]: - counts: dict[str, int] = {} - for signature_hash in signatures: - label = labels.get(signature_hash, "unknown") - counts[label] = counts.get(label, 0) + 1 - return dict(sorted(counts.items())) - - -def _coerce_signature_list(value: object) -> list[str]: - if _is_missing_cell(value): - return [] - if hasattr(value, "tolist"): - value = value.tolist() - if isinstance(value, list | tuple | set): - return [str(item) for item in value if not _is_missing_cell(item)] - if not isinstance(value, str): - return [] - text = value.strip() - if not text: - return [] - try: - parsed = json.loads(text) - except json.JSONDecodeError: 
- try: - parsed = ast.literal_eval(text) - except (SyntaxError, ValueError): - return [] - return _coerce_signature_list(parsed) - - -def _coerce_signature_labels(value: object) -> dict[str, str]: - if _is_missing_cell(value): - return {} - if hasattr(value, "to_dict"): - value = value.to_dict() - if isinstance(value, dict): - return {str(key): str(item) for key, item in value.items() if not _is_missing_cell(item)} - if not isinstance(value, str): - return {} - text = value.strip() - if not text: - return {} - try: - parsed = json.loads(text) - except json.JSONDecodeError: - try: - parsed = ast.literal_eval(text) - except (SyntaxError, ValueError): - return {} - return _coerce_signature_labels(parsed) - - -def _coerce_signature_details(value: object) -> dict[str, dict[str, object]]: - if _is_missing_cell(value): - return {} - if hasattr(value, "to_dict"): - value = value.to_dict() - if isinstance(value, dict): - details: dict[str, dict[str, object]] = {} - for signature_hash, raw_detail in value.items(): - if not isinstance(raw_detail, dict): - continue - details[str(signature_hash)] = { - str(key): _json_scalar(item) - for key, item in raw_detail.items() - if str(key) in _SIGNATURE_DETAIL_FIELDS and not _is_missing_cell(item) - } - return dict(sorted(details.items())) - if not isinstance(value, str): - return {} - text = value.strip() - if not text: - return {} - try: - parsed = json.loads(text) - except json.JSONDecodeError: - try: - parsed = ast.literal_eval(text) - except (SyntaxError, ValueError): - return {} - return _coerce_signature_details(parsed) - - -def _json_scalar(value: object) -> object: - if hasattr(value, "item"): - try: - return value.item() - except ValueError: - return value - return value - - -def _coerce_count_map(value: object) -> dict[str, int]: - if _is_missing_cell(value): - return {} - if hasattr(value, "to_dict"): - value = value.to_dict() - if isinstance(value, dict): - counts: dict[str, int] = {} - for key, item in value.items(): - if 
_is_missing_cell(item): - continue - count = _optional_float(item) - if count: - counts[str(key)] = int(count) - return dict(sorted(counts.items())) - if not isinstance(value, str): - return {} - text = value.strip() - if not text: - return {} - try: - parsed = json.loads(text) - except json.JSONDecodeError: - try: - parsed = ast.literal_eval(text) - except (SyntaxError, ValueError): - return {} - return _coerce_count_map(parsed) - - -def _merge_count_map(target: dict[str, int], update: dict[str, int]) -> None: - for key, value in update.items(): - target[key] = target.get(key, 0) + value - - -def _is_missing_cell(value: object) -> bool: - return value is None or (isinstance(value, float) and pd.isna(value)) - - -def _append_if_negative(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: - value = _optional_float(metrics.get(field)) - if value is not None and value < 0: - flags.append(flag) - - -def _append_if_positive(flags: list[str], metrics: dict[str, object], field: str, flag: str) -> None: - value = _optional_float(metrics.get(field)) - if value is not None and value > 0: - flags.append(flag) - - -def _optional_float(value: object) -> float | None: - if value is None or pd.isna(value): - return None - return float(value) - - -def _optional_int(value: object) -> int | None: - number = _optional_float(value) - return int(number) if number is not None else None - - -def _optional_string(value: object) -> str | None: - if value is None or pd.isna(value): - return None - return str(value) - - -def _delta(baseline: float | None, candidate: float | None) -> float | None: - return candidate - baseline if baseline is not None and candidate is not None else None - - -def _delta_pct(baseline: float | None, candidate: float | None) -> float | None: - if baseline is None or candidate is None or baseline == 0: - return None - return ((candidate - baseline) / baseline) * 100 - - -def write_comparisons(rows: list[ComparisonRow], output_path: Path, 
export_format: ExportFormat) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - table = pd.json_normalize([row.model_dump() for row in rows], sep=".") - table = _normalize_table_cells(table) - if export_format == ExportFormat.parquet: - table.to_parquet(output_path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(output_path, index=False) - else: - table.to_json(output_path, orient="records", lines=True) - - -def _normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame: - normalized = table.copy() - for column in normalized.columns: - if normalized[column].map(_is_nested_cell).any(): - normalized[column] = normalized[column].map(_json_cell) - return normalized - - -def _is_nested_cell(value: object) -> bool: - return isinstance(value, dict | list) - - -def _json_cell(value: object) -> object: - if not _is_nested_cell(value): - return value - return json.dumps(value, ensure_ascii=True, sort_keys=True) - - -def render_result(result: ComparisonResult, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - lines = [ - f"Compared {result.comparison_count} workload(s): " - f"viable={result.summary.candidate_verdict_counts.get(CandidateVerdict.candidate_viable.value, 0)}, " - f"review={result.summary.candidate_verdict_counts.get(CandidateVerdict.review.value, 0)}, " - f"reject={result.summary.candidate_verdict_counts.get(CandidateVerdict.reject.value, 0)}" - ] - for row in result.comparisons: - lines.append( - f"- {row.workload_id}: verdict={row.candidate_verdict or 'unknown'} " - f"(safety={row.safety_verdict or 'unknown'}, " - f"value_protection={row.value_protection_verdict or 'unknown'}, " - f"signature_parity={row.signature_parity_verdict or 'unknown'}, " - f"performance={row.performance_verdict or 'unknown'}), " - f"replacement={row.baseline_replacement_strategy or 'unknown'}" - f"->{row.candidate_replacement_strategy or 'unknown'}, " - f"elapsed 
{_number_pair(row.baseline_pipeline_elapsed_sec, row.candidate_pipeline_elapsed_sec, suffix='s')}, " - f"entities {row.baseline_final_entity_count}->{row.candidate_final_entity_count}, " - f"requests {_number_pair(row.baseline_observed_total_requests, row.candidate_observed_total_requests)}, " - f"tokens {_number_pair(row.baseline_observed_total_tokens, row.candidate_observed_total_tokens)}, " - f"original_value_leaks " - f"{_number_pair(row.baseline_original_value_leak_count, row.candidate_original_value_leak_count)}, " - f"lost_labels={_label_count_summary(row.baseline_only_final_entity_signature_label_counts)}, " - "covered_label_mismatch_labels=" - f"{_label_count_summary(row.baseline_only_candidate_label_mismatch_signature_label_counts)}, " - f"unstable_lost_labels={_label_count_summary(row.baseline_stable_candidate_unstable_final_entity_signature_label_counts)}, " - f"leak_labels={_label_count_summary(row.candidate_original_value_leak_label_counts)}, " - f"flags={','.join(row.flags) if row.flags else 'none'}" - ) - return "\n".join(lines) - - -def _number_pair(baseline: float | None, candidate: float | None, *, suffix: str = "") -> str: - return f"{_format_number(baseline, suffix=suffix)}->{_format_number(candidate, suffix=suffix)}" - - -def _format_number(value: float | None, *, suffix: str = "") -> str: - if value is None: - return "unknown" - return f"{value:.1f}{suffix}" - - -def _label_count_summary(counts: dict[str, int]) -> str: - if not counts: - return "none" - return ",".join(f"{label}:{count}" for label, count in sorted(counts.items())) - - -def _selector_label(*, config: str | None, strategy: str | None) -> str: - if config is not None: - return f"config:{config}" - if strategy is not None: - return f"strategy:{strategy}" - return "" - - -@app.default -def main( - case_analysis: Path, - *, - candidate_case_analysis: Annotated[Path | None, cyclopts.Parameter("--candidate-case-analysis")] = None, - baseline_config: Annotated[str | None, 
cyclopts.Parameter("--baseline-config")] = None, - candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None, - baseline_strategy: Annotated[str | None, cyclopts.Parameter("--baseline-strategy")] = None, - candidate_strategy: Annotated[str | None, cyclopts.Parameter("--candidate-strategy")] = None, - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.csv, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - baseline_table = read_case_analysis(case_analysis) - candidate_table = read_case_analysis(candidate_case_analysis) if candidate_case_analysis else baseline_table - comparisons = compare_case_tables( - baseline_table, - candidate_table, - baseline_config=baseline_config, - candidate_config=candidate_config, - baseline_strategy=baseline_strategy, - candidate_strategy=candidate_strategy, - ) - except ValueError as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - result = ComparisonResult( - input_path=str(case_analysis), - candidate_input_path=str(candidate_case_analysis) if candidate_case_analysis else None, - baseline_selector=_selector_label(config=baseline_config, strategy=baseline_strategy), - candidate_selector=_selector_label(config=candidate_config, strategy=candidate_strategy), - summary=summarize_comparisons(comparisons), - comparisons=comparisons, - ) - if output is not None: - write_comparisons(comparisons, output, format) - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/dd_parser_compat.py b/tools/measurement/dd_parser_compat.py deleted file mode 100644 index d3aa441c..00000000 --- a/tools/measurement/dd_parser_compat.py +++ /dev/null @@ 
-1,108 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Benchmark-only DataDesigner structured-output parser compatibility modes.""" - -from __future__ import annotations - -import json -from collections.abc import Callable, Iterator -from contextlib import contextmanager -from enum import StrEnum -from typing import Any - -from data_designer.engine.models.parsers.errors import ParserException -from data_designer.engine.models.recipes import response_recipes as recipes -from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate - - -class DDParserCompatMode(StrEnum): - none = "none" - raw_json = "raw_json" - - -@contextmanager -def dd_parser_compat_context(mode: DDParserCompatMode) -> Iterator[None]: - """Temporarily patch DataDesigner parsers for benchmark-only compatibility.""" - if mode == DDParserCompatMode.none: - yield - return - if mode != DDParserCompatMode.raw_json: - raise ValueError(f"unsupported DataDesigner parser compatibility mode: {mode}") - - original_pydantic = recipes.PydanticResponseRecipe._build_parser_fn - original_structured = recipes.StructuredResponseRecipe._build_parser_fn - recipes.PydanticResponseRecipe._build_parser_fn = _tolerant_pydantic_builder(original_pydantic) # type: ignore[method-assign] - recipes.StructuredResponseRecipe._build_parser_fn = _tolerant_structured_builder(original_structured) # type: ignore[method-assign] - try: - yield - finally: - recipes.PydanticResponseRecipe._build_parser_fn = original_pydantic # type: ignore[method-assign] - recipes.StructuredResponseRecipe._build_parser_fn = original_structured # type: ignore[method-assign] - - -def _tolerant_pydantic_builder(original: Callable[..., Callable[[str], Any]]) -> Callable[..., Callable[[str], Any]]: - def build_parser(self: Any) -> Callable[[str], Any]: - base_parse = original(self) - - def 
parse(response: str) -> Any: - try: - return base_parse(response) - except ParserException as exc: - try: - return self.data_type.model_validate(_load_embedded_json(response)) - except Exception: - raise exc - - return parse - - return build_parser - - -def _tolerant_structured_builder( - original: Callable[..., Callable[[str], dict]], -) -> Callable[..., Callable[[str], dict]]: - def build_parser(self: Any) -> Callable[[str], dict]: - base_parse = original(self) - - def parse(response: str) -> dict: - try: - return base_parse(response) - except ParserException as exc: - try: - return validate(_load_embedded_json(response), **self._validate_args) - except (json.JSONDecodeError, JSONSchemaValidationError, TypeError, ValueError): - raise exc - - return parse - - return build_parser - - -def _load_embedded_json(response: str) -> Any: - """Return the largest JSON object/array embedded in a model response.""" - decoder = json.JSONDecoder() - stripped = response.strip() - try: - return json.loads(stripped) - except json.JSONDecodeError: - pass - - best: tuple[int, int, Any] | None = None - for start, char in enumerate(response): - if char not in "{[": - continue - try: - parsed, end = decoder.raw_decode(response, start) - except json.JSONDecodeError: - continue - if not isinstance(parsed, dict | list): - continue - # Prefer the candidate that consumes the most response text. That - # selects the outer response object instead of nested item objects. 
- if best is None or end > best[1] or (end == best[1] and start < best[0]): - best = (start, end, parsed) - - if best is None: - raise json.JSONDecodeError("No embedded JSON object or array found", response, 0) - return best[2] diff --git a/tools/measurement/detection_strategies.py b/tools/measurement/detection_strategies.py deleted file mode 100644 index 8c45c6ab..00000000 --- a/tools/measurement/detection_strategies.py +++ /dev/null @@ -1,1695 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Experimental detection strategies for benchmark-only performance probes.""" - -from __future__ import annotations - -import json -import time -from collections import Counter -from collections.abc import Callable, Iterator -from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager -from dataclasses import dataclass -from enum import StrEnum -from typing import Any - -import pandas as pd -from data_designer.config import custom_column_generator -from data_designer.config.column_configs import CustomColumnConfig, LLMTextColumnConfig -from data_designer.config.models import ModelConfig -from dd_parser_compat import _load_embedded_json -from direct_detection_probe import DirectDetectionRequest, DirectGenerationRequest, PromptMode, build_direct_prompt -from staged_detection_probe import ( - CaseStatus as StagedCaseStatus, -) -from staged_detection_probe import ( - DirectCompletion, - DirectDetectionClient, - GlinerSeedClient, - HttpxDirectDetectionClient, - SeedSource, - StagedDetectionCase, - StagedDetectionRequest, - StagedExecutionConfig, - ValidationPromptMode, - _run_augmentation_phase, - _run_validation_phase, - execute_staged_detection_case, - normalize_birth_context_final_entities, -) - -from anonymizer.config.models import DetectionModelSelection -from anonymizer.engine.constants import ( - COL_AUGMENTED_ENTITIES, - 
COL_DETECTED_ENTITIES, - COL_MERGED_ENTITIES, - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_SEED_VALIDATION_CANDIDATES, - COL_TAGGED_TEXT, - COL_TEXT, - COL_VALIDATED_ENTITIES, - COL_VALIDATION_DECISIONS, - _jinja, -) -from anonymizer.engine.detection import detection_workflow as dw -from anonymizer.engine.detection.chunked_validation import ChunkedValidationParams, make_chunked_validation_generator -from anonymizer.engine.detection.custom_columns import ( - apply_validation_and_finalize, - apply_validation_to_seed_entities, - enrich_validation_decisions, - merge_and_build_candidates, - parse_detected_entities, - prepare_validation_inputs, -) -from anonymizer.engine.detection.postprocess import ( - EntitySpan, - build_tagged_text, - expand_entity_occurrences, - resolve_overlaps, -) -from anonymizer.engine.ndd.adapter import FailedRecord -from anonymizer.engine.ndd.model_loader import resolve_model_alias, resolve_model_aliases -from anonymizer.engine.schemas import AugmentedEntitiesSchema, EntitiesSchema, ValidationCandidatesSchema -from anonymizer.measurement import record_model_workflow - -_NATIVE_DIRECT_MODEL_ALIAS = "native-direct" -_GLINER_DIRECT_MODEL_ALIAS = "gliner-direct" -_DIRECT_MODEL_NAME = "configured-native-model" -_DIRECT_MODEL_PROVIDER = "configured-native-provider" -_DIRECT_MAX_TOKENS = 4096 -_DIRECT_TIMEOUT_SEC = 180.0 -_DIRECT_MAX_WORKERS = 4 - - -class ExperimentalDetectionStrategy(StrEnum): - default = "default" - prose_augment_focus = "prose_augment_focus" - compact_validation = "compact_validation" - no_augment = "no_augment" - detector_only = "detector_only" - native_candidate_validate_no_augment = "native_candidate_validate_no_augment" - detector_native_validate_no_augment = "detector_native_validate_no_augment" - detector_native_validate_native_augment = "detector_native_validate_native_augment" - gliner_native_validate_no_augment = "gliner_native_validate_no_augment" - gliner_native_validate_native_augment = 
"gliner_native_validate_native_augment" - native_single_pass = "native_single_pass" - native_single_pass_recall = "native_single_pass_recall" - native_single_pass_values = "native_single_pass_values" - native_single_pass_values_recall = "native_single_pass_values_recall" - - -_DetectAndValidate = Callable[..., dw.EntityDetectionResult] -_AugmentPrompt = Callable[..., str] -PROSE_AUGMENT_FOCUS_TEXT = """\ -Contextual prose recall focus: -- Re-scan untagged narrative prose for organization and institution names, named facilities, labs, research centers, street or place names, self-described beliefs, occupations, titles, family member names, and other quasi-identifiers that combine with already tagged entities. -- Prefer allowed labels when present, especially organization_name, company_name, place_name, religious_belief, political_view, occupation, first_name, last_name, city, state, university, degree, field_of_study, language, race_ethnicity, and age. -- Do not add generic common nouns, section headings, or labels outside the allowed list when strict labels are required. 
-""" - - -@dataclass(frozen=True) -class NativeDetectionRuntime: - """Runtime settings for benchmark-only native detection strategies.""" - - endpoint: str | None = None - model: str = _DIRECT_MODEL_NAME - provider: str = _DIRECT_MODEL_PROVIDER - alias: str = _NATIVE_DIRECT_MODEL_ALIAS - max_tokens: int = _DIRECT_MAX_TOKENS - timeout_sec: float = _DIRECT_TIMEOUT_SEC - gliner_endpoint: str | None = None - gliner_model: str = _DIRECT_MODEL_NAME - gliner_provider: str = _DIRECT_MODEL_PROVIDER - gliner_alias: str = _GLINER_DIRECT_MODEL_ALIAS - gliner_api_key_env: str = "NVIDIA_API_KEY" - gliner_threshold: float = 0.3 - max_workers: int = _DIRECT_MAX_WORKERS - - -@dataclass(frozen=True) -class _NativeStagedTask: - ordinal: int - index: Any - row: pd.Series - - -@dataclass(frozen=True) -class _NativeStagedRunParams: - labels: list[str] - client: DirectDetectionClient - gliner_seed_client: GlinerSeedClient | None - runtime: NativeDetectionRuntime - data_summary: str | None - validation_prompt_mode: ValidationPromptMode - validation_max_entities_per_call: int - validation_excerpt_window_chars: int - seed_source: SeedSource - workflow_name: str - skip_augmentation: bool - - -@dataclass(frozen=True) -class _NativeStagedRowResult: - ordinal: int - index: Any - output_row: dict[str, Any] | None - failed_record: FailedRecord | None - case: StagedDetectionCase | None - - -@dataclass(frozen=True) -class _DetectorNativeValidationParams: - labels: list[str] - client: DirectDetectionClient - runtime: NativeDetectionRuntime - data_summary: str | None - validation_prompt_mode: ValidationPromptMode - validation_max_entities_per_call: int - validation_excerpt_window_chars: int - workflow_name: str - skip_augmentation: bool - - -@dataclass(frozen=True) -class _DetectorNativeValidationRowResult: - ordinal: int - index: Any - workflow_name: str - runtime: NativeDetectionRuntime - output_row: dict[str, Any] | None - failed_record: FailedRecord | None - completion: DirectCompletion | None - 
elapsed_sec: float - request_count: int - - -@contextmanager -def experimental_detection_strategy_context( - strategy: ExperimentalDetectionStrategy, - *, - native_client: DirectDetectionClient | None = None, - gliner_seed_client: GlinerSeedClient | None = None, - native_runtime: NativeDetectionRuntime | None = None, -) -> Iterator[None]: - """Temporarily apply a benchmark-only detection strategy.""" - if strategy == ExperimentalDetectionStrategy.default: - yield - return - - original_method = dw.EntityDetectionWorkflow.detect_and_validate_entities - original_augment_prompt = dw._get_augment_prompt - if strategy == ExperimentalDetectionStrategy.prose_augment_focus: - dw._get_augment_prompt = _make_prose_augment_prompt(original_augment_prompt) # type: ignore[assignment] - else: - dw.EntityDetectionWorkflow.detect_and_validate_entities = _method_for_strategy( # type: ignore[method-assign] - strategy, - original=original_method, - native_client=native_client, - gliner_seed_client=gliner_seed_client, - native_runtime=native_runtime or NativeDetectionRuntime(), - ) - try: - yield - finally: - dw.EntityDetectionWorkflow.detect_and_validate_entities = original_method # type: ignore[method-assign] - dw._get_augment_prompt = original_augment_prompt # type: ignore[assignment] - - -def _make_prose_augment_prompt(original: _AugmentPrompt) -> _AugmentPrompt: - def get_augment_prompt(*, data_summary: str | None, labels: list[str], strict_labels: bool = False) -> str: - prompt = original(data_summary=data_summary, labels=labels, strict_labels=strict_labels) - return prompt.replace("Rules:\n", f"{PROSE_AUGMENT_FOCUS_TEXT}\nRules:\n", 1) - - return get_augment_prompt - - -def _method_for_strategy( - strategy: ExperimentalDetectionStrategy, - *, - original: _DetectAndValidate | None = None, - native_client: DirectDetectionClient | None = None, - gliner_seed_client: GlinerSeedClient | None = None, - native_runtime: NativeDetectionRuntime | None = None, -) -> _DetectAndValidate: - 
runtime = native_runtime or NativeDetectionRuntime() - if strategy == ExperimentalDetectionStrategy.compact_validation: - if original is None: - raise ValueError("compact_validation requires the original detection method") - return _make_default_compact_validation_method(original) - if strategy == ExperimentalDetectionStrategy.no_augment: - return _make_validated_no_augment_method() - if strategy == ExperimentalDetectionStrategy.detector_only: - return _detect_with_detector_only - if strategy == ExperimentalDetectionStrategy.native_candidate_validate_no_augment: - return _make_native_candidate_validate_no_augment_method(native_client=native_client, native_runtime=runtime) - if strategy == ExperimentalDetectionStrategy.detector_native_validate_no_augment: - return _make_detector_native_validate_no_augment_method(native_client=native_client, native_runtime=runtime) - if strategy == ExperimentalDetectionStrategy.detector_native_validate_native_augment: - return _make_detector_native_validate_native_augment_method(native_client=native_client, native_runtime=runtime) - if strategy == ExperimentalDetectionStrategy.gliner_native_validate_no_augment: - return _make_gliner_native_validate_no_augment_method( - native_client=native_client, - gliner_seed_client=gliner_seed_client, - native_runtime=runtime, - ) - if strategy == ExperimentalDetectionStrategy.gliner_native_validate_native_augment: - return _make_gliner_native_validate_native_augment_method( - native_client=native_client, - gliner_seed_client=gliner_seed_client, - native_runtime=runtime, - ) - if strategy == ExperimentalDetectionStrategy.native_single_pass: - return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime) - if strategy == ExperimentalDetectionStrategy.native_single_pass_recall: - return _make_native_single_pass_method(native_client=native_client, native_runtime=runtime, recall_prompt=True) - if strategy == ExperimentalDetectionStrategy.native_single_pass_values: - 
return _make_native_single_pass_method( - native_client=native_client, - native_runtime=runtime, - value_only_prompt=True, - ) - if strategy == ExperimentalDetectionStrategy.native_single_pass_values_recall: - return _make_native_single_pass_method( - native_client=native_client, - native_runtime=runtime, - recall_prompt=True, - value_only_prompt=True, - ) - raise ValueError(f"unsupported experimental detection strategy: {strategy}") - - -def _make_default_compact_validation_method(original: _DetectAndValidate) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - return original( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=False, - entity_labels=entity_labels, - data_summary=data_summary, - preview_num_records=preview_num_records, - ) - - return detect_and_validate_entities - - -def _make_validated_no_augment_method() -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = 
dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - return _run_validated_no_augment_detection( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - preview_num_records=preview_num_records, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - entity_labels=entity_labels, - data_summary=data_summary, - ) - - return detect_and_validate_entities - - -def _run_validated_no_augment_detection( - workflow: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - preview_num_records: int | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - entity_labels: list[str] | None, - data_summary: str | None, -) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - workflow_model_configs = workflow._inject_detector_params( - model_configs=model_configs, - selected_models=selected_models, - labels=labels, - gliner_detection_threshold=gliner_detection_threshold, - ) - detection_result = workflow._adapter.run_workflow( - dataframe, - model_configs=workflow_model_configs, - columns=_validated_no_augment_columns( - selected_models=selected_models, - labels=labels, - data_summary=data_summary, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - ), - workflow_name="entity-detection-no-augment", - preview_num_records=preview_num_records, - ) - return dw.EntityDetectionResult( - dataframe=detection_result.dataframe.copy(), - failed_records=detection_result.failed_records, - ) - - -def 
_validated_no_augment_columns( - *, - selected_models: DetectionModelSelection, - labels: list[str], - data_summary: str | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, -) -> list[LLMTextColumnConfig | CustomColumnConfig]: - validator_params = _validator_params( - selected_models=selected_models, - labels=labels, - data_summary=data_summary, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - ) - return [ - LLMTextColumnConfig( - name=COL_RAW_DETECTED, prompt=_jinja(COL_TEXT), model_alias=_detector_alias(selected_models) - ), - CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), - CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), - _validation_decisions_column(selected_models, validator_params), - CustomColumnConfig(name=COL_VALIDATED_ENTITIES, generator_function=enrich_validation_decisions), - CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=apply_validation_to_seed_entities), - CustomColumnConfig(name=COL_AUGMENTED_ENTITIES, generator_function=_empty_augmentation), - CustomColumnConfig(name=COL_MERGED_ENTITIES, generator_function=merge_and_build_candidates), - CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=apply_validation_and_finalize), - ] - - -def _validation_decisions_column( - selected_models: DetectionModelSelection, - validator_params: ChunkedValidationParams, -) -> CustomColumnConfig: - return CustomColumnConfig( - name=COL_VALIDATION_DECISIONS, - generator_function=make_chunked_validation_generator( - resolve_model_aliases("entity_validator", selected_models) - ), - generator_params=validator_params, - drop=True, - ) - - -def _validator_params( - *, - selected_models: DetectionModelSelection, - labels: list[str], - data_summary: str | None, - validation_max_entities_per_call: int, - 
validation_excerpt_window_chars: int, -) -> ChunkedValidationParams: - validator_aliases = resolve_model_aliases("entity_validator", selected_models) - return ChunkedValidationParams( - pool=list(validator_aliases), - max_entities_per_call=validation_max_entities_per_call, - excerpt_window_chars=validation_excerpt_window_chars, - prompt_template=dw._get_validation_prompt(data_summary=data_summary, labels=labels), - ) - - -def _detector_alias(selected_models: DetectionModelSelection) -> str: - return resolve_model_alias("entity_detector", selected_models) - - -def _detect_with_detector_only( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, -) -> dw.EntityDetectionResult: - return _run_detector_only_detection( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - gliner_detection_threshold=gliner_detection_threshold, - entity_labels=entity_labels, - preview_num_records=preview_num_records, - workflow_name="entity-detection-detector-only", - ) - - -def _run_detector_only_detection( - workflow: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - entity_labels: list[str] | None, - preview_num_records: int | None, - workflow_name: str, -) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - workflow_model_configs = workflow._inject_detector_params( - model_configs=model_configs, - 
selected_models=selected_models, - labels=labels, - gliner_detection_threshold=gliner_detection_threshold, - ) - detection_result = workflow._adapter.run_workflow( - dataframe, - model_configs=workflow_model_configs, - columns=_detector_only_columns(selected_models), - workflow_name=workflow_name, - preview_num_records=preview_num_records, - ) - return dw.EntityDetectionResult( - dataframe=detection_result.dataframe.copy(), - failed_records=detection_result.failed_records, - ) - - -def _detector_only_columns(selected_models: DetectionModelSelection) -> list[LLMTextColumnConfig | CustomColumnConfig]: - return [ - LLMTextColumnConfig( - name=COL_RAW_DETECTED, - prompt=_jinja(COL_TEXT), - model_alias=_detector_alias(selected_models), - ), - CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), - CustomColumnConfig(name=COL_SEED_ENTITIES_JSON, generator_function=_copy_seed_entities_json), - CustomColumnConfig(name=COL_DETECTED_ENTITIES, generator_function=_finalize_detector_only), - ] - - -@custom_column_generator(required_columns=[COL_SEED_ENTITIES]) -def _copy_seed_entities_json(row: dict[str, Any]) -> dict[str, Any]: - row[COL_SEED_ENTITIES_JSON] = json.dumps( - [span.as_dict() for span in _entity_spans_from_payload(row.get(COL_SEED_ENTITIES, {}))] - ) - return row - - -@custom_column_generator( - required_columns=[COL_TEXT, COL_SEED_ENTITIES], - side_effect_columns=[COL_TAGGED_TEXT], -) -def _finalize_detector_only(row: dict[str, Any]) -> dict[str, Any]: - text = str(row.get(COL_TEXT, "")) - spans = expand_entity_occurrences(text=text, entities=_entity_spans_from_payload(row.get(COL_SEED_ENTITIES, {}))) - row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump(mode="json") - row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=spans) - return row - - -@custom_column_generator(required_columns=[COL_TEXT]) -def _empty_augmentation(row: dict[str, Any]) -> dict[str, Any]: - 
row[COL_AUGMENTED_ENTITIES] = AugmentedEntitiesSchema().model_dump(mode="json") - return row - - -def _entity_spans_from_payload(raw_payload: object) -> list[EntitySpan]: - return [ - EntitySpan( - entity_id=entity.id, - value=entity.value, - label=entity.label, - start_position=entity.start_position, - end_position=entity.end_position, - score=entity.score, - source=entity.source, - ) - for entity in EntitiesSchema.from_raw(raw_payload).entities - ] - - -def _make_native_single_pass_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, - recall_prompt: bool = False, - value_only_prompt: bool = False, -) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - client = _native_client_or_default(native_client, native_runtime) - return _run_native_single_pass_detection( - dataframe, - labels=labels, - client=client, - runtime=native_runtime, - data_summary=data_summary, - preview_num_records=preview_num_records, - recall_prompt=recall_prompt, - value_only_prompt=value_only_prompt, - ) - - return detect_and_validate_entities - - -def _run_native_single_pass_detection( - dataframe: pd.DataFrame, - *, - labels: list[str], - client: DirectDetectionClient, - runtime: NativeDetectionRuntime, - data_summary: str | None, - preview_num_records: int | None, - recall_prompt: bool, - value_only_prompt: bool, -) -> 
dw.EntityDetectionResult: - source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() - output_rows: list[dict[str, Any]] = [] - output_indices: list[Any] = [] - failed_records: list[FailedRecord] = [] - - for index, row in source_df.iterrows(): - output_row, failed_record = _execute_native_single_pass_row( - row, - index=index, - labels=labels, - client=client, - runtime=runtime, - data_summary=data_summary, - recall_prompt=recall_prompt, - value_only_prompt=value_only_prompt, - ) - if failed_record is not None: - failed_records.append(failed_record) - continue - if output_row is None: - failed_records.append(_native_single_pass_failed_record(index, error="native single-pass produced no row")) - continue - output_rows.append(output_row) - output_indices.append(index) - - return dw.EntityDetectionResult( - dataframe=_native_output_dataframe(source_df, output_rows=output_rows, output_indices=output_indices), - failed_records=failed_records, - ) - - -def _execute_native_single_pass_row( - row: pd.Series, - *, - index: object, - labels: list[str], - client: DirectDetectionClient, - runtime: NativeDetectionRuntime, - data_summary: str | None, - recall_prompt: bool, - value_only_prompt: bool, -) -> tuple[dict[str, Any] | None, FailedRecord | None]: - text = str(row.get(COL_TEXT, "")) - started = time.perf_counter() - try: - completion = _complete_native_single_pass( - text=text, - labels=labels, - client=client, - runtime=runtime, - data_summary=data_summary, - recall_prompt=recall_prompt, - value_only_prompt=value_only_prompt, - ) - except Exception as exc: # noqa: BLE001 - benchmark experiment records case-local failures - _record_native_single_pass_request_error(elapsed_sec=time.perf_counter() - started, runtime=runtime) - return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") - try: - spans = _native_single_pass_spans(text, completion.content, labels=labels) - except 
Exception as exc: # noqa: BLE001 - parser fragility is a benchmark failure mode - _record_native_single_pass_completion( - completion, - status="error", - output_row_count=0, - failed_record_count=1, - runtime=runtime, - ) - return None, _native_single_pass_failed_record(index, error=f"{type(exc).__name__}: {exc}") - _record_native_single_pass_completion( - completion, - status="completed", - output_row_count=1, - failed_record_count=0, - runtime=runtime, - ) - return _native_single_pass_result_row(row, spans=spans), None - - -def _complete_native_single_pass( - *, - text: str, - labels: list[str], - client: DirectDetectionClient, - runtime: NativeDetectionRuntime, - data_summary: str | None, - recall_prompt: bool, - value_only_prompt: bool, -) -> Any: - return client.complete( - DirectGenerationRequest( - endpoint=runtime.endpoint or "", - model=runtime.model, - prompt=_native_single_pass_prompt( - text=text, - labels=labels, - data_summary=data_summary, - recall_prompt=recall_prompt, - value_only_prompt=value_only_prompt, - ), - max_tokens=runtime.max_tokens, - timeout_sec=runtime.timeout_sec, - ) - ) - - -def _native_single_pass_prompt( - *, - text: str, - labels: list[str], - data_summary: str | None, - recall_prompt: bool = False, - value_only_prompt: bool = False, -) -> str: - if value_only_prompt: - return build_direct_prompt( - DirectDetectionRequest( - case_id="native-single-pass-values", - text=text, - labels=labels, - prompt_mode=PromptMode.recall if recall_prompt else PromptMode.compact, - data_summary=data_summary, - ) - ) - - label_text = dw._format_label_examples(labels) if recall_prompt else ", ".join(labels) - recall_block = _native_single_pass_recall_block() if recall_prompt else "" - summary = f"\nData context: {data_summary}\n" if data_summary else "" - return f"""Extract privacy-sensitive entities from the input text in one pass. -{summary} -Use only these labels: -{label_text} - -Rules: -- Return exact substrings from the input text. 
-- Do not invent values. -- `start` and `end` must be zero-based Python character offsets, where text[start:end] equals `value`. -- Missing a sensitive value is worse than returning one extra plausible value. -- Skip generic nouns, syntax, and non-sensitive filler. -{recall_block}- Return only a JSON object with this shape: - {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "start": 0, "end": 0, "reason": "short reason"}}]}} - -Input text: ---- -{text} ----""" - - -def _native_single_pass_recall_block() -> str: - return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. -- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. -""" - - -def _native_single_pass_spans(text: str, content: str, *, labels: list[str]) -> list[EntitySpan]: - payload = _load_embedded_json(content) - if not isinstance(payload, dict): - raise ValueError("native single-pass response must be a JSON object") - entities = payload.get("entities") - if not isinstance(entities, list): - raise ValueError("native single-pass response must contain an entities list") - spans: list[EntitySpan] = [] - allowed = set(labels) - for item in entities: - if not isinstance(item, dict): - continue - value = str(item.get("value", "")).strip() - label = str(item.get("label", "")).strip() - if not value or label not in allowed: - continue - offset_span = _native_single_pass_offset_span(text, item, value=value, label=label) - if offset_span is not None: - spans.append(offset_span) - continue - spans.extend(_native_single_pass_value_spans(text, value=value, label=label)) - return resolve_overlaps(spans) - - -def _native_single_pass_offset_span( - text: str, - item: dict[str, Any], - *, - value: str, - label: str, -) -> EntitySpan | None: - start = _coerce_native_offset(item.get("start")) - end = _coerce_native_offset(item.get("end")) - if start is None 
or end is None or start < 0 or end <= start or end > len(text): - return None - if text[start:end] != value: - return None - return _native_single_pass_span(value=value, label=label, start=start, end=end) - - -def _coerce_native_offset(value: object) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return value - if isinstance(value, str) and value.strip().isdigit(): - return int(value.strip()) - return None - - -def _native_single_pass_value_spans(text: str, *, value: str, label: str) -> list[EntitySpan]: - spans: list[EntitySpan] = [] - start = 0 - while True: - match_start = text.find(value, start) - if match_start < 0: - return spans - match_end = match_start + len(value) - spans.append(_native_single_pass_span(value=value, label=label, start=match_start, end=match_end)) - start = match_end - - -def _native_single_pass_span(*, value: str, label: str, start: int, end: int) -> EntitySpan: - return EntitySpan( - entity_id=f"{label}_{start}_{end}", - value=value, - label=label, - start_position=start, - end_position=end, - score=1.0, - source="direct_single_pass", - ) - - -def _native_single_pass_result_row(row: pd.Series, *, spans: list[EntitySpan]) -> dict[str, Any]: - text = str(row.get(COL_TEXT, "")) - output_row = row.to_dict() - output_row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=[span.as_dict() for span in spans]).model_dump( - mode="json" - ) - output_row[COL_TAGGED_TEXT] = build_tagged_text(text=text, entities=spans) - return output_row - - -def _native_single_pass_failed_record(index: object, *, error: str | None) -> FailedRecord: - return FailedRecord( - record_id=str(index), - step="entity-detection-native-single-pass", - reason=error or "native single-pass detection failed", - ) - - -def _record_native_single_pass_completion( - completion: Any, - *, - status: str, - output_row_count: int, - failed_record_count: int, - runtime: NativeDetectionRuntime, -) -> None: - record_model_workflow( - 
workflow_name="entity-detection-native-single-pass", - model_aliases=[runtime.alias], - input_row_count=1, - output_row_count=output_row_count, - failed_record_count=failed_record_count, - elapsed_sec=float(getattr(completion, "elapsed_sec", 0.0) or 0.0), - status=status, - model_usage=_native_single_pass_model_usage( - successful_requests=1, - failed_requests=0, - usage=dict(getattr(completion, "usage", {}) or {}), - runtime=runtime, - ), - ) - - -def _record_native_single_pass_request_error(*, elapsed_sec: float, runtime: NativeDetectionRuntime) -> None: - record_model_workflow( - workflow_name="entity-detection-native-single-pass", - model_aliases=[runtime.alias], - input_row_count=1, - output_row_count=0, - failed_record_count=1, - elapsed_sec=elapsed_sec, - status="error", - model_usage=_native_single_pass_model_usage( - successful_requests=0, - failed_requests=1, - usage={}, - runtime=runtime, - ), - ) - - -def _native_single_pass_model_usage( - *, - successful_requests: int, - failed_requests: int, - usage: dict[str, int], - runtime: NativeDetectionRuntime, -) -> dict[str, dict[str, Any]]: - total_requests = successful_requests + failed_requests - return { - runtime.alias: { - "model_alias": runtime.alias, - "model_name": runtime.model, - "model_provider_name": runtime.provider, - "request_usage": { - "successful_requests": successful_requests, - "failed_requests": failed_requests, - "total_requests": total_requests, - }, - "token_usage": _native_token_usage(usage), - } - } - - -def _native_client_or_default( - native_client: DirectDetectionClient | None, - runtime: NativeDetectionRuntime, -) -> DirectDetectionClient: - if native_client is not None: - return native_client - _require_native_endpoint(runtime) - return HttpxDirectDetectionClient() - - -def _require_native_endpoint(runtime: NativeDetectionRuntime) -> None: - if not runtime.endpoint or not runtime.model: - raise ValueError( - "native detection strategies require configured native endpoint and 
model; " - "set native_runtime.endpoint and native_runtime.model in the benchmark suite" - ) - - -def _make_native_candidate_validate_no_augment_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_native_staged_method( - native_client=native_client, - gliner_seed_client=None, - native_runtime=native_runtime, - seed_source=SeedSource.direct_llm, - workflow_name="entity-detection-native-candidate-validate-no-augment", - skip_augmentation=True, - ) - - -def _make_gliner_native_validate_no_augment_method( - *, - native_client: DirectDetectionClient | None, - gliner_seed_client: GlinerSeedClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_native_staged_method( - native_client=native_client, - gliner_seed_client=gliner_seed_client, - native_runtime=native_runtime, - seed_source=SeedSource.gliner, - workflow_name="entity-detection-gliner-native-validate-no-augment", - skip_augmentation=True, - ) - - -def _make_gliner_native_validate_native_augment_method( - *, - native_client: DirectDetectionClient | None, - gliner_seed_client: GlinerSeedClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_native_staged_method( - native_client=native_client, - gliner_seed_client=gliner_seed_client, - native_runtime=native_runtime, - seed_source=SeedSource.gliner, - workflow_name="entity-detection-gliner-native-validate-native-augment", - skip_augmentation=False, - ) - - -def _make_detector_native_validate_no_augment_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_detector_native_validate_method( - native_client=native_client, - native_runtime=native_runtime, - workflow_name="entity-detection-detector-native-validate-no-augment", - seed_workflow_name="entity-detection-detector-native-validate-no-augment-seed", - 
skip_augmentation=True, - ) - - -def _make_detector_native_validate_native_augment_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, -) -> _DetectAndValidate: - return _make_detector_native_validate_method( - native_client=native_client, - native_runtime=native_runtime, - workflow_name="entity-detection-detector-native-validate-native-augment", - seed_workflow_name="entity-detection-detector-native-validate-native-augment-seed", - skip_augmentation=False, - ) - - -def _make_detector_native_validate_method( - *, - native_client: DirectDetectionClient | None, - native_runtime: NativeDetectionRuntime, - workflow_name: str, - seed_workflow_name: str, - skip_augmentation: bool, -) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - labels = dw._resolve_detection_labels(entity_labels) - client = _native_client_or_default(native_client, native_runtime) - return _run_detector_native_validate_detection( - self, - dataframe, - model_configs=model_configs, - selected_models=selected_models, - labels=labels, - gliner_detection_threshold=gliner_detection_threshold, - preview_num_records=preview_num_records, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - client=client, - runtime=native_runtime, - 
data_summary=data_summary, - workflow_name=workflow_name, - seed_workflow_name=seed_workflow_name, - skip_augmentation=skip_augmentation, - ) - - return detect_and_validate_entities - - -def _run_detector_native_validate_detection( - workflow: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - labels: list[str], - gliner_detection_threshold: float, - preview_num_records: int | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - validation_single_chunk_full_text: bool, - client: DirectDetectionClient, - runtime: NativeDetectionRuntime, - data_summary: str | None, - workflow_name: str, - seed_workflow_name: str, - skip_augmentation: bool, -) -> dw.EntityDetectionResult: - workflow_model_configs = workflow._inject_detector_params( - model_configs=model_configs, - selected_models=selected_models, - labels=labels, - gliner_detection_threshold=gliner_detection_threshold, - ) - seed_result = workflow._adapter.run_workflow( - dataframe, - model_configs=workflow_model_configs, - columns=_detector_native_validate_seed_columns(selected_models), - workflow_name=seed_workflow_name, - preview_num_records=preview_num_records, - ) - output_rows: list[dict[str, Any]] = [] - output_indices: list[Any] = [] - failed_records = list(seed_result.failed_records) - validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) - tasks = [ - _NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) - for ordinal, (index, row) in enumerate(seed_result.dataframe.iterrows()) - ] - params = _DetectorNativeValidationParams( - labels=labels, - client=client, - runtime=runtime, - data_summary=data_summary, - validation_prompt_mode=validation_prompt_mode, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - workflow_name=workflow_name, - 
skip_augmentation=skip_augmentation, - ) - - for result in _execute_detector_native_validate_tasks(tasks, params=params): - _record_detector_native_validation_result(result) - if result.failed_record is not None: - failed_records.append(result.failed_record) - continue - if result.output_row is None: - failed_records.append( - _native_failed_record( - result.index, - workflow_name=workflow_name, - error="native detector-seed validation produced no row", - ) - ) - continue - output_rows.append(result.output_row) - output_indices.append(result.index) - - return dw.EntityDetectionResult( - dataframe=_native_output_dataframe( - seed_result.dataframe, output_rows=output_rows, output_indices=output_indices - ), - failed_records=failed_records, - ) - - -def _execute_detector_native_validate_tasks( - tasks: list[_NativeStagedTask], - *, - params: _DetectorNativeValidationParams, -) -> list[_DetectorNativeValidationRowResult]: - if not tasks: - return [] - worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) - if worker_count == 1: - return [_execute_detector_native_validate_task(task, params=params) for task in tasks] - with ThreadPoolExecutor(max_workers=worker_count) as executor: - return list( - executor.map( - lambda task: _execute_detector_native_validate_task(task, params=params), - tasks, - ) - ) - - -def _execute_detector_native_validate_task( - task: _NativeStagedTask, - *, - params: _DetectorNativeValidationParams, -) -> _DetectorNativeValidationRowResult: - return _execute_detector_native_validate_row( - task.row, - index=task.index, - ordinal=task.ordinal, - labels=params.labels, - client=params.client, - runtime=params.runtime, - data_summary=params.data_summary, - validation_prompt_mode=params.validation_prompt_mode, - validation_max_entities_per_call=params.validation_max_entities_per_call, - validation_excerpt_window_chars=params.validation_excerpt_window_chars, - workflow_name=params.workflow_name, - 
skip_augmentation=params.skip_augmentation, - ) - - -def _detector_native_validate_seed_columns( - selected_models: DetectionModelSelection, -) -> list[LLMTextColumnConfig | CustomColumnConfig]: - return [ - LLMTextColumnConfig( - name=COL_RAW_DETECTED, - prompt=_jinja(COL_TEXT), - model_alias=_detector_alias(selected_models), - ), - CustomColumnConfig(name=COL_SEED_ENTITIES, generator_function=parse_detected_entities), - CustomColumnConfig(name=COL_SEED_VALIDATION_CANDIDATES, generator_function=prepare_validation_inputs), - ] - - -def _execute_detector_native_validate_row( - row: pd.Series, - *, - index: object, - ordinal: int, - labels: list[str], - client: DirectDetectionClient, - runtime: NativeDetectionRuntime, - data_summary: str | None, - validation_prompt_mode: ValidationPromptMode, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - workflow_name: str, - skip_augmentation: bool, -) -> _DetectorNativeValidationRowResult: - output_row = row.to_dict() - request = _native_staged_request(row, index=index, ordinal=ordinal, labels=labels, data_summary=data_summary) - config = StagedExecutionConfig( - endpoint=runtime.endpoint or "", - model=runtime.model, - max_tokens=runtime.max_tokens, - timeout_sec=runtime.timeout_sec, - validation_prompt_mode=validation_prompt_mode, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - skip_augmentation=skip_augmentation, - ) - request_count = _native_validation_request_count( - output_row, - validation_prompt_mode=validation_prompt_mode, - validation_max_entities_per_call=validation_max_entities_per_call, - ) - if not skip_augmentation: - request_count += 1 - started = time.perf_counter() - try: - validation_completion = _run_validation_phase(output_row, request, client, config) - augmentation_completion = _run_augmentation_phase(output_row, request, client, config) - completion = 
_combine_detector_native_completions([validation_completion, augmentation_completion]) - merge_and_build_candidates(output_row) - apply_validation_and_finalize(output_row) - normalize_birth_context_final_entities(output_row, allowed_labels=labels) - except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures - return _DetectorNativeValidationRowResult( - ordinal=ordinal, - index=index, - workflow_name=workflow_name, - runtime=runtime, - output_row=None, - failed_record=_native_failed_record( - index, - workflow_name=workflow_name, - error=f"{type(exc).__name__}: {exc}", - ), - completion=None, - elapsed_sec=time.perf_counter() - started, - request_count=request_count, - ) - return _DetectorNativeValidationRowResult( - ordinal=ordinal, - index=index, - workflow_name=workflow_name, - runtime=runtime, - output_row=output_row, - failed_record=None, - completion=completion, - elapsed_sec=time.perf_counter() - started, - request_count=request_count, - ) - - -def _native_validation_request_count( - row: dict[str, Any], - *, - validation_prompt_mode: ValidationPromptMode, - validation_max_entities_per_call: int, -) -> int: - candidate_count = len(ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})).candidates) - if candidate_count == 0: - return 0 - if validation_prompt_mode == ValidationPromptMode.full_text: - return 1 - return (candidate_count + validation_max_entities_per_call - 1) // validation_max_entities_per_call - - -def _combine_detector_native_completions(completions: list[DirectCompletion]) -> DirectCompletion: - return DirectCompletion( - content=completions[-1].content if completions else "", - elapsed_sec=sum(float(completion.elapsed_sec or 0.0) for completion in completions), - usage=_sum_usage_dicts([dict(completion.usage or {}) for completion in completions]), - ) - - -def _record_detector_native_validation_completion( - completion: DirectCompletion, - *, - request_count: int, - workflow_name: str, - 
runtime: NativeDetectionRuntime, -) -> None: - record_model_workflow( - workflow_name=workflow_name, - model_aliases=[runtime.alias], - input_row_count=1, - output_row_count=1, - failed_record_count=0, - elapsed_sec=float(completion.elapsed_sec or 0.0), - model_usage=_native_single_pass_model_usage( - successful_requests=request_count, - failed_requests=0, - usage=dict(completion.usage or {}), - runtime=runtime, - ), - ) - - -def _record_detector_native_validation_result(result: _DetectorNativeValidationRowResult) -> None: - if result.completion is not None: - _record_detector_native_validation_completion( - result.completion, - request_count=result.request_count, - workflow_name=result.workflow_name, - runtime=result.runtime, - ) - return - if result.failed_record is not None: - _record_detector_native_validation_error( - elapsed_sec=result.elapsed_sec, - request_count=result.request_count, - workflow_name=result.workflow_name, - runtime=result.runtime, - ) - - -def _record_detector_native_validation_error( - *, - elapsed_sec: float, - request_count: int, - workflow_name: str, - runtime: NativeDetectionRuntime, -) -> None: - record_model_workflow( - workflow_name=workflow_name, - model_aliases=[runtime.alias], - input_row_count=1, - output_row_count=0, - failed_record_count=1, - elapsed_sec=elapsed_sec, - status="error", - model_usage=_native_single_pass_model_usage( - successful_requests=0, - failed_requests=max(request_count, 1), - usage={}, - runtime=runtime, - ), - ) - - -def _make_native_staged_method( - *, - native_client: DirectDetectionClient | None, - gliner_seed_client: GlinerSeedClient | None, - native_runtime: NativeDetectionRuntime, - seed_source: SeedSource, - workflow_name: str, - skip_augmentation: bool, -) -> _DetectAndValidate: - def detect_and_validate_entities( - self: dw.EntityDetectionWorkflow, - dataframe: pd.DataFrame, - *, - model_configs: list[ModelConfig], - selected_models: DetectionModelSelection, - gliner_detection_threshold: float, - 
validation_max_entities_per_call: int = dw._DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, - validation_excerpt_window_chars: int = dw._DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, - validation_single_chunk_full_text: bool = True, - entity_labels: list[str] | None = None, - data_summary: str | None = None, - preview_num_records: int | None = None, - ) -> dw.EntityDetectionResult: - _ = self, model_configs, selected_models, gliner_detection_threshold - labels = dw._resolve_detection_labels(entity_labels) - client = _native_client_or_default(native_client, native_runtime) - return _run_native_staged_detection( - dataframe, - labels=labels, - client=client, - gliner_seed_client=gliner_seed_client, - runtime=native_runtime, - data_summary=data_summary, - preview_num_records=preview_num_records, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - validation_single_chunk_full_text=validation_single_chunk_full_text, - seed_source=seed_source, - workflow_name=workflow_name, - skip_augmentation=skip_augmentation, - ) - - return detect_and_validate_entities - - -def _run_native_staged_detection( - dataframe: pd.DataFrame, - *, - labels: list[str], - client: DirectDetectionClient, - gliner_seed_client: GlinerSeedClient | None, - runtime: NativeDetectionRuntime, - data_summary: str | None, - preview_num_records: int | None, - validation_max_entities_per_call: int, - validation_excerpt_window_chars: int, - validation_single_chunk_full_text: bool, - seed_source: SeedSource, - workflow_name: str, - skip_augmentation: bool, -) -> dw.EntityDetectionResult: - source_df = dataframe.iloc[:preview_num_records].copy() if preview_num_records is not None else dataframe.copy() - output_rows: list[dict[str, Any]] = [] - output_indices: list[Any] = [] - failed_records: list[FailedRecord] = [] - validation_prompt_mode = _native_validation_prompt_mode(validation_single_chunk_full_text) - params = 
_NativeStagedRunParams( - labels=labels, - client=client, - gliner_seed_client=gliner_seed_client, - runtime=runtime, - data_summary=data_summary, - validation_prompt_mode=validation_prompt_mode, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - seed_source=seed_source, - workflow_name=workflow_name, - skip_augmentation=skip_augmentation, - ) - tasks = [ - _NativeStagedTask(ordinal=ordinal, index=index, row=row.copy(deep=True)) - for ordinal, (index, row) in enumerate(source_df.iterrows()) - ] - - for result in _execute_native_staged_tasks(tasks, params=params): - if result.failed_record is not None: - failed_records.append(result.failed_record) - continue - if result.output_row is None: - failed_records.append( - _native_failed_record( - result.index, - workflow_name=workflow_name, - error="native staged detection produced no row", - ) - ) - continue - if result.case is not None: - _record_native_direct_usage(result.case, workflow_name=workflow_name, runtime=runtime) - output_rows.append(result.output_row) - output_indices.append(result.index) - - return dw.EntityDetectionResult( - dataframe=_native_output_dataframe(source_df, output_rows=output_rows, output_indices=output_indices), - failed_records=failed_records, - ) - - -def _execute_native_staged_tasks( - tasks: list[_NativeStagedTask], - *, - params: _NativeStagedRunParams, -) -> list[_NativeStagedRowResult]: - if not tasks: - return [] - worker_count = _native_staged_worker_count(len(tasks), runtime=params.runtime) - if worker_count == 1: - return [_execute_native_staged_task(task, params=params) for task in tasks] - with ThreadPoolExecutor(max_workers=worker_count) as executor: - return list(executor.map(lambda task: _execute_native_staged_task(task, params=params), tasks)) - - -def _native_staged_worker_count(task_count: int, *, runtime: NativeDetectionRuntime) -> int: - return max(1, min(task_count, 
runtime.max_workers)) - - -def _execute_native_staged_task( - task: _NativeStagedTask, - *, - params: _NativeStagedRunParams, -) -> _NativeStagedRowResult: - try: - request = _native_staged_request( - task.row, - index=task.index, - ordinal=task.ordinal, - labels=params.labels, - data_summary=params.data_summary, - ) - execution = execute_staged_detection_case( - request, - client=params.client, - seed_client=params.gliner_seed_client, - seed_source=params.seed_source, - endpoint=params.runtime.endpoint or "", - model=params.runtime.model, - gliner_endpoint=params.runtime.gliner_endpoint or "", - gliner_model=params.runtime.gliner_model, - gliner_api_key_env=params.runtime.gliner_api_key_env, - gliner_threshold=params.runtime.gliner_threshold, - max_tokens=params.runtime.max_tokens, - timeout_sec=params.runtime.timeout_sec, - skip_augmentation=params.skip_augmentation, - validation_prompt_mode=params.validation_prompt_mode, - validation_max_entities_per_call=params.validation_max_entities_per_call, - validation_excerpt_window_chars=params.validation_excerpt_window_chars, - ) - if execution.case.status != StagedCaseStatus.completed: - return _NativeStagedRowResult( - ordinal=task.ordinal, - index=task.index, - output_row=None, - failed_record=_native_failed_record( - task.index, - workflow_name=params.workflow_name, - error=execution.case.error, - ), - case=execution.case, - ) - return _NativeStagedRowResult( - ordinal=task.ordinal, - index=task.index, - output_row=_native_detection_result_row(task.row, execution_row=execution.row), - failed_record=None, - case=execution.case, - ) - except Exception as exc: # noqa: BLE001 - benchmark experiment records per-row failures - return _NativeStagedRowResult( - ordinal=task.ordinal, - index=task.index, - output_row=None, - failed_record=_native_failed_record( - task.index, - workflow_name=params.workflow_name, - error=f"{type(exc).__name__}: {exc}", - ), - case=None, - ) - - -def 
_native_validation_prompt_mode(validation_single_chunk_full_text: bool) -> ValidationPromptMode: - return ValidationPromptMode.full_text if validation_single_chunk_full_text else ValidationPromptMode.chunked_excerpt - - -def _native_failed_record(index: object, *, workflow_name: str, error: str | None) -> FailedRecord: - return FailedRecord( - record_id=str(index), - step=workflow_name, - reason=error or "native staged detection failed", - ) - - -def _record_native_direct_usage( - case: StagedDetectionCase, - *, - workflow_name: str, - runtime: NativeDetectionRuntime, -) -> None: - model_usage = _native_staged_model_usage(case, runtime=runtime) - if not model_usage: - return - record_model_workflow( - workflow_name=workflow_name, - model_aliases=sorted(model_usage), - input_row_count=1, - output_row_count=1 if case.status == StagedCaseStatus.completed else 0, - failed_record_count=0 if case.status == StagedCaseStatus.completed else 1, - elapsed_sec=case.model_elapsed_sec or case.elapsed_sec or 0.0, - model_usage=model_usage, - ) - - -def _native_staged_model_usage( - case: StagedDetectionCase, - *, - runtime: NativeDetectionRuntime, -) -> dict[str, dict[str, Any]]: - usage: dict[str, dict[str, Any]] = {} - native_requests = 0 - native_usage: list[dict[str, int]] = [] - - if case.phase_model_work.seed: - if case.seed_source == SeedSource.gliner: - usage[runtime.gliner_alias] = _direct_model_usage_entry( - alias=runtime.gliner_alias, - model_name=runtime.gliner_model, - provider_name=runtime.gliner_provider, - successful_requests=case.phase_model_requests.seed, - usage=case.phase_usage.seed, - ) - else: - native_requests += case.phase_model_requests.seed - native_usage.append(case.phase_usage.seed) - if case.phase_model_work.validation: - native_requests += case.phase_model_requests.validation - native_usage.append(case.phase_usage.validation) - if case.phase_model_work.augmentation: - native_requests += case.phase_model_requests.augmentation - 
native_usage.append(case.phase_usage.augmentation) - if native_requests: - usage[runtime.alias] = _direct_model_usage_entry( - alias=runtime.alias, - model_name=runtime.model, - provider_name=runtime.provider, - successful_requests=native_requests, - usage=_sum_usage_dicts(native_usage), - ) - return usage - - -def _direct_model_usage_entry( - *, - alias: str, - model_name: str, - provider_name: str, - successful_requests: int, - usage: dict[str, int], -) -> dict[str, Any]: - return { - "model_alias": alias, - "model_name": model_name, - "model_provider_name": provider_name, - "request_usage": { - "successful_requests": successful_requests, - "failed_requests": 0, - "total_requests": successful_requests, - }, - "token_usage": _native_token_usage(usage), - } - - -def _sum_usage_dicts(usages: list[dict[str, int]]) -> dict[str, int]: - totals: Counter[str] = Counter() - for usage in usages: - for key, value in usage.items(): - if isinstance(value, int): - totals[key] += value - return dict(sorted(totals.items())) - - -def _native_token_usage(usage: dict[str, int]) -> dict[str, int]: - input_tokens = usage.get("input_tokens", usage.get("prompt_tokens", 0)) - output_tokens = usage.get("output_tokens", usage.get("completion_tokens", 0)) - total_tokens = usage.get("total_tokens", input_tokens + output_tokens) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": total_tokens, - } - - -def _native_staged_request( - row: pd.Series, - *, - index: object, - ordinal: int, - labels: list[str], - data_summary: str | None, -) -> StagedDetectionRequest: - return StagedDetectionRequest( - case_id=f"native-staged-{ordinal}", - text=str(row.get(COL_TEXT, "")), - labels=labels, - row_index=_safe_row_index(index, fallback=ordinal), - data_summary=data_summary, - ) - - -def _native_detection_result_row(row: pd.Series, *, execution_row: dict[str, Any]) -> dict[str, Any]: - output_row = row.to_dict() - output_row[COL_DETECTED_ENTITIES] = 
execution_row.get( - COL_DETECTED_ENTITIES, - EntitiesSchema().model_dump(mode="json"), - ) - output_row[COL_TAGGED_TEXT] = execution_row.get(COL_TAGGED_TEXT, str(row.get(COL_TEXT, ""))) - return output_row - - -def _safe_row_index(index: object, *, fallback: int) -> int: - try: - return int(index) # type: ignore[arg-type] - except (TypeError, ValueError): - return fallback - - -def _native_output_dataframe( - source_df: pd.DataFrame, - *, - output_rows: list[dict[str, Any]], - output_indices: list[Any], -) -> pd.DataFrame: - if output_rows: - return pd.DataFrame(output_rows, index=output_indices) - output = source_df.iloc[0:0].copy() - output[COL_DETECTED_ENTITIES] = pd.Series(dtype="object") - output[COL_TAGGED_TEXT] = pd.Series(dtype="object") - return output diff --git a/tools/measurement/direct_detection_probe.py b/tools/measurement/direct_detection_probe.py deleted file mode 100644 index 05f93e0d..00000000 --- a/tools/measurement/direct_detection_probe.py +++ /dev/null @@ -1,555 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Run benchmark-only DD-free direct entity extraction probes. 
- -Usage: - uv run python tools/measurement/direct_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column biography --labels age,city,first_name,last_name,occupation \ - --output /tmp/direct-probe --overwrite -""" - -from __future__ import annotations - -import json -import logging -import os -import shutil -import sys -from collections import Counter -from enum import StrEnum -from pathlib import Path -from typing import Annotated, Any, Protocol - -import cyclopts -import httpx -import pandas as pd -from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities -from dd_parser_compat import _load_embedded_json -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from pydantic import BaseModel, Field, ValidationError, model_validator - -from anonymizer.engine.detection.detection_workflow import _format_label_examples -from anonymizer.engine.detection.postprocess import EntitySpan, apply_augmented_entities, expand_entity_occurrences -from anonymizer.engine.schemas import EntitySchema - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.direct_detection_probe") - -_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" -_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" -_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" -_UNCONFIGURED_MODEL = "configured-native-model" - - -class CaseStatus(StrEnum): - completed = "completed" - error = "error" - - -class PromptMode(StrEnum): - compact = "compact" - recall = "recall" - - -class DirectCompletion(BaseModel): - content: str - elapsed_sec: float - usage: dict[str, Any] = Field(default_factory=dict) - - -class DirectDetectionRequest(BaseModel): - case_id: str - text: str - labels: list[str] = Field(min_length=1) - row_index: int = 0 - prompt_mode: PromptMode = PromptMode.compact - data_summary: str | None = None - - @model_validator(mode="after") - def normalize_labels(self) -> DirectDetectionRequest: 
- self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) - if not self.labels: - raise ValueError("labels must contain at least one non-empty label") - return self - - -class DirectGenerationRequest(BaseModel): - endpoint: str - model: str - prompt: str - max_tokens: int = Field(gt=0) - temperature: float = 0.0 - top_p: float = 1.0 - timeout_sec: float = Field(gt=0) - json_mode: bool = True - disable_thinking: bool = True - - -class SignatureComparison(BaseModel): - baseline_final_entity_signature_count: int - shared_final_entity_signature_count: int - baseline_only_final_entity_signature_count: int - direct_only_final_entity_signature_count: int - baseline_only_label_counts: dict[str, int] = Field(default_factory=dict) - direct_only_label_counts: dict[str, int] = Field(default_factory=dict) - - -class DirectDetectionCase(BaseModel): - case_id: str - row_index: int - status: CaseStatus - elapsed_sec: float | None = None - usage: dict[str, Any] = Field(default_factory=dict) - raw_suggestion_count: int = 0 - allowed_suggestion_count: int = 0 - final_entity_count: int = 0 - final_entity_signature_count: int = 0 - final_entity_signature_hashes: list[str] = Field(default_factory=list) - final_label_counts: dict[str, int] = Field(default_factory=dict) - comparison: SignatureComparison | None = None - artifact: DetectionArtifactRow | None = None - error: str | None = None - - -class DirectDetectionRun(BaseModel): - input_path: str - text_column: str - endpoint: str - model: str - prompt_mode: PromptMode - labels: list[str] - rows: list[DirectDetectionCase] = Field(default_factory=list) - - @property - def error_count(self) -> int: - return sum(1 for row in self.rows if row.status == CaseStatus.error) - - -class DirectDetectionClient(Protocol): - def complete(self, request: DirectGenerationRequest) -> DirectCompletion: ... 
- - -class HttpxDirectDetectionClient: - def complete(self, request: DirectGenerationRequest) -> DirectCompletion: - payload = build_chat_payload(request) - response = httpx.post( - f"{request.endpoint.rstrip('/')}/chat/completions", - json=payload, - timeout=request.timeout_sec, - ) - response.raise_for_status() - data = response.json() - return DirectCompletion( - content=str(data["choices"][0]["message"].get("content") or ""), - elapsed_sec=float(response.elapsed.total_seconds()), - usage=data.get("usage") or {}, - ) - - -def build_chat_payload(request: DirectGenerationRequest) -> dict[str, Any]: - payload: dict[str, Any] = { - "model": request.model, - "messages": [ - {"role": "system", "content": "You are a precise entity extraction engine. Return JSON only."}, - {"role": "user", "content": request.prompt}, - ], - "temperature": request.temperature, - "top_p": request.top_p, - "max_tokens": request.max_tokens, - } - if request.json_mode: - payload["response_format"] = {"type": "json_object"} - if request.disable_thinking: - payload["chat_template_kwargs"] = {"enable_thinking": False} - return payload - - -def build_direct_prompt(request: DirectDetectionRequest) -> str: - label_text = ( - _format_label_examples(request.labels) - if request.prompt_mode == PromptMode.recall - else ", ".join(request.labels) - ) - recall_block = _recall_block() if request.prompt_mode == PromptMode.recall else "" - summary = f"\nData context: {request.data_summary}\n" if request.data_summary else "" - return f"""Extract privacy-sensitive entities from the input text. -{summary} -Use only these labels: -{label_text} - -Rules: -- Return exact substrings from the input text. -- Do not invent values. -- Prefer specific labels from the allowed list. -- Skip generic nouns, syntax, and non-sensitive filler. 
-{recall_block}- Return only a JSON object with this shape: - {{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} - -Input text: ---- -{request.text} ----""" - - -def _recall_block() -> str: - return """- Bias toward high recall. Missing a sensitive value is worse than returning one extra plausible value. -- Include family members, colleagues, employers, schools, institutions, locations, dates, demographics, beliefs, and identifiers when allowed. -""" - - -def run_direct_detection_case( - request: DirectDetectionRequest, - *, - client: DirectDetectionClient, - endpoint: str | None = None, - model: str | None = None, - max_tokens: int = 4096, - timeout_sec: float = 180.0, -) -> DirectDetectionCase: - endpoint = endpoint or _UNCONFIGURED_ENDPOINT - model = model or _UNCONFIGURED_MODEL - try: - completion = client.complete( - DirectGenerationRequest( - endpoint=endpoint, - model=model, - prompt=build_direct_prompt(request), - max_tokens=max_tokens, - timeout_sec=timeout_sec, - ) - ) - suggestions = _extract_suggestions(completion.content) - allowed_suggestions = filter_direct_suggestions(suggestions, request.labels) - artifact = finalize_direct_suggestions( - text=request.text, - suggestions=allowed_suggestions, - labels=request.labels, - row_index=request.row_index, - workflow_name="direct-detection", - ) - return _completed_case(request, completion, suggestions, allowed_suggestions, artifact) - except Exception as exc: # noqa: BLE001 - benchmark probe records per-case failures - return DirectDetectionCase( - case_id=request.case_id, - row_index=request.row_index, - status=CaseStatus.error, - error=f"{type(exc).__name__}: {exc}", - ) - - -def _extract_suggestions(content: str) -> list[dict[str, Any]]: - payload = _load_embedded_json(content) - if not isinstance(payload, dict): - return [] - suggestions = payload.get("entities") - return suggestions if isinstance(suggestions, list) else [] - - -def 
finalize_direct_suggestions( - *, - text: str, - suggestions: list[dict[str, Any]], - labels: list[str], - row_index: int, - workflow_name: str, -) -> DetectionArtifactRow: - cleaned = filter_direct_suggestions(suggestions, labels) - direct_spans = apply_augmented_entities(text=text, entities=[], augmented_output={"entities": cleaned}) - entities = [ - _span_to_entity_schema(span, source="direct_llm") for span in expand_entity_occurrences(text, direct_spans) - ] - return build_detection_artifact_row_from_entities( - workflow_name=workflow_name, - batch_file="direct-detection", - row_index=row_index, - seed_entities=[], - seed_validation_candidate_count=0, - merged_validation_candidate_count=0, - augmented_entities=entities, - final_entities=entities, - ) - - -def filter_direct_suggestions(suggestions: list[dict[str, Any]], labels: list[str]) -> list[dict[str, str]]: - allowed = set(labels) - cleaned = [ - {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} - for item in suggestions - if isinstance(item, dict) - ] - return [item for item in cleaned if item["value"] and item["label"] in allowed] - - -def _span_to_entity_schema(span: EntitySpan, *, source: str) -> EntitySchema: - span_source = source if span.source == "augmenter" else span.source - return EntitySchema( - value=span.value, - label=span.label, - start_position=span.start_position, - end_position=span.end_position, - score=span.score, - source=span_source, - ) - - -def _completed_case( - request: DirectDetectionRequest, - completion: DirectCompletion, - suggestions: list[dict[str, Any]], - allowed_suggestions: list[dict[str, str]], - artifact: DetectionArtifactRow, -) -> DirectDetectionCase: - return DirectDetectionCase( - case_id=request.case_id, - row_index=request.row_index, - status=CaseStatus.completed, - elapsed_sec=completion.elapsed_sec, - usage=completion.usage, - raw_suggestion_count=len(suggestions), - allowed_suggestion_count=len(allowed_suggestions), - 
final_entity_count=artifact.final_entity_count, - final_entity_signature_count=artifact.final_entity_signature_count, - final_entity_signature_hashes=artifact.final_entity_signature_hashes, - final_label_counts=artifact.final_label_counts, - artifact=artifact, - ) - - -def compare_signature_sets( - *, - baseline_hashes: set[str], - baseline_labels: dict[str, str], - direct_hashes: set[str], - direct_labels: dict[str, str], -) -> SignatureComparison: - baseline_only = baseline_hashes - direct_hashes - direct_only = direct_hashes - baseline_hashes - return SignatureComparison( - baseline_final_entity_signature_count=len(baseline_hashes), - shared_final_entity_signature_count=len(baseline_hashes & direct_hashes), - baseline_only_final_entity_signature_count=len(baseline_only), - direct_only_final_entity_signature_count=len(direct_only), - baseline_only_label_counts=_label_counts(baseline_only, baseline_labels), - direct_only_label_counts=_label_counts(direct_only, direct_labels), - ) - - -def _label_counts(hashes: set[str], labels: dict[str, str]) -> dict[str, int]: - return dict(sorted(Counter(labels.get(item, "unknown") for item in hashes).items())) - - -def apply_baseline_comparisons( - cases: list[DirectDetectionCase], - baseline_artifacts: Path, -) -> list[DirectDetectionCase]: - baseline = _read_baseline_artifacts(baseline_artifacts) - compared: list[DirectDetectionCase] = [] - for case in cases: - if case.status != CaseStatus.completed or case.artifact is None: - compared.append(case) - continue - baseline_row = baseline.get(case.row_index) - if baseline_row is None: - compared.append(case) - continue - compared.append(_case_with_comparison(case, baseline_row)) - return compared - - -def _case_with_comparison(case: DirectDetectionCase, baseline_row: dict[str, Any]) -> DirectDetectionCase: - baseline_hashes = _baseline_signature_hashes(baseline_row) - if baseline_hashes is None: - return case - comparison = compare_signature_sets( - 
baseline_hashes=baseline_hashes, - baseline_labels=_signature_labels(baseline_row), - direct_hashes=set(case.artifact.final_entity_signature_hashes if case.artifact else []), - direct_labels=case.artifact.final_entity_signature_labels if case.artifact else {}, - ) - return case.model_copy(update={"comparison": comparison}) - - -def _baseline_signature_hashes(row: dict[str, Any]) -> set[str] | None: - hashes = row.get("final_entity_signature_hashes") - if not isinstance(hashes, list): - return None - return {str(item) for item in hashes} - - -def _read_baseline_artifacts(path: Path) -> dict[int, dict[str, Any]]: - baseline: dict[int, dict[str, Any]] = {} - with path.open(encoding="utf-8") as source: - for line in source: - if not line.strip(): - continue - row = json.loads(line) - row_index = int(row.get("row_index", 0)) - if row_index in baseline: - raise ValueError( - f"baseline artifacts has multiple rows for row_index={row_index}; " - "pass a per-case sidecar or pre-filter the artifact file" - ) - baseline[row_index] = row - return baseline - - -def _signature_labels(row: dict[str, Any]) -> dict[str, str]: - return { - key.removeprefix("final_entity_signature_labels."): str(value) - for key, value in row.items() - if key.startswith("final_entity_signature_labels.") and value is not None - } - - -def run_probe( - input_path: Path, - *, - text_column: str, - labels: list[str], - output: Path | None = None, - overwrite: bool = False, - endpoint: str | None = None, - model: str | None = None, - limit: int = 1, - offset: int = 0, - prompt_mode: PromptMode = PromptMode.compact, - baseline_artifacts: Path | None = None, -) -> DirectDetectionRun: - endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) - model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) - requests = _load_requests( - input_path, text_column=text_column, labels=labels, limit=limit, offset=offset, prompt_mode=prompt_mode - ) - 
client = HttpxDirectDetectionClient() - cases = [run_direct_detection_case(request, client=client, endpoint=endpoint, model=model) for request in requests] - if baseline_artifacts is not None: - cases = apply_baseline_comparisons(cases, baseline_artifacts) - result = DirectDetectionRun( - input_path=str(input_path), - text_column=text_column, - endpoint=endpoint, - model=model, - prompt_mode=prompt_mode, - labels=labels, - rows=cases, - ) - if output is not None: - write_outputs(result, output, overwrite=overwrite) - return result - - -def _required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: - value = explicit or os.environ.get(env_var) - if not value: - raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") - return value - - -def _load_requests( - input_path: Path, - *, - text_column: str, - labels: list[str], - limit: int, - offset: int, - prompt_mode: PromptMode, -) -> list[DirectDetectionRequest]: - dataframe = pd.read_csv(input_path) - if text_column not in dataframe.columns: - raise ValueError(f"text column {text_column!r} not found in {input_path}") - selected = dataframe.iloc[offset : offset + limit] - return [ - DirectDetectionRequest( - case_id=f"{input_path.stem}-row-{int(index)}", - text=str(row[text_column]), - labels=labels, - row_index=int(index), - prompt_mode=prompt_mode, - ) - for index, row in selected.iterrows() - ] - - -def write_outputs(result: DirectDetectionRun, output_dir: Path, *, overwrite: bool) -> None: - if output_dir.exists(): - if not overwrite: - raise ValueError(f"output directory already exists: {output_dir}") - shutil.rmtree(output_dir) - output_dir.mkdir(parents=True) - _write_jsonl(output_dir / "direct-detection-cases.jsonl", [_case_payload(case) for case in result.rows]) - _write_jsonl(output_dir / "direct-detection-artifacts.jsonl", [_artifact_payload(case) for case in result.rows]) - (output_dir / "summary.json").write_text(result.model_dump_json(indent=2) + 
"\n", encoding="utf-8") - - -def _case_payload(case: DirectDetectionCase) -> dict[str, Any]: - payload = case.model_dump(exclude={"artifact"}) - payload["record_type"] = "direct_detection_case" - return payload - - -def _artifact_payload(case: DirectDetectionCase) -> dict[str, Any]: - payload = case.artifact.model_dump() if case.artifact is not None else {} - payload.update({"case_id": case.case_id, "row_index": case.row_index, "record_type": "direct_detection_artifact"}) - return payload - - -def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: - with path.open("w", encoding="utf-8") as target: - for row in rows: - target.write(json.dumps(row, ensure_ascii=True, sort_keys=True) + "\n") - - -def render_result(result: DirectDetectionRun, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - completed = len(result.rows) - result.error_count - return f"Ran {completed}/{len(result.rows)} direct detection case(s); errors={result.error_count}" - - -def parse_labels(raw: str) -> list[str]: - return [label.strip() for label in raw.split(",") if label.strip()] - - -@app.default -def main( - input_path: Path, - *, - text_column: Annotated[str, cyclopts.Parameter("--text-column")], - labels: Annotated[str, cyclopts.Parameter("--labels")], - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, - endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, - model: Annotated[str | None, cyclopts.Parameter("--model")] = None, - limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, - offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, - prompt_mode: Annotated[PromptMode, cyclopts.Parameter("--prompt-mode")] = PromptMode.compact, - baseline_artifacts: Annotated[Path | None, cyclopts.Parameter("--baseline-artifacts")] = None, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: 
Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - result = run_probe( - input_path, - text_column=text_column, - labels=parse_labels(labels), - output=output, - overwrite=overwrite, - endpoint=endpoint, - model=model, - limit=limit, - offset=offset, - prompt_mode=prompt_mode, - baseline_artifacts=baseline_artifacts, - ) - except (ValueError, ValidationError, httpx.HTTPError) as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - if result.error_count: - raise SystemExit(1) - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/extract_signature_deltas.py b/tools/measurement/extract_signature_deltas.py deleted file mode 100644 index a676255f..00000000 --- a/tools/measurement/extract_signature_deltas.py +++ /dev/null @@ -1,513 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Extract masked context for entity signature differences between benchmark artifacts. 
- -Usage: - uv run python tools/measurement/extract_signature_deltas.py \ - baseline/detection-artifacts.jsonl candidate/detection-artifacts.jsonl \ - --baseline-artifact-root baseline/artifacts --candidate-artifact-root candidate/artifacts \ - --baseline-config legal-default --candidate-config legal-rules-guardrail \ - --workload legal-r2 --output deltas.csv --format csv -""" - -from __future__ import annotations - -import json -import logging -import sys -from enum import StrEnum -from pathlib import Path -from typing import Annotated - -import cyclopts -import pandas as pd -from analyze_detection_artifacts import _entity_signature_hash -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from measurement_tools.tables import ExportFormat -from pydantic import BaseModel, Field, ValidationError - -from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_TEXT -from anonymizer.engine.schemas import EntitiesSchema, EntitySchema - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.signature_deltas") - - -class DeltaSide(StrEnum): - baseline_only = "baseline_only" - candidate_only = "candidate_only" - - -class ContextResolution(StrEnum): - parquet = "parquet" - artifact_details = "artifact_details" - metadata_only = "metadata_only" - - -_SIGNATURE_DETAIL_FIELDS = { - "label", - "source", - "row_index", - "start_position", - "end_position", - "value_length", -} - - -class SignatureDeltaRow(BaseModel): - workload_id: str - row_index: int - side: DeltaSide - config_id: str - signature_hash: str - label: str | None = None - source: str | None = None - start_position: int | None = None - end_position: int | None = None - value_length: int | None = None - masked_context: str | None = None - resolution: ContextResolution = ContextResolution.metadata_only - batch_file: str | None = None - - -class SignatureDeltaResult(BaseModel): - baseline_artifacts: str - candidate_artifacts: str - workload_id: str | None = None - 
baseline_config: str | None = None - candidate_config: str | None = None - delta_count: int - rows: list[SignatureDeltaRow] = Field(default_factory=list) - - -class _ArtifactSide(BaseModel): - artifacts_path: str - artifact_root: str - config_id: str | None = None - - -def extract_signature_deltas( - baseline_artifacts: Path, - candidate_artifacts: Path, - *, - baseline_artifact_root: Path, - candidate_artifact_root: Path, - baseline_config: str | None = None, - candidate_config: str | None = None, - workload: str | None = None, - context_window: int = 40, -) -> SignatureDeltaResult: - baseline = _select_artifact_rows( - _read_artifact_rows(baseline_artifacts), config_id=baseline_config, workload=workload - ) - candidate = _select_artifact_rows( - _read_artifact_rows(candidate_artifacts), - config_id=candidate_config, - workload=workload, - ) - rows = _compare_artifact_rows( - baseline, - candidate, - baseline_side=_artifact_side(baseline_artifacts, baseline_artifact_root, baseline_config), - candidate_side=_artifact_side(candidate_artifacts, candidate_artifact_root, candidate_config), - context_window=context_window, - ) - return SignatureDeltaResult( - baseline_artifacts=str(baseline_artifacts), - candidate_artifacts=str(candidate_artifacts), - workload_id=workload, - baseline_config=baseline_config, - candidate_config=candidate_config, - delta_count=len(rows), - rows=rows, - ) - - -def _artifact_side(artifacts_path: Path, artifact_root: Path, config_id: str | None) -> _ArtifactSide: - return _ArtifactSide( - artifacts_path=str(artifacts_path), - artifact_root=str(artifact_root), - config_id=config_id, - ) - - -def _read_artifact_rows(path: Path) -> list[dict[str, object]]: - if not path.exists() or path.is_dir(): - raise ValueError(f"detection artifact path is not a file: {path}") - with path.open(encoding="utf-8") as source: - return [json.loads(line) for line in source if line.strip()] - - -def _select_artifact_rows( - rows: list[dict[str, object]], - *, - 
config_id: str | None, - workload: str | None, -) -> dict[tuple[str, int], dict[str, object]]: - selected = [_row for _row in rows if _matches(_row, config_id=config_id, workload=workload)] - if not selected: - raise ValueError("artifact selector matched no rows") - return {_artifact_key(row): row for row in selected} - - -def _matches(row: dict[str, object], *, config_id: str | None, workload: str | None) -> bool: - if config_id is not None and str(row.get("config_id")) != config_id: - return False - if workload is not None and str(row.get("workload_id")) != workload: - return False - return True - - -def _artifact_key(row: dict[str, object]) -> tuple[str, int]: - return str(row.get("workload_id")), int(row.get("row_index", 0)) - - -def _compare_artifact_rows( - baseline: dict[tuple[str, int], dict[str, object]], - candidate: dict[tuple[str, int], dict[str, object]], - *, - baseline_side: _ArtifactSide, - candidate_side: _ArtifactSide, - context_window: int, -) -> list[SignatureDeltaRow]: - rows: list[SignatureDeltaRow] = [] - for key in sorted(set(baseline) & set(candidate)): - rows.extend( - _row_signature_deltas(key, baseline[key], candidate[key], baseline_side, candidate_side, context_window) - ) - return rows - - -def _row_signature_deltas( - key: tuple[str, int], - baseline_row: dict[str, object], - candidate_row: dict[str, object], - baseline_side: _ArtifactSide, - candidate_side: _ArtifactSide, - context_window: int, -) -> list[SignatureDeltaRow]: - baseline_signatures = _signature_set(baseline_row) - candidate_signatures = _signature_set(candidate_row) - return [ - *_delta_rows( - key, - baseline_row, - baseline_signatures - candidate_signatures, - DeltaSide.baseline_only, - baseline_side, - context_window, - ), - *_delta_rows( - key, - candidate_row, - candidate_signatures - baseline_signatures, - DeltaSide.candidate_only, - candidate_side, - context_window, - ), - ] - - -def _signature_set(row: dict[str, object]) -> set[str]: - raw = 
row.get("final_entity_signature_hashes", []) - if isinstance(raw, str): - raw = json.loads(raw) - return {str(item) for item in raw} if isinstance(raw, list) else set() - - -def _delta_rows( - key: tuple[str, int], - artifact_row: dict[str, object], - signatures: set[str], - side: DeltaSide, - artifact_side: _ArtifactSide, - context_window: int, -) -> list[SignatureDeltaRow]: - labels = _signature_labels(artifact_row) - return [ - _contextual_delta_row(key, artifact_row, signature, labels.get(signature), side, artifact_side, context_window) - for signature in sorted(signatures) - ] - - -def _signature_labels(row: dict[str, object]) -> dict[str, str]: - labels = row.get("final_entity_signature_labels", {}) - if isinstance(labels, str): - labels = json.loads(labels) - flattened = { - key.removeprefix("final_entity_signature_labels."): str(value) - for key, value in row.items() - if key.startswith("final_entity_signature_labels.") - } - return {**(labels if isinstance(labels, dict) else {}), **flattened} - - -def _contextual_delta_row( - key: tuple[str, int], - artifact_row: dict[str, object], - signature: str, - label: str | None, - side: DeltaSide, - artifact_side: _ArtifactSide, - context_window: int, -) -> SignatureDeltaRow: - context = _resolve_signature_context( - artifact_row, signature, label, Path(artifact_side.artifact_root), context_window - ) - return SignatureDeltaRow( - workload_id=key[0], - row_index=key[1], - side=side, - config_id=str(artifact_row.get("config_id") or artifact_side.config_id or ""), - signature_hash=signature, - label=label, - batch_file=_optional_string(artifact_row.get("batch_file")), - **context, - ) - - -def _resolve_signature_context( - artifact_row: dict[str, object], - signature: str, - label: str | None, - artifact_root: Path, - context_window: int, -) -> dict[str, object]: - parquet_context = _parquet_entity_context(artifact_row, signature, artifact_root, context_window) - if parquet_context is not None: - return 
parquet_context - detail_context = _artifact_detail_context(artifact_row, signature, label, artifact_root, context_window) - return detail_context or {"resolution": ContextResolution.metadata_only} - - -def _parquet_entity_context( - artifact_row: dict[str, object], - signature: str, - artifact_root: Path, - context_window: int, -) -> dict[str, object] | None: - record = _artifact_record(artifact_row, artifact_root) - if record is None: - return None - text, row_index, row = record - for entity in EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES)).entities: - if _entity_signature_hash(entity, row_index=row_index) == signature: - return _entity_context(entity, text, signature, context_window, ContextResolution.parquet) - return None - - -def _artifact_detail_context( - artifact_row: dict[str, object], - signature: str, - label: str | None, - artifact_root: Path, - context_window: int, -) -> dict[str, object] | None: - details = _signature_details(artifact_row).get(signature) - if details is None: - return None - start_position = _optional_int(details.get("start_position")) - end_position = _optional_int(details.get("end_position")) - resolved_label = _optional_string(details.get("label")) or label - if start_position is None or end_position is None or resolved_label is None: - return _metadata_context_from_details(details) - record = _artifact_record(artifact_row, artifact_root) - masked_context = None - if record is not None: - text, _row_index, _row = record - masked_context = _masked_context_from_details( - text, - label=resolved_label, - signature_hash=signature, - start_position=start_position, - end_position=end_position, - window=context_window, - ) - return { - "source": _optional_string(details.get("source")), - "start_position": start_position, - "end_position": end_position, - "value_length": _optional_int(details.get("value_length")), - "masked_context": masked_context, - "resolution": ContextResolution.artifact_details - if masked_context is not 
None - else ContextResolution.metadata_only, - } - - -def _signature_details(row: dict[str, object]) -> dict[str, dict[str, object]]: - details = _coerce_detail_map(row.get("final_entity_signature_details", {})) - prefix = "final_entity_signature_details." - for key, value in row.items(): - if not key.startswith(prefix): - continue - remainder = key.removeprefix(prefix) - signature_hash, _, field = remainder.partition(".") - if not signature_hash or not field: - continue - if field not in _SIGNATURE_DETAIL_FIELDS: - continue - details.setdefault(signature_hash, {})[field] = value - return details - - -def _coerce_detail_map(raw: object) -> dict[str, dict[str, object]]: - if isinstance(raw, str): - try: - raw = json.loads(raw) - except json.JSONDecodeError: - return {} - if not isinstance(raw, dict): - return {} - return { - str(key): {str(field): item for field, item in value.items() if str(field) in _SIGNATURE_DETAIL_FIELDS} - for key, value in raw.items() - if isinstance(value, dict) - } - - -def _metadata_context_from_details(details: dict[str, object]) -> dict[str, object]: - return { - "source": _optional_string(details.get("source")), - "start_position": _optional_int(details.get("start_position")), - "end_position": _optional_int(details.get("end_position")), - "value_length": _optional_int(details.get("value_length")), - "resolution": ContextResolution.metadata_only, - } - - -def _artifact_record(artifact_row: dict[str, object], artifact_root: Path) -> tuple[str, int, pd.Series] | None: - batch_file = _optional_string(artifact_row.get("batch_file")) - if batch_file is None: - return None - parquet_file = artifact_root / batch_file - if not parquet_file.exists(): - return None - row_index = int(artifact_row.get("row_index", 0)) - row = pd.read_parquet(parquet_file).iloc[row_index] - return str(row.get(COL_TEXT, "")), row_index, row - - -def _entity_context( - entity: EntitySchema, - text: str, - signature_hash: str, - context_window: int, - resolution: 
ContextResolution, -) -> dict[str, object]: - return { - "source": entity.source, - "start_position": entity.start_position, - "end_position": entity.end_position, - "value_length": len(entity.value), - "masked_context": _masked_context(text, entity, signature_hash, context_window), - "resolution": resolution, - } - - -def _masked_context(text: str, entity: EntitySchema, signature_hash: str, window: int) -> str: - before = text[max(0, entity.start_position - window) : entity.start_position] - after = text[entity.end_position : entity.end_position + window] - placeholder = f"[{entity.label.upper()}:{signature_hash}]" - return (before + placeholder + after).replace("\n", " ") - - -def _masked_context_from_details( - text: str, - *, - label: str, - signature_hash: str, - start_position: int, - end_position: int, - window: int, -) -> str: - before = text[max(0, start_position - window) : start_position] - after = text[end_position : end_position + window] - placeholder = f"[{label.upper()}:{signature_hash}]" - return (before + placeholder + after).replace("\n", " ") - - -def _optional_int(value: object) -> int | None: - try: - if pd.isna(value): - return None - except (TypeError, ValueError): - pass - try: - return int(value) - except (TypeError, ValueError): - return None - - -def _optional_string(value: object) -> str | None: - if value is None: - return None - try: - if pd.isna(value): - return None - except (TypeError, ValueError): - pass - return str(value) - - -def write_rows(rows: list[SignatureDeltaRow], output_path: Path, export_format: ExportFormat) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - table = pd.json_normalize([row.model_dump() for row in rows], sep=".") - if export_format == ExportFormat.parquet: - table.to_parquet(output_path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(output_path, index=False) - else: - table.to_json(output_path, orient="records", lines=True) - - -def render_result(result: 
SignatureDeltaResult, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - counts = pd.Series([row.side.value for row in result.rows]).value_counts().to_dict() if result.rows else {} - return ( - f"Extracted {result.delta_count} signature delta(s)" - f"; baseline_only={counts.get('baseline_only', 0)}" - f"; candidate_only={counts.get('candidate_only', 0)}" - ) - - -@app.default -def main( - baseline_artifacts: Path, - candidate_artifacts: Path, - *, - baseline_artifact_root: Annotated[Path, cyclopts.Parameter("--baseline-artifact-root")], - candidate_artifact_root: Annotated[Path, cyclopts.Parameter("--candidate-artifact-root")], - baseline_config: Annotated[str | None, cyclopts.Parameter("--baseline-config")] = None, - candidate_config: Annotated[str | None, cyclopts.Parameter("--candidate-config")] = None, - workload: Annotated[str | None, cyclopts.Parameter("--workload")] = None, - context_window: Annotated[int, cyclopts.Parameter("--context-window")] = 40, - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.jsonl, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - result = extract_signature_deltas( - baseline_artifacts, - candidate_artifacts, - baseline_artifact_root=baseline_artifact_root, - candidate_artifact_root=candidate_artifact_root, - baseline_config=baseline_config, - candidate_config=candidate_config, - workload=workload, - context_window=context_window, - ) - except (ValueError, ValidationError) as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - if output is not None: - write_rows(result.rows, output, format) - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - - -if __name__ == "__main__": - 
app() diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 23c586af..2533dc09 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -8,9 +8,7 @@ uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json """ -import json import logging -import os import shutil import sys import time @@ -25,17 +23,10 @@ import yaml from analyze_detection_artifacts import ( analyze_artifacts, - build_detection_artifact_row_from_entities, iter_detection_parquet_files, ) from data_designer.config.models import ModelProvider from data_designer.config.utils.io_helpers import load_config_file -from dd_parser_compat import DDParserCompatMode, dd_parser_compat_context -from detection_strategies import ( - ExperimentalDetectionStrategy, - NativeDetectionRuntime, - experimental_detection_strategy_context, -) from export_measurements import export_tables, read_measurements from measurement_tools.cli import LogFormat, configure_logging, log_bad_input from measurement_tools.tables import ExportFormat @@ -51,10 +42,8 @@ ) from anonymizer.config.replace_strategies import Annotate, Hash, Redact, Substitute from anonymizer.config.rewrite import DEFAULT_PRESERVE_TEXT, DEFAULT_PROTECT_TEXT, PrivacyGoal, RiskTolerance -from anonymizer.engine.constants import COL_DETECTED_ENTITIES, COL_FINAL_ENTITIES from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references -from anonymizer.engine.schemas import EntitiesSchema from anonymizer.interface.anonymizer import Anonymizer from anonymizer.measurement import MeasurementConfig, configured_measurement_session @@ -74,35 +63,6 @@ class DDTraceMode(StrEnum): all_messages = "all_messages" -_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" -_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" -_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" 
-_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" - - -class NativeRuntimeSpec(BaseModel): - model_config = ConfigDict(extra="forbid") - - runtime_id: str | None = None - endpoint: str | None = None - endpoint_env: str | None = _NATIVE_ENDPOINT_ENV - model: str | None = None - model_env: str | None = _NATIVE_MODEL_ENV - provider: str = "native" - alias: str = "native-direct" - max_tokens: int = Field(default=4096, gt=0) - timeout_sec: float = Field(default=180.0, gt=0) - gliner_endpoint: str | None = None - gliner_endpoint_env: str | None = _GLINER_ENDPOINT_ENV - gliner_model: str | None = None - gliner_model_env: str | None = _GLINER_MODEL_ENV - gliner_provider: str = "gliner" - gliner_alias: str = "gliner-direct" - gliner_api_key_env: str = "NVIDIA_API_KEY" - gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) - max_workers: int = Field(default=4, ge=1) - - class ReplaceKind(StrEnum): redact = "redact" hash = "hash" @@ -152,8 +112,6 @@ class ConfigSpec(BaseModel): replace: str | ReplaceSpec | None = None rewrite: RewriteSpec | None = None emit_telemetry: bool = False - experimental_detection_strategy: ExperimentalDetectionStrategy = ExperimentalDetectionStrategy.default - native_runtime: NativeRuntimeSpec | None = None @model_validator(mode="after") def validate_mode(self) -> "ConfigSpec": @@ -189,8 +147,6 @@ class BenchmarkSpec(BaseModel): model_configs: str | None = None model_providers: str | None = None artifact_path: str | None = None - dd_parser_compat: DDParserCompatMode = DDParserCompatMode.none - native_runtime: NativeRuntimeSpec | None = None case_retries: int = Field(default=0, ge=0) case_retry_backoff_sec: float = Field(default=0.0, ge=0.0) workloads: list[WorkloadSpec] = Field(min_length=1) @@ -260,54 +216,6 @@ class _CaseRunPaths: export_detection_artifacts: bool -@dataclass(frozen=True) -class _CaseExecution: - input_data: AnonymizerInput - trace_dataframe: pd.DataFrame | None = None - - -_TRACE_FINAL_ARTIFACT_STRATEGIES = { - 
ExperimentalDetectionStrategy.native_candidate_validate_no_augment, - ExperimentalDetectionStrategy.detector_native_validate_no_augment, - ExperimentalDetectionStrategy.detector_native_validate_native_augment, - ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - ExperimentalDetectionStrategy.gliner_native_validate_native_augment, - ExperimentalDetectionStrategy.native_single_pass, - ExperimentalDetectionStrategy.native_single_pass_recall, - ExperimentalDetectionStrategy.native_single_pass_values, - ExperimentalDetectionStrategy.native_single_pass_values_recall, -} -_NATIVE_RUNTIME_STRATEGIES = { - ExperimentalDetectionStrategy.native_candidate_validate_no_augment, - ExperimentalDetectionStrategy.detector_native_validate_no_augment, - ExperimentalDetectionStrategy.detector_native_validate_native_augment, - ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - ExperimentalDetectionStrategy.gliner_native_validate_native_augment, - ExperimentalDetectionStrategy.native_single_pass, - ExperimentalDetectionStrategy.native_single_pass_recall, - ExperimentalDetectionStrategy.native_single_pass_values, - ExperimentalDetectionStrategy.native_single_pass_values_recall, -} -_GLINER_NATIVE_RUNTIME_STRATEGIES = { - ExperimentalDetectionStrategy.gliner_native_validate_no_augment, - ExperimentalDetectionStrategy.gliner_native_validate_native_augment, -} - -_FINAL_ARTIFACT_KEYS = { - "final_entity_count", - "weak_api_key_shape_count", - "final_entity_signature_count", - "final_entity_signature_hashes", - "final_entity_signature_labels", - "final_entity_signature_details", - "weak_api_key_shape_label_counts", - "final_label_counts", - "final_source_counts", -} - -_FINAL_ARTIFACT_PREFIXES = tuple(f"{key}." 
for key in _FINAL_ARTIFACT_KEYS if key != "final_entity_signature_hashes") - - def load_spec(path: Path) -> BenchmarkSpec: if not path.exists() or path.is_dir(): raise ValueError(f"spec path is not a file: {path}") @@ -419,10 +327,6 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) except Exception as exc: errors.append(f"config '{config.id}' invalid: {exc}") continue - try: - _preflight_native_runtime(config, spec=spec) - except Exception as exc: - errors.append(f"config '{config.id}' native_runtime invalid: {exc}") if parsed_models is None: continue try: @@ -444,65 +348,6 @@ def _active_config_ids(spec: BenchmarkSpec) -> set[str]: return {entry.config for entry in spec.matrix} -def _preflight_native_runtime(config: ConfigSpec, *, spec: BenchmarkSpec) -> None: - strategy = config.experimental_detection_strategy - if strategy not in _NATIVE_RUNTIME_STRATEGIES: - return - runtime = _resolve_native_runtime_spec(spec, config) - if not runtime.runtime_id: - raise ValueError("native strategies require native_runtime.runtime_id") - if not runtime.endpoint or not runtime.model: - raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") - if strategy in _GLINER_NATIVE_RUNTIME_STRATEGIES: - if not runtime.gliner_endpoint or not runtime.gliner_model: - raise ValueError("GLiNER-native strategies require native_runtime.gliner_endpoint and gliner_model") - if not os.environ.get(runtime.gliner_api_key_env): - raise ValueError(f"{runtime.gliner_api_key_env} is not set for GLiNER-native strategy") - - -def _resolve_native_runtime_spec(spec: BenchmarkSpec, config: ConfigSpec) -> NativeRuntimeSpec: - runtime = spec.native_runtime or NativeRuntimeSpec() - if config.native_runtime is not None: - runtime = runtime.model_copy(update=config.native_runtime.model_dump(exclude_unset=True)) - return runtime.model_copy( - update={ - "endpoint": _resolve_runtime_value(runtime.endpoint, runtime.endpoint_env), - "model": 
_resolve_runtime_value(runtime.model, runtime.model_env), - "gliner_endpoint": _resolve_runtime_value(runtime.gliner_endpoint, runtime.gliner_endpoint_env), - "gliner_model": _resolve_runtime_value(runtime.gliner_model, runtime.gliner_model_env), - } - ) - - -def _resolve_runtime_value(explicit: str | None, env_var: str | None) -> str | None: - if explicit: - return explicit - return os.environ.get(env_var) if env_var else None - - -def _native_detection_runtime(spec: BenchmarkSpec, config: ConfigSpec) -> NativeDetectionRuntime: - runtime = _resolve_native_runtime_spec(spec, config) - if not runtime.runtime_id: - raise ValueError("native strategies require native_runtime.runtime_id") - if not runtime.endpoint or not runtime.model: - raise ValueError("native strategies require native_runtime.endpoint and native_runtime.model") - return NativeDetectionRuntime( - endpoint=runtime.endpoint, - model=runtime.model, - provider=runtime.provider, - alias=runtime.alias, - max_tokens=runtime.max_tokens, - timeout_sec=runtime.timeout_sec, - gliner_endpoint=runtime.gliner_endpoint, - gliner_model=runtime.gliner_model or "", - gliner_provider=runtime.gliner_provider, - gliner_alias=runtime.gliner_alias, - gliner_api_key_env=runtime.gliner_api_key_env, - gliner_threshold=runtime.gliner_threshold, - max_workers=runtime.max_workers, - ) - - def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: raw = _resolve_config_source(spec.model_providers, base_dir) if raw is None: @@ -683,7 +528,6 @@ def _build_contexts( "trace_dir": trace_dir or output_dir / "traces", "dd_task_trace": dd_task_trace, "task_trace_dir": task_trace_dir or output_dir / "task-traces", - "dd_parser_compat": spec.dd_parser_compat, "artifact_path": artifact_path, "anonymizer_kwargs": { "model_configs": _resolve_config_source(spec.model_configs, base_dir), @@ -786,7 +630,7 @@ def _run_case_success( ) -> BenchmarkCase: workload = _get_item(contexts["workloads"], case.workload_id, "workload") 
config = _get_item(contexts["configs"], case.config_id, "config") - execution = _execute_case( + _execute_case( anonymizer, workload, config, @@ -797,14 +641,11 @@ def _run_case_success( spec=spec, base_dir=contexts["base_dir"], dd_trace=contexts["dd_trace"], - dd_parser_compat=contexts["dd_parser_compat"], ) detection_artifact_path = _case_detection_artifact_path( contexts, paths, case=case, - config=config, - execution=execution, ) return _case_with_result( case, @@ -854,8 +695,6 @@ def _case_detection_artifact_path( paths: _CaseRunPaths, *, case: BenchmarkCase, - config: ConfigSpec, - execution: _CaseExecution, ) -> Path | None: detection_artifact_path = _export_case_detection_artifacts_if_requested( contexts, @@ -863,117 +702,11 @@ def _case_detection_artifact_path( case=case, artifact_snapshot=paths.artifact_snapshot, ) - if paths.export_detection_artifacts: - detection_artifact_path = _trace_final_artifact_path_if_requested( - config, - detection_artifact_path, - paths.artifact_output_path, - case=case, - trace_dataframe=execution.trace_dataframe, - ) if detection_artifact_path is not None or paths.artifact_snapshot is None: return detection_artifact_path return None -def _trace_final_artifact_path_if_requested( - config: ConfigSpec, - detection_artifact_path: Path | None, - output_path: Path, - *, - case: BenchmarkCase, - trace_dataframe: pd.DataFrame | None, -) -> Path | None: - if config.experimental_detection_strategy not in _TRACE_FINAL_ARTIFACT_STRATEGIES: - return detection_artifact_path - if trace_dataframe is None: - return detection_artifact_path - return patch_case_detection_artifacts_from_trace_dataframe( - detection_artifact_path or output_path, - trace_dataframe, - case=case, - ) - - -def patch_case_detection_artifacts_from_trace_dataframe( - output_path: Path, - trace_dataframe: pd.DataFrame, - *, - case: BenchmarkCase | None = None, -) -> Path | None: - final_rows = _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe) - if not 
final_rows: - return None - rows = _read_detection_artifact_payloads(output_path) if output_path.exists() else [] - patched = _merge_final_entity_artifact_rows(rows, final_rows) - if case is not None: - patched = [_with_case_metadata(row, case=case) for row in patched] - write_detection_artifact_payloads(patched, output_path) - return output_path - - -def _final_entity_artifact_rows_from_trace_dataframe(trace_dataframe: pd.DataFrame) -> list[dict[str, Any]]: - entities_column = _trace_final_entities_column(trace_dataframe) - if entities_column is None: - return [] - return [ - _final_entity_artifact_row(raw_entities, row_index=row_index) - for row_index, raw_entities in enumerate(trace_dataframe[entities_column]) - ] - - -def _trace_final_entities_column(trace_dataframe: pd.DataFrame) -> str | None: - if COL_FINAL_ENTITIES in trace_dataframe.columns: - return COL_FINAL_ENTITIES - if COL_DETECTED_ENTITIES in trace_dataframe.columns: - return COL_DETECTED_ENTITIES - return None - - -def _final_entity_artifact_row(raw_entities: object, *, row_index: int) -> dict[str, Any]: - entities = EntitiesSchema.from_raw(raw_entities).entities - return build_detection_artifact_row_from_entities( - workflow_name="entity-detection-final-trace", - batch_file="trace_dataframe", - row_index=row_index, - seed_entities=[], - seed_validation_candidate_count=0, - merged_validation_candidate_count=0, - augmented_entities=[], - final_entities=entities, - ).model_dump() - - -def _read_detection_artifact_payloads(output_path: Path) -> list[dict[str, Any]]: - with output_path.open(encoding="utf-8") as source: - return [json.loads(line) for line in source if line.strip()] - - -def _merge_final_entity_artifact_rows( - rows: list[dict[str, Any]], - final_rows: list[dict[str, Any]], -) -> list[dict[str, Any]]: - patched = [_patch_final_entity_artifact_row(row, final_row) for row, final_row in zip(rows, final_rows)] - return patched + rows[len(final_rows) :] + final_rows[len(rows) :] - - -def 
_patch_final_entity_artifact_row(row: dict[str, Any], final_row: dict[str, Any]) -> dict[str, Any]: - clean_row = _without_final_entity_artifact_fields(row) - return {**clean_row, **_final_entity_artifact_fields(final_row)} - - -def _without_final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: - return { - key: value - for key, value in row.items() - if key not in _FINAL_ARTIFACT_KEYS and not key.startswith(_FINAL_ARTIFACT_PREFIXES) - } - - -def _final_entity_artifact_fields(row: dict[str, Any]) -> dict[str, Any]: - return {key: row[key] for key in _FINAL_ARTIFACT_KEYS} - - def _case_with_result( case: BenchmarkCase, *, @@ -1043,8 +776,7 @@ def _execute_case( spec: BenchmarkSpec, base_dir: Path, dd_trace: DDTraceMode, - dd_parser_compat: DDParserCompatMode, -) -> _CaseExecution: +) -> None: anonymizer_config = build_anonymizer_config(config) input_data = build_input( workload, @@ -1064,19 +796,10 @@ def _execute_case( fail_on_write_error=True, ) with configured_measurement_session(measurement): - with dd_parser_compat_context(dd_parser_compat): - detection_context_kwargs: dict[str, Any] = {} - if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: - detection_context_kwargs["native_runtime"] = _native_detection_runtime(spec, config) - with experimental_detection_strategy_context( - config.experimental_detection_strategy, - **detection_context_kwargs, - ): - result = anonymizer.run( - config=anonymizer_config, - data=input_data, - ) - return _CaseExecution(input_data=input_data, trace_dataframe=getattr(result, "trace_dataframe", None)) + anonymizer.run( + config=anonymizer_config, + data=input_data, + ) def build_input( @@ -1339,42 +1062,13 @@ def render_result(result: BenchmarkResult, *, json_output: bool) -> str: def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: - config = next(item for item in spec.configs if item.id == case.config_id) - tags = { + return { "suite_id": spec.suite_id, "workload_id": 
case.workload_id, "config_id": case.config_id, "repetition": case.repetition, "case_id": case.case_id, - "experimental_detection_strategy": config.experimental_detection_strategy.value, - "dd_parser_compat": spec.dd_parser_compat.value, } - if config.experimental_detection_strategy in _NATIVE_RUNTIME_STRATEGIES: - tags.update(_native_runtime_tags(_resolve_native_runtime_spec(spec, config))) - return tags - - -def _native_runtime_tags(runtime: NativeRuntimeSpec) -> dict[str, Any]: - return _present( - { - "native_runtime_id": runtime.runtime_id, - "native_endpoint_env": runtime.endpoint_env, - "native_model": runtime.model, - "native_model_env": runtime.model_env, - "native_provider": runtime.provider, - "native_alias": runtime.alias, - "native_max_tokens": runtime.max_tokens, - "native_timeout_sec": runtime.timeout_sec, - "native_max_workers": runtime.max_workers, - "gliner_endpoint_env": runtime.gliner_endpoint_env, - "gliner_model": runtime.gliner_model, - "gliner_model_env": runtime.gliner_model_env, - "gliner_provider": runtime.gliner_provider, - "gliner_alias": runtime.gliner_alias, - "gliner_api_key_env": runtime.gliner_api_key_env, - "gliner_threshold": runtime.gliner_threshold, - } - ) def _present(values: dict[str, Any]) -> dict[str, Any]: diff --git a/tools/measurement/screen_strategy_comparisons.py b/tools/measurement/screen_strategy_comparisons.py deleted file mode 100644 index 69542205..00000000 --- a/tools/measurement/screen_strategy_comparisons.py +++ /dev/null @@ -1,1107 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Screen strategy-comparison CSVs across benchmark analysis directories. 
- -Usage: - uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ - uv run python tools/measurement/screen_strategy_comparisons.py benchmark-runs/ --output strategy-screen.csv - uv run python tools/measurement/screen_strategy_comparisons.py run-a/analysis run-b/analysis --json -""" - -from __future__ import annotations - -import ast -import json -import logging -import sys -from enum import StrEnum -from pathlib import Path -from typing import Annotated - -import cyclopts -import pandas as pd -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from measurement_tools.tables import ExportFormat -from pydantic import BaseModel, Field, model_validator - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.strategy_screen") - -COMPARISON_COLUMNS = { - "workload_id", - "baseline_config_id", - "candidate_config_id", - "safety_verdict", - "performance_verdict", - "candidate_verdict", -} - - -class GroupBy(StrEnum): - strategy = "strategy" - strategy_workload_family = "strategy_workload_family" - strategy_workload = "strategy_workload" - - -class ScreenRow(BaseModel): - source_path: str - workload_id: str - workload_family: str | None = None - baseline_config_id: str - candidate_config_id: str - baseline_strategy: str | None = None - candidate_strategy: str | None = None - baseline_replacement_strategy: str | None = None - candidate_replacement_strategy: str | None = None - baseline_case_count: int | None = None - candidate_case_count: int | None = None - value_protection_verdict: str | None = None - signature_parity_verdict: str | None = None - safety_verdict: str - performance_verdict: str - candidate_verdict: str - evidence_level: str = "legacy" - pipeline_elapsed_sec_delta_pct: float | None = None - observed_total_requests_delta: float | None = None - observed_total_tokens_delta: float | None = None - final_entity_count_delta: float | None = None - augmented_entity_count_delta: float | None = 
None - augmented_new_final_value_count_delta: float | None = None - baseline_original_value_leak_count: float | None = None - candidate_original_value_leak_count: float | None = None - original_value_leak_count_delta: float | None = None - baseline_original_value_leak_record_count: float | None = None - candidate_original_value_leak_record_count: float | None = None - original_value_leak_record_count_delta: float | None = None - baseline_replacement_missing_final_entity_count: float | None = None - candidate_replacement_missing_final_entity_count: float | None = None - replacement_missing_final_entity_count_delta: float | None = None - baseline_replacement_synthetic_original_collision_count: float | None = None - candidate_replacement_synthetic_original_collision_count: float | None = None - replacement_synthetic_original_collision_count_delta: float | None = None - baseline_replacement_synthetic_original_collision_value_count: float | None = None - candidate_replacement_synthetic_original_collision_value_count: float | None = None - replacement_synthetic_original_collision_value_count_delta: float | None = None - baseline_duplicate_synthetic_replacement_count: float | None = None - candidate_duplicate_synthetic_replacement_count: float | None = None - duplicate_synthetic_replacement_count_delta: float | None = None - baseline_only_final_entity_signature_count: float | None = None - candidate_only_final_entity_signature_count: float | None = None - shared_final_entity_signature_count: float | None = None - baseline_stable_final_entity_signature_count: int | None = None - candidate_stable_final_entity_signature_count: int | None = None - baseline_stable_candidate_unstable_final_entity_signature_count: float | None = None - candidate_stable_baseline_unstable_final_entity_signature_count: float | None = None - shared_stable_final_entity_signature_count: int | None = None - flags: list[str] = Field(default_factory=list) - baseline_only_label_counts: dict[str, int] = 
Field(default_factory=dict) - label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) - stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - - @model_validator(mode="after") - def fill_workload_family(self) -> "ScreenRow": - if self.workload_family is None: - self.workload_family = workload_family(self.workload_id) - return self - - -class ScreenSummary(BaseModel): - viable_count: int = 0 - review_count: int = 0 - reject_count: int = 0 - candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) - value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) - signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) - safety_verdict_counts: dict[str, int] = Field(default_factory=dict) - performance_verdict_counts: dict[str, int] = Field(default_factory=dict) - evidence_level_counts: dict[str, int] = Field(default_factory=dict) - - -class ScreenGroup(BaseModel): - group_key: str - candidate_strategy: str | None = None - candidate_replacement_strategy: str | None = None - candidate_config_ids: list[str] = Field(default_factory=list) - workload_ids: list[str] = Field(default_factory=list) - workload_families: list[str] = Field(default_factory=list) - row_count: int = 0 - min_baseline_case_count: int | None = None - min_candidate_case_count: int | None = None - viable_count: int = 0 - 
review_count: int = 0 - reject_count: int = 0 - has_conflicting_verdicts: bool = False - recommendation: str = "unknown" - value_protection_verdict_counts: dict[str, int] = Field(default_factory=dict) - signature_parity_verdict_counts: dict[str, int] = Field(default_factory=dict) - performance_verdict_counts: dict[str, int] = Field(default_factory=dict) - evidence_level_counts: dict[str, int] = Field(default_factory=dict) - split_verdict_candidate_verdict_counts: dict[str, int] = Field(default_factory=dict) - best_pipeline_elapsed_sec_delta_pct: float | None = None - best_observed_total_tokens_delta: float | None = None - best_observed_total_requests_delta: float | None = None - worst_pipeline_elapsed_sec_delta_pct: float | None = None - worst_observed_total_tokens_delta: float | None = None - worst_observed_total_requests_delta: float | None = None - min_shared_stable_final_entity_signature_count: int | None = None - sum_baseline_replacement_missing_final_entity_count: float | None = None - sum_candidate_replacement_missing_final_entity_count: float | None = None - sum_baseline_original_value_leak_count: float | None = None - sum_candidate_original_value_leak_count: float | None = None - sum_baseline_original_value_leak_record_count: float | None = None - sum_candidate_original_value_leak_record_count: float | None = None - sum_baseline_replacement_synthetic_original_collision_count: float | None = None - sum_candidate_replacement_synthetic_original_collision_count: float | None = None - sum_baseline_replacement_synthetic_original_collision_value_count: float | None = None - sum_candidate_replacement_synthetic_original_collision_value_count: float | None = None - sum_baseline_duplicate_synthetic_replacement_count: float | None = None - sum_candidate_duplicate_synthetic_replacement_count: float | None = None - baseline_defect_improvement_count: int = 0 - flag_counts: dict[str, int] = Field(default_factory=dict) - baseline_only_label_counts: dict[str, int] = 
Field(default_factory=dict) - label_mismatch_label_counts: dict[str, int] = Field(default_factory=dict) - stable_lost_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_missing_final_entity_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_original_value_leak_label_counts: dict[str, int] = Field(default_factory=dict) - baseline_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - candidate_replacement_synthetic_original_collision_label_counts: dict[str, int] = Field(default_factory=dict) - - @model_validator(mode="after") - def fill_recommendation(self) -> "ScreenGroup": - if self.recommendation == "unknown": - self.recommendation = group_recommendation(self) - return self - - -class ScreenResult(BaseModel): - input_paths: list[str] - scanned_file_count: int - comparison_file_count: int - row_count: int - duplicate_row_count: int = 0 - summary: ScreenSummary = Field(default_factory=ScreenSummary) - rows: list[ScreenRow] = Field(default_factory=list) - groups: list[ScreenGroup] = Field(default_factory=list) - - -def screen_comparison_paths( - paths: list[Path], - *, - group_by: GroupBy = GroupBy.strategy, - config_aliases: dict[str, str] | None = None, - source_includes: list[str] | None = None, - source_excludes: list[str] | None = None, -) -> ScreenResult: - files = filter_source_paths( - list(iter_csv_files(paths)), - includes=source_includes or [], - excludes=source_excludes or [], - ) - rows: list[ScreenRow] = [] - comparison_file_count = 0 - for csv_file in files: - table = read_csv_or_empty(csv_file) - if table is None: - continue - if not is_comparison_table(table): - continue - comparison_file_count += 1 - rows.extend(build_rows_from_table(table, 
source_path=csv_file)) - deduped_rows = sorted(dedupe_rows(rows), key=screen_sort_key) - groups = sorted( - summarize_groups(deduped_rows, group_by=group_by, config_aliases=config_aliases or {}), - key=group_sort_key, - ) - return ScreenResult( - input_paths=[str(path) for path in paths], - scanned_file_count=len(files), - comparison_file_count=comparison_file_count, - row_count=len(deduped_rows), - duplicate_row_count=len(rows) - len(deduped_rows), - summary=summarize_rows(deduped_rows), - rows=deduped_rows, - groups=groups, - ) - - -def iter_csv_files(paths: list[Path]) -> list[Path]: - files: list[Path] = [] - for path in paths: - if not path.exists(): - raise ValueError(f"comparison path does not exist: {path}") - if path.is_file(): - if path.suffix == ".csv": - files.append(path) - continue - files.extend(sorted(path.rglob("*.csv"))) - return sorted(set(files)) - - -def filter_source_paths(paths: list[Path], *, includes: list[str], excludes: list[str]) -> list[Path]: - return [path for path in paths if source_path_matches(path, includes=includes, excludes=excludes)] - - -def source_path_matches(path: Path, *, includes: list[str], excludes: list[str]) -> bool: - text = str(path) - if includes and not any(fragment in text for fragment in includes): - return False - return not any(fragment in text for fragment in excludes) - - -def is_comparison_table(table: pd.DataFrame) -> bool: - if "source_path" in table.columns: - return False - return COMPARISON_COLUMNS.issubset(set(table.columns)) - - -def read_csv_or_empty(csv_file: Path) -> pd.DataFrame | None: - try: - return pd.read_csv(csv_file) - except pd.errors.EmptyDataError: - return None - - -def read_config_aliases(path: Path | None) -> dict[str, str]: - if path is None: - return {} - try: - payload = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError as exc: - raise ValueError(f"config aliases file is not valid JSON: {path}") from exc - if not isinstance(payload, dict): - raise 
ValueError("config aliases file must contain a JSON object mapping config IDs to aliases") - return {str(key): str(value) for key, value in payload.items()} - - -def build_rows_from_table(table: pd.DataFrame, *, source_path: Path) -> list[ScreenRow]: - return [build_row(row, source_path=source_path) for _, row in table.iterrows()] - - -def build_row(row: pd.Series, *, source_path: Path) -> ScreenRow: - return ScreenRow( - source_path=str(source_path), - workload_id=required_string(row, "workload_id"), - workload_family=workload_family(required_string(row, "workload_id")), - baseline_config_id=required_string(row, "baseline_config_id"), - candidate_config_id=required_string(row, "candidate_config_id"), - baseline_strategy=optional_string(row.get("baseline_strategy")), - candidate_strategy=optional_string(row.get("candidate_strategy")), - baseline_replacement_strategy=optional_string(row.get("baseline_replacement_strategy")), - candidate_replacement_strategy=optional_string(row.get("candidate_replacement_strategy")), - baseline_case_count=optional_int(row.get("baseline_case_count")), - candidate_case_count=optional_int(row.get("candidate_case_count")), - value_protection_verdict=optional_string(row.get("value_protection_verdict")), - signature_parity_verdict=optional_string(row.get("signature_parity_verdict")), - safety_verdict=required_string(row, "safety_verdict"), - performance_verdict=required_string(row, "performance_verdict"), - candidate_verdict=required_string(row, "candidate_verdict"), - evidence_level=comparison_evidence_level(row), - pipeline_elapsed_sec_delta_pct=optional_float(row.get("pipeline_elapsed_sec_delta_pct")), - observed_total_requests_delta=optional_float(row.get("observed_total_requests_delta")), - observed_total_tokens_delta=optional_float(row.get("observed_total_tokens_delta")), - final_entity_count_delta=optional_float(row.get("final_entity_count_delta")), - 
augmented_entity_count_delta=optional_float(row.get("augmented_entity_count_delta")), - augmented_new_final_value_count_delta=optional_float(row.get("augmented_new_final_value_count_delta")), - baseline_original_value_leak_count=optional_float(row.get("baseline_original_value_leak_count")), - candidate_original_value_leak_count=optional_float(row.get("candidate_original_value_leak_count")), - original_value_leak_count_delta=optional_float(row.get("original_value_leak_count_delta")), - baseline_original_value_leak_record_count=optional_float(row.get("baseline_original_value_leak_record_count")), - candidate_original_value_leak_record_count=optional_float( - row.get("candidate_original_value_leak_record_count") - ), - original_value_leak_record_count_delta=optional_float(row.get("original_value_leak_record_count_delta")), - baseline_replacement_missing_final_entity_count=optional_float( - row.get("baseline_replacement_missing_final_entity_count") - ), - candidate_replacement_missing_final_entity_count=optional_float( - row.get("candidate_replacement_missing_final_entity_count") - ), - replacement_missing_final_entity_count_delta=optional_float( - row.get("replacement_missing_final_entity_count_delta") - ), - baseline_replacement_synthetic_original_collision_count=optional_float( - row.get("baseline_replacement_synthetic_original_collision_count") - ), - candidate_replacement_synthetic_original_collision_count=optional_float( - row.get("candidate_replacement_synthetic_original_collision_count") - ), - replacement_synthetic_original_collision_count_delta=optional_float( - row.get("replacement_synthetic_original_collision_count_delta") - ), - baseline_replacement_synthetic_original_collision_value_count=optional_float( - row.get("baseline_replacement_synthetic_original_collision_value_count") - ), - candidate_replacement_synthetic_original_collision_value_count=optional_float( - row.get("candidate_replacement_synthetic_original_collision_value_count") - ), - 
replacement_synthetic_original_collision_value_count_delta=optional_float( - row.get("replacement_synthetic_original_collision_value_count_delta") - ), - baseline_duplicate_synthetic_replacement_count=optional_float( - row.get("baseline_duplicate_synthetic_replacement_count") - ), - candidate_duplicate_synthetic_replacement_count=optional_float( - row.get("candidate_duplicate_synthetic_replacement_count") - ), - duplicate_synthetic_replacement_count_delta=optional_float( - row.get("duplicate_synthetic_replacement_count_delta") - ), - baseline_only_final_entity_signature_count=optional_float( - row.get("baseline_only_final_entity_signature_count") - ), - candidate_only_final_entity_signature_count=optional_float( - row.get("candidate_only_final_entity_signature_count") - ), - shared_final_entity_signature_count=optional_float(row.get("shared_final_entity_signature_count")), - baseline_stable_final_entity_signature_count=optional_int( - row.get("baseline_stable_final_entity_signature_count") - ), - candidate_stable_final_entity_signature_count=optional_int( - row.get("candidate_stable_final_entity_signature_count") - ), - baseline_stable_candidate_unstable_final_entity_signature_count=optional_float( - row.get("baseline_stable_candidate_unstable_final_entity_signature_count") - ), - candidate_stable_baseline_unstable_final_entity_signature_count=optional_float( - row.get("candidate_stable_baseline_unstable_final_entity_signature_count") - ), - shared_stable_final_entity_signature_count=optional_int(row.get("shared_stable_final_entity_signature_count")), - flags=parse_flags(row.get("flags")), - baseline_only_label_counts=preferred_count_columns( - row, - preferred_prefix="baseline_only_candidate_uncovered_signature_label_counts", - fallback_prefix="baseline_only_final_entity_signature_label_counts", - preferred_count_column="baseline_only_candidate_uncovered_signature_count", - ), - label_mismatch_label_counts=count_columns( - row, - 
"baseline_only_candidate_label_mismatch_signature_label_counts", - ), - baseline_replacement_missing_final_entity_label_counts=count_columns( - row, - "baseline_replacement_missing_final_entity_label_counts", - ), - candidate_replacement_missing_final_entity_label_counts=count_columns( - row, - "candidate_replacement_missing_final_entity_label_counts", - ), - baseline_original_value_leak_label_counts=count_columns(row, "baseline_original_value_leak_label_counts"), - candidate_original_value_leak_label_counts=count_columns(row, "candidate_original_value_leak_label_counts"), - baseline_replacement_synthetic_original_collision_label_counts=count_columns( - row, - "baseline_replacement_synthetic_original_collision_label_counts", - ), - candidate_replacement_synthetic_original_collision_label_counts=count_columns( - row, - "candidate_replacement_synthetic_original_collision_label_counts", - ), - stable_lost_label_counts=preferred_count_columns( - row, - preferred_prefix="baseline_stable_candidate_uncovered_signature_label_counts", - fallback_prefix="baseline_stable_candidate_unstable_final_entity_signature_label_counts", - preferred_count_column="baseline_stable_candidate_uncovered_signature_count", - ), - ) - - -def comparison_evidence_level(row: pd.Series) -> str: - if optional_string(row.get("value_protection_verdict")) and optional_string(row.get("signature_parity_verdict")): - return "split_verdicts" - if ( - not is_missing(row.get("baseline_stable_final_entity_signature_count")) - or not is_missing(row.get("candidate_stable_final_entity_signature_count")) - or not is_missing(row.get("shared_stable_final_entity_signature_count")) - ): - return "stable_signatures" - if ( - not is_missing(row.get("baseline_only_final_entity_signature_count")) - or not is_missing(row.get("candidate_only_final_entity_signature_count")) - or not is_missing(row.get("shared_final_entity_signature_count")) - ): - return "signature_counts" - return "legacy" - - -def summarize_rows(rows: 
list[ScreenRow]) -> ScreenSummary: - candidate_counts = count_values(row.candidate_verdict for row in rows) - return ScreenSummary( - viable_count=candidate_counts.get("candidate_viable", 0), - review_count=candidate_counts.get("review", 0), - reject_count=candidate_counts.get("reject", 0), - candidate_verdict_counts=candidate_counts, - value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), - signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), - safety_verdict_counts=count_values(row.safety_verdict for row in rows), - performance_verdict_counts=count_values(row.performance_verdict for row in rows), - evidence_level_counts=count_values(row.evidence_level for row in rows), - ) - - -def summarize_groups( - rows: list[ScreenRow], - *, - group_by: GroupBy = GroupBy.strategy, - config_aliases: dict[str, str] | None = None, -) -> list[ScreenGroup]: - return [ - build_group(group_key, group_rows) - for group_key, group_rows in grouped_rows( - rows, - group_by=group_by, - config_aliases=config_aliases or {}, - ).items() - ] - - -def grouped_rows( - rows: list[ScreenRow], - *, - group_by: GroupBy, - config_aliases: dict[str, str], -) -> dict[str, list[ScreenRow]]: - groups: dict[str, list[ScreenRow]] = {} - for row in rows: - groups.setdefault(group_key_for_row(row, group_by=group_by, config_aliases=config_aliases), []).append(row) - return dict(sorted(groups.items())) - - -def group_key_for_row(row: ScreenRow, *, group_by: GroupBy, config_aliases: dict[str, str]) -> str: - base = group_base_for_row(row, config_aliases=config_aliases) - if group_by == GroupBy.strategy_workload_family: - return f"{base}|family:{row.workload_family}" - if group_by == GroupBy.strategy_workload: - return f"{base}|workload:{row.workload_id}" - return base - - -def group_base_for_row(row: ScreenRow, *, config_aliases: dict[str, str]) -> str: - replacement = _non_default_replacement_strategy(row) - if 
row.candidate_strategy and row.candidate_strategy != "default": - base = f"strategy:{row.candidate_strategy}" - return f"{base}|replacement:{replacement}" if replacement else base - if replacement: - return f"replacement:{replacement}" - if alias := config_aliases.get(row.candidate_config_id): - return f"alias:{alias}" - return f"config:{row.candidate_config_id}" - - -def _non_default_replacement_strategy(row: ScreenRow) -> str | None: - if row.candidate_replacement_strategy and row.candidate_replacement_strategy != "default": - return row.candidate_replacement_strategy - return None - - -def workload_family(workload_id: str) -> str: - parts = [part for part in workload_id.split("-") if part] - while parts and _is_workload_slice_suffix(parts[-1]): - parts.pop() - return "-".join(parts) if parts else workload_id - - -def _is_workload_slice_suffix(value: str) -> bool: - return ( - value.isdigit() - or value == "slice" - or (value.startswith("r") and value[1:].isdigit()) - or (value.startswith("offset") and value[len("offset") :].isdigit()) - ) - - -def build_group(group_key: str, rows: list[ScreenRow]) -> ScreenGroup: - verdict_counts = count_values(row.candidate_verdict for row in rows) - group = ScreenGroup( - group_key=group_key, - candidate_strategy=single_optional_value(row.candidate_strategy for row in rows), - candidate_replacement_strategy=single_optional_value(row.candidate_replacement_strategy for row in rows), - candidate_config_ids=sorted({row.candidate_config_id for row in rows}), - workload_ids=sorted({row.workload_id for row in rows}), - workload_families=sorted({row.workload_family or workload_family(row.workload_id) for row in rows}), - row_count=len(rows), - min_baseline_case_count=min_present_int(row.baseline_case_count for row in rows), - min_candidate_case_count=min_present_int(row.candidate_case_count for row in rows), - viable_count=verdict_counts.get("candidate_viable", 0), - review_count=verdict_counts.get("review", 0), - 
reject_count=verdict_counts.get("reject", 0), - has_conflicting_verdicts=len(verdict_counts) > 1, - value_protection_verdict_counts=count_present_values(row.value_protection_verdict for row in rows), - signature_parity_verdict_counts=count_present_values(row.signature_parity_verdict for row in rows), - performance_verdict_counts=count_values(row.performance_verdict for row in rows), - evidence_level_counts=count_values(row.evidence_level for row in rows), - split_verdict_candidate_verdict_counts=count_values( - row.candidate_verdict for row in rows if row.evidence_level == "split_verdicts" - ), - best_pipeline_elapsed_sec_delta_pct=min_present(row.pipeline_elapsed_sec_delta_pct for row in rows), - best_observed_total_tokens_delta=min_present(row.observed_total_tokens_delta for row in rows), - best_observed_total_requests_delta=min_present(row.observed_total_requests_delta for row in rows), - worst_pipeline_elapsed_sec_delta_pct=max_present(row.pipeline_elapsed_sec_delta_pct for row in rows), - worst_observed_total_tokens_delta=max_present(row.observed_total_tokens_delta for row in rows), - worst_observed_total_requests_delta=max_present(row.observed_total_requests_delta for row in rows), - min_shared_stable_final_entity_signature_count=min_present_int( - row.shared_stable_final_entity_signature_count for row in rows - ), - sum_baseline_replacement_missing_final_entity_count=sum_present( - row.baseline_replacement_missing_final_entity_count for row in rows - ), - sum_candidate_replacement_missing_final_entity_count=sum_present( - row.candidate_replacement_missing_final_entity_count for row in rows - ), - sum_baseline_original_value_leak_count=sum_present(row.baseline_original_value_leak_count for row in rows), - sum_candidate_original_value_leak_count=sum_present(row.candidate_original_value_leak_count for row in rows), - sum_baseline_original_value_leak_record_count=sum_present( - row.baseline_original_value_leak_record_count for row in rows - ), - 
sum_candidate_original_value_leak_record_count=sum_present( - row.candidate_original_value_leak_record_count for row in rows - ), - sum_baseline_replacement_synthetic_original_collision_count=sum_present( - row.baseline_replacement_synthetic_original_collision_count for row in rows - ), - sum_candidate_replacement_synthetic_original_collision_count=sum_present( - row.candidate_replacement_synthetic_original_collision_count for row in rows - ), - sum_baseline_replacement_synthetic_original_collision_value_count=sum_present( - row.baseline_replacement_synthetic_original_collision_value_count for row in rows - ), - sum_candidate_replacement_synthetic_original_collision_value_count=sum_present( - row.candidate_replacement_synthetic_original_collision_value_count for row in rows - ), - sum_baseline_duplicate_synthetic_replacement_count=sum_present( - row.baseline_duplicate_synthetic_replacement_count for row in rows - ), - sum_candidate_duplicate_synthetic_replacement_count=sum_present( - row.candidate_duplicate_synthetic_replacement_count for row in rows - ), - baseline_defect_improvement_count=sum(1 for row in rows if row_has_baseline_defect_improvement(row)), - flag_counts=sum_string_counts(row.flags for row in rows), - baseline_only_label_counts=sum_dict_counts(row.baseline_only_label_counts for row in rows), - label_mismatch_label_counts=sum_dict_counts(row.label_mismatch_label_counts for row in rows), - stable_lost_label_counts=sum_dict_counts(row.stable_lost_label_counts for row in rows), - baseline_replacement_missing_final_entity_label_counts=sum_dict_counts( - row.baseline_replacement_missing_final_entity_label_counts for row in rows - ), - candidate_replacement_missing_final_entity_label_counts=sum_dict_counts( - row.candidate_replacement_missing_final_entity_label_counts for row in rows - ), - baseline_original_value_leak_label_counts=sum_dict_counts( - row.baseline_original_value_leak_label_counts for row in rows - ), - 
candidate_original_value_leak_label_counts=sum_dict_counts( - row.candidate_original_value_leak_label_counts for row in rows - ), - baseline_replacement_synthetic_original_collision_label_counts=sum_dict_counts( - row.baseline_replacement_synthetic_original_collision_label_counts for row in rows - ), - candidate_replacement_synthetic_original_collision_label_counts=sum_dict_counts( - row.candidate_replacement_synthetic_original_collision_label_counts for row in rows - ), - ) - group.recommendation = group_recommendation(group) - return group - - -def group_recommendation(group: ScreenGroup) -> str: - if group.viable_count and group.reject_count: - return "conflicting_evidence" - if group_needs_split_verdict_rerun(group): - return "needs_split_verdict_rerun" - if group_needs_viable_split_verdict(group): - return "needs_viable_split_verdict" - if group.viable_count and group.review_count and group_has_baseline_defect_improvement_reviews(group): - return "promising_with_baseline_defect_improvements" - if group.viable_count and group.review_count: - return "promising_needs_review" - if group.viable_count: - if group.row_count == 1: - return "single_slice_viable" - return "candidate_family_viable" - if group.reject_count: - return "reject" - if group.review_count: - if is_baseline_defect_improvement_group(group): - return "candidate_covers_baseline_defects" - if is_replacement_replay_review_group(group): - return "replacement_replay_review" - if is_reliability_review_group(group): - return "reliability_review" - if is_label_policy_review_group(group): - return "label_policy_review" - if group.performance_verdict_counts.get("improved", 0) == group.review_count: - return "review_only" - if group.performance_verdict_counts.get("improved", 0) or group.performance_verdict_counts.get("mixed", 0): - return "review_mixed_performance" - return "no_performance_win" - return "unknown" - - -def group_needs_split_verdict_rerun(group: ScreenGroup) -> bool: - if not 
group.evidence_level_counts: - return False - if group.reject_count: - return False - split_verdict_count = group.evidence_level_counts.get("split_verdicts", 0) - if split_verdict_count == group.row_count: - return False - if group.viable_count and group.review_count: - return split_verdict_count == 0 - if group.review_count == group.row_count and split_verdict_count: - return True - return False - - -def group_needs_viable_split_verdict(group: ScreenGroup) -> bool: - if not (group.viable_count and group.review_count): - return False - if not group.evidence_level_counts.get("split_verdicts", 0): - return False - return not bool(group.split_verdict_candidate_verdict_counts.get("candidate_viable", 0)) - - -def is_replacement_replay_review_group(group: ScreenGroup) -> bool: - if not group.candidate_replacement_strategy or group.candidate_replacement_strategy == "default": - return False - if group.review_count != group.row_count: - return False - if group.performance_verdict_counts.get("improved", 0) != group.review_count: - return False - return bool(group.flag_counts.get("replacement_only_detection_instability", 0)) - - -_BASELINE_DEFECT_IMPROVEMENT_FLAGS = { - "candidate_covers_baseline_original_value_leak", - "candidate_covers_baseline_replacement_missing_final_entity", - "candidate_covers_baseline_replacement_synthetic_original_collision", -} - - -def row_has_baseline_defect_improvement(row: ScreenRow) -> bool: - return bool(set(row.flags) & _BASELINE_DEFECT_IMPROVEMENT_FLAGS) - - -def group_has_baseline_defect_improvement_reviews(group: ScreenGroup) -> bool: - if not group.review_count: - return False - return group.baseline_defect_improvement_count == group.review_count - - -def is_baseline_defect_improvement_group(group: ScreenGroup) -> bool: - if group.review_count != group.row_count: - return False - if group.performance_verdict_counts.get("improved", 0) != group.review_count: - return False - if group.value_protection_verdict_counts.get("pass", 0) != 
group.review_count: - return False - if group.signature_parity_verdict_counts.get("review", 0) != group.review_count: - return False - return group.baseline_defect_improvement_count == group.review_count - - -_RELIABILITY_REVIEW_FLAGS = {"failed_request_increase", "bridge_fallback_increase"} - - -def is_reliability_review_group(group: ScreenGroup) -> bool: - if group.review_count != group.row_count: - return False - if group.performance_verdict_counts.get("improved", 0) != group.review_count: - return False - return any(group.flag_counts.get(flag, 0) > 0 for flag in _RELIABILITY_REVIEW_FLAGS) - - -def is_label_policy_review_group(group: ScreenGroup) -> bool: - if group.review_count != group.row_count: - return False - if group.performance_verdict_counts.get("improved", 0) != group.review_count: - return False - if group.value_protection_verdict_counts.get("pass", 0) != group.review_count: - return False - if group.signature_parity_verdict_counts.get("review", 0) != group.review_count: - return False - return bool(group.label_mismatch_label_counts or group.flag_counts.get("covered_label_mismatch")) - - -def group_performance_summary(group: ScreenGroup) -> str: - if group.performance_verdict_counts: - return label_summary(group.performance_verdict_counts) - return "unknown" - - -def single_optional_value(values: object) -> str | None: - unique = sorted({str(value) for value in values if value is not None}) - return unique[0] if len(unique) == 1 else None - - -def min_present(values: object) -> float | None: - present = [float(value) for value in values if value is not None] - return min(present) if present else None - - -def min_present_int(values: object) -> int | None: - value = min_present(values) - return int(value) if value is not None else None - - -def max_present(values: object) -> float | None: - present = [float(value) for value in values if value is not None] - return max(present) if present else None - - -def sum_present(values: object) -> float | None: - 
present = [float(value) for value in values if value is not None] - return sum(present) if present else None - - -def sum_string_counts(values: object) -> dict[str, int]: - counts: dict[str, int] = {} - for items in values: - for item in items: - counts[str(item)] = counts.get(str(item), 0) + 1 - return dict(sorted(counts.items())) - - -def sum_dict_counts(values: object) -> dict[str, int]: - counts: dict[str, int] = {} - for mapping in values: - for key, value in mapping.items(): - counts[str(key)] = counts.get(str(key), 0) + int(value) - return dict(sorted(counts.items())) - - -def group_sort_key(group: ScreenGroup) -> tuple[int, int, float, float, str]: - conflict_rank = 1 if group.has_conflicting_verdicts else 0 - verdict_rank = 0 if group.viable_count else 1 if group.review_count else 2 - elapsed = group.best_pipeline_elapsed_sec_delta_pct - tokens = group.best_observed_total_tokens_delta - return ( - conflict_rank, - verdict_rank, - elapsed if elapsed is not None else float("inf"), - tokens if tokens is not None else float("inf"), - group.group_key, - ) - - -def dedupe_rows(rows: list[ScreenRow]) -> list[ScreenRow]: - deduped: dict[str, ScreenRow] = {} - for row in rows: - payload = row.model_dump() - payload.pop("source_path", None) - key = json.dumps(payload, ensure_ascii=True, sort_keys=True) - deduped.setdefault(key, row) - return list(deduped.values()) - - -def count_values(values: object) -> dict[str, int]: - counts: dict[str, int] = {} - for value in values: - key = str(value) - counts[key] = counts.get(key, 0) + 1 - return dict(sorted(counts.items())) - - -def count_present_values(values: object) -> dict[str, int]: - return count_values(value for value in values if value is not None) - - -def screen_sort_key(row: ScreenRow) -> tuple[int, float, float, str, str]: - verdict_rank = {"candidate_viable": 0, "review": 1, "reject": 2}.get(row.candidate_verdict, 3) - elapsed_delta = row.pipeline_elapsed_sec_delta_pct - token_delta = 
row.observed_total_tokens_delta - return ( - verdict_rank, - elapsed_delta if elapsed_delta is not None else float("inf"), - token_delta if token_delta is not None else float("inf"), - row.workload_id, - row.candidate_config_id, - ) - - -def required_string(row: pd.Series, column: str) -> str: - value = optional_string(row.get(column)) - if value is None: - raise ValueError(f"comparison row missing required value: {column}") - return value - - -def optional_string(value: object) -> str | None: - if is_missing(value): - return None - return str(value) - - -def optional_float(value: object) -> float | None: - if is_missing(value): - return None - return float(value) - - -def optional_int(value: object) -> int | None: - value = optional_float(value) - return int(value) if value is not None else None - - -def is_missing(value: object) -> bool: - return value is None or (isinstance(value, float) and pd.isna(value)) - - -def parse_flags(value: object) -> list[str]: - if is_missing(value): - return [] - if isinstance(value, list): - return [str(item) for item in value] - if not isinstance(value, str): - return [str(value)] - parsed = parse_nested(value) - if isinstance(parsed, list): - return [str(item) for item in parsed] - return [value] if value else [] - - -def parse_nested(value: str) -> object: - try: - return json.loads(value) - except json.JSONDecodeError: - try: - return ast.literal_eval(value) - except (SyntaxError, ValueError): - return value - - -def count_columns(row: pd.Series, prefix: str) -> dict[str, int]: - label_counts: dict[str, int] = {} - column_prefix = f"{prefix}." 
- for column, raw_value in row.items(): - if not str(column).startswith(column_prefix): - continue - value = optional_float(raw_value) - if value is not None and value > 0: - label_counts[str(column)[len(column_prefix) :]] = int(value) - return dict(sorted(label_counts.items())) - - -def preferred_count_columns( - row: pd.Series, - *, - preferred_prefix: str, - fallback_prefix: str, - preferred_count_column: str, -) -> dict[str, int]: - if has_comparison_field(row, preferred_prefix) or not is_missing(row.get(preferred_count_column)): - return count_columns(row, preferred_prefix) - return count_columns(row, fallback_prefix) - - -def has_comparison_field(row: pd.Series, prefix: str) -> bool: - column_prefix = f"{prefix}." - return any(str(column).startswith(column_prefix) for column in row.index) - - -def write_rows(rows: list[ScreenRow], output_path: Path, export_format: ExportFormat) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - table = pd.json_normalize([row.model_dump() for row in rows], sep=".") - table = normalize_table_cells(table) - if export_format == ExportFormat.parquet: - table.to_parquet(output_path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(output_path, index=False) - else: - table.to_json(output_path, orient="records", lines=True) - - -def write_groups(groups: list[ScreenGroup], output_path: Path, export_format: ExportFormat) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - table = pd.DataFrame([group.model_dump() for group in groups]) - table = normalize_table_cells(table) - if export_format == ExportFormat.parquet: - table.to_parquet(output_path, index=False) - elif export_format == ExportFormat.csv: - table.to_csv(output_path, index=False) - else: - table.to_json(output_path, orient="records", lines=True) - - -def normalize_table_cells(table: pd.DataFrame) -> pd.DataFrame: - normalized = table.copy() - for column in normalized.columns: - if 
normalized[column].map(is_nested_cell).any(): - normalized[column] = normalized[column].map(json_cell) - return normalized - - -def is_nested_cell(value: object) -> bool: - return isinstance(value, dict | list) - - -def json_cell(value: object) -> object: - if not is_nested_cell(value): - return value - return json.dumps(value, ensure_ascii=True, sort_keys=True) - - -def render_result(result: ScreenResult, *, json_output: bool, limit: int) -> str: - if json_output: - return result.model_dump_json(indent=2) - lines = [ - f"Screened {result.row_count} comparison row(s) from " - f"{result.comparison_file_count}/{result.scanned_file_count} CSV file(s): " - f"viable={result.summary.viable_count}, review={result.summary.review_count}, " - f"reject={result.summary.reject_count}, duplicates_skipped={result.duplicate_row_count}", - ] - for row in result.rows[:limit]: - lines.append(render_row(row)) - if len(result.rows) > limit: - lines.append(f"... {len(result.rows) - limit} more row(s)") - lines.append("Candidate groups:") - for group in result.groups[:limit]: - lines.append(render_group(group)) - if len(result.groups) > limit: - lines.append(f"... 
{len(result.groups) - limit} more group(s)") - return "\n".join(lines) - - -def render_row(row: ScreenRow) -> str: - return ( - f"- {row.workload_id}: {row.baseline_config_id}->{row.candidate_config_id} " - f"verdict={row.candidate_verdict} safety={row.safety_verdict} " - f"value_protection={row.value_protection_verdict or 'unknown'} " - f"signature_parity={row.signature_parity_verdict or 'unknown'} " - f"evidence={row.evidence_level} " - f"perf={row.performance_verdict} elapsed_delta={format_number(row.pipeline_elapsed_sec_delta_pct, '%')} " - f"replacement={row.baseline_replacement_strategy or 'unknown'}" - f"->{row.candidate_replacement_strategy or 'unknown'} " - f"tokens_delta={format_number(row.observed_total_tokens_delta)} " - f"cases={format_count(row.baseline_case_count)}/{format_count(row.candidate_case_count)} " - f"shared_stable={format_count(row.shared_stable_final_entity_signature_count)} " - f"aug_new_final_delta={format_number(row.augmented_new_final_value_count_delta)} " - f"baseline_missing_replacements={format_number(row.baseline_replacement_missing_final_entity_count)} " - f"candidate_missing_replacements={format_number(row.candidate_replacement_missing_final_entity_count)} " - f"baseline_original_value_leaks={format_number(row.baseline_original_value_leak_count)} " - f"candidate_original_value_leaks={format_number(row.candidate_original_value_leak_count)} " - f"baseline_replacement_collisions=" - f"{format_number(row.baseline_replacement_synthetic_original_collision_count)} " - f"candidate_replacement_collisions=" - f"{format_number(row.candidate_replacement_synthetic_original_collision_count)} " - f"baseline_duplicate_synthetics={format_number(row.baseline_duplicate_synthetic_replacement_count)} " - f"candidate_duplicate_synthetics={format_number(row.candidate_duplicate_synthetic_replacement_count)} " - f"lost={label_summary(row.baseline_only_label_counts)} " - f"label_mismatch={label_summary(row.label_mismatch_label_counts)} " - 
f"stable_lost={label_summary(row.stable_lost_label_counts)} " - f"baseline_missing_labels={label_summary(row.baseline_replacement_missing_final_entity_label_counts)} " - f"candidate_missing_labels={label_summary(row.candidate_replacement_missing_final_entity_label_counts)} " - f"baseline_leak_labels={label_summary(row.baseline_original_value_leak_label_counts)} " - f"leak_labels={label_summary(row.candidate_original_value_leak_label_counts)} " - f"baseline_collision_labels=" - f"{label_summary(row.baseline_replacement_synthetic_original_collision_label_counts)} " - f"collision_labels=" - f"{label_summary(row.candidate_replacement_synthetic_original_collision_label_counts)} " - f"flags={','.join(row.flags) if row.flags else 'none'}" - ) - - -def render_group(group: ScreenGroup) -> str: - return ( - f"- {group.group_key}: rows={group.row_count} viable={group.viable_count} " - f"review={group.review_count} reject={group.reject_count} " - f"conflict={str(group.has_conflicting_verdicts).lower()} recommendation={group.recommendation} " - f"replacement={group.candidate_replacement_strategy or 'unknown'} " - f"value_protection_counts={label_summary(group.value_protection_verdict_counts)} " - f"signature_parity_counts={label_summary(group.signature_parity_verdict_counts)} " - f"evidence_counts={label_summary(group.evidence_level_counts)} " - f"split_verdict_candidate_counts={label_summary(group.split_verdict_candidate_verdict_counts)} " - f"perf_counts={group_performance_summary(group)} " - f"best_elapsed_delta={format_number(group.best_pipeline_elapsed_sec_delta_pct, '%')} " - f"worst_elapsed_delta={format_number(group.worst_pipeline_elapsed_sec_delta_pct, '%')} " - f"best_tokens_delta={format_number(group.best_observed_total_tokens_delta)} " - f"worst_tokens_delta={format_number(group.worst_observed_total_tokens_delta)} " - f"best_requests_delta={format_number(group.best_observed_total_requests_delta)} " - 
f"worst_requests_delta={format_number(group.worst_observed_total_requests_delta)} " - f"min_cases={format_count(group.min_baseline_case_count)}/{format_count(group.min_candidate_case_count)} " - f"min_shared_stable={format_count(group.min_shared_stable_final_entity_signature_count)} " - f"baseline_defect_improvements={group.baseline_defect_improvement_count} " - f"baseline_missing_replacements=" - f"{format_number(group.sum_baseline_replacement_missing_final_entity_count)} " - f"candidate_missing_replacements=" - f"{format_number(group.sum_candidate_replacement_missing_final_entity_count)} " - f"baseline_original_value_leaks={format_number(group.sum_baseline_original_value_leak_count)} " - f"candidate_original_value_leaks={format_number(group.sum_candidate_original_value_leak_count)} " - f"baseline_replacement_collisions=" - f"{format_number(group.sum_baseline_replacement_synthetic_original_collision_count)} " - f"candidate_replacement_collisions=" - f"{format_number(group.sum_candidate_replacement_synthetic_original_collision_count)} " - f"baseline_duplicate_synthetics={format_number(group.sum_baseline_duplicate_synthetic_replacement_count)} " - f"candidate_duplicate_synthetics={format_number(group.sum_candidate_duplicate_synthetic_replacement_count)} " - f"lost={label_summary(group.baseline_only_label_counts)} " - f"label_mismatch={label_summary(group.label_mismatch_label_counts)} " - f"stable_lost={label_summary(group.stable_lost_label_counts)} " - f"baseline_missing_labels={label_summary(group.baseline_replacement_missing_final_entity_label_counts)} " - f"candidate_missing_labels={label_summary(group.candidate_replacement_missing_final_entity_label_counts)} " - f"baseline_leak_labels={label_summary(group.baseline_original_value_leak_label_counts)} " - f"leak_labels={label_summary(group.candidate_original_value_leak_label_counts)} " - f"baseline_collision_labels=" - f"{label_summary(group.baseline_replacement_synthetic_original_collision_label_counts)} " - 
f"collision_labels=" - f"{label_summary(group.candidate_replacement_synthetic_original_collision_label_counts)} " - f"flags={label_summary(group.flag_counts)}" - ) - - -def format_number(value: float | None, suffix: str = "") -> str: - if value is None: - return "unknown" - return f"{value:.1f}{suffix}" - - -def format_count(value: int | None) -> str: - return str(value) if value is not None else "unknown" - - -def label_summary(counts: dict[str, int]) -> str: - if not counts: - return "none" - return ",".join(f"{label}:{count}" for label, count in counts.items()) - - -@app.default -def main( - comparison_paths: list[Path], - *, - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - group_output: Annotated[Path | None, cyclopts.Parameter("--group-output")] = None, - group_by: Annotated[GroupBy, cyclopts.Parameter("--group-by")] = GroupBy.strategy, - config_aliases: Annotated[Path | None, cyclopts.Parameter("--config-aliases")] = None, - source_include: Annotated[list[str] | None, cyclopts.Parameter("--source-include")] = None, - source_exclude: Annotated[list[str] | None, cyclopts.Parameter("--source-exclude")] = None, - format: Annotated[ExportFormat, cyclopts.Parameter("--format")] = ExportFormat.csv, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - limit: Annotated[int, cyclopts.Parameter("--limit")] = 20, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - configure_logging(log_format) - try: - result = screen_comparison_paths( - comparison_paths, - group_by=group_by, - config_aliases=read_config_aliases(config_aliases), - source_includes=source_include or [], - source_excludes=source_exclude or [], - ) - except ValueError as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - if output is not None: - write_rows(result.rows, output, format) - if group_output is not None: - write_groups(result.groups, group_output, format) - 
sys.stdout.write(render_result(result, json_output=json_output, limit=limit) + "\n") - - -if __name__ == "__main__": - app() diff --git a/tools/measurement/staged_detection_probe.py b/tools/measurement/staged_detection_probe.py deleted file mode 100644 index 57295a54..00000000 --- a/tools/measurement/staged_detection_probe.py +++ /dev/null @@ -1,1375 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Run benchmark-only DD-free staged entity detection probes. - -Usage: - uv run python tools/measurement/staged_detection_probe.py docs/data/NVIDIA_synthetic_biographies.csv \ - --text-column biography --labels age,city,first_name,last_name,occupation \ - --output /tmp/staged-probe --overwrite -""" - -from __future__ import annotations - -import json -import logging -import os -import re -import shutil -import sys -import time -from collections import Counter -from dataclasses import dataclass, replace -from enum import StrEnum -from pathlib import Path -from typing import Annotated, Any, Protocol - -import cyclopts -import httpx -import pandas as pd -from analyze_detection_artifacts import DetectionArtifactRow, build_detection_artifact_row_from_entities -from dd_parser_compat import _load_embedded_json -from direct_detection_probe import ( - CaseStatus, - DirectCompletion, - DirectDetectionClient, - DirectDetectionRequest, - DirectGenerationRequest, - HttpxDirectDetectionClient, - PromptMode, - SignatureComparison, - build_direct_prompt, - compare_signature_sets, - parse_labels, -) -from measurement_tools.cli import LogFormat, configure_logging, log_bad_input -from pydantic import BaseModel, Field, ValidationError, model_validator - -from anonymizer.engine.constants import ( - COL_AUGMENTED_ENTITIES, - COL_DETECTED_ENTITIES, - COL_RAW_DETECTED, - COL_SEED_ENTITIES, - COL_SEED_ENTITIES_JSON, - COL_SEED_VALIDATION_CANDIDATES, - COL_TAG_NOTATION, - 
COL_TAGGED_TEXT, - COL_TEXT, - COL_VALIDATED_ENTITIES, - COL_VALIDATION_CANDIDATES, - COL_VALIDATION_DECISIONS, -) -from anonymizer.engine.detection.chunked_validation import ( - build_chunk_excerpt, - chunk_candidates, - merge_chunk_decisions, - order_candidates_by_position, -) -from anonymizer.engine.detection.custom_columns import ( - apply_validation_and_finalize, - apply_validation_to_seed_entities, - enrich_validation_decisions, - merge_and_build_candidates, - prepare_validation_inputs, -) -from anonymizer.engine.detection.postprocess import ( - VALIDATION_CONTEXT_WINDOW, - EntitySpan, - TagNotation, - apply_augmented_entities, - build_tagged_text, - get_tag_notation, - parse_raw_entities, -) -from anonymizer.engine.schemas import ( - EntitiesSchema, - EntitySchema, - RawValidationDecisionsSchema, - ValidatedDecisionsSchema, - ValidationCandidatesSchema, -) - -app = cyclopts.App(help=__doc__) -logger = logging.getLogger("measurement.staged_detection_probe") - -_NATIVE_ENDPOINT_ENV = "ANONYMIZER_BENCH_NATIVE_ENDPOINT" -_NATIVE_MODEL_ENV = "ANONYMIZER_BENCH_NATIVE_MODEL" -_GLINER_ENDPOINT_ENV = "ANONYMIZER_BENCH_GLINER_ENDPOINT" -_GLINER_MODEL_ENV = "ANONYMIZER_BENCH_GLINER_MODEL" -_UNCONFIGURED_ENDPOINT = "configured-native-endpoint" -_UNCONFIGURED_MODEL = "configured-native-model" -_UNCONFIGURED_GLINER_ENDPOINT = "configured-gliner-endpoint" -_UNCONFIGURED_GLINER_MODEL = "configured-gliner-model" -_DATE_OF_BIRTH_CONTEXT_RE = re.compile(r"\b(born|birth|date of birth|dob)\b", re.IGNORECASE) - - -class SeedSource(StrEnum): - direct_llm = "direct_llm" - gliner = "gliner" - - -class ValidationPromptMode(StrEnum): - full_text = "full_text" - chunked_excerpt = "chunked_excerpt" - - -class GlinerDetectionRequest(BaseModel): - endpoint: str - model: str - text: str - labels: list[str] = Field(min_length=1) - threshold: float = Field(default=0.3, ge=0.0, le=1.0) - max_tokens: int = Field(default=4096, gt=0) - timeout_sec: float = Field(default=120.0, gt=0) - 
api_key_env: str = "NVIDIA_API_KEY" - - -class GlinerSeedClient(Protocol): - def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: ... - - -class StagedExecutionConfig(BaseModel): - endpoint: str = _UNCONFIGURED_ENDPOINT - model: str = _UNCONFIGURED_MODEL - seed_source: SeedSource = SeedSource.direct_llm - gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT - gliner_model: str = _UNCONFIGURED_GLINER_MODEL - gliner_api_key_env: str = "NVIDIA_API_KEY" - gliner_threshold: float = Field(default=0.3, ge=0.0, le=1.0) - max_tokens: int = Field(default=4096, gt=0) - timeout_sec: float = Field(default=180.0, gt=0) - skip_augmentation: bool = False - validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text - validation_max_entities_per_call: int = Field(default=10, gt=0) - validation_excerpt_window_chars: int = Field(default=160, gt=0) - - -class HttpxGlinerSeedClient: - def detect(self, request: GlinerDetectionRequest) -> DirectCompletion: - api_key = os.environ.get(request.api_key_env) - if not api_key: - raise ValueError(f"{request.api_key_env} is not set") - response = httpx.post( - f"{request.endpoint.rstrip('/')}/chat/completions", - headers={"Authorization": f"Bearer {api_key}"}, - json=_gliner_payload(request), - timeout=request.timeout_sec, - ) - response.raise_for_status() - data = response.json() - return DirectCompletion( - content=_completion_content(data), - elapsed_sec=float(response.elapsed.total_seconds()), - usage=data.get("usage") or {}, - ) - - -def _gliner_payload(request: GlinerDetectionRequest) -> dict[str, Any]: - return { - "model": request.model, - "messages": [{"role": "user", "content": request.text}], - "temperature": 0, - "max_tokens": request.max_tokens, - "labels": request.labels, - "threshold": request.threshold, - "chunk_length": 384, - "overlap": 128, - "flat_ner": False, - } - - -def _completion_content(data: dict[str, Any]) -> str: - choice = (data.get("choices") or [{}])[0] - message = 
choice.get("message") if isinstance(choice, dict) else {} - if isinstance(message, dict): - return str(message.get("content") or "") - return str(choice.get("text") or "") if isinstance(choice, dict) else "" - - -class StagedDetectionRequest(BaseModel): - case_id: str - text: str - labels: list[str] = Field(min_length=1) - row_index: int = 0 - data_summary: str | None = None - - @model_validator(mode="after") - def normalize_labels(self) -> StagedDetectionRequest: - self.labels = list(dict.fromkeys(label.strip() for label in self.labels if label.strip())) - if not self.labels: - raise ValueError("labels must contain at least one non-empty label") - return self - - -class PhaseUsage(BaseModel): - seed: dict[str, Any] = Field(default_factory=dict) - validation: dict[str, Any] = Field(default_factory=dict) - augmentation: dict[str, Any] = Field(default_factory=dict) - - -class PhaseModelWork(BaseModel): - seed: bool = False - validation: bool = False - augmentation: bool = False - - -class PhaseSkipReasons(BaseModel): - seed: str | None = None - validation: str | None = None - augmentation: str | None = None - - -class PhaseModelRequests(BaseModel): - seed: int = 0 - validation: int = 0 - augmentation: int = 0 - - -class StagedDetectionCase(BaseModel): - case_id: str - row_index: int - seed_source: SeedSource = SeedSource.direct_llm - status: CaseStatus - elapsed_sec: float | None = None - model_elapsed_sec: float | None = None - phase_usage: PhaseUsage = Field(default_factory=PhaseUsage) - phase_model_work: PhaseModelWork = Field(default_factory=PhaseModelWork) - phase_skip_reasons: PhaseSkipReasons = Field(default_factory=PhaseSkipReasons) - phase_model_requests: PhaseModelRequests = Field(default_factory=PhaseModelRequests) - total_usage: dict[str, int] = Field(default_factory=dict) - model_phase_count: int = 0 - model_request_count: int = 0 - seed_suggestion_count: int = 0 - seed_entity_count: int = 0 - validation_candidate_count: int = 0 - 
validation_decision_count: int = 0 - augmented_suggestion_count: int = 0 - final_entity_count: int = 0 - final_entity_signature_count: int = 0 - final_entity_signature_hashes: list[str] = Field(default_factory=list) - final_label_counts: dict[str, int] = Field(default_factory=dict) - comparison: SignatureComparison | None = None - artifact: DetectionArtifactRow | None = None - error: str | None = None - - -class StagedDetectionRun(BaseModel): - input_path: str - text_column: str - endpoint: str - model: str - labels: list[str] - rows: list[StagedDetectionCase] = Field(default_factory=list) - - @property - def error_count(self) -> int: - return sum(1 for row in self.rows if row.status == CaseStatus.error) - - -@dataclass(frozen=True) -class StagedDetectionExecution: - case: StagedDetectionCase - row: dict[str, Any] - - -def run_staged_detection_case( - request: StagedDetectionRequest, - *, - client: DirectDetectionClient, - seed_client: GlinerSeedClient | None = None, - seed_source: SeedSource = SeedSource.direct_llm, - endpoint: str = _UNCONFIGURED_ENDPOINT, - model: str = _UNCONFIGURED_MODEL, - gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, - gliner_model: str = _UNCONFIGURED_GLINER_MODEL, - gliner_api_key_env: str = "NVIDIA_API_KEY", - gliner_threshold: float = 0.3, - max_tokens: int = 4096, - timeout_sec: float = 180.0, - skip_augmentation: bool = False, - validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, - validation_max_entities_per_call: int = 10, - validation_excerpt_window_chars: int = 160, -) -> StagedDetectionCase: - return execute_staged_detection_case( - request, - client=client, - seed_client=seed_client, - seed_source=seed_source, - endpoint=endpoint, - model=model, - gliner_endpoint=gliner_endpoint, - gliner_model=gliner_model, - gliner_api_key_env=gliner_api_key_env, - gliner_threshold=gliner_threshold, - max_tokens=max_tokens, - timeout_sec=timeout_sec, - skip_augmentation=skip_augmentation, - 
validation_prompt_mode=validation_prompt_mode, - validation_max_entities_per_call=validation_max_entities_per_call, - validation_excerpt_window_chars=validation_excerpt_window_chars, - ).case - - -def execute_staged_detection_case( - request: StagedDetectionRequest, - *, - client: DirectDetectionClient, - seed_client: GlinerSeedClient | None = None, - seed_source: SeedSource = SeedSource.direct_llm, - endpoint: str = _UNCONFIGURED_ENDPOINT, - model: str = _UNCONFIGURED_MODEL, - gliner_endpoint: str = _UNCONFIGURED_GLINER_ENDPOINT, - gliner_model: str = _UNCONFIGURED_GLINER_MODEL, - gliner_api_key_env: str = "NVIDIA_API_KEY", - gliner_threshold: float = 0.3, - max_tokens: int = 4096, - timeout_sec: float = 180.0, - skip_augmentation: bool = False, - validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, - validation_max_entities_per_call: int = 10, - validation_excerpt_window_chars: int = 160, -) -> StagedDetectionExecution: - config = _execution_config_from_params(locals()) - try: - return _run_staged_detection_execution( - request, - client=client, - seed_client=seed_client, - config=config, - ) - except Exception as exc: # noqa: BLE001 - benchmark probe records per-case failures - return StagedDetectionExecution(case=_errored_case(request, config.seed_source, exc), row={}) - - -def _execution_config_from_params(params: dict[str, Any]) -> StagedExecutionConfig: - return StagedExecutionConfig( - **{key: value for key, value in params.items() if key in StagedExecutionConfig.model_fields} - ) - - -def _errored_case(request: StagedDetectionRequest, seed_source: SeedSource, exc: Exception) -> StagedDetectionCase: - return StagedDetectionCase( - case_id=request.case_id, - row_index=request.row_index, - seed_source=seed_source, - status=CaseStatus.error, - error=f"{type(exc).__name__}: {exc}", - ) - - -def _run_staged_detection_execution( - request: StagedDetectionRequest, - *, - client: DirectDetectionClient, - seed_client: GlinerSeedClient | 
None, - config: StagedExecutionConfig, -) -> StagedDetectionExecution: - started = time.perf_counter() - row, seed_suggestion_count, seed_completion = _run_seed_phase( - request, - client=client, - seed_client=seed_client, - config=config, - ) - validation_completion = _run_validation_phase(row, request, client, config) - augmentation_completion = _run_augmentation_phase(row, request, client, config) - artifact = _finalize_row(row, request) - return StagedDetectionExecution( - case=_completed_case( - request, - seed_completion, - validation_completion, - augmentation_completion, - artifact, - row, - config=config, - seed_suggestion_count=seed_suggestion_count, - elapsed_sec=time.perf_counter() - started, - ), - row=row, - ) - - -def _run_seed_phase( - request: StagedDetectionRequest, - *, - client: DirectDetectionClient, - seed_client: GlinerSeedClient | None, - config: StagedExecutionConfig, -) -> tuple[dict[str, Any], int, DirectCompletion]: - if config.seed_source == SeedSource.gliner: - return _run_gliner_seed_phase(request, seed_client or HttpxGlinerSeedClient(), config) - return _run_direct_llm_seed_phase(request, client, config) - - -def _run_gliner_seed_phase( - request: StagedDetectionRequest, - detector: GlinerSeedClient, - config: StagedExecutionConfig, -) -> tuple[dict[str, Any], int, DirectCompletion]: - completion = detector.detect( - GlinerDetectionRequest( - endpoint=config.gliner_endpoint, - model=config.gliner_model, - text=request.text, - labels=request.labels, - threshold=config.gliner_threshold, - max_tokens=config.max_tokens, - timeout_sec=config.timeout_sec, - api_key_env=config.gliner_api_key_env, - ) - ) - seed_spans = parse_raw_entities(raw_response=completion.content, text=request.text) - return _seed_row_from_spans(request, seed_spans), _raw_detector_entity_count(completion.content), completion - - -def _run_direct_llm_seed_phase( - request: StagedDetectionRequest, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> 
tuple[dict[str, Any], int, DirectCompletion]: - completion = _complete( - client, - prompt=_seed_prompt(request), - config=config, - ) - row, seed_suggestions = _seed_row_from_llm(request, completion.content) - return row, len(seed_suggestions), completion - - -def _complete( - client: DirectDetectionClient, - *, - prompt: str, - config: StagedExecutionConfig, -) -> DirectCompletion: - return client.complete( - DirectGenerationRequest( - endpoint=config.endpoint, - model=config.model, - prompt=prompt, - max_tokens=config.max_tokens, - timeout_sec=config.timeout_sec, - ) - ) - - -def _seed_prompt(request: StagedDetectionRequest) -> str: - return build_direct_prompt(_direct_seed_request(request)) - - -def _direct_seed_request(request: StagedDetectionRequest) -> DirectDetectionRequest: - return DirectDetectionRequest( - case_id=request.case_id, - text=request.text, - labels=request.labels, - row_index=request.row_index, - prompt_mode=PromptMode.compact, - data_summary=request.data_summary, - ) - - -def _seed_row_from_llm(request: StagedDetectionRequest, content: str) -> tuple[dict[str, Any], list[dict[str, Any]]]: - seed_spans, suggestions = _direct_seed_spans(request, content) - return _seed_row_from_spans(request, seed_spans), suggestions - - -def _direct_seed_spans(request: StagedDetectionRequest, content: str) -> tuple[list[EntitySpan], list[dict[str, Any]]]: - suggestions = _extract_entity_suggestions(content) - seed_spans = _spans_from_suggestions(request.text, suggestions, labels=request.labels, source="direct_seed") - return seed_spans, suggestions - - -def _seed_row_from_spans(request: StagedDetectionRequest, seed_spans: list[EntitySpan]) -> dict[str, Any]: - allowed = set(request.labels) - seed_spans = [_normalize_seed_span(request.text, span, allowed_labels=allowed) for span in seed_spans] - seed_spans = [span for span in seed_spans if span.label in allowed] - row = {COL_TEXT: request.text, COL_TAG_NOTATION: get_tag_notation(text=request.text)} - 
row[COL_RAW_DETECTED] = "" - row[COL_SEED_ENTITIES] = EntitiesSchema(entities=[_span_to_entity_schema(span) for span in seed_spans]).model_dump( - mode="json" - ) - prepare_validation_inputs(row) - return row - - -def _normalize_seed_span(text: str, span: EntitySpan, *, allowed_labels: set[str]) -> EntitySpan: - if _should_promote_birth_context_date_seed(text, span, allowed_labels=allowed_labels): - return replace(span, entity_id=_seed_entity_id("date_of_birth", span), label="date_of_birth") - return span - - -def _should_promote_birth_context_date_seed( - text: str, - span: EntitySpan, - *, - allowed_labels: set[str], -) -> bool: - return span.label == "date" and "date_of_birth" in allowed_labels and _span_has_date_of_birth_context(text, span) - - -def _span_has_date_of_birth_context(text: str, span: EntitySpan) -> bool: - before_start = max(0, span.start_position - VALIDATION_CONTEXT_WINDOW) - after_end = min(len(text), span.end_position + VALIDATION_CONTEXT_WINDOW) - return _contains_date_of_birth_context( - text[before_start : span.start_position], span.value, text[span.end_position : after_end] - ) - - -def _seed_entity_id(label: str, span: EntitySpan) -> str: - return f"{label}_{span.start_position}_{span.end_position}" - - -def _run_validation_phase( - row: dict[str, Any], - request: StagedDetectionRequest, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> DirectCompletion: - candidates = ValidationCandidatesSchema.from_raw(row.get(COL_SEED_VALIDATION_CANDIDATES, {})) - if not candidates.candidates: - row[COL_VALIDATION_DECISIONS] = {"decisions": []} - row[COL_VALIDATED_ENTITIES] = {"decisions": []} - apply_validation_to_seed_entities(row) - return DirectCompletion(content='{"decisions": []}', elapsed_sec=0.0, usage={}) - if config.validation_prompt_mode == ValidationPromptMode.chunked_excerpt: - completion = _run_chunked_validation_phase(row, request, candidates, client, config) - else: - completion = 
_run_full_text_validation_phase(row, request, candidates, client, config) - row[COL_VALIDATION_DECISIONS] = _extract_validation_decisions( - completion.content, - candidates=candidates, - labels=request.labels, - ) - enrich_validation_decisions(row) - apply_validation_to_seed_entities(row) - return completion - - -def _run_full_text_validation_phase( - row: dict[str, Any], - request: StagedDetectionRequest, - candidates: ValidationCandidatesSchema, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> DirectCompletion: - completion = _complete( - client, - prompt=_validation_prompt(request, candidates), - config=config, - ) - return completion - - -def _run_chunked_validation_phase( - row: dict[str, Any], - request: StagedDetectionRequest, - candidates: ValidationCandidatesSchema, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> DirectCompletion: - completions: list[DirectCompletion] = [] - chunk_results: list[RawValidationDecisionsSchema] = [] - for chunk in _validation_chunks(row, candidates, config): - completion = _complete(client, prompt=_chunk_validation_prompt(request, chunk), config=config) - completions.append(completion) - chunk_results.append( - RawValidationDecisionsSchema.from_raw( - _extract_validation_decisions(completion.content, candidates=chunk[1], labels=request.labels) - ) - ) - decisions = merge_chunk_decisions(chunk_results, candidates) - return _combine_completions(completions, content=json.dumps(decisions, ensure_ascii=True, sort_keys=True)) - - -def _validation_chunks( - row: dict[str, Any], - candidates: ValidationCandidatesSchema, - config: StagedExecutionConfig, -) -> list[tuple[str, ValidationCandidatesSchema]]: - seed_spans = _seed_entity_spans(row) - ordered = order_candidates_by_position(candidates, seed_spans) - return [ - (_validation_excerpt(row, chunk, seed_spans, config), _chunk_candidate_schema(chunk)) - for chunk in chunk_candidates(ordered, config.validation_max_entities_per_call) - ] 
- - -def _validation_excerpt( - row: dict[str, Any], - chunk: list[tuple[Any, EntitySpan]], - seed_spans: list[EntitySpan], - config: StagedExecutionConfig, -) -> str: - notation = TagNotation(str(row.get(COL_TAG_NOTATION) or TagNotation.sentinel.value)) - return build_chunk_excerpt( - text=str(row.get(COL_TEXT, "")), - chunk_spans=[span for _candidate, span in chunk], - all_spans=seed_spans, - window_chars=config.validation_excerpt_window_chars, - notation=notation, - ) - - -def _chunk_candidate_schema(chunk: list[tuple[Any, EntitySpan]]) -> ValidationCandidatesSchema: - return ValidationCandidatesSchema(candidates=[candidate for candidate, _span in chunk]) - - -def _chunk_validation_prompt( - request: StagedDetectionRequest, - chunk: tuple[str, ValidationCandidatesSchema], -) -> str: - excerpt, candidates = chunk - return _validation_prompt(request, candidates, input_text=excerpt) - - -def _seed_entity_spans(row: dict[str, Any]) -> list[EntitySpan]: - return [ - EntitySpan( - entity_id=entity.id, - value=entity.value, - label=entity.label, - start_position=entity.start_position, - end_position=entity.end_position, - score=entity.score, - source=entity.source, - ) - for entity in EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities - ] - - -def _combine_completions(completions: list[DirectCompletion], *, content: str) -> DirectCompletion: - return DirectCompletion( - content=content, - elapsed_sec=sum(completion.elapsed_sec for completion in completions), - usage=_sum_completion_usage(completions), - ) - - -def _sum_completion_usage(completions: list[DirectCompletion]) -> dict[str, int]: - totals: Counter[str] = Counter() - for completion in completions: - for key, value in completion.usage.items(): - if isinstance(value, int): - totals[key] += value - return dict(sorted(totals.items())) - - -def _validation_prompt( - request: StagedDetectionRequest, - candidates: ValidationCandidatesSchema, - *, - input_text: str | None = None, -) -> str: - 
text_for_prompt = input_text if input_text is not None else request.text - label_guidance = _validation_label_guidance(request.labels) - return f"""Validate candidate privacy-sensitive entities. -Use only these decisions: keep, reclass, drop. -Use only these labels when reclassifying: {", ".join(request.labels)}. -Prefer keep when the candidate already has the right specific label. For example, reclass date_of_birth to date only when the surrounding context is not birth-related. -{label_guidance} -Return JSON only with this shape: -{{"decisions": [{{"id": "candidate id", "decision": "keep|reclass|drop", "proposed_label": "", "reason": "short reason"}}]}} - -Context text: ---- -{text_for_prompt} ---- - -Candidates: -{candidates.model_dump_json()} -""" - - -def _validation_label_guidance(labels: list[str]) -> str: - allowed = set(labels) - lines: list[str] = [] - if "degree" in allowed: - lines.append( - "Prefer degree for credential names and degree abbreviations such as Bachelor of Science, BSc, BA, MA, MSc, PhD, or JD; use education_level for broad levels such as undergraduate, graduate, or high school." 
- ) - if not lines: - return "" - return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) - - -def _run_augmentation_phase( - row: dict[str, Any], - request: StagedDetectionRequest, - client: DirectDetectionClient, - config: StagedExecutionConfig, -) -> DirectCompletion: - if _should_skip_augmentation(request, config): - row[COL_AUGMENTED_ENTITIES] = {"entities": []} - return DirectCompletion(content="", elapsed_sec=0.0, usage={}) - completion = _complete( - client, - prompt=_augmentation_prompt(request, row), - config=config, - ) - row[COL_AUGMENTED_ENTITIES] = { - "entities": _allowed_entity_suggestions(_extract_entity_suggestions(completion.content), labels=request.labels) - } - return completion - - -def _should_skip_augmentation(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: - return config.skip_augmentation - - -def _augmentation_prompt(request: StagedDetectionRequest, row: dict[str, Any]) -> str: - seed_entities = row.get(COL_SEED_ENTITIES_JSON, "[]") - label_guidance = _augmentation_label_guidance(request.labels) - return f"""Find additional sensitive entities not already covered by the validated seed entities. -Use only these labels: {", ".join(request.labels)}. -Return exact substrings only. Do not invent values. -{label_guidance} -Return JSON only with this shape: -{{"entities": [{{"value": "exact substring", "label": "one_allowed_label", "reason": "short reason"}}]}} - -Input text: ---- -{request.text} ---- - -Validated seed entities: -{seed_entities} -""" - - -def _augmentation_label_guidance(labels: list[str]) -> str: - allowed = set(labels) - lines: list[str] = [] - if "first_name" in allowed: - lines.append( - "For family/member lists, split personal names connected by 'and' into separate first_name or last_name entities. Do not label a list of people as organization_name." 
- ) - if "last_name" in allowed and ({"place_name", "company_name", "organization_name"} & allowed): - lines.append( - "If a surname appears inside a street, place, organization, company, or possessive business phrase, also return the surname substring as last_name when it is not already covered." - ) - if not lines: - return "" - return "Label boundary guidance:\n" + "\n".join(f"- {line}" for line in lines) - - -def _finalize_row(row: dict[str, Any], request: StagedDetectionRequest) -> DetectionArtifactRow: - merge_and_build_candidates(row) - apply_validation_and_finalize(row) - normalize_birth_context_final_entities(row, allowed_labels=request.labels) - seed_entities = EntitiesSchema.from_raw(row.get(COL_SEED_ENTITIES, {})).entities - final_entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities - augmented = _augmented_entity_schemas(row, request) - return build_detection_artifact_row_from_entities( - workflow_name="staged-direct-detection", - batch_file="staged-direct-detection", - row_index=request.row_index, - seed_entities=list(seed_entities), - seed_validation_candidate_count=_candidate_count(row.get(COL_SEED_VALIDATION_CANDIDATES)), - merged_validation_candidate_count=_candidate_count(row.get(COL_VALIDATION_CANDIDATES)), - augmented_entities=augmented, - final_entities=list(final_entities), - ) - - -def _completed_case( - request: StagedDetectionRequest, - seed: DirectCompletion, - validation: DirectCompletion, - augmentation: DirectCompletion, - artifact: DetectionArtifactRow, - row: dict[str, Any], - config: StagedExecutionConfig, - seed_suggestion_count: int, - elapsed_sec: float, -) -> StagedDetectionCase: - phase_usage = PhaseUsage(seed=seed.usage, validation=validation.usage, augmentation=augmentation.usage) - phase_model_work = _phase_model_work(request, artifact, config) - phase_model_requests = _phase_model_requests(artifact, phase_model_work, config) - model_elapsed_sec = seed.elapsed_sec + validation.elapsed_sec + 
augmentation.elapsed_sec - return StagedDetectionCase( - case_id=request.case_id, - row_index=request.row_index, - seed_source=config.seed_source, - status=CaseStatus.completed, - elapsed_sec=elapsed_sec, - model_elapsed_sec=model_elapsed_sec, - phase_usage=phase_usage, - phase_model_work=phase_model_work, - phase_skip_reasons=_phase_skip_reasons(request, artifact, config), - phase_model_requests=phase_model_requests, - total_usage=_sum_usage(phase_usage), - model_phase_count=_model_phase_count(phase_model_work), - model_request_count=_model_request_count(phase_model_requests), - seed_suggestion_count=seed_suggestion_count, - seed_entity_count=artifact.seed_entity_count, - validation_candidate_count=artifact.seed_validation_candidate_count, - validation_decision_count=_validated_decision_count(row.get(COL_VALIDATED_ENTITIES)), - augmented_suggestion_count=artifact.augmented_entity_count, - final_entity_count=artifact.final_entity_count, - final_entity_signature_count=artifact.final_entity_signature_count, - final_entity_signature_hashes=artifact.final_entity_signature_hashes, - final_label_counts=artifact.final_label_counts, - artifact=artifact, - ) - - -def _phase_model_work( - request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig -) -> PhaseModelWork: - return PhaseModelWork( - seed=_uses_seed_model(request, config), - validation=_uses_validation_model(request, artifact, config), - augmentation=not _should_skip_augmentation(request, config), - ) - - -def _uses_seed_model(request: StagedDetectionRequest, config: StagedExecutionConfig) -> bool: - return config.seed_source in {SeedSource.direct_llm, SeedSource.gliner} - - -def _uses_validation_model( - request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig -) -> bool: - return artifact.seed_validation_candidate_count > 0 - - -def _phase_skip_reasons( - request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: 
StagedExecutionConfig -) -> PhaseSkipReasons: - return PhaseSkipReasons( - seed=_seed_skip_reason(request, config), - validation=_validation_skip_reason(request, artifact, config), - augmentation=_augmentation_skip_reason(request, config), - ) - - -def _seed_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: - return None - - -def _validation_skip_reason( - request: StagedDetectionRequest, artifact: DetectionArtifactRow, config: StagedExecutionConfig -) -> str | None: - if artifact.seed_validation_candidate_count == 0: - return "no_seed_candidates" - return None - - -def _augmentation_skip_reason(request: StagedDetectionRequest, config: StagedExecutionConfig) -> str | None: - if config.skip_augmentation: - return "disabled" - return None - - -def _model_phase_count(phase_model_work: PhaseModelWork) -> int: - return sum((phase_model_work.seed, phase_model_work.validation, phase_model_work.augmentation)) - - -def _phase_model_requests( - artifact: DetectionArtifactRow, - phase_model_work: PhaseModelWork, - config: StagedExecutionConfig, -) -> PhaseModelRequests: - return PhaseModelRequests( - seed=1 if phase_model_work.seed else 0, - validation=_validation_model_request_count(artifact, phase_model_work, config), - augmentation=1 if phase_model_work.augmentation else 0, - ) - - -def _validation_model_request_count( - artifact: DetectionArtifactRow, - phase_model_work: PhaseModelWork, - config: StagedExecutionConfig, -) -> int: - if not phase_model_work.validation: - return 0 - if config.validation_prompt_mode == ValidationPromptMode.full_text: - return 1 - return _ceil_div(artifact.seed_validation_candidate_count, config.validation_max_entities_per_call) - - -def _ceil_div(numerator: int, denominator: int) -> int: - return (numerator + denominator - 1) // denominator - - -def _model_request_count(phase_model_requests: PhaseModelRequests) -> int: - return phase_model_requests.seed + phase_model_requests.validation + 
phase_model_requests.augmentation - - -def _extract_entity_suggestions(content: str) -> list[dict[str, str]]: - payload = _load_embedded_json(content) - raw_entities = payload.get("entities", []) if isinstance(payload, dict) else [] - return [ - {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} - for item in raw_entities - if isinstance(item, dict) and item.get("value") and item.get("label") - ] - - -def _allowed_entity_suggestions(suggestions: list[dict[str, str]], *, labels: list[str]) -> list[dict[str, str]]: - allowed = set(labels) - return [suggestion for suggestion in suggestions if suggestion["label"] in allowed] - - -def _extract_validation_decisions( - content: str, - *, - candidates: ValidationCandidatesSchema | None = None, - labels: list[str] | None = None, -) -> dict[str, Any]: - payload = _load_embedded_json(content) - decisions = payload.get("decisions", []) if isinstance(payload, dict) else [] - if not isinstance(decisions, list): - decisions = [] - if candidates is not None: - decisions = _preserve_specific_seed_labels(decisions, candidates) - if labels is not None: - decisions = _keep_invalid_reclass_labels(decisions, allowed_labels=set(labels)) - return {"decisions": decisions} - - -def _preserve_specific_seed_labels( - decisions: list[object], - candidates: ValidationCandidatesSchema, -) -> list[object]: - candidate_by_id = {candidate.id: candidate for candidate in candidates.candidates} - normalized: list[object] = [] - for decision in decisions: - if not isinstance(decision, dict): - normalized.append(decision) - continue - candidate = candidate_by_id.get(str(decision.get("id") or "")) - if candidate is None or not _should_preserve_specific_seed_label(decision, candidate): - normalized.append(decision) - continue - normalized.append( - { - **decision, - "decision": "keep", - "proposed_label": "", - "reason": _preserved_label_reason(decision.get("reason")), - } - ) - return normalized - - -def 
_should_preserve_specific_seed_label(decision: dict[str, object], candidate: Any) -> bool: - if str(decision.get("decision") or "") != "reclass": - return False - proposed_label = str(decision.get("proposed_label") or "") - if candidate.label == "date_of_birth" and proposed_label == "date": - return _has_date_of_birth_context(candidate) - return False - - -def _has_date_of_birth_context(candidate: Any) -> bool: - return _contains_date_of_birth_context(candidate.context_before, candidate.value, candidate.context_after) - - -def _keep_invalid_reclass_labels(decisions: list[object], *, allowed_labels: set[str]) -> list[object]: - normalized: list[object] = [] - for decision in decisions: - if not isinstance(decision, dict): - normalized.append(decision) - continue - if not _has_invalid_reclass_label(decision, allowed_labels=allowed_labels): - normalized.append(decision) - continue - normalized.append( - { - **decision, - "decision": "keep", - "proposed_label": "", - "reason": _invalid_reclass_label_reason(decision.get("reason"), decision.get("proposed_label")), - } - ) - return normalized - - -def _has_invalid_reclass_label(decision: dict[str, object], *, allowed_labels: set[str]) -> bool: - if str(decision.get("decision") or "") != "reclass": - return False - proposed_label = str(decision.get("proposed_label") or "").strip() - return proposed_label not in allowed_labels - - -def _invalid_reclass_label_reason(reason: object, proposed_label: object) -> str: - text = str(reason or "").strip() - suffix = f"ignored invalid reclass label {str(proposed_label or '').strip()!r}" - return f"{text}; {suffix}" if text else suffix - - -def normalize_birth_context_final_entities( - row: dict[str, Any], *, allowed_labels: list[str] | set[str] -) -> dict[str, Any]: - """Demote native date_of_birth spans without birth context back to date.""" - allowed = set(allowed_labels) - if not {"date", "date_of_birth"}.issubset(allowed): - return row - text = str(row.get(COL_TEXT, "")) - 
entities = EntitiesSchema.from_raw(row.get(COL_DETECTED_ENTITIES, {})).entities - normalized: list[EntitySchema] = [] - changed = False - for entity in entities: - if entity.label == "date_of_birth" and not _entity_has_date_of_birth_context(text, entity): - normalized.append(entity.model_copy(update={"id": _entity_id("date", entity), "label": "date"})) - changed = True - else: - normalized.append(entity) - if changed: - row[COL_DETECTED_ENTITIES] = EntitiesSchema(entities=normalized).model_dump(mode="json") - row[COL_TAGGED_TEXT] = build_tagged_text( - text=text, - entities=[_entity_schema_to_span(entity) for entity in normalized], - ) - return row - - -def _entity_has_date_of_birth_context(text: str, entity: EntitySchema) -> bool: - before_start = max(0, entity.start_position - VALIDATION_CONTEXT_WINDOW) - after_end = min(len(text), entity.end_position + VALIDATION_CONTEXT_WINDOW) - return _contains_date_of_birth_context( - text[before_start : entity.start_position], - entity.value, - text[entity.end_position : after_end], - ) - - -def _entity_id(label: str, entity: EntitySchema) -> str: - return f"{label}_{entity.start_position}_{entity.end_position}" - - -def _entity_schema_to_span(entity: EntitySchema) -> EntitySpan: - return EntitySpan( - entity_id=entity.id, - value=entity.value, - label=entity.label, - start_position=entity.start_position, - end_position=entity.end_position, - score=entity.score, - source=entity.source, - ) - - -def _contains_date_of_birth_context(context_before: object, value: object, context_after: object) -> bool: - context = f"{context_before} {value} {context_after}" - return bool(_DATE_OF_BIRTH_CONTEXT_RE.search(context)) - - -def _preserved_label_reason(reason: object) -> str: - text = str(reason or "").strip() - suffix = "preserved specific seed label from birth-related context" - return f"{text}; {suffix}" if text else suffix - - -def _raw_detector_entity_count(content: str) -> int: - payload = _load_embedded_json(content) - 
entities = payload.get("entities", []) if isinstance(payload, dict) else [] - return len(entities) if isinstance(entities, list) else 0 - - -def _spans_from_suggestions( - text: str, suggestions: list[dict[str, str]], *, labels: list[str], source: str -) -> list[EntitySpan]: - allowed = set(labels) - cleaned = [item for item in suggestions if item["value"] and item["label"] in allowed] - spans = apply_augmented_entities(text=text, entities=[], augmented_output={"entities": cleaned}) - return [ - EntitySpan( - entity_id=span.entity_id, - value=span.value, - label=span.label, - start_position=span.start_position, - end_position=span.end_position, - score=span.score, - source=source, - ) - for span in spans - ] - - -def _span_to_entity_schema(span: EntitySpan) -> EntitySchema: - return EntitySchema( - id=span.entity_id, - value=span.value, - label=span.label, - start_position=span.start_position, - end_position=span.end_position, - score=span.score, - source=span.source, - ) - - -def _augmented_entity_schemas(row: dict[str, Any], request: StagedDetectionRequest) -> list[EntitySchema]: - spans = _spans_from_suggestions( - request.text, - _extract_entities_from_payload(row.get(COL_AUGMENTED_ENTITIES)), - labels=request.labels, - source="augmenter", - ) - return [_span_to_entity_schema(span) for span in spans] - - -def _extract_entities_from_payload(payload: object) -> list[dict[str, str]]: - entities = payload.get("entities", []) if isinstance(payload, dict) else [] - return [ - {"value": str(item.get("value", "")).strip(), "label": str(item.get("label", "")).strip()} - for item in entities - if isinstance(item, dict) - ] - - -def _candidate_count(raw: object) -> int: - return len(ValidationCandidatesSchema.from_raw(raw).candidates) - - -def _validated_decision_count(raw: object) -> int: - return len(ValidatedDecisionsSchema.from_raw(raw).decisions) - - -def _sum_usage(usage: PhaseUsage) -> dict[str, int]: - totals: Counter[str] = Counter() - for phase in (usage.seed, 
usage.validation, usage.augmentation): - for key, value in phase.items(): - if isinstance(value, int): - totals[key] += value - return dict(sorted(totals.items())) - - -def apply_baseline_comparisons( - cases: list[StagedDetectionCase], - baseline_artifacts: Path, -) -> list[StagedDetectionCase]: - baseline = _read_baseline_artifacts(baseline_artifacts) - return [_case_with_comparison(case, baseline.get(case.row_index)) for case in cases] - - -def _case_with_comparison(case: StagedDetectionCase, baseline_row: dict[str, Any] | None) -> StagedDetectionCase: - if baseline_row is None or case.status != CaseStatus.completed or case.artifact is None: - return case - baseline_hashes = _baseline_signature_hashes(baseline_row) - if baseline_hashes is None: - return case - comparison = compare_signature_sets( - baseline_hashes=baseline_hashes, - baseline_labels=_signature_labels(baseline_row), - direct_hashes=set(case.artifact.final_entity_signature_hashes), - direct_labels=case.artifact.final_entity_signature_labels, - ) - return case.model_copy(update={"comparison": comparison}) - - -def _baseline_signature_hashes(row: dict[str, Any]) -> set[str] | None: - hashes = row.get("final_entity_signature_hashes") - if not isinstance(hashes, list): - return None - return {str(item) for item in hashes} - - -def _read_baseline_artifacts(path: Path) -> dict[int, dict[str, Any]]: - baseline: dict[int, dict[str, Any]] = {} - with path.open(encoding="utf-8") as source: - for line in source: - if not line.strip(): - continue - row = json.loads(line) - row_index = int(row.get("row_index", 0)) - if row_index in baseline: - raise ValueError(f"baseline artifacts has multiple rows for row_index={row_index}") - baseline[row_index] = row - return baseline - - -def _signature_labels(row: dict[str, Any]) -> dict[str, str]: - return { - key.removeprefix("final_entity_signature_labels."): str(value) - for key, value in row.items() - if key.startswith("final_entity_signature_labels.") and value is 
not None - } - - -def run_probe( - input_path: Path, - *, - text_column: str, - labels: list[str], - output: Path | None = None, - overwrite: bool = False, - endpoint: str | None = None, - model: str | None = None, - seed_source: SeedSource = SeedSource.direct_llm, - gliner_endpoint: str | None = None, - gliner_model: str | None = None, - gliner_api_key_env: str = "NVIDIA_API_KEY", - gliner_threshold: float = 0.3, - skip_augmentation: bool = False, - validation_prompt_mode: ValidationPromptMode = ValidationPromptMode.full_text, - validation_max_entities_per_call: int = 10, - validation_excerpt_window_chars: int = 160, - limit: int = 1, - offset: int = 0, - baseline_artifacts: Path | None = None, -) -> StagedDetectionRun: - endpoint = _required_runtime_value("endpoint", explicit=endpoint, env_var=_NATIVE_ENDPOINT_ENV) - model = _required_runtime_value("model", explicit=model, env_var=_NATIVE_MODEL_ENV) - if seed_source == SeedSource.gliner: - gliner_endpoint = _required_runtime_value( - "gliner_endpoint", explicit=gliner_endpoint, env_var=_GLINER_ENDPOINT_ENV - ) - gliner_model = _required_runtime_value("gliner_model", explicit=gliner_model, env_var=_GLINER_MODEL_ENV) - if not os.environ.get(gliner_api_key_env): - raise ValueError(f"{gliner_api_key_env} is not set") - else: - gliner_endpoint = gliner_endpoint or _UNCONFIGURED_GLINER_ENDPOINT - gliner_model = gliner_model or _UNCONFIGURED_GLINER_MODEL - requests = _load_requests(input_path, text_column=text_column, labels=labels, limit=limit, offset=offset) - config = _execution_config_from_params(locals()) - cases = _run_probe_cases(requests, config) - if baseline_artifacts is not None: - cases = apply_baseline_comparisons(cases, baseline_artifacts) - result = StagedDetectionRun( - input_path=str(input_path), text_column=text_column, endpoint=endpoint, model=model, labels=labels, rows=cases - ) - if output is not None: - write_outputs(result, output, overwrite=overwrite) - return result - - -def 
_required_runtime_value(name: str, *, explicit: str | None, env_var: str) -> str: - value = explicit or os.environ.get(env_var) - if not value: - raise ValueError(f"{name} is required; pass --{name.replace('_', '-')} or set {env_var}") - return value - - -def _run_probe_cases( - requests: list[StagedDetectionRequest], - config: StagedExecutionConfig, -) -> list[StagedDetectionCase]: - client = HttpxDirectDetectionClient() - seed_client = HttpxGlinerSeedClient() if config.seed_source == SeedSource.gliner else None - return [ - run_staged_detection_case( - request, - client=client, - seed_client=seed_client, - seed_source=config.seed_source, - endpoint=config.endpoint, - model=config.model, - gliner_endpoint=config.gliner_endpoint, - gliner_model=config.gliner_model, - gliner_api_key_env=config.gliner_api_key_env, - gliner_threshold=config.gliner_threshold, - skip_augmentation=config.skip_augmentation, - validation_prompt_mode=config.validation_prompt_mode, - validation_max_entities_per_call=config.validation_max_entities_per_call, - validation_excerpt_window_chars=config.validation_excerpt_window_chars, - ) - for request in requests - ] - - -def _load_requests( - input_path: Path, *, text_column: str, labels: list[str], limit: int, offset: int -) -> list[StagedDetectionRequest]: - dataframe = pd.read_csv(input_path) - if text_column not in dataframe.columns: - raise ValueError(f"text column {text_column!r} not found in {input_path}") - selected = dataframe.iloc[offset : offset + limit] - return [ - StagedDetectionRequest( - case_id=f"{input_path.stem}-row-{int(index)}", - text=str(row[text_column]), - labels=labels, - row_index=int(index), - ) - for index, row in selected.iterrows() - ] - - -def write_outputs(result: StagedDetectionRun, output_dir: Path, *, overwrite: bool) -> None: - if output_dir.exists(): - if not overwrite: - raise ValueError(f"output directory already exists: {output_dir}") - shutil.rmtree(output_dir) - output_dir.mkdir(parents=True) - 
_write_jsonl(output_dir / "staged-detection-cases.jsonl", [_case_payload(case) for case in result.rows]) - _write_jsonl(output_dir / "staged-detection-artifacts.jsonl", [_artifact_payload(case) for case in result.rows]) - (output_dir / "summary.json").write_text(result.model_dump_json(indent=2) + "\n", encoding="utf-8") - - -def _case_payload(case: StagedDetectionCase) -> dict[str, Any]: - payload = case.model_dump(exclude={"artifact"}) - payload["record_type"] = "staged_detection_case" - return payload - - -def _artifact_payload(case: StagedDetectionCase) -> dict[str, Any]: - payload = case.artifact.model_dump() if case.artifact is not None else {} - payload.update({"case_id": case.case_id, "row_index": case.row_index, "record_type": "staged_detection_artifact"}) - return payload - - -def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: - with path.open("w", encoding="utf-8") as target: - for row in rows: - target.write(json.dumps(row, ensure_ascii=True, sort_keys=True) + "\n") - - -def render_result(result: StagedDetectionRun, *, json_output: bool) -> str: - if json_output: - return result.model_dump_json(indent=2) - completed = len(result.rows) - result.error_count - return f"Ran {completed}/{len(result.rows)} staged detection case(s); errors={result.error_count}" - - -@app.default -def main( - input_path: Path, - *, - text_column: Annotated[str, cyclopts.Parameter("--text-column")], - labels: Annotated[str, cyclopts.Parameter("--labels")], - output: Annotated[Path | None, cyclopts.Parameter(("--output", "-o"))] = None, - overwrite: Annotated[bool, cyclopts.Parameter("--overwrite")] = False, - endpoint: Annotated[str | None, cyclopts.Parameter("--endpoint")] = None, - model: Annotated[str | None, cyclopts.Parameter("--model")] = None, - seed_source: Annotated[SeedSource, cyclopts.Parameter("--seed-source")] = SeedSource.direct_llm, - gliner_endpoint: Annotated[str | None, cyclopts.Parameter("--gliner-endpoint")] = None, - gliner_model: Annotated[str 
| None, cyclopts.Parameter("--gliner-model")] = None, - gliner_api_key_env: Annotated[str, cyclopts.Parameter("--gliner-api-key-env")] = "NVIDIA_API_KEY", - gliner_threshold: Annotated[float, cyclopts.Parameter("--gliner-threshold")] = 0.3, - skip_augmentation: Annotated[bool, cyclopts.Parameter("--skip-augmentation")] = False, - validation_prompt_mode: Annotated[ - ValidationPromptMode, cyclopts.Parameter("--validation-prompt-mode") - ] = ValidationPromptMode.full_text, - validation_max_entities_per_call: Annotated[int, cyclopts.Parameter("--validation-max-entities-per-call")] = 10, - validation_excerpt_window_chars: Annotated[int, cyclopts.Parameter("--validation-excerpt-window-chars")] = 160, - limit: Annotated[int, cyclopts.Parameter("--limit")] = 1, - offset: Annotated[int, cyclopts.Parameter("--offset")] = 0, - baseline_artifacts: Annotated[Path | None, cyclopts.Parameter("--baseline-artifacts")] = None, - json_output: Annotated[bool, cyclopts.Parameter("--json")] = False, - log_format: Annotated[LogFormat, cyclopts.Parameter("--log-format")] = LogFormat.plain, -) -> None: - params = locals() - json_output = bool(params.pop("json_output")) - result = _run_main_probe(params) - sys.stdout.write(render_result(result, json_output=json_output) + "\n") - if result.error_count: - raise SystemExit(1) - - -def _run_main_probe(params: dict[str, Any]) -> StagedDetectionRun: - configure_logging(params.pop("log_format")) - params["labels"] = parse_labels(params["labels"]) - try: - return run_probe(**params) - except (ValueError, ValidationError, httpx.HTTPError) as exc: - log_bad_input(logger, str(exc)) - raise SystemExit(125) from exc - - -if __name__ == "__main__": - app() From 7d58fd2b15084787272e253d23088a608d174ac4 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Wed, 10 Jun 2026 20:57:40 +0000 Subject: [PATCH 21/26] Trace custom DataDesigner model calls Signed-off-by: Aaron Gonzales --- docs/development/observability.md | 19 +- 
src/anonymizer/engine/ndd/adapter.py | 369 +++++++++++++++++++++++++-- src/anonymizer/measurement/sinks.py | 8 +- tests/test_measurement.py | 112 ++++++-- tools/measurement/README.md | 18 +- 5 files changed, 479 insertions(+), 47 deletions(-) diff --git a/docs/development/observability.md b/docs/development/observability.md index b69cff29..c1af461e 100644 --- a/docs/development/observability.md +++ b/docs/development/observability.md @@ -115,12 +115,19 @@ Message traces are separate from measurement records. They may contain raw input text, prompts, generated output, entity values, replacement values, secrets, and PII. Do not share them unless they have been reviewed or redacted. -Anonymizer requests these traces through DataDesigner native LLM column trace -side effects. That covers `LLMTextColumnConfig` and -`LLMStructuredColumnConfig`. It does not cover model calls made inside -`CustomColumnConfig` generator functions. When tracing is enabled, the -measurement stream records a `dd_trace_coverage` row so benchmark analysis can -see unsupported columns. +Anonymizer requests standard LLM-column traces through DataDesigner native LLM +column trace side effects. That covers `LLMTextColumnConfig` and +`LLMStructuredColumnConfig`. + +Model-backed `CustomColumnConfig` generator functions use a temporary +Anonymizer shim that instruments the per-run DataDesigner model registry and +returned model facades. This captures model calls that DataDesigner does not yet +expose through a public trace sink. Treat this as a brittle bridge over private +DataDesigner internals, not as a stable integration point. + +When tracing is enabled, the measurement stream records a `dd_trace_coverage` +row with native, private-facade, and unsupported column counts so benchmark +analysis can see which trace path covered each workflow. 
## DataDesigner Task Traces diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 3ca6855e..96e1b4cd 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -5,12 +5,15 @@ import json import logging +import re import tempfile import time import uuid from collections.abc import Iterator, Mapping from contextlib import contextmanager +from contextvars import ContextVar from dataclasses import dataclass +from functools import wraps from pathlib import Path from threading import RLock from typing import TYPE_CHECKING, Any, cast @@ -36,6 +39,8 @@ RECORD_ID_COLUMN = "_anonymizer_record_id" _TRACEABLE_LLM_COLUMN_TYPES = (LLMTextColumnConfig, LLMStructuredColumnConfig) +_MODEL_TRACE_COLUMN: ContextVar[str | None] = ContextVar("anonymizer_dd_model_trace_column", default=None) +_MODEL_TRACE_PURPOSE: ContextVar[str | None] = ContextVar("anonymizer_dd_model_trace_purpose", default=None) @dataclass(frozen=True) @@ -64,6 +69,11 @@ class _NativeTraceColumn: model_provider_name: str | None +@dataclass(frozen=True) +class _PrivateFacadeTraceColumn: + column_name: str + + class NddAdapter: """Adapter for running NDD workflows with uniform I/O and record tracking.""" @@ -115,16 +125,23 @@ def run_workflow( ) started = time.perf_counter() collector = current_collector() - usage_probe = _DataDesignerUsageProbe(self._data_designer, enabled=collector is not None) - columns, native_trace_columns, unsupported_trace_columns = _configure_native_dd_message_traces( + columns, native_trace_columns, private_trace_columns, unsupported_trace_columns = _configure_dd_message_traces( columns=columns, model_configs=model_configs, collector=collector, ) + usage_probe = _DataDesignerUsageProbe( + self._data_designer, + enabled=collector is not None, + collector=collector, + workflow_name=workflow_name, + private_trace_columns=private_trace_columns, + ) _record_dd_trace_coverage( workflow_name=workflow_name, 
collector=collector, native_trace_columns=native_trace_columns, + private_trace_columns=private_trace_columns, unsupported_trace_columns=unsupported_trace_columns, ) @@ -328,11 +345,24 @@ def _as_alias_list(raw: Any) -> list[str]: class _DataDesignerUsageProbe: """Capture DataDesigner model usage from the per-run private ResourceProvider.""" - def __init__(self, data_designer: DataDesigner, *, enabled: bool) -> None: + def __init__( + self, + data_designer: DataDesigner, + *, + enabled: bool, + collector: Any | None = None, + workflow_name: str | None = None, + private_trace_columns: list[_PrivateFacadeTraceColumn] | None = None, + ) -> None: self._data_designer = data_designer self._enabled = enabled + self._collector = collector + self._workflow_name = workflow_name + self._private_trace_column_names = {column.column_name for column in private_trace_columns or []} self._original_create_resource_provider: Any | None = None self._resource_providers: list[Any] = [] + self._model_registry_patches: list[tuple[Any, Any]] = [] + self._facade_patches: dict[int, tuple[Any, dict[str, Any]]] = {} def __enter__(self) -> _DataDesignerUsageProbe: if not self._enabled: @@ -347,12 +377,14 @@ def __enter__(self) -> _DataDesignerUsageProbe: def wrapper(*args: Any, **kwargs: Any) -> Any: resource_provider = original(*args, **kwargs) self._resource_providers.append(resource_provider) + self._install_private_model_trace(resource_provider) return resource_provider setattr(self._data_designer, "_create_resource_provider", wrapper) return self def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + self._restore_private_trace_patches() if self._original_create_resource_provider is not None: setattr(self._data_designer, "_create_resource_provider", self._original_create_resource_provider) @@ -367,6 +399,157 @@ def model_usage(self) -> dict[str, Any] | None: usage[str(model_name)] = _model_usage_as_json(stats) return usage or None + def _private_trace_enabled(self) 
-> bool: + return bool( + self._collector is not None + and self._collector.dd_trace_enabled + and self._workflow_name + and self._private_trace_column_names + ) + + def _install_private_model_trace(self, resource_provider: Any) -> None: + if not self._private_trace_enabled(): + return + model_registry = getattr(resource_provider, "model_registry", None) + get_model = getattr(model_registry, "get_model", None) + if not callable(get_model): + return + + def wrapped_get_model(*args: Any, **kwargs: Any) -> Any: + facade = get_model(*args, **kwargs) + self._patch_model_facade(facade) + return facade + + # Temporary private DataDesigner shim: CustomColumnConfig receives + # ModelFacade objects directly and DD does not yet expose a public + # model-call event sink for those calls. + setattr(model_registry, "get_model", wrapped_get_model) + self._model_registry_patches.append((model_registry, get_model)) + + def _patch_model_facade(self, facade: Any) -> None: + facade_id = id(facade) + if facade_id in self._facade_patches: + return + + originals: dict[str, Any] = {} + for method_name in ("completion", "acompletion", "generate", "agenerate"): + method = getattr(facade, method_name, None) + if not callable(method): + continue + originals[method_name] = method + setattr(facade, method_name, self._wrap_facade_method(facade, method_name, method)) + + if originals: + self._facade_patches[facade_id] = (facade, originals) + + def _wrap_facade_method(self, facade: Any, method_name: str, method: Any) -> Any: + if method_name == "acompletion": + return self._wrap_async_completion(facade, method) + if method_name == "completion": + return self._wrap_completion(facade, method) + if method_name == "agenerate": + return self._wrap_async_generate(method) + return self._wrap_generate(method) + + def _wrap_generate(self, method: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: + token = _MODEL_TRACE_PURPOSE.set(_purpose_from_kwargs(kwargs)) + try: + return method(*args, 
**kwargs) + finally: + _MODEL_TRACE_PURPOSE.reset(token) + + return wrapper + + def _wrap_async_generate(self, method: Any) -> Any: + async def wrapper(*args: Any, **kwargs: Any) -> Any: + token = _MODEL_TRACE_PURPOSE.set(_purpose_from_kwargs(kwargs)) + try: + return await method(*args, **kwargs) + finally: + _MODEL_TRACE_PURPOSE.reset(token) + + return wrapper + + def _wrap_completion(self, facade: Any, method: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: + started = time.perf_counter() + error: Exception | None = None + response: Any = None + try: + response = method(*args, **kwargs) + return response + except Exception as exc: + error = exc + raise + finally: + self._record_private_completion_trace(facade, args, kwargs, started, response, error, is_async=False) + + return wrapper + + def _wrap_async_completion(self, facade: Any, method: Any) -> Any: + async def wrapper(*args: Any, **kwargs: Any) -> Any: + started = time.perf_counter() + error: Exception | None = None + response: Any = None + try: + response = await method(*args, **kwargs) + return response + except Exception as exc: + error = exc + raise + finally: + self._record_private_completion_trace(facade, args, kwargs, started, response, error, is_async=True) + + return wrapper + + def _record_private_completion_trace( + self, + facade: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + started: float, + response: Any, + error: Exception | None, + *, + is_async: bool, + ) -> None: + if not self._private_trace_enabled(): + return + column_name = _private_trace_column_name( + column_names=self._private_trace_column_names, + purpose=_purpose_from_kwargs(kwargs) or _MODEL_TRACE_PURPOSE.get(), + ) + if column_name is None: + return + collector = self._collector + if collector is None: + return + collector.record_dd_message_trace( + **_private_completion_trace_fields( + workflow_name=self._workflow_name, + column_name=column_name, + facade=facade, + args=args, + kwargs=kwargs, + 
response=response, + error=error, + elapsed_sec=time.perf_counter() - started, + is_async=is_async, + trace_mode=collector.dd_trace_mode, + ) + ) + + def _restore_private_trace_patches(self) -> None: + for facade, originals in reversed(list(self._facade_patches.values())): + for method_name, original in originals.items(): + setattr(facade, method_name, original) + self._facade_patches.clear() + + for model_registry, get_model in reversed(self._model_registry_patches): + setattr(model_registry, "get_model", get_model) + self._model_registry_patches.clear() + def _get_model_usage_snapshot(model_registry: object) -> Mapping[str, object] | None: alias_snapshot = _get_model_usage_snapshot_by_alias(model_registry) @@ -411,6 +594,120 @@ def _model_usage_as_json(stats: object) -> Any: return stats +def _purpose_from_kwargs(kwargs: Mapping[str, Any]) -> str | None: + purpose = kwargs.get("purpose") + return purpose if isinstance(purpose, str) and purpose else None + + +def _private_trace_column_name(*, column_names: set[str], purpose: str | None) -> str | None: + context_column = _MODEL_TRACE_COLUMN.get() + if context_column in column_names: + return context_column + + task_column = _runtime_correlation_task_column() + if task_column in column_names: + return task_column + + purpose_column = _column_name_from_purpose(purpose) + if purpose_column in column_names: + return purpose_column + + if len(column_names) == 1: + return next(iter(column_names)) + return None + + +def _runtime_correlation_task_column() -> str | None: + try: + from data_designer.engine.observability import runtime_correlation_provider + except Exception: + return None + + correlation = runtime_correlation_provider.current() + task_column = getattr(correlation, "task_column", None) + return task_column if isinstance(task_column, str) and task_column else None + + +def _column_name_from_purpose(purpose: str | None) -> str | None: + if not purpose: + return None + match = re.search(r"column '([^']+)'", 
purpose) + if match: + return match.group(1) + return None + + +def _model_provider_endpoint(facade: Any) -> str | None: + provider = getattr(facade, "model_provider", None) + endpoint = getattr(provider, "endpoint", None) + return endpoint if isinstance(endpoint, str) and endpoint else None + + +def _private_trace_messages(*, args: tuple[Any, ...], kwargs: Mapping[str, Any]) -> list[dict[str, Any]]: + messages = args[0] if args else kwargs.get("messages") + if isinstance(messages, list): + return [_trace_message(message) for message in messages] + return [] + + +def _private_completion_trace_fields( + *, + workflow_name: str | None, + column_name: str, + facade: Any, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + response: Any, + error: Exception | None, + elapsed_sec: float, + is_async: bool, + trace_mode: str, +) -> dict[str, Any]: + return { + "workflow_name": workflow_name, + "trace_source": "anonymizer_private_model_facade", + "column_name": column_name, + "trace_column_name": None, + "model_alias": getattr(facade, "model_alias", None), + "model_name": getattr(facade, "model_name", None), + "model_provider_name": getattr(facade, "model_provider_name", None), + "model_provider_endpoint": _model_provider_endpoint(facade), + "modality": "chat", + "is_async": is_async, + "status": "error" if error is not None else "completed", + "error_type": type(error).__name__ if error is not None else None, + "elapsed_sec": elapsed_sec, + "messages": _select_native_trace_messages(_private_trace_messages(args=args, kwargs=kwargs), mode=trace_mode), + "response": _model_trace_response(response), + "usage": _model_trace_usage(response), + } + + +def _model_trace_response(response: Any) -> dict[str, Any] | None: + message = getattr(response, "message", None) + if message is None: + return None + return { + "content": getattr(message, "content", None), + "reasoning_content": getattr(message, "reasoning_content", None), + "tool_calls": _trace_tool_calls(getattr(message, 
"tool_calls", [])), + } + + +def _model_trace_usage(response: Any) -> Any: + usage = getattr(response, "usage", None) + if usage is None: + return None + model_dump = getattr(usage, "model_dump", None) + if callable(model_dump): + return model_dump(mode="json") + if isinstance(usage, Mapping): + return dict(usage) + fields = ("input_tokens", "output_tokens", "total_tokens", "reasoning_tokens") + payload = {field: getattr(usage, field) for field in fields if getattr(usage, field, None) is not None} + return payload or None + + @contextmanager def _temporary_dd_task_trace(data_designer: DataDesigner, *, collector: Any | None) -> Iterator[None]: if collector is None or not collector.dd_task_trace_enabled: @@ -452,17 +749,33 @@ def _task_traces_from_result(result: Any) -> list[Any]: return [] -def _configure_native_dd_message_traces( +def _custom_column_with_trace_context(column: CustomColumnConfig) -> ColumnConfigT: + generator = column.generator_function + + @wraps(generator) + def traced_generator(*args: Any, **kwargs: Any) -> Any: + token = _MODEL_TRACE_COLUMN.set(column.name) + try: + return generator(*args, **kwargs) + finally: + _MODEL_TRACE_COLUMN.reset(token) + + traced_generator.custom_column_metadata = getattr(generator, "custom_column_metadata", {}) # type: ignore[attr-defined] + return cast(ColumnConfigT, column.model_copy(update={"generator_function": traced_generator})) + + +def _configure_dd_message_traces( *, columns: list[ColumnConfigT], model_configs: list[ModelConfig], collector: Any | None, -) -> tuple[list[ColumnConfigT], list[_NativeTraceColumn], list[ColumnConfigT]]: +) -> tuple[list[ColumnConfigT], list[_NativeTraceColumn], list[_PrivateFacadeTraceColumn], list[ColumnConfigT]]: if collector is None or not collector.dd_trace_enabled: - return columns, [], [] + return columns, [], [], [] model_configs_by_alias = {model_config.alias: model_config for model_config in model_configs} - traced_columns: list[_NativeTraceColumn] = [] + 
native_trace_columns: list[_NativeTraceColumn] = [] + private_trace_columns: list[_PrivateFacadeTraceColumn] = [] unsupported_columns: list[ColumnConfigT] = [] configured_columns: list[ColumnConfigT] = [] trace_type = _native_dd_trace_type() @@ -472,7 +785,7 @@ def _configure_native_dd_message_traces( configured_column = cast(ColumnConfigT, column.model_copy(update={"with_trace": trace_type})) configured_columns.append(configured_column) model_config = model_configs_by_alias.get(column.model_alias) - traced_columns.append( + native_trace_columns.append( _NativeTraceColumn( column_name=column.name, trace_column_name=f"{column.name}{TRACE_COLUMN_POSTFIX}", @@ -483,11 +796,14 @@ def _configure_native_dd_message_traces( ) continue + if _column_has_private_facade_model_calls(column): + configured_columns.append(_custom_column_with_trace_context(column)) + private_trace_columns.append(_PrivateFacadeTraceColumn(column_name=column.name)) + continue + configured_columns.append(column) - if _column_has_untraced_model_calls(column): - unsupported_columns.append(column) - return configured_columns, traced_columns, unsupported_columns + return configured_columns, native_trace_columns, private_trace_columns, unsupported_columns def _native_dd_trace_type() -> TraceType: @@ -497,7 +813,7 @@ def _native_dd_trace_type() -> TraceType: return TraceType.ALL_MESSAGES -def _column_has_untraced_model_calls(column: ColumnConfigT) -> bool: +def _column_has_private_facade_model_calls(column: ColumnConfigT) -> bool: return isinstance(column, CustomColumnConfig) and bool(_extract_workflow_model_aliases([column])) @@ -506,24 +822,47 @@ def _record_dd_trace_coverage( workflow_name: str, collector: Any, native_trace_columns: list[_NativeTraceColumn], + private_trace_columns: list[_PrivateFacadeTraceColumn], unsupported_trace_columns: list[ColumnConfigT], ) -> Any: if collector is None or not collector.dd_trace_enabled: return + traced_column_names = [column.column_name for column in 
native_trace_columns] + [ + column.column_name for column in private_trace_columns + ] collector.record( "dd_trace_coverage", workflow_name=workflow_name, - trace_backend="data_designer_column", + trace_backend=_dd_trace_backend(native_trace_columns, private_trace_columns), trace_mode=collector.dd_trace_mode, native_trace_type=_native_dd_trace_type().value, - traced_column_count=len(native_trace_columns), - traced_column_names=[column.column_name for column in native_trace_columns], + traced_column_count=len(traced_column_names), + traced_column_names=traced_column_names, + native_trace_column_count=len(native_trace_columns), + native_trace_column_names=[column.column_name for column in native_trace_columns], + private_trace_column_count=len(private_trace_columns), + private_trace_column_names=[column.column_name for column in private_trace_columns], + private_trace_backend="anonymizer_private_model_facade" if private_trace_columns else None, + private_trace_note=( + "temporary private DataDesigner model registry/facade instrumentation" if private_trace_columns else None + ), unsupported_column_count=len(unsupported_trace_columns), unsupported_column_names=[column.name for column in unsupported_trace_columns], unsupported_column_types=[_column_type_name(column) for column in unsupported_trace_columns], ) +def _dd_trace_backend( + native_trace_columns: list[_NativeTraceColumn], + private_trace_columns: list[_PrivateFacadeTraceColumn], +) -> str: + if native_trace_columns and private_trace_columns: + return "mixed" + if private_trace_columns: + return "anonymizer_private_model_facade" + return "data_designer_column" + + def _column_type_name(column: ColumnConfigT) -> str: column_type = getattr(column, "column_type", None) return str(column_type) if column_type is not None else type(column).__name__ diff --git a/src/anonymizer/measurement/sinks.py b/src/anonymizer/measurement/sinks.py index 0465ddbf..4ec0f0aa 100644 --- a/src/anonymizer/measurement/sinks.py +++ 
b/src/anonymizer/measurement/sinks.py @@ -5,6 +5,7 @@ import json from pathlib import Path +from threading import Lock from typing import Any, Literal, Protocol @@ -46,9 +47,12 @@ def __init__(self, path: str | Path) -> None: output_path = Path(path) output_path.parent.mkdir(parents=True, exist_ok=True) self._file = output_path.open("w", encoding="utf-8", buffering=1) + self._lock = Lock() def write_record(self, record: dict[str, Any]) -> None: - self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") + with self._lock: + self._file.write(json.dumps(record, ensure_ascii=True, sort_keys=True) + "\n") def close(self) -> None: - self._file.close() + with self._lock: + self._file.close() diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 4c16b783..faa419e4 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -823,25 +823,72 @@ def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespac assert "secret response" not in serialized_measurements -def test_ndd_adapter_records_custom_column_dd_trace_coverage_gap(tmp_path: Path) -> None: - input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) +class _TraceModelFacade: + model_alias = "alias" + model_name = "dummy-model" + model_provider_name = "provider" + model_provider = SimpleNamespace(endpoint="http://provider/v1") + + def generate(self, prompt: str, **_kwargs: Any) -> str: + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + return self.completion(messages).message.content + + def completion(self, _messages: list[dict[str, Any]]) -> SimpleNamespace: + return SimpleNamespace( + message=SimpleNamespace(content="custom response secret", reasoning_content="scratch", tool_calls=[]), + usage=SimpleNamespace(input_tokens=3, output_tokens=5, total_tokens=8), + ) + + +class _TraceModelRegistry: + def __init__(self) -> None: + self._models = {"alias": _TraceModelFacade()} + + def 
get_model(self, *, model_alias: str) -> _TraceModelFacade: + return self._models[model_alias] + + +class _CustomTraceDataDesigner: + def __init__(self, input_df: pd.DataFrame) -> None: + self.input_df = input_df + self.resource_provider = SimpleNamespace(model_registry=_TraceModelRegistry()) + def _create_resource_provider(self, *_args: Any, **_kwargs: Any) -> SimpleNamespace: + return self.resource_provider + + def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespace: + resource_provider = self._create_resource_provider() + model = resource_provider.model_registry.get_model(model_alias="alias") + output = self.input_df.iloc[:num_records].copy() + for column in cast(Any, config_builder).get_column_configs(): + for row_index, row in output.iterrows(): + generated = column.generator_function( + row.to_dict(), + generator_params=None, + models={"alias": model}, + ) + for key, value in generated.items(): + output.loc[row_index, key] = value + return SimpleNamespace(dataset=output) + + +def _custom_trace_column(name: str, *, prompt: str, value: str) -> CustomColumnConfig: @custom_column_generator(required_columns=["text"], model_aliases=["alias"]) - def custom_generator( + def generator( row: dict[str, Any], generator_params: Any, models: dict[str, Any], ) -> dict[str, str]: - _ = row, generator_params, models - return {"raw_detected": "[]"} + _ = row, generator_params + models["alias"].generate(prompt) + return {name: value} - class TraceDataDesigner: - def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: - output = input_df.iloc[:num_records].copy() - output["raw_detected"] = "[]" - return SimpleNamespace(dataset=output) + return CustomColumnConfig(name=name, generator_function=generator) - adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + +def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace(tmp_path: Path) -> None: + input_df = pd.DataFrame({"text": ["Alice works 
at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + adapter = NddAdapter(data_designer=cast(DataDesigner, _CustomTraceDataDesigner(input_df))) trace_path = tmp_path / "trace.jsonl" with configured_measurement_session( @@ -852,19 +899,50 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa adapter.run_workflow( input_df, model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], - columns=[CustomColumnConfig(name="raw_detected", generator_function=custom_generator)], + columns=[ + _custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]"), + _custom_trace_column("quality_check", prompt="quality prompt secret", value="ok"), + ], workflow_name="entity-detection", preview_num_records=1, ) - assert trace_path.read_text(encoding="utf-8") == "" + traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] + traces_by_column = {trace["column_name"]: trace for trace in traces} + assert set(traces_by_column) == {"raw_detected", "quality_check"} + + raw_trace = traces_by_column["raw_detected"] + assert raw_trace["record_type"] == "dd_message_trace" + assert raw_trace["trace_source"] == "anonymizer_private_model_facade" + assert raw_trace["workflow_name"] == "entity-detection" + assert raw_trace["model_alias"] == "alias" + assert raw_trace["model_name"] == "dummy-model" + assert raw_trace["model_provider_name"] == "provider" + assert raw_trace["model_provider_endpoint"] == "http://provider/v1" + assert raw_trace["status"] == "completed" + assert raw_trace["messages"] == [{"role": "user", "content": [{"type": "text", "text": "raw prompt secret"}]}] + assert raw_trace["response"]["content"] == "custom response secret" + assert raw_trace["response"]["reasoning_content"] == "scratch" + assert raw_trace["usage"] == {"input_tokens": 3, "output_tokens": 5, "total_tokens": 8} + + quality_trace = traces_by_column["quality_check"] + assert quality_trace["messages"] == [ + {"role": 
"user", "content": [{"type": "text", "text": "quality prompt secret"}]} + ] + measurements = [json.loads(line) for line in (tmp_path / "measurements.jsonl").read_text().splitlines()] coverage = [record for record in measurements if record["record_type"] == "dd_trace_coverage"] assert len(coverage) == 1 - assert coverage[0]["traced_column_count"] == 0 - assert coverage[0]["unsupported_column_count"] == 1 - assert coverage[0]["unsupported_column_names"] == ["raw_detected"] - assert coverage[0]["unsupported_column_types"] == ["custom"] + assert coverage[0]["trace_backend"] == "anonymizer_private_model_facade" + assert coverage[0]["traced_column_count"] == 2 + assert coverage[0]["private_trace_column_count"] == 2 + assert coverage[0]["private_trace_column_names"] == ["raw_detected", "quality_check"] + assert coverage[0]["unsupported_column_count"] == 0 + + serialized_measurements = json.dumps(measurements) + assert "raw prompt secret" not in serialized_measurements + assert "quality prompt secret" not in serialized_measurements + assert "custom response secret" not in serialized_measurements def test_ndd_adapter_writes_sanitized_dd_task_traces_and_restores_run_config(tmp_path: Path) -> None: diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 4ae539ba..08f3a4db 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -247,13 +247,17 @@ values, replacement values, secrets, and PII. Treat them as debug artifacts: keep them out of shared benchmark bundles unless they have been reviewed or redacted. -Anonymizer requests these traces through DataDesigner native LLM column trace -side effects. That covers `LLMTextColumnConfig` and -`LLMStructuredColumnConfig`, but not model calls made inside -`CustomColumnConfig` generator functions. Safe measurement output includes a -`dd_trace_coverage` record with unsupported custom columns so trace-enabled -runs can detect this gap. 
Full custom-column message tracing would need a -DataDesigner hook; it is a good candidate for an upstream DataDesigner PR. +Anonymizer requests standard LLM-column traces through DataDesigner native LLM +column trace side effects. That covers `LLMTextColumnConfig` and +`LLMStructuredColumnConfig`. Model-backed `CustomColumnConfig` generator +functions are traced through a temporary Anonymizer shim that instruments the +per-run DataDesigner model registry and returned model facades. This is a +brittle bridge over private DataDesigner internals until DataDesigner exposes a +public model-call trace sink. + +Safe measurement output includes a `dd_trace_coverage` record with native, +private-facade, and unsupported column counts so trace-enabled runs can detect +which path covered each workflow. ## DataDesigner Scheduler Task Traces From 3afb44cc54ede7a5da8e212e27f9e3e078f15bf3 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Thu, 11 Jun 2026 17:59:39 +0000 Subject: [PATCH 22/26] Improve benchmark observability tracing Signed-off-by: Aaron Gonzales --- docs/development/observability.md | 16 +- src/anonymizer/engine/ndd/adapter.py | 167 ++++++-- tests/test_measurement.py | 395 ++++++++++++++---- tests/tools/test_measurement_tools.py | 245 ++++++++--- tools/measurement/README.md | 32 +- .../examples/repo-data-smoke-models.yaml | 30 ++ .../examples/repo-data-smoke-providers.yaml | 8 + .../measurement/examples/repo-data-smoke.yaml | 2 + tools/measurement/run_benchmarks.py | 30 +- 9 files changed, 730 insertions(+), 195 deletions(-) create mode 100644 tools/measurement/examples/repo-data-smoke-models.yaml create mode 100644 tools/measurement/examples/repo-data-smoke-providers.yaml diff --git a/docs/development/observability.md b/docs/development/observability.md index c1af461e..be4f2ba5 100644 --- a/docs/development/observability.md +++ b/docs/development/observability.md @@ -141,9 +141,14 @@ measurement = MeasurementConfig( ``` Task traces capture DataDesigner scheduler 
timing metadata: workflow, column, -row group, row index, task type, status, queue wait time, execution time, total -time, and whether an error was present. They do not store raw DataDesigner error -strings because those strings can contain prompts, outputs, or source values. +row group, row index, task type, status, relative dispatch/slot-acquired/ +completion offsets, queue wait time, execution time, total time, and whether an +error was present. They do not store raw DataDesigner error strings because +those strings can contain prompts, outputs, or source values. + +Offsets are relative to the earliest positive `dispatched_at` timestamp in the +task-trace batch for that workflow. They make task overlap easier to inspect +without persisting host-specific wall-clock timestamps. ## Safety Rules @@ -161,8 +166,9 @@ When adding instrumentation: - Put timing around stable phase boundaries, not every helper call. - Record metadata at the boundary where the information is known. - Keep raw debug payloads in explicit sidecars, never in measurement records. -- Prefer `run_tags` for benchmark context such as suite ID, case ID, workload, - config, or experimental strategy. +- Prefer `run_tags` for external run context such as source refs, CI IDs, + topology labels, or experimental strategy. The benchmark runner owns + `suite_id`, `case_id`, `workload_id`, `config_id`, and `repetition`. - Keep benchmark-only strategy switches in `tools/measurement`, not product defaults. 
diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 96e1b4cd..96795c36 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -3,6 +3,7 @@ from __future__ import annotations +import importlib import json import logging import re @@ -16,7 +17,7 @@ from functools import wraps from pathlib import Path from threading import RLock -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeGuard, cast from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig from data_designer.config.column_types import ColumnConfigT @@ -74,6 +75,38 @@ class _PrivateFacadeTraceColumn: column_name: str +class _TaskTraceLike(Protocol): + column: Any + row_group: Any + row_index: Any + task_type: Any + status: Any + error: Any + dispatched_at: Any + slot_acquired_at: Any + completed_at: Any + + +_TaskTrace = Mapping[str, Any] | _TaskTraceLike + + +class _DDTaskTraceFields(TypedDict): + workflow_name: str + trace_source: Literal["data_designer_scheduler"] + column: Any + row_group: Any + row_index: Any + task_type: Any + status: Any + error_present: bool + dispatched_offset_sec: float | None + slot_acquired_offset_sec: float | None + completed_offset_sec: float | None + queue_wait_sec: float | None + execution_sec: float | None + total_sec: float | None + + class NddAdapter: """Adapter for running NDD workflows with uniform I/O and record tracking.""" @@ -154,8 +187,8 @@ def run_workflow( for column in columns: config_builder.add_column(column) + task_traces: list[_TaskTrace] = [] try: - task_traces: list[Any] = [] with self._run_lock, usage_probe, _temporary_dd_task_trace(self._data_designer, collector=collector): if preview_num_records is None: run_results = self._data_designer.create( @@ -175,17 +208,6 @@ def run_workflow( output_df = workflow_input_df.iloc[0:0].copy() else: output_df = 
preview_results.dataset - output_df = _record_and_strip_native_dd_message_traces( - output_df=output_df, - workflow_name=workflow_name, - collector=collector, - native_trace_columns=native_trace_columns, - ) - _record_dd_task_traces( - workflow_name=workflow_name, - collector=collector, - task_traces=task_traces, - ) except Exception as exc: logger.warning( "Workflow failed for %d input record(s) on model(s) %s: %s", @@ -198,6 +220,10 @@ def run_workflow( workflow_name, col_names, ) + try: + usage_probe.flush_private_trace_records() + except Exception: + logger.warning("Failed to write DataDesigner private message trace records after workflow failure") record_ndd_workflow( workflow_name=workflow_name, model_aliases=model_aliases, @@ -214,6 +240,19 @@ def run_workflow( ) raise AnonymizerWorkflowError(f"Workflow failed: {exc}") from exc + output_df = _record_and_strip_native_dd_message_traces( + output_df=output_df, + workflow_name=workflow_name, + collector=collector, + native_trace_columns=native_trace_columns, + ) + _record_dd_task_traces( + workflow_name=workflow_name, + collector=collector, + task_traces=task_traces, + ) + usage_probe.flush_private_trace_records() + logger.debug("NDD workflow '%s' returned %d records", workflow_name, len(output_df)) failed_records = self._detect_missing_records( workflow_name=workflow_name, @@ -363,6 +402,7 @@ def __init__( self._resource_providers: list[Any] = [] self._model_registry_patches: list[tuple[Any, Any]] = [] self._facade_patches: dict[int, tuple[Any, dict[str, Any]]] = {} + self._private_trace_records: list[dict[str, Any]] = [] def __enter__(self) -> _DataDesignerUsageProbe: if not self._enabled: @@ -399,6 +439,14 @@ def model_usage(self) -> dict[str, Any] | None: usage[str(model_name)] = _model_usage_as_json(stats) return usage or None + def flush_private_trace_records(self) -> None: + collector = self._collector + if collector is None: + self._private_trace_records.clear() + return + while 
self._private_trace_records: + collector.record_dd_message_trace(**self._private_trace_records.pop(0)) + def _private_trace_enabled(self) -> bool: return bool( self._collector is not None @@ -525,8 +573,8 @@ def _record_private_completion_trace( collector = self._collector if collector is None: return - collector.record_dd_message_trace( - **_private_completion_trace_fields( + self._private_trace_records.append( + _private_completion_trace_fields( workflow_name=self._workflow_name, column_name=column_name, facade=facade, @@ -619,11 +667,15 @@ def _private_trace_column_name(*, column_names: set[str], purpose: str | None) - def _runtime_correlation_task_column() -> str | None: try: - from data_designer.engine.observability import runtime_correlation_provider + observability = importlib.import_module("data_designer.engine.observability") except Exception: return None - correlation = runtime_correlation_provider.current() + runtime_correlation_provider = getattr(observability, "runtime_correlation_provider", None) + current = getattr(runtime_correlation_provider, "current", None) + if not callable(current): + return None + correlation = current() task_column = getattr(correlation, "task_column", None) return task_column if isinstance(task_column, str) and task_column else None @@ -737,14 +789,14 @@ def _run_config_with_async_trace(run_config: Any) -> Any: return run_config -def _task_traces_from_result(result: Any) -> list[Any]: +def _task_traces_from_result(result: Any) -> list[_TaskTrace]: raw_traces = getattr(result, "task_traces", None) if raw_traces is None: return [] if isinstance(raw_traces, list): - return raw_traces + return cast(list[_TaskTrace], raw_traces) try: - return list(raw_traces) + return cast(list[_TaskTrace], list(raw_traces)) except TypeError: return [] @@ -813,7 +865,7 @@ def _native_dd_trace_type() -> TraceType: return TraceType.ALL_MESSAGES -def _column_has_private_facade_model_calls(column: ColumnConfigT) -> bool: +def 
_column_has_private_facade_model_calls(column: ColumnConfigT) -> TypeGuard[CustomColumnConfig]: return isinstance(column, CustomColumnConfig) and bool(_extract_workflow_model_aliases([column])) @@ -950,40 +1002,63 @@ def _trace_tool_calls(tool_calls: Any) -> list[Any]: return [] -def _record_dd_task_traces(*, workflow_name: str, collector: Any | None, task_traces: list[Any]) -> None: +def _record_dd_task_traces(*, workflow_name: str, collector: Any | None, task_traces: list[_TaskTrace]) -> None: if collector is None or not collector.dd_task_trace_enabled: return + trace_origin = _task_trace_origin(task_traces) for task_trace in task_traces: - collector.record_dd_task_trace( - workflow_name=workflow_name, - trace_source="data_designer_scheduler", - column=_trace_attr(task_trace, "column"), - row_group=_trace_attr(task_trace, "row_group"), - row_index=_trace_attr(task_trace, "row_index"), - task_type=_trace_attr(task_trace, "task_type"), - status=_trace_attr(task_trace, "status"), - error_present=bool(_trace_attr(task_trace, "error")), - queue_wait_sec=_trace_duration( - _trace_attr(task_trace, "dispatched_at"), - _trace_attr(task_trace, "slot_acquired_at"), - ), - execution_sec=_trace_duration( - _trace_attr(task_trace, "slot_acquired_at"), - _trace_attr(task_trace, "completed_at"), - ), - total_sec=_trace_duration( - _trace_attr(task_trace, "dispatched_at"), - _trace_attr(task_trace, "completed_at"), - ), - ) + collector.record_dd_task_trace(**_dd_task_trace_fields(workflow_name, task_trace, trace_origin)) + + +def _dd_task_trace_fields( + workflow_name: str, + task_trace: _TaskTrace, + trace_origin: float | None, +) -> _DDTaskTraceFields: + dispatched_at = _trace_attr(task_trace, "dispatched_at") + slot_acquired_at = _trace_attr(task_trace, "slot_acquired_at") + completed_at = _trace_attr(task_trace, "completed_at") + return { + "workflow_name": workflow_name, + "trace_source": "data_designer_scheduler", + "column": _trace_attr(task_trace, "column"), + 
"row_group": _trace_attr(task_trace, "row_group"), + "row_index": _trace_attr(task_trace, "row_index"), + "task_type": _trace_attr(task_trace, "task_type"), + "status": _trace_attr(task_trace, "status"), + "error_present": bool(_trace_attr(task_trace, "error")), + "dispatched_offset_sec": _trace_offset(trace_origin, dispatched_at), + "slot_acquired_offset_sec": _trace_offset(trace_origin, slot_acquired_at), + "completed_offset_sec": _trace_offset(trace_origin, completed_at), + "queue_wait_sec": _trace_duration(dispatched_at, slot_acquired_at), + "execution_sec": _trace_duration(slot_acquired_at, completed_at), + "total_sec": _trace_duration(dispatched_at, completed_at), + } + + +def _task_trace_origin(task_traces: list[_TaskTrace]) -> float | None: + dispatch_times: list[float] = [] + for task_trace in task_traces: + dispatched_at = _trace_attr(task_trace, "dispatched_at") + if isinstance(dispatched_at, (int, float)) and dispatched_at > 0: + dispatch_times.append(float(dispatched_at)) + return min(dispatch_times) if dispatch_times else None -def _trace_attr(task_trace: Any, name: str) -> Any: +def _trace_attr(task_trace: _TaskTrace, name: str) -> Any: if isinstance(task_trace, Mapping): - return task_trace.get(name) + return cast(Mapping[str, Any], task_trace).get(name) return getattr(task_trace, name, None) +def _trace_offset(origin: float | None, timestamp: Any) -> float | None: + if origin is None or not isinstance(timestamp, (int, float)): + return None + if timestamp <= 0 or timestamp < origin: + return None + return float(timestamp - origin) + + def _trace_duration(start: Any, end: Any) -> float | None: if not isinstance(start, (int, float)) or not isinstance(end, (int, float)): return None diff --git a/tests/test_measurement.py b/tests/test_measurement.py index faa419e4..1fc6823a 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -14,6 +14,7 @@ import pandas as pd import pytest from data_designer.config.column_configs import 
CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig +from data_designer.config.column_types import ColumnConfigT from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ModelConfig from data_designer.config.run_config import RunConfig @@ -40,7 +41,7 @@ COL_WEIGHTED_LEAKAGE_RATE, ) from anonymizer.engine.detection.detection_workflow import EntityDetectionResult, EntityDetectionWorkflow -from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter +from anonymizer.engine.ndd.adapter import RECORD_ID_COLUMN, NddAdapter, WorkflowRunResult from anonymizer.engine.replace.replace_runner import ReplacementResult, ReplacementWorkflow from anonymizer.engine.rewrite.rewrite_workflow import RewriteResult, RewriteWorkflow from anonymizer.interface.anonymizer import Anonymizer @@ -58,6 +59,45 @@ ) +class _FailingSink: + def __init__(self, message: str) -> None: + self.message = message + + def write_record(self, record: dict[str, Any]) -> None: + _ = record + raise OSError(self.message) + + def close(self) -> None: + pass + + +@pytest.fixture +def trace_input_df() -> pd.DataFrame: + return pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) + + +def _trace_model_configs() -> list[ModelConfig]: + return [ModelConfig(alias="alias", model="dummy-model", provider="provider")] + + +def _run_entity_detection_preview( + adapter: NddAdapter, + input_df: pd.DataFrame, + columns: list[ColumnConfigT], +) -> WorkflowRunResult: + return adapter.run_workflow( + input_df, + model_configs=_trace_model_configs(), + columns=columns, + workflow_name="entity-detection", + preview_num_records=1, + ) + + +def _raw_detected_text_column() -> LLMTextColumnConfig: + return LLMTextColumnConfig(name="raw_detected", prompt="{{ text }}", model_alias="alias") + + def test_ndd_adapter_records_workflow_measurement_without_raw_text() -> None: input_df = pd.DataFrame( { @@ -705,8 +745,8 @@ def 
__init__(self, name: str, *, fail: bool = False) -> None: self.name = name self.fail = fail - def write_record(self, _record: dict[str, Any]) -> None: - pass + def write_record(self, record: dict[str, Any]) -> None: + _ = record def close(self) -> None: close_events.append(self.name) @@ -787,13 +827,7 @@ def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespac output_path=tmp_path / "measurements.jsonl", dd_trace="last_message", dd_trace_path=trace_path ) ): - result = adapter.run_workflow( - input_df, - model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], - columns=[original_column], - workflow_name="entity-detection", - preview_num_records=1, - ) + result = _run_entity_detection_preview(adapter, input_df, [original_column]) assert original_column.with_trace == TraceType.NONE assert captured_columns[0].with_trace == TraceType.ALL_MESSAGES @@ -841,17 +875,17 @@ def completion(self, _messages: list[dict[str, Any]]) -> SimpleNamespace: class _TraceModelRegistry: - def __init__(self) -> None: - self._models = {"alias": _TraceModelFacade()} + def __init__(self, facade: Any | None = None) -> None: + self._models = {"alias": facade or _TraceModelFacade()} def get_model(self, *, model_alias: str) -> _TraceModelFacade: return self._models[model_alias] class _CustomTraceDataDesigner: - def __init__(self, input_df: pd.DataFrame) -> None: + def __init__(self, input_df: pd.DataFrame, *, facade: Any | None = None) -> None: self.input_df = input_df - self.resource_provider = SimpleNamespace(model_registry=_TraceModelRegistry()) + self.resource_provider = SimpleNamespace(model_registry=_TraceModelRegistry(facade)) def _create_resource_provider(self, *_args: Any, **_kwargs: Any) -> SimpleNamespace: return self.resource_provider @@ -872,6 +906,31 @@ def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespac return SimpleNamespace(dataset=output) +class _TaskTraceDataDesigner: + def __init__( + self, + 
input_df: pd.DataFrame, + *, + task_traces: list[Any] | None = None, + error: Exception | None = None, + ) -> None: + self.input_df = input_df + self.task_traces = task_traces or [] + self.error = error + self.run_config = RunConfig(async_trace=False) + self.async_trace_values: list[bool] = [] + + def set_run_config(self, run_config: RunConfig) -> None: + self.async_trace_values.append(run_config.async_trace) + self.run_config = run_config + + def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: + assert self.run_config.async_trace is True + if self.error is not None: + raise self.error + return SimpleNamespace(dataset=self.input_df.iloc[:num_records].copy(), task_traces=self.task_traces) + + def _custom_trace_column(name: str, *, prompt: str, value: str) -> CustomColumnConfig: @custom_column_generator(required_columns=["text"], model_aliases=["alias"]) def generator( @@ -886,9 +945,11 @@ def generator( return CustomColumnConfig(name=name, generator_function=generator) -def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace(tmp_path: Path) -> None: - input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) - adapter = NddAdapter(data_designer=cast(DataDesigner, _CustomTraceDataDesigner(input_df))) +def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) -> None: + adapter = NddAdapter(data_designer=cast(DataDesigner, _CustomTraceDataDesigner(trace_input_df))) trace_path = tmp_path / "trace.jsonl" with configured_measurement_session( @@ -896,15 +957,13 @@ def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace(tmp_path output_path=tmp_path / "measurements.jsonl", dd_trace="all_messages", dd_trace_path=trace_path ) ): - adapter.run_workflow( - input_df, - model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], - columns=[ + _run_entity_detection_preview( + adapter, + 
trace_input_df, + [ _custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]"), _custom_trace_column("quality_check", prompt="quality prompt secret", value="ok"), ], - workflow_name="entity-detection", - preview_num_records=1, ) traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] @@ -945,46 +1004,114 @@ def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace(tmp_path assert "custom response secret" not in serialized_measurements -def test_ndd_adapter_writes_sanitized_dd_task_traces_and_restores_run_config(tmp_path: Path) -> None: - input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) +def test_ndd_adapter_private_model_facade_trace_write_error_is_not_wrapped( + trace_input_df: pd.DataFrame, +) -> None: + adapter = NddAdapter(data_designer=cast(DataDesigner, _CustomTraceDataDesigner(trace_input_df))) + collector = MeasurementCollector( + dd_trace_mode="all_messages", + dd_trace_sink=_FailingSink("private trace sidecar unavailable"), + fail_on_write_error=True, + ) - class TraceDataDesigner: - def __init__(self) -> None: - self.run_config = RunConfig(async_trace=False) - self.async_trace_values: list[bool] = [] + with measurement_session(collector), pytest.raises(OSError, match="private trace sidecar unavailable"): + _run_entity_detection_preview( + adapter, + trace_input_df, + [_custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]")], + ) - def set_run_config(self, run_config: RunConfig) -> None: - self.async_trace_values.append(run_config.async_trace) - self.run_config = run_config - def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: - assert self.run_config.async_trace is True - task_trace = SimpleNamespace( - column="raw_detected", - row_group=0, - row_index=7, - task_type="llm", - dispatched_at=10.0, - slot_acquired_at=10.25, - completed_at=12.0, - status="error", - error="raw secret token Alice", 
+def test_ndd_adapter_flushes_private_model_facade_error_trace_when_workflow_fails( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) -> None: + class FailingTraceModelFacade(_TraceModelFacade): + def completion(self, _messages: list[dict[str, Any]]) -> SimpleNamespace: + raise RuntimeError("custom model call failed") + + trace_path = tmp_path / "trace.jsonl" + adapter = NddAdapter( + data_designer=cast(DataDesigner, _CustomTraceDataDesigner(trace_input_df, facade=FailingTraceModelFacade())) + ) + + with pytest.raises(AnonymizerWorkflowError, match="Workflow failed"): + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", + dd_trace="all_messages", + dd_trace_path=trace_path, + ) + ): + _run_entity_detection_preview( + adapter, + trace_input_df, + [_custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]")], ) - return SimpleNamespace(dataset=input_df.iloc[:num_records].copy(), task_traces=[task_trace]) - data_designer = TraceDataDesigner() + traces = [json.loads(line) for line in trace_path.read_text(encoding="utf-8").splitlines()] + assert len(traces) == 1 + trace = traces[0] + assert trace["record_type"] == "dd_message_trace" + assert trace["trace_source"] == "anonymizer_private_model_facade" + assert trace["workflow_name"] == "entity-detection" + assert trace["column_name"] == "raw_detected" + assert trace["status"] == "error" + assert trace["error_type"] == "RuntimeError" + assert trace["messages"] == [{"role": "user", "content": [{"type": "text", "text": "raw prompt secret"}]}] + assert "custom model call failed" not in trace_path.read_text(encoding="utf-8") + + +def test_ndd_adapter_private_model_facade_trace_write_error_does_not_mask_workflow_failure( + trace_input_df: pd.DataFrame, +) -> None: + class FailingTraceModelFacade(_TraceModelFacade): + def completion(self, _messages: list[dict[str, Any]]) -> SimpleNamespace: + raise RuntimeError("custom model call failed") + + adapter = 
NddAdapter( + data_designer=cast(DataDesigner, _CustomTraceDataDesigner(trace_input_df, facade=FailingTraceModelFacade())) + ) + collector = MeasurementCollector( + dd_trace_mode="all_messages", + dd_trace_sink=_FailingSink("private trace sidecar unavailable"), + fail_on_write_error=True, + ) + + with measurement_session(collector), pytest.raises(AnonymizerWorkflowError, match="Workflow failed"): + _run_entity_detection_preview( + adapter, + trace_input_df, + [_custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]")], + ) + + +def test_ndd_adapter_writes_sanitized_dd_task_traces_and_restores_run_config( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) -> None: + task_trace = SimpleNamespace( + column="raw_detected", + row_group=0, + row_index=7, + task_type="llm", + dispatched_at=10.0, + slot_acquired_at=10.25, + completed_at=12.0, + status="error", + error="raw secret token Alice", + ) + data_designer = _TaskTraceDataDesigner(trace_input_df, task_traces=[task_trace]) adapter = NddAdapter(data_designer=cast(DataDesigner, data_designer)) task_trace_path = tmp_path / "task-trace.jsonl" with configured_measurement_session( MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_task_trace_path=task_trace_path) ): - adapter.run_workflow( - input_df, - model_configs=[ModelConfig(alias="alias", model="dummy-model", provider="provider")], - columns=[LLMTextColumnConfig(name="raw_detected", prompt="{{ text }}", model_alias="alias")], - workflow_name="entity-detection", - preview_num_records=1, + _run_entity_detection_preview( + adapter, + trace_input_df, + [_raw_detected_text_column()], ) assert data_designer.async_trace_values == [True, False] @@ -998,6 +1125,9 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa assert task_trace["task_type"] == "llm" assert task_trace["status"] == "error" assert task_trace["error_present"] is True + assert task_trace["dispatched_offset_sec"] == pytest.approx(0.0) + assert 
task_trace["slot_acquired_offset_sec"] == pytest.approx(0.25) + assert task_trace["completed_offset_sec"] == pytest.approx(2.0) assert task_trace["queue_wait_sec"] == pytest.approx(0.25) assert task_trace["execution_sec"] == pytest.approx(1.75) assert task_trace["total_sec"] == pytest.approx(2.0) @@ -1005,36 +1135,159 @@ def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespa assert "raw secret token Alice" not in (tmp_path / "measurements.jsonl").read_text(encoding="utf-8") -def test_ndd_adapter_restores_run_config_when_task_traced_workflow_fails(tmp_path: Path) -> None: - input_df = pd.DataFrame({"text": ["Alice works at Acme"], RECORD_ID_COLUMN: ["record-a"]}) +def test_ndd_adapter_task_trace_handles_mapping_and_invalid_timestamps( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) -> None: + task_traces = [ + { + "column": "missing_timestamps", + "row_group": 0, + "row_index": 1, + "task_type": "llm", + "status": "completed", + "error": None, + }, + { + "column": "nonpositive_timestamps", + "row_group": 0, + "row_index": 2, + "task_type": "llm", + "dispatched_at": 0.0, + "slot_acquired_at": -1.0, + "completed_at": -2.0, + "status": "completed", + "error": None, + }, + { + "column": "out_of_order_timestamps", + "row_group": 0, + "row_index": 3, + "task_type": "llm", + "dispatched_at": 20.0, + "slot_acquired_at": 19.0, + "completed_at": 18.0, + "status": "completed", + "error": None, + }, + ] + task_trace_path = tmp_path / "task-trace.jsonl" + adapter = NddAdapter( + data_designer=cast(DataDesigner, _TaskTraceDataDesigner(trace_input_df, task_traces=task_traces)) + ) - class TraceDataDesigner: - def __init__(self) -> None: - self.run_config = RunConfig(async_trace=False) - self.async_trace_values: list[bool] = [] + with configured_measurement_session( + MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_task_trace_path=task_trace_path) + ): + _run_entity_detection_preview( + adapter, + trace_input_df, + 
[_raw_detected_text_column()], + ) + + traces = [json.loads(line) for line in task_trace_path.read_text(encoding="utf-8").splitlines()] + traces_by_column = {trace["column"]: trace for trace in traces} + + missing = traces_by_column["missing_timestamps"] + assert missing["dispatched_offset_sec"] is None + assert missing["slot_acquired_offset_sec"] is None + assert missing["completed_offset_sec"] is None + assert missing["queue_wait_sec"] is None + assert missing["execution_sec"] is None + assert missing["total_sec"] is None + + nonpositive = traces_by_column["nonpositive_timestamps"] + assert nonpositive["dispatched_offset_sec"] is None + assert nonpositive["slot_acquired_offset_sec"] is None + assert nonpositive["completed_offset_sec"] is None + assert nonpositive["queue_wait_sec"] is None + assert nonpositive["execution_sec"] is None + assert nonpositive["total_sec"] is None + + out_of_order = traces_by_column["out_of_order_timestamps"] + assert out_of_order["dispatched_offset_sec"] == pytest.approx(0.0) + assert out_of_order["slot_acquired_offset_sec"] is None + assert out_of_order["completed_offset_sec"] is None + assert out_of_order["queue_wait_sec"] is None + assert out_of_order["execution_sec"] is None + assert out_of_order["total_sec"] is None + + +def test_ndd_adapter_task_trace_write_error_is_not_wrapped_as_workflow_error( + trace_input_df: pd.DataFrame, +) -> None: + task_trace = SimpleNamespace( + column="raw_detected", + row_group=0, + row_index=7, + task_type="llm", + dispatched_at=10.0, + slot_acquired_at=10.25, + completed_at=12.0, + status="completed", + error=None, + ) + adapter = NddAdapter( + data_designer=cast(DataDesigner, _TaskTraceDataDesigner(trace_input_df, task_traces=[task_trace])) + ) + collector = MeasurementCollector( + dd_task_trace_sink=_FailingSink("task trace sidecar unavailable"), + fail_on_write_error=True, + ) + + with measurement_session(collector), pytest.raises(OSError, match="task trace sidecar unavailable"): + 
_run_entity_detection_preview( + adapter, + trace_input_df, + [_raw_detected_text_column()], + ) - def set_run_config(self, run_config: RunConfig) -> None: - self.async_trace_values.append(run_config.async_trace) - self.run_config = run_config +def test_ndd_adapter_trace_write_error_is_not_wrapped_as_workflow_error( + trace_input_df: pd.DataFrame, +) -> None: + class TraceDataDesigner: def preview(self, _config_builder: object, *, num_records: int) -> SimpleNamespace: - assert self.run_config.async_trace is True - _ = num_records - raise RuntimeError("raw secret failure") + output = trace_input_df.iloc[:num_records].copy() + output["raw_detected"] = "[]" + output[f"raw_detected{TRACE_COLUMN_POSTFIX}"] = [ + [ + {"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}, + {"role": "assistant", "content": "secret response"}, + ] + ] + return SimpleNamespace(dataset=output) + + adapter = NddAdapter(data_designer=cast(DataDesigner, TraceDataDesigner())) + collector = MeasurementCollector( + dd_trace_mode="last_message", + dd_trace_sink=_FailingSink("trace sidecar unavailable"), + fail_on_write_error=True, + ) + + with measurement_session(collector), pytest.raises(OSError, match="trace sidecar unavailable"): + _run_entity_detection_preview( + adapter, + trace_input_df, + [_raw_detected_text_column()], + ) - data_designer = TraceDataDesigner() + +def test_ndd_adapter_restores_run_config_when_task_traced_workflow_fails( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) -> None: + data_designer = _TaskTraceDataDesigner(trace_input_df, error=RuntimeError("raw secret failure")) adapter = NddAdapter(data_designer=cast(DataDesigner, data_designer)) with pytest.raises(AnonymizerWorkflowError, match="Workflow failed"): with configured_measurement_session( MeasurementConfig(output_path=tmp_path / "measurements.jsonl", dd_task_trace_path=tmp_path / "task.jsonl") ): - adapter.run_workflow( - input_df, - model_configs=[ModelConfig(alias="alias", 
model="dummy-model", provider="provider")], - columns=[LLMTextColumnConfig(name="raw_detected", prompt="{{ text }}", model_alias="alias")], - workflow_name="entity-detection", - preview_num_records=1, + _run_entity_detection_preview( + adapter, + trace_input_df, + [_raw_detected_text_column()], ) assert data_designer.async_trace_values == [True, False] diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 7566974f..07967b8f 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -46,6 +46,48 @@ def _minimal_case_contexts(tool: ModuleType, spec: Any, tmp_path: Path) -> dict[ } +def _minimal_benchmark_spec( + tool: ModuleType, + *, + suite_id: str = "suite", + configs: list[Any] | None = None, + case_retries: int = 0, + case_retry_backoff_sec: float = 0.0, + run_tags: dict[str, Any] | None = None, +) -> Any: + return tool.BenchmarkSpec( + suite_id=suite_id, + case_retries=case_retries, + case_retry_backoff_sec=case_retry_backoff_sec, + run_tags=run_tags or {}, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=configs or [tool.ConfigSpec(id="redact", replace="redact")], + ) + + +def _minimal_benchmark_case( + tool: ModuleType, + *, + suite_id: str = "suite", + workload_id: str = "input", + config_id: str = "redact", + repetition: int = 0, +) -> Any: + return tool.BenchmarkCase( + suite_id=suite_id, + workload_id=workload_id, + config_id=config_id, + repetition=repetition, + case_id=f"{workload_id}__{config_id}__r{repetition:03d}", + ) + + +def _write_text_input(tmp_path: Path, text: str = "Alice works at Acme") -> Path: + input_path = tmp_path / "input.csv" + pd.DataFrame({"text": [text]}).to_csv(input_path, index=False) + return input_path + + def _copy_biography_data(tmp_path: Path, filename: str = "input.csv") -> Path: source = REPO_ROOT / "docs" / "data" / "NVIDIA_synthetic_biographies.csv" destination = tmp_path / filename @@ -68,6 +110,29 @@ def 
test_benchmark_spec_rejects_duplicate_matrix_entries() -> None: ) +def test_benchmark_spec_rejects_reserved_run_tags() -> None: + tool = load_tool("measurement_benchmark_tool_reserved_tags", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + with pytest.raises(ValidationError, match="reserved benchmark tag"): + tool.BenchmarkSpec( + suite_id="tag-suite", + run_tags={"pipeline_id": "1234", "case_id": "manual"}, + workloads=[tool.WorkloadSpec(id="input", source="input.csv")], + configs=[tool.ConfigSpec(id="redact", replace="redact")], + ) + + +def test_benchmark_config_rejects_evaluate_on_rewrite() -> None: + tool = load_tool("measurement_benchmark_tool_evaluate_rewrite", REPO_ROOT / "tools/measurement/run_benchmarks.py") + + with pytest.raises(ValidationError, match="evaluate is only supported for replace configs"): + tool.ConfigSpec( + id="rewrite-evaluate", + rewrite=tool.RewriteSpec(), + evaluate=True, + ) + + def test_export_measurements_groups_records_by_type(tmp_path: Path) -> None: tool = load_tool("measurement_export_tool", REPO_ROOT / "tools/measurement/export_measurements.py") dataframe = pd.DataFrame( @@ -212,11 +277,7 @@ def test_run_suite_records_detection_artifact_analysis_path( tmp_path: Path, ) -> None: tool = load_tool("measurement_benchmark_tool_run_suite_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py") - spec = tool.BenchmarkSpec( - suite_id="artifact-suite", - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) + spec = _minimal_benchmark_spec(tool, suite_id="artifact-suite") output_dir = tmp_path / "output" output_dir.mkdir() artifact_path = output_dir / "artifacts" @@ -274,11 +335,7 @@ def test_run_suite_skips_detection_artifact_analysis_when_export_disabled( tool = load_tool( "measurement_benchmark_tool_run_suite_no_export_artifact", REPO_ROOT / "tools/measurement/run_benchmarks.py" ) - spec = tool.BenchmarkSpec( - suite_id="artifact-suite", - 
workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) + spec = _minimal_benchmark_spec(tool, suite_id="artifact-suite") output_dir = tmp_path / "output" output_dir.mkdir() @@ -321,21 +378,9 @@ def test_benchmark_case_retries_transient_errors_and_records_attempts( ) -> None: tool = load_tool("measurement_benchmark_tool_case_retry_success", REPO_ROOT / "tools/measurement/run_benchmarks.py") attempts: list[Path] = [] - spec = tool.BenchmarkSpec( - suite_id="retry-suite", - case_retries=1, - case_retry_backoff_sec=0, - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) - case = tool.BenchmarkCase( - suite_id="retry-suite", - workload_id="input", - config_id="redact", - repetition=0, - case_id="input__redact__r000", - ) - pd.DataFrame({"text": ["Alice"]}).to_csv(tmp_path / "input.csv", index=False) + spec = _minimal_benchmark_spec(tool, suite_id="retry-suite", case_retries=1) + case = _minimal_benchmark_case(tool, suite_id="retry-suite") + _write_text_input(tmp_path, text="Alice") def fake_execute_case(*_args: Any, raw_path: Path, **_kwargs: Any) -> None: attempts.append(raw_path) @@ -366,20 +411,8 @@ def test_benchmark_case_records_persistent_retry_failures( tmp_path: Path, ) -> None: tool = load_tool("measurement_benchmark_tool_case_retry_failure", REPO_ROOT / "tools/measurement/run_benchmarks.py") - spec = tool.BenchmarkSpec( - suite_id="retry-suite", - case_retries=1, - case_retry_backoff_sec=0, - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) - case = tool.BenchmarkCase( - suite_id="retry-suite", - workload_id="input", - config_id="redact", - repetition=0, - case_id="input__redact__r000", - ) + spec = _minimal_benchmark_spec(tool, suite_id="retry-suite", case_retries=1) + case = _minimal_benchmark_case(tool, suite_id="retry-suite") attempts 
= 0 errors: list[str] = [] @@ -421,20 +454,8 @@ def test_benchmark_case_fail_fast_skips_retries( tool = load_tool( "measurement_benchmark_tool_case_retry_fail_fast", REPO_ROOT / "tools/measurement/run_benchmarks.py" ) - spec = tool.BenchmarkSpec( - suite_id="retry-suite", - case_retries=3, - case_retry_backoff_sec=0, - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) - case = tool.BenchmarkCase( - suite_id="retry-suite", - workload_id="input", - config_id="redact", - repetition=0, - case_id="input__redact__r000", - ) + spec = _minimal_benchmark_spec(tool, suite_id="retry-suite", case_retries=3) + case = _minimal_benchmark_case(tool, suite_id="retry-suite") attempts = 0 def fake_execute_case(*_args: Any, **_kwargs: Any) -> Any: @@ -694,10 +715,52 @@ def test_benchmark_preflight_rejects_bad_model_alias_references(tmp_path: Path) tool.preflight_suite(spec, spec_path=spec_path) +def test_benchmark_preflight_rejects_missing_evaluate_model_alias(tmp_path: Path) -> None: + tool = load_tool( + "measurement_benchmark_tool_preflight_evaluate_models", REPO_ROOT / "tools/measurement/run_benchmarks.py" + ) + _copy_biography_data(tmp_path) + spec_path = tmp_path / "suite.yaml" + spec_path.write_text( + """ +suite_id: bad-evaluate-model-suite +model_configs: | + selected_models: + detection: + entity_detector: detector + entity_validator: [validator] + entity_augmenter: augmenter + evaluate: + detection_validity_judge: missing-evaluator + model_configs: + - alias: detector + model: test/detector + - alias: validator + model: test/validator + - alias: augmenter + model: test/augmenter +workloads: + - id: biography + source: input.csv + text_column: biography +configs: + - id: redact-evaluate + replace: redact + evaluate: true +""", + encoding="utf-8", + ) + spec = tool.load_spec(spec_path) + + with pytest.raises(ValueError, match="evaluate.detection_validity_judge='missing-evaluator'"): + 
tool.preflight_suite(spec, spec_path=spec_path) + + def test_benchmark_example_suites_are_portable() -> None: example_paths = sorted((REPO_ROOT / "tools/measurement/examples").glob("*.yaml")) assert example_paths + allowed_public_endpoints = {"https://integrate.api.nvidia.com/v1"} machine_specific_fragments = ( "/root/", "/Users/", @@ -728,7 +791,9 @@ def walk(value: Any) -> Iterator[tuple[str, Any]]: if key in path_fields: assert not Path(value).is_absolute(), f"{example_path} uses absolute path for {key}: {value}" if key in {"endpoint", "gliner_endpoint"}: - raise AssertionError(f"{example_path} should use endpoint_env for {key}, not a literal endpoint") + assert value in allowed_public_endpoints, ( + f"{example_path} should use an approved portable endpoint for {key}: {value}" + ) def test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: @@ -814,19 +879,13 @@ def run(self, *, config: Any, data: Any) -> None: monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) - spec = tool.BenchmarkSpec( - suite_id="trace-suite", - workloads=[tool.WorkloadSpec(id="input", source="input.csv")], - configs=[tool.ConfigSpec(id="redact", replace="redact")], - ) - pd.DataFrame({"text": ["Alice works at Acme"]}).to_csv(tmp_path / "input.csv", index=False) - case = tool.BenchmarkCase( + spec = _minimal_benchmark_spec( + tool, suite_id="trace-suite", - workload_id="input", - config_id="redact", - repetition=0, - case_id="input__redact__r000", + run_tags={"commit_sha": "abc123", "pipeline_id": "456"}, ) + _write_text_input(tmp_path) + case = _minimal_benchmark_case(tool, suite_id="trace-suite") trace_path = tmp_path / "traces" / "input__redact__r000.jsonl" task_trace_path = tmp_path / "task-traces" / "input__redact__r000.jsonl" @@ -849,3 +908,59 @@ def run(self, *, config: Any, data: Any) -> None: assert captured[0].dd_task_trace_path == task_trace_path assert captured[0].streaming is True assert captured[0].keep_records is False + 
assert captured[0].run_tags == { + "suite_id": "trace-suite", + "workload_id": "input", + "config_id": "redact", + "repetition": 0, + "case_id": "input__redact__r000", + "commit_sha": "abc123", + "pipeline_id": "456", + } + + +def test_benchmark_case_can_run_optional_evaluation( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + tool = load_tool("measurement_benchmark_tool_evaluate", REPO_ROOT / "tools/measurement/run_benchmarks.py") + calls: list[Any] = [] + run_result = object() + + @contextmanager + def fake_measurement_session(_config: Any) -> Iterator[None]: + yield None + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> object: + calls.append(("run", config.replace, data.text_column)) + return run_result + + def evaluate(self, result: object) -> object: + calls.append(("evaluate", result)) + return result + + monkeypatch.setattr(tool, "configured_measurement_session", fake_measurement_session) + + spec = _minimal_benchmark_spec( + tool, + suite_id="evaluate-suite", + configs=[tool.ConfigSpec(id="redact", replace="redact", evaluate=True)], + ) + _write_text_input(tmp_path) + case = _minimal_benchmark_case(tool, suite_id="evaluate-suite") + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=tmp_path / "raw" / "input__redact__r000.jsonl", + trace_path=None, + task_trace_path=None, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.none, + ) + + assert calls == [("run", tool.Redact(), "text"), ("evaluate", run_result)] diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 08f3a4db..7681217b 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -176,6 +176,10 @@ Example: suite_id: biography-smoke model_configs: ./model-configs.yaml model_providers: ./providers.yaml +run_tags: + anonymizer_ref: main + commit_sha: abc123 + pipeline_id: "456" case_retries: 1 case_retry_backoff_sec: 10 workloads: @@ -192,6 +196,7 @@ 
workloads: configs: - id: redact-default replace: redact + evaluate: true - id: hash-agent-labels detect: entity_labels: [person, email, api_key, password] @@ -222,6 +227,22 @@ The runner refuses to write into a non-empty output directory unless `--overwrite` is set. By default it also exports Parquet tables into `tables/`; pass `--no-export` when only raw measurement JSONL is needed. +Set `model_configs` and `model_providers` explicitly in checked-in or CI suites. +Relying on Anonymizer defaults makes a run depend on the caller's installed +defaults and provider environment. In provider YAML, put environment variable +names such as `NVIDIA_API_KEY` in `api_key`; do not commit raw keys. The bundled +`repo-data-smoke.yaml` follows this pattern with adjacent model/provider files. + +Use `run_tags` for stable suite-level metadata copied into every measurement +record, such as source refs, commit SHAs, CI pipeline IDs, topology labels, or +benchmark-suite revisions. The runner reserves `suite_id`, `workload_id`, +`config_id`, `repetition`, and `case_id` for its own case identity tags. + +Set `evaluate: true` on a replace config when the benchmark should run +`Anonymizer.evaluate()` after `run()` and capture the LLM-as-judge work in the +same case. This is intentionally replace-only for now; rewrite runs already +perform their internal evaluation/repair loop during `run()`. + Before starting a real run, the benchmark runner performs cheap preflight checks: suite/config parsing, local dataset existence, CSV/Parquet text-column metadata, provider YAML shape, and active model-alias references. `--dry-run` @@ -268,9 +289,14 @@ another directory. Task trace records are separate from raw message traces. They include scheduler metadata such as workflow name, column, row group, row index, task type, status, -queue wait time, execution time, total time, and whether an error was present. 
-They intentionally do not store raw DataDesigner error strings because those -can contain prompts, outputs, or source values. +relative dispatch/slot-acquired/completion offsets, queue wait time, execution +time, total time, and whether an error was present. They intentionally do not +store raw DataDesigner error strings because those can contain prompts, outputs, +or source values. + +Offsets are relative to the earliest positive `dispatched_at` timestamp in each +DataDesigner workflow trace batch written into the case sidecar. They are meant +for timeline analysis without storing host-specific wall-clock timestamps. ```bash uv run python tools/measurement/run_benchmarks.py \ diff --git a/tools/measurement/examples/repo-data-smoke-models.yaml b/tools/measurement/examples/repo-data-smoke-models.yaml new file mode 100644 index 00000000..5559587a --- /dev/null +++ b/tools/measurement/examples/repo-data-smoke-models.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +model_configs: + - alias: gliner-pii-detector + model: nvidia/gliner-pii + provider: nvidia + inference_parameters: + max_parallel_requests: 1 + timeout: 120 + + - alias: gpt-oss-120b + model: openai/gpt-oss-120b + provider: nvidia + inference_parameters: + max_parallel_requests: 16 + max_tokens: 16384 + temperature: 0.3 + top_p: 0.95 + timeout: 300 + + - alias: nemotron-30b-thinking + model: nvidia/nemotron-3-nano-30b-a3b + provider: nvidia + inference_parameters: + max_parallel_requests: 16 + max_tokens: 8192 + temperature: 0.4 + top_p: 1.0 + timeout: 300 diff --git a/tools/measurement/examples/repo-data-smoke-providers.yaml b/tools/measurement/examples/repo-data-smoke-providers.yaml new file mode 100644 index 00000000..8799886f --- /dev/null +++ b/tools/measurement/examples/repo-data-smoke-providers.yaml @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +providers: + - name: nvidia + endpoint: https://integrate.api.nvidia.com/v1 + provider_type: openai + api_key: NVIDIA_API_KEY diff --git a/tools/measurement/examples/repo-data-smoke.yaml b/tools/measurement/examples/repo-data-smoke.yaml index 5009f054..7a381091 100644 --- a/tools/measurement/examples/repo-data-smoke.yaml +++ b/tools/measurement/examples/repo-data-smoke.yaml @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 suite_id: repo-data-smoke +model_configs: ./repo-data-smoke-models.yaml +model_providers: ./repo-data-smoke-providers.yaml workloads: - id: biographies source: ../../../docs/data/NVIDIA_synthetic_biographies.csv diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 2533dc09..83c8db68 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -111,6 +111,7 @@ class ConfigSpec(BaseModel): detect: dict[str, Any] = Field(default_factory=dict) replace: str | ReplaceSpec 
| None = None rewrite: RewriteSpec | None = None + evaluate: bool = False emit_telemetry: bool = False @model_validator(mode="after") @@ -119,6 +120,8 @@ def validate_mode(self) -> "ConfigSpec": raise ValueError("config must define replace or rewrite") if self.replace is not None and self.rewrite is not None: raise ValueError("config cannot define both replace and rewrite") + if self.evaluate and self.rewrite is not None: + raise ValueError("evaluate is only supported for replace configs") return self @@ -130,6 +133,9 @@ class MatrixEntry(BaseModel): repetitions: int = Field(default=1, ge=1) +RESERVED_RUN_TAG_KEYS = frozenset({"suite_id", "workload_id", "config_id", "repetition", "case_id"}) + + def _duplicates(values: list[str]) -> list[str]: seen: set[str] = set() duplicates: set[str] = set() @@ -147,6 +153,7 @@ class BenchmarkSpec(BaseModel): model_configs: str | None = None model_providers: str | None = None artifact_path: str | None = None + run_tags: dict[str, Any] = Field(default_factory=dict) case_retries: int = Field(default=0, ge=0) case_retry_backoff_sec: float = Field(default=0.0, ge=0.0) workloads: list[WorkloadSpec] = Field(min_length=1) @@ -162,6 +169,7 @@ def validate_ids(self) -> "BenchmarkSpec": if duplicate_configs := _duplicates(config_ids): raise ValueError(f"duplicate config id(s): {', '.join(duplicate_configs)}") self._validate_matrix_references(set(workload_ids), set(config_ids)) + self._validate_run_tags() return self def _validate_matrix_references(self, workload_ids: set[str], config_ids: set[str]) -> None: @@ -178,6 +186,12 @@ def _validate_matrix_references(self, workload_ids: set[str], config_ids: set[st formatted = ", ".join(f"{workload}/{config}" for workload, config in duplicate_entries) raise ValueError(f"duplicate matrix workload/config entry(s): {formatted}; use repetitions for repeats") + def _validate_run_tags(self) -> None: + reserved_tags = sorted(set(self.run_tags) & RESERVED_RUN_TAG_KEYS) + if reserved_tags: + formatted = 
", ".join(reserved_tags) + raise ValueError(f"run_tags cannot define reserved benchmark tag(s): {formatted}") + class BenchmarkCase(BaseModel): suite_id: str @@ -336,6 +350,7 @@ def _preflight_config_errors(spec: BenchmarkSpec, *, parsed_models: Any | None) check_substitute=isinstance(anonymizer_config.replace, Substitute) or anonymizer_config.rewrite is not None, check_rewrite=anonymizer_config.rewrite is not None, + check_evaluate=config.evaluate, ) except ValueError as exc: errors.append(f"config '{config.id}' model aliases invalid: {exc}") @@ -352,14 +367,16 @@ def _preflight_model_providers(spec: BenchmarkSpec, *, base_dir: Path) -> None: raw = _resolve_config_source(spec.model_providers, base_dir) if raw is None: return - config_source: str | Path = raw - if isinstance(raw, str) and "\n" not in raw: + if "\n" in raw: + config_dict = yaml.safe_load(raw) + else: candidate = Path(raw.strip()).expanduser() if candidate.suffix in (".yaml", ".yml"): if not candidate.is_file(): raise FileNotFoundError(f"Providers config file not found: {candidate}") - config_source = candidate - config_dict = yaml.safe_load(raw) if isinstance(raw, str) and "\n" in raw else load_config_file(config_source) + config_dict = load_config_file(candidate) + else: + config_dict = yaml.safe_load(raw) raw_providers = config_dict.get("providers") if isinstance(config_dict, dict) else None if not isinstance(raw_providers, list): raise ValueError("model_providers YAML must contain a top-level 'providers' list.") @@ -796,10 +813,12 @@ def _execute_case( fail_on_write_error=True, ) with configured_measurement_session(measurement): - anonymizer.run( + result = anonymizer.run( config=anonymizer_config, data=input_data, ) + if config.evaluate: + anonymizer.evaluate(result) def build_input( @@ -1063,6 +1082,7 @@ def render_result(result: BenchmarkResult, *, json_output: bool) -> str: def _run_tags(case: BenchmarkCase, spec: BenchmarkSpec) -> dict[str, Any]: return { + **spec.run_tags, "suite_id": 
spec.suite_id, "workload_id": case.workload_id, "config_id": case.config_id, From 1d88389445d093120e68f0899fc5e697b2305350 Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Thu, 11 Jun 2026 18:47:40 +0000 Subject: [PATCH 23/26] Track unsupported DD trace columns Signed-off-by: Aaron Gonzales --- src/anonymizer/engine/ndd/adapter.py | 1 + tests/test_measurement.py | 44 ++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 96795c36..8c812346 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -853,6 +853,7 @@ def _configure_dd_message_traces( private_trace_columns.append(_PrivateFacadeTraceColumn(column_name=column.name)) continue + unsupported_columns.append(column) configured_columns.append(column) return configured_columns, native_trace_columns, private_trace_columns, unsupported_columns diff --git a/tests/test_measurement.py b/tests/test_measurement.py index 1fc6823a..f36aa4e9 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -945,6 +945,19 @@ def generator( return CustomColumnConfig(name=name, generator_function=generator) +def _local_custom_column(name: str, *, value: str) -> CustomColumnConfig: + @custom_column_generator(required_columns=["text"]) + def generator( + row: dict[str, Any], + generator_params: Any, + models: dict[str, Any], + ) -> dict[str, str]: + _ = row, generator_params, models + return {name: value} + + return CustomColumnConfig(name=name, generator_function=generator) + + def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace( tmp_path: Path, trace_input_df: pd.DataFrame, @@ -1004,6 +1017,37 @@ def test_ndd_adapter_writes_custom_column_private_model_facade_dd_trace( assert "custom response secret" not in serialized_measurements +def test_ndd_adapter_reports_untraced_custom_columns_in_dd_trace_coverage( + tmp_path: Path, + trace_input_df: pd.DataFrame, +) 
-> None: + adapter = NddAdapter(data_designer=cast(DataDesigner, _CustomTraceDataDesigner(trace_input_df))) + trace_path = tmp_path / "trace.jsonl" + + with configured_measurement_session( + MeasurementConfig( + output_path=tmp_path / "measurements.jsonl", + dd_trace="all_messages", + dd_trace_path=trace_path, + ) + ): + _run_entity_detection_preview( + adapter, + trace_input_df, + [ + _custom_trace_column("raw_detected", prompt="raw prompt secret", value="[]"), + _local_custom_column("local_note", value="ok"), + ], + ) + + measurements = [json.loads(line) for line in (tmp_path / "measurements.jsonl").read_text().splitlines()] + coverage = [record for record in measurements if record["record_type"] == "dd_trace_coverage"] + assert len(coverage) == 1 + assert coverage[0]["traced_column_names"] == ["raw_detected"] + assert coverage[0]["unsupported_column_count"] == 1 + assert coverage[0]["unsupported_column_names"] == ["local_note"] + + def test_ndd_adapter_private_model_facade_trace_write_error_is_not_wrapped( trace_input_df: pd.DataFrame, ) -> None: From 00c729393c31528676191cae47c3360eefead16b Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Thu, 11 Jun 2026 19:16:04 +0000 Subject: [PATCH 24/26] Harden DD trace benchmark plumbing Signed-off-by: Aaron Gonzales --- .github/workflows/benchmark-ci.yml | 4 +- src/anonymizer/engine/ndd/adapter.py | 320 +++++++++--------- tests/test_measurement.py | 13 +- tests/tools/test_measurement_tools.py | 10 + tools/measurement/README.md | 14 +- .../run-repo-data-smoke-with-dd-traces.sh | 2 +- 6 files changed, 196 insertions(+), 167 deletions(-) diff --git a/.github/workflows/benchmark-ci.yml b/.github/workflows/benchmark-ci.yml index a67ffebf..687690ca 100644 --- a/.github/workflows/benchmark-ci.yml +++ b/.github/workflows/benchmark-ci.yml @@ -24,8 +24,8 @@ on: type: choice options: - "none" - - "last-message" - - "all-messages" + - "last_message" + - "all_messages" default: "none" dd_task_trace: description: "Capture sanitized 
DataDesigner scheduler task traces" diff --git a/src/anonymizer/engine/ndd/adapter.py b/src/anonymizer/engine/ndd/adapter.py index 8c812346..a7d9015b 100644 --- a/src/anonymizer/engine/ndd/adapter.py +++ b/src/anonymizer/engine/ndd/adapter.py @@ -75,6 +75,163 @@ class _PrivateFacadeTraceColumn: column_name: str +@dataclass(frozen=True) +class _DDMessageTracePlan: + columns: list[ColumnConfigT] + native_columns: list[_NativeTraceColumn] + private_columns: list[_PrivateFacadeTraceColumn] + unsupported_columns: list[ColumnConfigT] + + @classmethod + def from_columns( + cls, + *, + columns: list[ColumnConfigT], + model_configs: list[ModelConfig], + collector: Any | None, + ) -> _DDMessageTracePlan: + if collector is None or not collector.dd_trace_enabled: + return cls(columns=columns, native_columns=[], private_columns=[], unsupported_columns=[]) + + model_configs_by_alias = {model_config.alias: model_config for model_config in model_configs} + native_columns: list[_NativeTraceColumn] = [] + private_columns: list[_PrivateFacadeTraceColumn] = [] + unsupported_columns: list[ColumnConfigT] = [] + configured_columns: list[ColumnConfigT] = [] + + for column in columns: + if isinstance(column, _TRACEABLE_LLM_COLUMN_TYPES): + configured_columns.append( + cast(ColumnConfigT, column.model_copy(update={"with_trace": cls.trace_type()})) + ) + model_config = model_configs_by_alias.get(column.model_alias) + native_columns.append( + _NativeTraceColumn( + column_name=column.name, + trace_column_name=f"{column.name}{TRACE_COLUMN_POSTFIX}", + model_alias=column.model_alias, + model_name=getattr(model_config, "model", None), + model_provider_name=getattr(model_config, "provider", None), + ) + ) + continue + + if _column_has_private_facade_model_calls(column): + configured_columns.append(_custom_column_with_trace_context(column)) + private_columns.append(_PrivateFacadeTraceColumn(column_name=column.name)) + continue + + unsupported_columns.append(column) + 
configured_columns.append(column) + + return cls( + columns=configured_columns, + native_columns=native_columns, + private_columns=private_columns, + unsupported_columns=unsupported_columns, + ) + + @staticmethod + def trace_type() -> TraceType: + # Preserve Anonymizer's existing dd_trace=last_message semantics: the trace + # sink records the final prompt message and response separately, while DD's + # native LAST_MESSAGE side effect only keeps the final assistant message. + return TraceType.ALL_MESSAGES + + def record_coverage(self, *, workflow_name: str, collector: Any | None) -> None: + if collector is None or not collector.dd_trace_enabled: + return + + traced_column_names = [column.column_name for column in self.native_columns] + [ + column.column_name for column in self.private_columns + ] + collector.record( + "dd_trace_coverage", + workflow_name=workflow_name, + trace_backend=self.backend, + trace_mode=collector.dd_trace_mode, + native_trace_type=self.trace_type().value, + traced_column_count=len(traced_column_names), + traced_column_names=traced_column_names, + native_trace_column_count=len(self.native_columns), + native_trace_column_names=[column.column_name for column in self.native_columns], + private_trace_column_count=len(self.private_columns), + private_trace_column_names=[column.column_name for column in self.private_columns], + private_trace_backend="anonymizer_private_model_facade" if self.private_columns else None, + private_trace_note=( + "temporary private DataDesigner model registry/facade instrumentation" if self.private_columns else None + ), + unsupported_column_count=len(self.unsupported_columns), + unsupported_column_names=[column.name for column in self.unsupported_columns], + unsupported_column_types=[_column_type_name(column) for column in self.unsupported_columns], + ) + + @property + def backend(self) -> str: + if self.native_columns and self.private_columns: + return "mixed" + if self.private_columns: + return 
"anonymizer_private_model_facade" + return "data_designer_column" + + def record_and_strip_native_traces( + self, + *, + output_df: pd.DataFrame, + workflow_name: str, + collector: Any | None, + ) -> pd.DataFrame: + if not self.native_columns: + return output_df + + trace_column_names = [column.trace_column_name for column in self.native_columns] + if collector is not None and collector.dd_trace_enabled: + for _, row in output_df.iterrows(): + for trace_column in self.native_columns: + if trace_column.trace_column_name not in output_df.columns: + continue + self._record_native_trace( + trace_column=trace_column, + trace_value=row.get(trace_column.trace_column_name), + workflow_name=workflow_name, + collector=collector, + ) + + existing_trace_columns = [column_name for column_name in trace_column_names if column_name in output_df.columns] + if not existing_trace_columns: + return output_df + return output_df.drop(columns=existing_trace_columns) + + @staticmethod + def _record_native_trace( + *, + trace_column: _NativeTraceColumn, + trace_value: Any, + workflow_name: str, + collector: Any, + ) -> None: + trace_messages = _native_trace_messages(trace_value) + if not trace_messages: + return + collector.record_dd_message_trace( + workflow_name=workflow_name, + trace_source="data_designer_column", + column_name=trace_column.column_name, + trace_column_name=trace_column.trace_column_name, + model_alias=trace_column.model_alias, + model_name=trace_column.model_name, + model_provider_name=trace_column.model_provider_name, + modality="chat", + is_async=None, + status="completed", + error_type=None, + elapsed_sec=None, + messages=_select_native_trace_messages(trace_messages, mode=collector.dd_trace_mode), + response=_native_trace_response(trace_messages), + usage=None, + ) + + class _TaskTraceLike(Protocol): column: Any row_group: Any @@ -158,25 +315,20 @@ def run_workflow( ) started = time.perf_counter() collector = current_collector() - columns, native_trace_columns, 
private_trace_columns, unsupported_trace_columns = _configure_dd_message_traces( + trace_plan = _DDMessageTracePlan.from_columns( columns=columns, model_configs=model_configs, collector=collector, ) + columns = trace_plan.columns usage_probe = _DataDesignerUsageProbe( self._data_designer, enabled=collector is not None, collector=collector, workflow_name=workflow_name, - private_trace_columns=private_trace_columns, - ) - _record_dd_trace_coverage( - workflow_name=workflow_name, - collector=collector, - native_trace_columns=native_trace_columns, - private_trace_columns=private_trace_columns, - unsupported_trace_columns=unsupported_trace_columns, + private_trace_columns=trace_plan.private_columns, ) + trace_plan.record_coverage(workflow_name=workflow_name, collector=collector) with tempfile.TemporaryDirectory(prefix=f"anonymizer_{workflow_name}_") as tmp_dir: seed_path = str(Path(tmp_dir) / "seed.parquet") @@ -240,11 +392,10 @@ def run_workflow( ) raise AnonymizerWorkflowError(f"Workflow failed: {exc}") from exc - output_df = _record_and_strip_native_dd_message_traces( + output_df = trace_plan.record_and_strip_native_traces( output_df=output_df, workflow_name=workflow_name, collector=collector, - native_trace_columns=native_trace_columns, ) _record_dd_task_traces( workflow_name=workflow_name, @@ -816,158 +967,23 @@ def traced_generator(*args: Any, **kwargs: Any) -> Any: return cast(ColumnConfigT, column.model_copy(update={"generator_function": traced_generator})) -def _configure_dd_message_traces( - *, - columns: list[ColumnConfigT], - model_configs: list[ModelConfig], - collector: Any | None, -) -> tuple[list[ColumnConfigT], list[_NativeTraceColumn], list[_PrivateFacadeTraceColumn], list[ColumnConfigT]]: - if collector is None or not collector.dd_trace_enabled: - return columns, [], [], [] - - model_configs_by_alias = {model_config.alias: model_config for model_config in model_configs} - native_trace_columns: list[_NativeTraceColumn] = [] - private_trace_columns: 
list[_PrivateFacadeTraceColumn] = [] - unsupported_columns: list[ColumnConfigT] = [] - configured_columns: list[ColumnConfigT] = [] - trace_type = _native_dd_trace_type() - - for column in columns: - if isinstance(column, _TRACEABLE_LLM_COLUMN_TYPES): - configured_column = cast(ColumnConfigT, column.model_copy(update={"with_trace": trace_type})) - configured_columns.append(configured_column) - model_config = model_configs_by_alias.get(column.model_alias) - native_trace_columns.append( - _NativeTraceColumn( - column_name=column.name, - trace_column_name=f"{column.name}{TRACE_COLUMN_POSTFIX}", - model_alias=column.model_alias, - model_name=getattr(model_config, "model", None), - model_provider_name=getattr(model_config, "provider", None), - ) - ) - continue - - if _column_has_private_facade_model_calls(column): - configured_columns.append(_custom_column_with_trace_context(column)) - private_trace_columns.append(_PrivateFacadeTraceColumn(column_name=column.name)) - continue - - unsupported_columns.append(column) - configured_columns.append(column) - - return configured_columns, native_trace_columns, private_trace_columns, unsupported_columns - - -def _native_dd_trace_type() -> TraceType: - # Preserve Anonymizer's existing dd_trace=last_message semantics: the trace - # sink records the final prompt message and response separately, while DD's - # native LAST_MESSAGE side effect only keeps the final assistant message. 
- return TraceType.ALL_MESSAGES - - def _column_has_private_facade_model_calls(column: ColumnConfigT) -> TypeGuard[CustomColumnConfig]: return isinstance(column, CustomColumnConfig) and bool(_extract_workflow_model_aliases([column])) -def _record_dd_trace_coverage( - *, - workflow_name: str, - collector: Any, - native_trace_columns: list[_NativeTraceColumn], - private_trace_columns: list[_PrivateFacadeTraceColumn], - unsupported_trace_columns: list[ColumnConfigT], -) -> Any: - if collector is None or not collector.dd_trace_enabled: - return - traced_column_names = [column.column_name for column in native_trace_columns] + [ - column.column_name for column in private_trace_columns - ] - collector.record( - "dd_trace_coverage", - workflow_name=workflow_name, - trace_backend=_dd_trace_backend(native_trace_columns, private_trace_columns), - trace_mode=collector.dd_trace_mode, - native_trace_type=_native_dd_trace_type().value, - traced_column_count=len(traced_column_names), - traced_column_names=traced_column_names, - native_trace_column_count=len(native_trace_columns), - native_trace_column_names=[column.column_name for column in native_trace_columns], - private_trace_column_count=len(private_trace_columns), - private_trace_column_names=[column.column_name for column in private_trace_columns], - private_trace_backend="anonymizer_private_model_facade" if private_trace_columns else None, - private_trace_note=( - "temporary private DataDesigner model registry/facade instrumentation" if private_trace_columns else None - ), - unsupported_column_count=len(unsupported_trace_columns), - unsupported_column_names=[column.name for column in unsupported_trace_columns], - unsupported_column_types=[_column_type_name(column) for column in unsupported_trace_columns], - ) - - -def _dd_trace_backend( - native_trace_columns: list[_NativeTraceColumn], - private_trace_columns: list[_PrivateFacadeTraceColumn], -) -> str: - if native_trace_columns and private_trace_columns: - return "mixed" - 
if private_trace_columns: - return "anonymizer_private_model_facade" - return "data_designer_column" - - def _column_type_name(column: ColumnConfigT) -> str: column_type = getattr(column, "column_type", None) return str(column_type) if column_type is not None else type(column).__name__ -def _record_and_strip_native_dd_message_traces( - *, - output_df: pd.DataFrame, - workflow_name: str, - collector: Any, - native_trace_columns: list[_NativeTraceColumn], -) -> pd.DataFrame: - if not native_trace_columns: - return output_df - - trace_column_names = [column.trace_column_name for column in native_trace_columns] - if collector is not None and collector.dd_trace_enabled: - for _, row in output_df.iterrows(): - for trace_column in native_trace_columns: - if trace_column.trace_column_name not in output_df.columns: - continue - trace_messages = _native_trace_messages(row.get(trace_column.trace_column_name)) - if not trace_messages: - continue - collector.record_dd_message_trace( - workflow_name=workflow_name, - trace_source="data_designer_column", - column_name=trace_column.column_name, - trace_column_name=trace_column.trace_column_name, - model_alias=trace_column.model_alias, - model_name=trace_column.model_name, - model_provider_name=trace_column.model_provider_name, - modality="chat", - is_async=None, - status="completed", - error_type=None, - elapsed_sec=None, - messages=_select_native_trace_messages(trace_messages, mode=collector.dd_trace_mode), - response=_native_trace_response(trace_messages), - usage=None, - ) - - existing_trace_columns = [column_name for column_name in trace_column_names if column_name in output_df.columns] - if not existing_trace_columns: - return output_df - return output_df.drop(columns=existing_trace_columns) - - def _native_trace_messages(value: Any) -> list[dict[str, Any]]: - if not isinstance(value, list): + if value is None or isinstance(value, (str, bytes, Mapping)): + return [] + try: + messages = list(value) + except TypeError: return [] 
- return [_trace_message(message) for message in value] + return [_trace_message(message) for message in messages] def _select_native_trace_messages(messages: list[dict[str, Any]], *, mode: str) -> list[dict[str, Any]]: diff --git a/tests/test_measurement.py b/tests/test_measurement.py index f36aa4e9..5be9e361 100644 --- a/tests/test_measurement.py +++ b/tests/test_measurement.py @@ -811,11 +811,14 @@ def preview(self, config_builder: object, *, num_records: int) -> SimpleNamespac output = input_df.iloc[:num_records].copy() output["raw_detected"] = "[]" output[f"raw_detected{TRACE_COLUMN_POSTFIX}"] = [ - [ - {"role": "system", "content": [{"type": "text", "text": "system secret"}]}, - {"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}, - {"role": "assistant", "content": "secret response", "reasoning_content": "scratch"}, - ] + np.array( + [ + {"role": "system", "content": [{"type": "text", "text": "system secret"}]}, + {"role": "user", "content": [{"type": "text", "text": "prompt secret"}]}, + {"role": "assistant", "content": "secret response", "reasoning_content": "scratch"}, + ], + dtype=object, + ) ] return SimpleNamespace(dataset=output) diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 07967b8f..0dc1993e 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -796,6 +796,16 @@ def walk(value: Any) -> Iterator[tuple[str, Any]]: ) +def test_benchmark_ci_dd_trace_options_match_runner_enum() -> None: + tool = load_tool("measurement_benchmark_tool_ci", REPO_ROOT / "tools/measurement/run_benchmarks.py") + workflow = yaml.safe_load((REPO_ROOT / ".github/workflows/benchmark-ci.yml").read_text(encoding="utf-8")) + on_section = workflow.get("on", workflow.get(True)) + + options = on_section["workflow_dispatch"]["inputs"]["dd_trace"]["options"] + + assert options == [mode.value for mode in tool.DDTraceMode] + + def 
test_benchmark_preflight_rejects_bad_provider_config(tmp_path: Path) -> None: tool = load_tool( "measurement_benchmark_tool_preflight_providers", REPO_ROOT / "tools/measurement/run_benchmarks.py" diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 7681217b..2c424f30 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -115,7 +115,7 @@ uv run python tools/measurement/run_benchmarks.py suite.yaml --output benchmark- uv run python tools/measurement/run_benchmarks.py suite.yaml --dry-run --json uv run python tools/measurement/run_benchmarks.py suite.yaml \ --output benchmark-runs/suite \ - --dd-trace last-message + --dd-trace last_message uv run python tools/measurement/run_benchmarks.py suite.yaml \ --output benchmark-runs/suite \ --dd-task-trace @@ -129,10 +129,10 @@ bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh The script writes to `/tmp/anonymizer-repo-data-smoke-dd-traces` by default. Pass a different output directory as the first argument, or set -`DD_TRACE_MODE=all-messages` when full chat history is needed: +`DD_TRACE_MODE=all_messages` when full chat history is needed: ```bash -DD_TRACE_MODE=all-messages \ +DD_TRACE_MODE=all_messages \ bash tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh \ /tmp/anonymizer-repo-data-smoke-dd-traces-full ``` @@ -257,11 +257,11 @@ remains fail-fast and bypasses retries. ## DataDesigner traces -For debugging DataDesigner calls, pass `--dd-trace last-message` or -`--dd-trace all-messages`. Trace records are written separately from sanitized +For debugging DataDesigner calls, pass `--dd-trace last_message` or +`--dd-trace all_messages`. Trace records are written separately from sanitized measurements, under `traces/{case_id}.jsonl` by default. Use `--trace-dir` to -choose another directory. `last-message` stores only the final prompt message -for each DataDesigner model call; `all-messages` stores the full message list. 
+choose another directory. `last_message` stores only the final prompt message +for each DataDesigner model call; `all_messages` stores the full message list. DataDesigner traces may contain raw input text, prompts, model outputs, entity values, replacement values, secrets, and PII. Treat them as debug artifacts: diff --git a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh index 9000f03f..5bc5612a 100644 --- a/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh +++ b/tools/measurement/examples/run-repo-data-smoke-with-dd-traces.sh @@ -5,7 +5,7 @@ set -euo pipefail output_dir="${1:-/tmp/anonymizer-repo-data-smoke-dd-traces}" -trace_mode="${DD_TRACE_MODE:-last-message}" +trace_mode="${DD_TRACE_MODE:-last_message}" script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" repo_root="$(cd "${script_dir}/../../.." && pwd)" suite_file="${BENCHMARK_SUITE:-${script_dir}/repo-data-smoke.yaml}" From 2a49d1889a48bd488946c7047a8b5f79872134de Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Fri, 12 Jun 2026 00:25:52 +0000 Subject: [PATCH 25/26] Add sanitized evaluation measurement records Signed-off-by: Aaron Gonzales --- src/anonymizer/measurement/__init__.py | 3 +- src/anonymizer/measurement/records/row.py | 127 +++++++++++++++++- tests/tools/test_measurement_tools.py | 151 +++++++++++++++++++++- tools/measurement/README.md | 22 +++- tools/measurement/run_benchmarks.py | 10 +- 5 files changed, 306 insertions(+), 7 deletions(-) diff --git a/src/anonymizer/measurement/__init__.py b/src/anonymizer/measurement/__init__.py index 944ba35e..aee9499a 100644 --- a/src/anonymizer/measurement/__init__.py +++ b/src/anonymizer/measurement/__init__.py @@ -19,7 +19,7 @@ record_stage, stage_timer, ) -from anonymizer.measurement.records.row import record_record_metrics +from anonymizer.measurement.records.row import record_evaluation_metrics, record_record_metrics from 
anonymizer.measurement.session import configured_measurement_session, current_collector, measurement_session __all__ = [ @@ -35,6 +35,7 @@ "measurement_session", "record_model_workflow", "record_ndd_workflow", + "record_evaluation_metrics", "record_record_metrics", "record_run_metadata", "record_stage", diff --git a/src/anonymizer/measurement/records/row.py b/src/anonymizer/measurement/records/row.py index e8b64106..087e938f 100644 --- a/src/anonymizer/measurement/records/row.py +++ b/src/anonymizer/measurement/records/row.py @@ -4,10 +4,22 @@ from __future__ import annotations from collections import Counter -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, TypedDict -from anonymizer.engine.constants import COL_FINAL_ENTITIES +from anonymizer.engine.constants import ( + COL_ATTRIBUTE_FIDELITY_INVALID_ENTITIES, + COL_ATTRIBUTE_FIDELITY_VALID, + COL_DETECTION_INVALID_ENTITIES, + COL_DETECTION_VALID, + COL_FINAL_ENTITIES, + COL_RELATIONAL_CONSISTENCY_INVALID_RELATIONS, + COL_RELATIONAL_CONSISTENCY_VALID, + COL_TYPE_FIDELITY_INVALID_REPLACEMENTS, + COL_TYPE_FIDELITY_VALID, +) from anonymizer.measurement._coerce import ( + _coerce_bool, _coerce_int, _count_items, _count_text_tokens, @@ -34,6 +46,69 @@ from anonymizer.measurement.collector import MeasurementCollector +_EvaluationBoolField = Literal[ + "detection_valid", + "type_fidelity_valid", + "relational_consistency_valid", + "attribute_fidelity_valid", +] +_EvaluationCountField = Literal[ + "detection_invalid_entity_count", + "type_fidelity_invalid_replacement_count", + "relational_consistency_invalid_relation_count", + "attribute_fidelity_invalid_entity_count", +] + + +class _EvaluationRecordFields(TypedDict, total=False): + detection_valid: bool | None + type_fidelity_valid: bool | None + relational_consistency_valid: bool | None + attribute_fidelity_valid: bool | None + detection_invalid_entity_count: int + 
type_fidelity_invalid_replacement_count: int + relational_consistency_invalid_relation_count: int + attribute_fidelity_invalid_entity_count: int + + +@dataclass(frozen=True) +class _EvaluationBoolMetric: + source_column: str + output_field: _EvaluationBoolField + + +@dataclass(frozen=True) +class _EvaluationCountMetric: + source_column: str + output_field: _EvaluationCountField + primary_key: str + + +_EVALUATION_BOOL_METRICS = ( + _EvaluationBoolMetric(COL_DETECTION_VALID, "detection_valid"), + _EvaluationBoolMetric(COL_TYPE_FIDELITY_VALID, "type_fidelity_valid"), + _EvaluationBoolMetric(COL_RELATIONAL_CONSISTENCY_VALID, "relational_consistency_valid"), + _EvaluationBoolMetric(COL_ATTRIBUTE_FIDELITY_VALID, "attribute_fidelity_valid"), +) + +_EVALUATION_COUNT_METRICS = ( + _EvaluationCountMetric(COL_DETECTION_INVALID_ENTITIES, "detection_invalid_entity_count", "invalid_entities"), + _EvaluationCountMetric( + COL_TYPE_FIDELITY_INVALID_REPLACEMENTS, + "type_fidelity_invalid_replacement_count", + "invalid_replacements", + ), + _EvaluationCountMetric( + COL_RELATIONAL_CONSISTENCY_INVALID_RELATIONS, + "relational_consistency_invalid_relation_count", + "invalid_relations", + ), + _EvaluationCountMetric( + COL_ATTRIBUTE_FIDELITY_INVALID_ENTITIES, "attribute_fidelity_invalid_entity_count", "entities" + ), +) + + def record_record_metrics( dataframe: pd.DataFrame, *, @@ -76,6 +151,54 @@ def record_record_metrics( ) +def record_evaluation_metrics( + dataframe: pd.DataFrame, + *, + mode: str, + strategy: str, + text_column: str, +) -> None: + """Record sanitized per-row LLM-as-judge verdict metrics from an evaluated trace dataframe.""" + collector = current_collector() + if collector is None or not collector.record_level: + return + + columns = set(dataframe.columns) + if not _has_evaluation_metrics(columns): + return + + for row_index, row in dataframe.iterrows(): + collector.record( + "evaluation_record", + **_base_record_fields( + collector=collector, + 
row_index=row_index, + row=row, + text_column=text_column, + mode=mode, + strategy=strategy, + ), + **_evaluation_record_fields(row, columns=columns), + ) + + +def _has_evaluation_metrics(columns: set[str]) -> bool: + return any(metric.source_column in columns for metric in _EVALUATION_BOOL_METRICS) or any( + metric.source_column in columns for metric in _EVALUATION_COUNT_METRICS + ) + + +def _evaluation_record_fields(row: pd.Series, *, columns: set[str]) -> _EvaluationRecordFields: + fields: _EvaluationRecordFields = {} + for metric in _EVALUATION_BOOL_METRICS: + if metric.source_column in columns: + fields[metric.output_field] = _coerce_bool(row.get(metric.source_column)) + for metric in _EVALUATION_COUNT_METRICS: + if metric.source_column in columns: + fields[metric.output_field] = _count_items(row.get(metric.source_column), primary_key=metric.primary_key) + return fields + + def _base_record_fields( *, collector: MeasurementCollector, diff --git a/tests/tools/test_measurement_tools.py b/tests/tools/test_measurement_tools.py index 0dc1993e..3aa4a009 100644 --- a/tests/tools/test_measurement_tools.py +++ b/tests/tools/test_measurement_tools.py @@ -4,6 +4,7 @@ from __future__ import annotations import importlib.util +import json import sys from collections.abc import Iterator from contextlib import contextmanager @@ -934,8 +935,16 @@ def test_benchmark_case_can_run_optional_evaluation( tmp_path: Path, ) -> None: tool = load_tool("measurement_benchmark_tool_evaluate", REPO_ROOT / "tools/measurement/run_benchmarks.py") + from anonymizer.interface.results import AnonymizerResult + calls: list[Any] = [] - run_result = object() + run_result = AnonymizerResult( + dataframe=pd.DataFrame({"text": ["Alice works at Acme"]}), + trace_dataframe=pd.DataFrame({"text": ["Alice works at Acme"]}), + resolved_text_column="text", + failed_records=[], + replace_method=None, + ) @contextmanager def fake_measurement_session(_config: Any) -> Iterator[None]: @@ -974,3 +983,143 @@ def 
evaluate(self, result: object) -> object: ) assert calls == [("run", tool.Redact(), "text"), ("evaluate", run_result)] + + +def test_benchmark_optional_evaluation_records_sanitized_judge_metrics(tmp_path: Path) -> None: + tool = load_tool("measurement_benchmark_tool_evaluate_metrics", REPO_ROOT / "tools/measurement/run_benchmarks.py") + from anonymizer.interface.results import AnonymizerResult + + dangerous_values = [ + "alice@example.com", + "bob@example.com", + "sk-secret-123", + "replacement-output-secret", + "nested-malformed-secret", + "raw judge prompt", + "raw judge response", + ] + run_result = AnonymizerResult( + dataframe=pd.DataFrame({"text": ["Alice has sk-secret-123"]}), + trace_dataframe=pd.DataFrame({"text": ["Alice has sk-secret-123"]}), + resolved_text_column="text", + failed_records=[], + replace_method=None, + ) + evaluated_public_columns = { + "text": ["Alice has sk-secret-123"], + "text_replaced": ["Avery has replacement-output-secret"], + "final_entities": [[{"value": "alice@example.com", "label": "email"}]], + "detection_valid": [False], + "detection_invalid_entities": [{"invalid_entities": [{"value": "alice@example.com", "label": "email"}]}], + "type_fidelity_valid": [False], + "type_fidelity_invalid_replacements": [ + {"invalid_replacements": [{"original": "alice@example.com", "synthetic": "bob@example.com"}]} + ], + "relational_consistency_valid": [False], + "relational_consistency_invalid_relations": [{"invalid_relations": [{"reasoning": "raw judge response"}]}], + "attribute_fidelity_valid": [False], + "attribute_fidelity_invalid_entities": ['[{"entity": "nested-malformed-secret"}'], + } + evaluated_result = AnonymizerResult( + dataframe=pd.DataFrame(evaluated_public_columns), + trace_dataframe=pd.DataFrame( + { + **evaluated_public_columns, + "_detection_judge": [ + { + "prompt": "raw judge prompt", + "response": "raw judge response", + "invalid_entities": [{"value": "alice@example.com"}], + } + ], + "_type_fidelity_judge": [ + 
{"invalid_replacements": [{"original": "alice@example.com", "synthetic": "bob@example.com"}]} + ], + } + ), + resolved_text_column="text", + failed_records=[], + replace_method=None, + ) + + class FakeAnonymizer: + def run(self, *, config: Any, data: Any) -> AnonymizerResult: + return run_result + + def evaluate(self, result: AnonymizerResult) -> AnonymizerResult: + assert result is run_result + return evaluated_result + + spec = _minimal_benchmark_spec( + tool, + suite_id="evaluate-suite", + configs=[ + tool.ConfigSpec( + id="substitute", + replace=tool.ReplaceSpec(strategy=tool.ReplaceKind.substitute), + evaluate=True, + ) + ], + ) + _write_text_input(tmp_path, "Alice has sk-secret-123") + case = _minimal_benchmark_case(tool, suite_id="evaluate-suite", config_id="substitute") + measurement_path = tmp_path / "raw" / "input__substitute__r000.jsonl" + + tool._execute_case( + FakeAnonymizer(), + spec.workloads[0], + spec.configs[0], + raw_path=measurement_path, + trace_path=None, + task_trace_path=None, + case=case, + spec=spec, + base_dir=tmp_path, + dd_trace=tool.DDTraceMode.none, + ) + + serialized = measurement_path.read_text(encoding="utf-8") + rows = [json.loads(line) for line in serialized.splitlines()] + evaluation_rows = [row for row in rows if row["record_type"] == "evaluation_record"] + + assert len(evaluation_rows) == 1 + assert { + "record_type": "evaluation_record", + "mode": "replace", + "strategy": "Substitute", + "row_index": 0, + "detection_valid": False, + "detection_invalid_entity_count": 1, + "type_fidelity_valid": False, + "type_fidelity_invalid_replacement_count": 1, + "relational_consistency_valid": False, + "relational_consistency_invalid_relation_count": 1, + "attribute_fidelity_valid": False, + "attribute_fidelity_invalid_entity_count": 0, + }.items() <= evaluation_rows[0].items() + forbidden_fields = { + "text", + "text_replaced", + "text_with_spans", + "final_entities", + "detection_invalid_entities", + 
"type_fidelity_invalid_replacements", + "relational_consistency_invalid_relations", + "attribute_fidelity_invalid_entities", + "_detection_judge", + "_type_fidelity_judge", + "_relational_consistency_judge", + "_attribute_fidelity_judge", + } + assert forbidden_fields.isdisjoint(evaluation_rows[0]) + for raw_value in dangerous_values: + assert raw_value not in serialized + + table_dir = tmp_path / "tables" + tool.export_measurement_tables(measurement_path, table_dir) + exported = pd.read_parquet(table_dir / "evaluation_record.parquet") + exported_text = str(exported.to_json(orient="records")) + + assert forbidden_fields.isdisjoint(exported.columns) + for raw_value in dangerous_values: + assert raw_value not in exported_text diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 2c424f30..6b69cc4b 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -28,6 +28,7 @@ plus `manifest.json`: - `run.parquet` - `stage.parquet` - `record.parquet` +- `evaluation_record.parquet` when replace judge evaluation is enabled - `ndd_workflow.parquet` when DataDesigner adapter records are present - `model_workflow.parquet` when direct model workflow records are present @@ -65,7 +66,8 @@ raw = pd.read_json("benchmark-runs/suite/measurements.jsonl", lines=True) The measurement system has three layers: - Instrumentation in Anonymizer emits JSONL records for runs, stages, - DataDesigner workflows, direct model workflows, and per-record safety metrics. + DataDesigner workflows, direct model workflows, per-record safety metrics, + and optional sanitized replace-judge evaluation metrics. - Benchmark runners create repeatable workloads and write those JSONL records plus optional sidecars such as detection artifacts and DataDesigner traces. - Analysis tools convert raw run artifacts into case, group, and model tables. @@ -243,6 +245,12 @@ Set `evaluate: true` on a replace config when the benchmark should run same case. 
This is intentionally replace-only for now; rewrite runs already perform their internal evaluation/repair loop during `run()`. +When evaluation is enabled, the safe measurement log includes +`evaluation_record` rows with judge verdict booleans and invalid-item counts. +It does not persist the evaluated result dataframe or trace dataframe. Those +dataframes can contain original text, entity values, replacement values, raw +judge outputs, prompts, and model responses. + Before starting a real run, the benchmark runner performs cheap preflight checks: suite/config parsing, local dataset existence, CSV/Parquet text-column metadata, provider YAML shape, and active model-alias references. `--dry-run` @@ -409,3 +417,15 @@ Safety and replacement: - `replacement_synthetic_original_collision_count`: final entity occurrences whose original value was reused as a synthetic replacement value elsewhere in the same record. + +Replace judge evaluation: + +- `detection_valid`, `type_fidelity_valid`, + `relational_consistency_valid`, and `attribute_fidelity_valid`: per-record + judge verdicts when `evaluate: true` is enabled. +- `detection_invalid_entity_count`, + `type_fidelity_invalid_replacement_count`, + `relational_consistency_invalid_relation_count`, and + `attribute_fidelity_invalid_entity_count`: counts of invalid judge findings. + These fields count structures returned by the judges but do not include raw + values, replacement strings, or judge reasoning text. 
diff --git a/tools/measurement/run_benchmarks.py b/tools/measurement/run_benchmarks.py index 83c8db68..5e67ca6c 100755 --- a/tools/measurement/run_benchmarks.py +++ b/tools/measurement/run_benchmarks.py @@ -45,7 +45,7 @@ from anonymizer.engine.io.constants import SUPPORTED_IO_FORMATS from anonymizer.engine.ndd.model_loader import parse_model_configs, validate_model_alias_references from anonymizer.interface.anonymizer import Anonymizer -from anonymizer.measurement import MeasurementConfig, configured_measurement_session +from anonymizer.measurement import MeasurementConfig, configured_measurement_session, record_evaluation_metrics app = cyclopts.App(help=__doc__) logger = logging.getLogger("measurement.benchmark") @@ -818,7 +818,13 @@ def _execute_case( data=input_data, ) if config.evaluate: - anonymizer.evaluate(result) + evaluated = anonymizer.evaluate(result) + record_evaluation_metrics( + evaluated.trace_dataframe, + mode="replace", + strategy=type(anonymizer_config.replace).__name__, + text_column=evaluated.resolved_text_column, + ) def build_input( From 6f843afa744ab96069d30767ba06b9f39d7cc2bc Mon Sep 17 00:00:00 2001 From: Aaron Gonzales Date: Fri, 12 Jun 2026 06:29:21 +0000 Subject: [PATCH 26/26] Add evaluation rollups to benchmark analysis Signed-off-by: Aaron Gonzales --- tests/tools/test_benchmark_output_analysis.py | 98 ++++++++++++++++ tools/measurement/README.md | 7 ++ tools/measurement/analyze_benchmark_output.py | 105 ++++++++++++++++++ 3 files changed, 210 insertions(+) diff --git a/tests/tools/test_benchmark_output_analysis.py b/tests/tools/test_benchmark_output_analysis.py index dbb8afc4..0a88236e 100644 --- a/tests/tools/test_benchmark_output_analysis.py +++ b/tests/tools/test_benchmark_output_analysis.py @@ -187,6 +187,104 @@ def test_analyze_benchmark_output_counts_generic_model_workflow_records(tmp_path assert result.model_usage_groups[0].sum_observed_total_tokens == 42 +def test_analyze_benchmark_output_rolls_up_evaluation_records(tmp_path: 
Path) -> None: + tool = load_tool( + "measurement_benchmark_output_analysis_evaluation_rollups", + REPO_ROOT / "tools/measurement/analyze_benchmark_output.py", + ) + benchmark_dir = tmp_path / "benchmark" + benchmark_dir.mkdir() + _write_jsonl( + benchmark_dir / "measurements.jsonl", + [ + { + "record_type": "evaluation_record", + "run_id": "bio__substitute__r000", + "detection_valid": True, + "detection_invalid_entity_count": 0, + "type_fidelity_valid": True, + "type_fidelity_invalid_replacement_count": 0, + "relational_consistency_valid": False, + "relational_consistency_invalid_relation_count": 2, + "attribute_fidelity_valid": True, + "attribute_fidelity_invalid_entity_count": 0, + "run_tags": { + "workload_id": "bio", + "config_id": "substitute", + "case_id": "bio__substitute__r000", + }, + }, + { + "record_type": "evaluation_record", + "run_id": "bio__substitute__r000", + "detection_valid": False, + "detection_invalid_entity_count": 3, + "type_fidelity_valid": True, + "type_fidelity_invalid_replacement_count": 0, + "relational_consistency_valid": True, + "relational_consistency_invalid_relation_count": 0, + "attribute_fidelity_valid": None, + "attribute_fidelity_invalid_entity_count": 0, + "run_tags": { + "workload_id": "bio", + "config_id": "substitute", + "case_id": "bio__substitute__r000", + }, + }, + { + "record_type": "evaluation_record", + "run_id": "bio__substitute__r001", + "detection_valid": True, + "detection_invalid_entity_count": 1, + "type_fidelity_valid": False, + "type_fidelity_invalid_replacement_count": 4, + "relational_consistency_valid": True, + "relational_consistency_invalid_relation_count": 0, + "attribute_fidelity_valid": False, + "attribute_fidelity_invalid_entity_count": 5, + "run_tags": { + "workload_id": "bio", + "config_id": "substitute", + "case_id": "bio__substitute__r001", + }, + }, + ], + ) + + result = tool.analyze_benchmark_output(benchmark_dir) + + cases = {row.case_id: row for row in result.cases} + first_case = 
cases["bio__substitute__r000"] + assert first_case.detection_judged_record_count == 2 + assert first_case.detection_valid_record_count == 1 + assert first_case.detection_valid_rate == pytest.approx(0.5) + assert first_case.detection_invalid_entity_count == 3 + assert first_case.relational_consistency_judged_record_count == 2 + assert first_case.relational_consistency_valid_rate == pytest.approx(0.5) + assert first_case.attribute_fidelity_judged_record_count == 1 + assert first_case.attribute_fidelity_valid_rate == pytest.approx(1.0) + + second_case = cases["bio__substitute__r001"] + assert second_case.type_fidelity_judged_record_count == 1 + assert second_case.type_fidelity_valid_record_count == 0 + assert second_case.type_fidelity_valid_rate == pytest.approx(0.0) + assert second_case.type_fidelity_invalid_replacement_count == 4 + + group = result.groups[0] + assert group.sum_detection_judged_record_count == 3 + assert group.sum_detection_valid_record_count == 2 + assert group.micro_detection_valid_rate == pytest.approx(2 / 3) + assert group.sum_detection_invalid_entity_count == 4 + assert group.sum_type_fidelity_judged_record_count == 3 + assert group.sum_type_fidelity_valid_record_count == 2 + assert group.micro_type_fidelity_valid_rate == pytest.approx(2 / 3) + assert group.sum_type_fidelity_invalid_replacement_count == 4 + assert group.sum_attribute_fidelity_judged_record_count == 2 + assert group.sum_attribute_fidelity_valid_record_count == 1 + assert group.micro_attribute_fidelity_valid_rate == pytest.approx(0.5) + assert group.sum_attribute_fidelity_invalid_entity_count == 5 + + def test_analyze_benchmark_output_accepts_detection_artifact_override(tmp_path: Path) -> None: tool = load_tool( "measurement_benchmark_output_analysis_artifact_override", diff --git a/tools/measurement/README.md b/tools/measurement/README.md index 6b69cc4b..f98f5360 100644 --- a/tools/measurement/README.md +++ b/tools/measurement/README.md @@ -429,3 +429,10 @@ Replace judge 
evaluation: `attribute_fidelity_invalid_entity_count`: counts of invalid judge findings. These fields count structures returned by the judges but do not include raw values, replacement strings, or judge reasoning text. +- `case_analysis` also includes per-case rollups for each judge family: + `{family}_judged_record_count`, `{family}_valid_record_count`, + `{family}_valid_rate`, and the corresponding invalid-count field. +- `group_analysis` includes grouped micro-rate rollups: + `sum_{family}_judged_record_count`, `sum_{family}_valid_record_count`, + `micro_{family}_valid_rate`, and `sum_{invalid_count_field}`. These rates are + computed from summed counts, not medians of case-level rates. diff --git a/tools/measurement/analyze_benchmark_output.py b/tools/measurement/analyze_benchmark_output.py index 5e1a482e..434cef6f 100644 --- a/tools/measurement/analyze_benchmark_output.py +++ b/tools/measurement/analyze_benchmark_output.py @@ -16,6 +16,7 @@ import logging import math import sys +from dataclasses import dataclass from pathlib import Path from typing import Annotated, Any, cast @@ -45,6 +46,25 @@ } +@dataclass(frozen=True) +class _EvaluationRollup: + prefix: str + valid_column: str + invalid_count_column: str + + +_EVALUATION_ROLLUPS = ( + _EvaluationRollup("detection", "detection_valid", "detection_invalid_entity_count"), + _EvaluationRollup("type_fidelity", "type_fidelity_valid", "type_fidelity_invalid_replacement_count"), + _EvaluationRollup( + "relational_consistency", + "relational_consistency_valid", + "relational_consistency_invalid_relation_count", + ), + _EvaluationRollup("attribute_fidelity", "attribute_fidelity_valid", "attribute_fidelity_invalid_entity_count"), +) + + class CaseAnalysisRow(BaseModel): suite_id: str | None = None workload_id: str | None = None @@ -125,6 +145,22 @@ class CaseAnalysisRow(BaseModel): original_value_leak_count: float | None = None original_value_leak_record_count: int = 0 original_value_leak_label_counts: dict[str, int] = 
Field(default_factory=dict) + detection_judged_record_count: int = 0 + detection_valid_record_count: int = 0 + detection_valid_rate: float | None = None + detection_invalid_entity_count: int = 0 + type_fidelity_judged_record_count: int = 0 + type_fidelity_valid_record_count: int = 0 + type_fidelity_valid_rate: float | None = None + type_fidelity_invalid_replacement_count: int = 0 + relational_consistency_judged_record_count: int = 0 + relational_consistency_valid_record_count: int = 0 + relational_consistency_valid_rate: float | None = None + relational_consistency_invalid_relation_count: int = 0 + attribute_fidelity_judged_record_count: int = 0 + attribute_fidelity_valid_record_count: int = 0 + attribute_fidelity_valid_rate: float | None = None + attribute_fidelity_invalid_entity_count: int = 0 validation_max_entities_per_call: int | None = None detection_artifact_rows: int = 0 seed_entity_count: float | None = None @@ -217,6 +253,22 @@ class GroupAnalysisRow(BaseModel): sum_original_value_leak_count: float | None = None leaking_case_count: int = 0 median_original_value_leak_count: float | None = None + sum_detection_judged_record_count: int = 0 + sum_detection_valid_record_count: int = 0 + micro_detection_valid_rate: float | None = None + sum_detection_invalid_entity_count: int = 0 + sum_type_fidelity_judged_record_count: int = 0 + sum_type_fidelity_valid_record_count: int = 0 + micro_type_fidelity_valid_rate: float | None = None + sum_type_fidelity_invalid_replacement_count: int = 0 + sum_relational_consistency_judged_record_count: int = 0 + sum_relational_consistency_valid_record_count: int = 0 + micro_relational_consistency_valid_rate: float | None = None + sum_relational_consistency_invalid_relation_count: int = 0 + sum_attribute_fidelity_judged_record_count: int = 0 + sum_attribute_fidelity_valid_record_count: int = 0 + micro_attribute_fidelity_valid_rate: float | None = None + sum_attribute_fidelity_invalid_entity_count: int = 0 median_seed_entity_count: 
float | None = None median_seed_validation_candidate_count: float | None = None median_estimated_seed_validation_chunk_count: float | None = None @@ -404,6 +456,7 @@ def _build_case_row( artifact_rows = _rows_for_case(artifacts, case_id) trace_rows = _rows_for_case(traces, case_id) record_rows = _records_of_type(measurement_rows, "record") + evaluation_rows = _records_of_type(measurement_rows, "evaluation_record") ndd_rows = _records_of_type(measurement_rows, "ndd_workflow") model_rows = _model_workflow_rows(measurement_rows) stage_rows = _records_of_type(measurement_rows, "stage") @@ -493,6 +546,7 @@ def _build_case_row( original_value_leak_count=_sum_or_none(record_rows, "original_value_leak_count"), original_value_leak_record_count=_positive_count(record_rows, "original_value_leak_count"), original_value_leak_label_counts=_sum_prefixed_ints(record_rows, "original_value_leak_label_counts."), + **_case_evaluation_metrics(evaluation_rows), validation_max_entities_per_call=validation_max_entities_per_call, **_case_artifact_metrics( artifact_rows, @@ -690,6 +744,43 @@ def _error_status_count(rows: pd.DataFrame) -> int: return int(statuses.isin({"error", "failed"}).sum()) +def _case_evaluation_metrics(evaluation_rows: pd.DataFrame) -> dict[str, int | float | None]: + metrics: dict[str, int | float | None] = {} + for rollup in _EVALUATION_ROLLUPS: + judged_count, valid_count = _evaluation_judged_and_valid_counts(evaluation_rows, rollup.valid_column) + metrics[f"{rollup.prefix}_judged_record_count"] = judged_count + metrics[f"{rollup.prefix}_valid_record_count"] = valid_count + metrics[f"{rollup.prefix}_valid_rate"] = _safe_ratio(valid_count, judged_count) + metrics[rollup.invalid_count_column] = _sum_int_or_zero(evaluation_rows, rollup.invalid_count_column) + return metrics + + +def _evaluation_judged_and_valid_counts(evaluation_rows: pd.DataFrame, valid_column: str) -> tuple[int, int]: + if valid_column not in evaluation_rows.columns: + return 0, 0 + verdicts = 
[_optional_bool(value) for value in evaluation_rows[valid_column].tolist()] + judged_count = sum(verdict is not None for verdict in verdicts) + valid_count = sum(verdict is True for verdict in verdicts) + return judged_count, valid_count + + +def _optional_bool(value: object) -> bool | None: + if value is None or pd.isna(value): + return None + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes"}: + return True + if normalized in {"false", "0", "no"}: + return False + return None + if isinstance(value, int | float): + return bool(value) + return None + + def _case_artifact_metrics( artifact_rows: pd.DataFrame, *, @@ -1187,6 +1278,7 @@ def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysi relaxed_recall = _safe_ratio(relaxed_gt_found, ground_truth_entity_count) label_compatible_precision = _safe_ratio(label_compatible_detected_tp, final_entity_count) label_compatible_recall = _safe_ratio(label_compatible_gt_found, ground_truth_entity_count) + evaluation_metrics = _group_evaluation_metrics(group) return GroupAnalysisRow( workload_id=_none_if_nan(workload_id), workload_category=_none_if_nan(workload_category), @@ -1287,6 +1379,7 @@ def _build_group_row(keys: tuple[Any, ...], group: pd.DataFrame) -> GroupAnalysi sum_original_value_leak_count=_sum_or_none(group, "original_value_leak_count"), leaking_case_count=_positive_count(group, "original_value_leak_count"), median_original_value_leak_count=_median_or_none(group, "original_value_leak_count"), + **evaluation_metrics, median_seed_entity_count=_median_or_none(group, "seed_entity_count"), median_seed_validation_candidate_count=_median_or_none(group, "seed_validation_candidate_count"), median_estimated_seed_validation_chunk_count=_median_or_none(group, "estimated_seed_validation_chunk_count"), @@ -1317,6 +1410,18 @@ def _sum_bool_or_zero(dataframe: pd.DataFrame, column: str) -> int: return 
def _group_evaluation_metrics(group: pd.DataFrame) -> dict[str, int | float | None]:
    """Build the grouped judge rollup fields from a group of case rows.

    For every entry in ``_EVALUATION_ROLLUPS`` the per-case judged, valid and
    invalid counts are summed across the group, and the micro rate is derived
    from those summed counts rather than from the per-case rates.

    Args:
        group: Case-analysis rows belonging to one group.

    Returns:
        Mapping with ``sum_{prefix}_judged_record_count``,
        ``sum_{prefix}_valid_record_count``, ``micro_{prefix}_valid_rate`` and
        ``sum_{invalid_count_column}`` for each judge family.
    """
    rollup_fields: dict[str, int | float | None] = {}
    for family in _EVALUATION_ROLLUPS:
        name = family.prefix
        judged_total = _sum_int_or_zero(group, f"{name}_judged_record_count")
        valid_total = _sum_int_or_zero(group, f"{name}_valid_record_count")
        rollup_fields.update(
            {
                f"sum_{name}_judged_record_count": judged_total,
                f"sum_{name}_valid_record_count": valid_total,
                f"micro_{name}_valid_rate": _safe_ratio(valid_total, judged_total),
                f"sum_{family.invalid_count_column}": _sum_int_or_zero(group, family.invalid_count_column),
            }
        )
    return rollup_fields