# Summary of this patch.
# - src/anonymizer/engine/detection/detection_workflow.py: adds a
#   `validation_single_chunk_full_text: bool = True` parameter in the hunks for
#   `_build_detection_spec`, `build_detection_config`,
#   `build_detection_builder_for_seed` and `run`, and passes it on as a keyword
#   argument at the four call sites touched here.
# - tests/engine/test_detection_workflow.py: adds four tests that read
#   `generator_params.single_chunk_full_text` from the COL_VALIDATION_DECISIONS
#   column captured from the mocked adapter. False is expected when False is
#   passed via `run`, `build_detection_config` and
#   `build_detection_builder_for_seed`; True is expected when
#   `build_detection_config` is called without the argument.
# NOTE(review): the first hunk (context `def detect_and_validate_entities(`)
# forwards the name, but no hunk shown here adds it to that function's
# signature - confirm it is already in scope there.
# NOTE(review): line breaks and indentation inside the hunks appear collapsed to
# single spaces; presumably the original layout must be restored before the
# patch can be applied - verify against the source commit.
diff --git a/src/anonymizer/engine/detection/detection_workflow.py b/src/anonymizer/engine/detection/detection_workflow.py index d59246b7..1c762844 100644 --- a/src/anonymizer/engine/detection/detection_workflow.py +++ b/src/anonymizer/engine/detection/detection_workflow.py @@ -117,6 +117,7 @@ def detect_and_validate_entities( gliner_detection_threshold=gliner_detection_threshold, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, entity_labels=entity_labels, data_summary=data_summary, ) @@ -138,6 +139,7 @@ def _build_detection_spec( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, data_summary: str | None = None, ) -> tuple[list[ModelConfig], list[ColumnConfigT]]: @@ -245,6 +247,7 @@ def build_detection_config( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, data_summary: str | None = None, ) -> DataDesignerConfigBuilder: @@ -259,6 +262,7 @@ def build_detection_config( gliner_detection_threshold=gliner_detection_threshold, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, entity_labels=entity_labels, data_summary=data_summary, ) @@ -278,6 +282,7 @@ def build_detection_builder_for_seed( gliner_detection_threshold: float, validation_max_entities_per_call: int = 
_DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, data_summary: str | None = None, job_index: int = 0, @@ -298,6 +303,7 @@ def build_detection_builder_for_seed( gliner_detection_threshold=gliner_detection_threshold, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, entity_labels=entity_labels, data_summary=data_summary, ) @@ -362,6 +368,7 @@ def run( gliner_detection_threshold: float, validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL, validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS, + validation_single_chunk_full_text: bool = True, entity_labels: list[str] | None = None, privacy_goal: PrivacyGoal | None = None, data_summary: str | None = None, @@ -391,6 +398,7 @@ def run( gliner_detection_threshold=gliner_detection_threshold, validation_max_entities_per_call=validation_max_entities_per_call, validation_excerpt_window_chars=validation_excerpt_window_chars, + validation_single_chunk_full_text=validation_single_chunk_full_text, entity_labels=entity_labels, data_summary=data_summary, preview_num_records=preview_num_records, diff --git a/tests/engine/test_detection_workflow.py b/tests/engine/test_detection_workflow.py index 3f45f1b1..27dfbe8c 100644 --- a/tests/engine/test_detection_workflow.py +++ b/tests/engine/test_detection_workflow.py @@ -543,6 +543,92 @@ def test_validator_pool_kwargs_thread_through_to_generator_params( assert params.excerpt_window_chars == 42 +def test_validation_single_chunk_full_text_threads_to_generator_params( + stub_detector_model_configs: list[ModelConfig], + stub_detection_model_selection: DetectionModelSelection, +) -> None: + adapter = Mock() + 
adapter.run_workflow.return_value = WorkflowRunResult( + dataframe=pd.DataFrame( + { + COL_TEXT: ["Alice"], + COL_DETECTED_ENTITIES: [{"entities": [{"value": "Alice", "label": "first_name"}]}], + } + ), + failed_records=[], + ) + workflow = EntityDetectionWorkflow(adapter=adapter) + workflow.run( + pd.DataFrame({COL_TEXT: ["Alice"]}), + model_configs=stub_detector_model_configs, + selected_models=stub_detection_model_selection, + gliner_detection_threshold=0.5, + validation_single_chunk_full_text=False, + tag_latent_entities=False, + ) + columns = adapter.run_workflow.call_args.kwargs["columns"] + params = _find_column(columns, COL_VALIDATION_DECISIONS).generator_params + assert params.single_chunk_full_text is False + + +def test_build_detection_config_threads_validation_single_chunk_full_text( + tmp_path, + stub_detector_model_configs: list[ModelConfig], + stub_detection_model_selection: DetectionModelSelection, +) -> None: + adapter = Mock() + workflow = EntityDetectionWorkflow(adapter=adapter) + workflow.build_detection_config( + pd.DataFrame({COL_TEXT: ["Alice"]}), + seed_path=tmp_path / "seed.parquet", + model_configs=stub_detector_model_configs, + selected_models=stub_detection_model_selection, + gliner_detection_threshold=0.5, + validation_single_chunk_full_text=False, + ) + columns = adapter.build_config.call_args.kwargs["columns"] + params = _find_column(columns, COL_VALIDATION_DECISIONS).generator_params + assert params.single_chunk_full_text is False + + +def test_build_detection_builder_for_seed_threads_validation_single_chunk_full_text( + tmp_path, + stub_detector_model_configs: list[ModelConfig], + stub_detection_model_selection: DetectionModelSelection, +) -> None: + adapter = Mock() + workflow = EntityDetectionWorkflow(adapter=adapter) + workflow.build_detection_builder_for_seed( + seed_path=tmp_path / "seed.parquet", + model_configs=stub_detector_model_configs, + selected_models=stub_detection_model_selection, + gliner_detection_threshold=0.5, + 
validation_single_chunk_full_text=False, + ) + columns = adapter.build_config_for_seed.call_args.kwargs["columns"] + params = _find_column(columns, COL_VALIDATION_DECISIONS).generator_params + assert params.single_chunk_full_text is False + + +def test_build_detection_config_uses_default_validation_single_chunk_full_text( + tmp_path, + stub_detector_model_configs: list[ModelConfig], + stub_detection_model_selection: DetectionModelSelection, +) -> None: + adapter = Mock() + workflow = EntityDetectionWorkflow(adapter=adapter) + workflow.build_detection_config( + pd.DataFrame({COL_TEXT: ["Alice"]}), + seed_path=tmp_path / "seed.parquet", + model_configs=stub_detector_model_configs, + selected_models=stub_detection_model_selection, + gliner_detection_threshold=0.5, + ) + columns = adapter.build_config.call_args.kwargs["columns"] + params = _find_column(columns, COL_VALIDATION_DECISIONS).generator_params + assert params.single_chunk_full_text is True + + def test_pool_size_greater_than_one_emits_warning( stub_detector_model_configs: list[ModelConfig], stub_detection_model_selection: DetectionModelSelection,