Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/anonymizer/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Distributed-executor entrypoint for running the detection workflow on a SLURM cluster.

For running detection at scale, an external DataDesigner runtime (e.g. a SLURM
orchestrator) provisions the model servers, partitions the dataset across workers, and
runs the workflow. Such runtimes usually ship a *serialized* config to each worker and
rebuild it with ``from_config`` — but the detection workflow can't go through that path:
it uses ``CustomColumnConfig`` columns whose ``generator_function`` is a live Python
callable (DataDesigner custom columns are "library only"), which do not survive JSON
serialization.

This module is the alternative: a factory the runtime imports and calls **in-process on
each worker** to get the live ``DataDesignerConfigBuilder`` (callables intact). The custom
columns reference their LLM by *alias* and receive model facades injected by the
DataDesigner runtime, so the runtime's provider wiring (alias → provisioned server) still
routes their calls correctly. The seed parquet is read from the path the runtime provides
(not rewritten — workers may share it), and ``num_jobs > 1`` selects this worker's ordered
partition.

The runtime calls:
build_detection_builder(seed_path=..., job_index=..., num_jobs=..., spec={...})
where ``spec`` is the JSON-serializable detection spec produced by the submitting side.
Requires ``nemo-anonymizer`` installed in the worker environment.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from data_designer.config.config_builder import DataDesignerConfigBuilder

# Placeholder provider endpoint; the distributed runtime overrides providers at run time
# (the workflow is only *built* here, never run against this endpoint).
_PLACEHOLDER_ENDPOINT = "http://overridden-by-runtime:8000/v1"


def build_detection_builder(
    *,
    seed_path: str,
    spec: dict[str, Any],
    job_index: int = 0,
    num_jobs: int = 1,
) -> DataDesignerConfigBuilder:
    """Return the live detection ``DataDesignerConfigBuilder`` for one distributed worker.

    The required ``spec`` keys are checked before anything else happens, so a malformed
    spec fails immediately with a message naming the missing key, rather than surfacing
    as a bare ``KeyError`` part-way through (or only after an ``Anonymizer`` was built).

    Args:
        seed_path: Path to the seed parquet the runtime placed on this worker (read, not
            written). Record ids are assumed already attached by the submitting side.
        spec: JSON-serializable detection spec with keys:
            ``model_configs_yaml`` (str): the Anonymizer model_configs YAML (selected_models
            + model_configs aliases) — the alias ``model`` ids must match the served
            model names the runtime provisions, so its provider wiring can map them.
            ``provider_names`` (list[str]): provider names referenced by the YAML; placeholder
            ``ModelProvider``s are created for them (the runtime supplies the real ones).
            ``detect`` (dict): ``gliner_threshold`` (float) and optional ``entity_labels``
            (list[str] | None).
            ``data_summary`` (str | None): optional dataset description for prompts.
        job_index: index of this worker's ordered partition of the seed.
        num_jobs: total number of partitions the seed is split across.

    Returns:
        The builder returned by ``Anonymizer.export_detection_builder_for_seed``.

    Raises:
        KeyError: If ``spec`` lacks ``model_configs_yaml``, ``provider_names`` or
            ``detect``, or if ``spec["detect"]`` lacks ``gliner_threshold``.
    """
    # Validate first: cheap, needs no project imports, and avoids constructing an
    # Anonymizer only to discover afterwards that the spec is incomplete.
    required = (
        ("model_configs_yaml", "key"),
        ("provider_names", "key"),
        ("detect", "section"),
    )
    for key, noun in required:
        if key not in spec:
            raise KeyError(f"spec must include required {key!r} {noun}")
    detect = spec["detect"]
    if "gliner_threshold" not in detect:
        raise KeyError("spec['detect'] must include required 'gliner_threshold' key")

    # Imported lazily so that merely importing this module does not pull in the
    # full anonymizer package.
    from anonymizer import Anonymizer, AnonymizerConfig, ModelProvider, Redact  # noqa: PLC0415
    from anonymizer.config.anonymizer_config import Detect  # noqa: PLC0415

    # Placeholder providers: only the names matter here, the runtime overrides the
    # endpoint when it actually runs the workflow.
    providers = [
        ModelProvider(name=name, endpoint=_PLACEHOLDER_ENDPOINT, provider_type="openai", api_key="EMPTY")
        for name in spec["provider_names"]
    ]
    anonymizer = Anonymizer(model_configs=spec["model_configs_yaml"], model_providers=providers)

    detect_kwargs: dict[str, Any] = {"gliner_threshold": detect["gliner_threshold"]}
    # Only pass entity_labels when explicitly provided, so Detect keeps its own default.
    if detect.get("entity_labels") is not None:
        detect_kwargs["entity_labels"] = detect["entity_labels"]
    config = AnonymizerConfig(detect=Detect(**detect_kwargs), replace=Redact())

    return anonymizer.export_detection_builder_for_seed(
        config=config,
        seed_path=seed_path,
        job_index=job_index,
        num_jobs=num_jobs,
        data_summary=spec.get("data_summary"),
    )
206 changes: 156 additions & 50 deletions src/anonymizer/engine/detection/detection_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import logging
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from data_designer.config.column_configs import CustomColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig
from data_designer.config.column_types import ColumnConfigT
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.models import ModelConfig

from anonymizer.config.anonymizer_config import Detect as AnonymizerDetectConfig
Expand Down Expand Up @@ -106,6 +109,42 @@ def detect_and_validate_entities(
have missed, and produces final standoff entity spans with overlap
resolution.
"""
workflow_model_configs, columns = self._build_detection_spec(
model_configs=model_configs,
selected_models=selected_models,
gliner_detection_threshold=gliner_detection_threshold,
validation_max_entities_per_call=validation_max_entities_per_call,
validation_excerpt_window_chars=validation_excerpt_window_chars,
entity_labels=entity_labels,
data_summary=data_summary,
)
detection_result = self._adapter.run_workflow(
dataframe,
model_configs=workflow_model_configs,
columns=columns,
workflow_name="entity-detection",
preview_num_records=preview_num_records,
)
detected_df = detection_result.dataframe.copy()
return EntityDetectionResult(dataframe=detected_df, failed_records=detection_result.failed_records)

def _build_detection_spec(
self,
*,
model_configs: list[ModelConfig],
selected_models: DetectionModelSelection,
gliner_detection_threshold: float,
validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
entity_labels: list[str] | None = None,
data_summary: str | None = None,
) -> tuple[list[ModelConfig], list[ColumnConfigT]]:
"""Build the (model_configs, columns) for the core detection workflow.

Shared by :meth:`detect_and_validate_entities` (which executes it in-process)
and :meth:`build_detection_config` (which exports it for an external runtime),
so both paths run exactly the same workflow.
"""
labels = _resolve_detection_labels(entity_labels)
workflow_model_configs = self._inject_detector_params(
model_configs=model_configs,
Expand Down Expand Up @@ -146,59 +185,126 @@ def detect_and_validate_entities(
prompt_template=_get_validation_prompt(data_summary=data_summary, labels=labels),
)

detection_result = self._adapter.run_workflow(
columns: list[ColumnConfigT] = [
LLMTextColumnConfig(
name=COL_RAW_DETECTED,
prompt=_jinja(COL_TEXT),
model_alias=detection_alias,
),
CustomColumnConfig(
name=COL_SEED_ENTITIES,
generator_function=parse_detected_entities,
),
CustomColumnConfig(
name=COL_SEED_VALIDATION_CANDIDATES,
generator_function=prepare_validation_inputs,
),
CustomColumnConfig(
name=COL_VALIDATION_DECISIONS,
generator_function=validator_generator,
generator_params=validator_params,
drop=True,
),
CustomColumnConfig(
name=COL_VALIDATED_ENTITIES,
generator_function=enrich_validation_decisions,
),
CustomColumnConfig(
name=COL_SEED_ENTITIES_JSON,
generator_function=apply_validation_to_seed_entities,
),
LLMStructuredColumnConfig(
name=COL_AUGMENTED_ENTITIES,
prompt=_get_augment_prompt(
data_summary=data_summary, labels=labels, strict_labels=entity_labels is not None
),
model_alias=augmenter_alias,
output_format=AugmentedEntitiesSchema,
),
CustomColumnConfig(
name=COL_MERGED_ENTITIES,
generator_function=merge_and_build_candidates,
),
CustomColumnConfig(
name=COL_DETECTED_ENTITIES,
generator_function=apply_validation_and_finalize,
),
]
return workflow_model_configs, columns

def build_detection_config(
    self,
    dataframe: pd.DataFrame,
    *,
    seed_path: str | Path,
    model_configs: list[ModelConfig],
    selected_models: DetectionModelSelection,
    gliner_detection_threshold: float,
    validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
    validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
    entity_labels: list[str] | None = None,
    data_summary: str | None = None,
) -> DataDesignerConfigBuilder:
    """Build (without executing) the core detection workflow as a DataDesigner config.

    Meant for an external at-scale executor to run. The columns come from
    :meth:`_build_detection_spec`, i.e. the same columns as
    :meth:`detect_and_validate_entities` (culminating in final entities); the
    external runtime supplies the model providers and the seed dataset.

    Args:
        dataframe: Input records, handed to the adapter's ``build_config``.
        seed_path: Seed dataset path, handed to the adapter's ``build_config``.
        model_configs: Forwarded to :meth:`_build_detection_spec`.
        selected_models: Forwarded to :meth:`_build_detection_spec`.
        gliner_detection_threshold: Forwarded to :meth:`_build_detection_spec`.
        validation_max_entities_per_call: Forwarded to :meth:`_build_detection_spec`.
        validation_excerpt_window_chars: Forwarded to :meth:`_build_detection_spec`.
        entity_labels: Forwarded to :meth:`_build_detection_spec`.
        data_summary: Forwarded to :meth:`_build_detection_spec`.

    Returns:
        The ``DataDesignerConfigBuilder`` returned by the adapter's ``build_config``.
    """
    workflow_model_configs, columns = self._build_detection_spec(
        model_configs=model_configs,
        selected_models=selected_models,
        gliner_detection_threshold=gliner_detection_threshold,
        validation_max_entities_per_call=validation_max_entities_per_call,
        validation_excerpt_window_chars=validation_excerpt_window_chars,
        entity_labels=entity_labels,
        data_summary=data_summary,
    )
    # Only assemble the config here; nothing is executed (no workflow name or
    # preview size is passed, unlike the in-process run path).
    return self._adapter.build_config(
        dataframe,
        model_configs=workflow_model_configs,
        columns=columns,
        seed_path=seed_path,
    )

def build_detection_builder_for_seed(
    self,
    *,
    seed_path: str | Path,
    model_configs: list[ModelConfig],
    selected_models: DetectionModelSelection,
    gliner_detection_threshold: float,
    validation_max_entities_per_call: int = _DEFAULT_VALIDATION_MAX_ENTITIES_PER_CALL,
    validation_excerpt_window_chars: int = _DEFAULT_VALIDATION_EXCERPT_WINDOW_CHARS,
    entity_labels: list[str] | None = None,
    data_summary: str | None = None,
    job_index: int = 0,
    num_jobs: int = 1,
) -> DataDesignerConfigBuilder:
    """Build the detection workflow reading an EXISTING seed parquet (no write).

    Same columns as :meth:`build_detection_config`, but the seed dataset points at an
    already-written ``seed_path`` and optionally selects this worker's ordered
    partition (``job_index`` of ``num_jobs``). For a distributed executor (e.g. a SLURM
    orchestrator) that builds this workflow *in-process on the worker* — the
    custom-column callables stay live (they can't survive JSON serialization) and the
    model aliases are resolved by the runtime's providers.

    Args:
        seed_path: Path of the already-written seed parquet.
        model_configs: Forwarded to :meth:`_build_detection_spec`.
        selected_models: Forwarded to :meth:`_build_detection_spec`.
        gliner_detection_threshold: Forwarded to :meth:`_build_detection_spec`.
        validation_max_entities_per_call: Forwarded to :meth:`_build_detection_spec`.
        validation_excerpt_window_chars: Forwarded to :meth:`_build_detection_spec`.
        entity_labels: Forwarded to :meth:`_build_detection_spec`.
        data_summary: Forwarded to :meth:`_build_detection_spec`.
        job_index: Index of this worker's partition, passed to the adapter.
        num_jobs: Total number of partitions, passed to the adapter.

    Returns:
        The ``DataDesignerConfigBuilder`` returned by the adapter's
        ``build_config_for_seed``.
    """
    workflow_model_configs, columns = self._build_detection_spec(
        model_configs=model_configs,
        selected_models=selected_models,
        gliner_detection_threshold=gliner_detection_threshold,
        validation_max_entities_per_call=validation_max_entities_per_call,
        validation_excerpt_window_chars=validation_excerpt_window_chars,
        entity_labels=entity_labels,
        data_summary=data_summary,
    )
    # Partition arguments are passed straight through to the adapter.
    return self._adapter.build_config_for_seed(
        model_configs=workflow_model_configs,
        columns=columns,
        seed_path=seed_path,
        job_index=job_index,
        num_jobs=num_jobs,
    )

def identify_latent_entities(
self,
Expand Down
71 changes: 71 additions & 0 deletions src/anonymizer/engine/ndd/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,77 @@ def run_workflow(
)
return WorkflowRunResult(dataframe=output_df, failed_records=failed_records)

def build_config(
    self,
    df: pd.DataFrame,
    *,
    model_configs: list[ModelConfig],
    columns: list[ColumnConfigT],
    seed_path: str | Path,
) -> DataDesignerConfigBuilder:
    """Assemble the DataDesigner config for a workflow without running it.

    The input is tagged with record ids and materialized at ``seed_path`` as the
    seed dataset; the resulting ``DataDesignerConfigBuilder`` is returned so an
    *external* executor (e.g. an at-scale SLURM orchestrator) can run it. This is
    the config assembly of :meth:`run_workflow` minus the
    ``DataDesigner.create()/.preview()`` call, letting callers hand the same
    workflow to a different DataDesigner runtime.

    Args:
        df: Input DataFrame.
        model_configs: NDD model aliases available to the workflow.
        columns: NDD column configs to add to the workflow, in order.
        seed_path: Destination parquet path for the seed dataset (persisted; the
            caller owns its lifetime, unlike ``run_workflow``'s tempdir).

    Returns:
        The assembled ``DataDesignerConfigBuilder`` (seed dataset + columns added).
    """
    # Seed first: the record-id-tagged frame is handed to the seed source together
    # with the destination path before the builder is created.
    tagged_df = self._attach_record_ids(df=df)
    seed = LocalFileSeedSource.from_dataframe(tagged_df, str(seed_path))

    builder = DataDesignerConfigBuilder(model_configs=model_configs)
    builder.with_seed_dataset(seed, sampling_strategy=SamplingStrategy.ORDERED)
    for column_config in columns:
        builder.add_column(column_config)
    return builder

def build_config_for_seed(
    self,
    *,
    model_configs: list[ModelConfig],
    columns: list[ColumnConfigT],
    seed_path: str | Path,
    job_index: int = 0,
    num_jobs: int = 1,
) -> DataDesignerConfigBuilder:
    """Assemble the workflow config reading an EXISTING seed parquet (no write).

    Like :meth:`build_config` but the seed dataset points at an already-written
    ``seed_path`` (record ids assumed already attached) instead of materializing a
    DataFrame. Use this on a distributed worker that received the seed from an
    orchestrator and must NOT rewrite the shared file. ``num_jobs > 1`` selects this
    worker's ordered partition (``job_index`` of ``num_jobs``), matching how the
    orchestrator shards the seed.

    Args:
        model_configs: NDD model aliases available to the workflow.
        columns: NDD column configs to add to the workflow, in order.
        seed_path: Path of the existing seed parquet.
        job_index: Index of this worker's partition; must satisfy
            ``0 <= job_index < num_jobs``.
        num_jobs: Total number of partitions; must be ``>= 1``.

    Returns:
        The assembled ``DataDesignerConfigBuilder`` (seed dataset + columns added).

    Raises:
        ValueError: If ``num_jobs < 1`` or ``job_index`` is outside ``[0, num_jobs)``.
    """
    # Check the partition arguments before the lazy import below, so a bad
    # job_index/num_jobs is reported as such regardless of the import outcome.
    if num_jobs < 1:
        raise ValueError(f"num_jobs must be >= 1, got {num_jobs}")
    if not (0 <= job_index < num_jobs):
        raise ValueError(f"job_index must be in [0, num_jobs), got job_index={job_index}, num_jobs={num_jobs}")

    from data_designer.config.seed import PartitionBlock  # noqa: PLC0415

    config_builder = DataDesignerConfigBuilder(model_configs=model_configs)
    seed_source = LocalFileSeedSource(path=str(seed_path))
    # A single job reads the whole seed; no partition selection is applied.
    selection = PartitionBlock(index=job_index, num_partitions=num_jobs) if num_jobs > 1 else None
    config_builder.with_seed_dataset(
        seed_source,
        sampling_strategy=SamplingStrategy.ORDERED,
        selection_strategy=selection,
    )
    for column in columns:
        config_builder.add_column(column)
    return config_builder

def _attach_record_ids(self, df: pd.DataFrame) -> pd.DataFrame:
if RECORD_ID_COLUMN in df.columns:
return df.copy()
Expand Down
Loading
Loading