Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 229 additions & 53 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,171 @@ def __post_init__(self):
)


@dataclass
class RejectionSamplingConfig:
"""Unified configuration for sample filtering based on policy divergence.

Filters tokens/sequences where the divergence between proximal policy
and behavior policy exceeds a threshold, via two action modes:
- 'mask': zero out loss_mask (rejection, exclude from gradient)
- 'clamp': clamp importance weight to bounds (truncation, bounded gradient)

Supports direct ratio bounds and KL divergence estimators (K1/K2/K3),
at both token-level and sequence-level granularity.

Replaces the removed ``behave_imp_weight_cap`` and ``behave_imp_weight_mode``.

Attributes:
level: Filtering granularity ('token' or 'sequence'). When ``level='sequence'``
and ``metric='ratio'``, both the filtering decision and the correction
weight (behave_imp_weight) use the sequence-level geometric mean,
matching the old ``sequence_mask``/``sequence_truncate`` semantics.
action: Action mode ('mask' or 'clamp').
metric: Divergence metric ('ratio', 'kl_k1', 'kl_k2', 'kl_k3').
agg: Aggregation method for sequence-level ('sum', 'mean', 'max').
For 'ratio' metric, aggregation is performed in log space (geometric
mean/sum) to avoid the "length trap" and match GSPO semantics.
For KL metrics, aggregation is arithmetic.
upper: Upper bound for filtering.
lower: Lower bound for filtering (optional).
"""

level: str = field(
default="token",
metadata={
"help": "Filtering granularity. "
"'token': per-token filtering (each token judged independently). "
"'sequence': per-sequence filtering (all tokens in a sequence share the same fate). "
"When metric='ratio', both the filtering decision and the correction weight "
"(behave_imp_weight) operate at sequence level using the geometric mean.",
"choices": ["token", "sequence"],
},
)
action: str = field(
default="mask",
metadata={
"help": "Action to take when metric exceeds threshold. "
"'mask': zero out loss_mask for filtered tokens/sequences (rejection, "
"completely excludes from gradient computation). "
"'clamp': clamp importance weight to [lower, upper] bounds (truncation, "
"tokens still participate in gradient but with bounded weight).",
"choices": ["mask", "clamp"],
},
)
metric: str = field(
default="ratio",
metadata={
"help": "Divergence metric for filtering. "
"'ratio': direct importance ratio π_proximal/π_behave. "
"'kl_k1': KL estimator k1 = log(r), forward KL unbiased estimator (can be negative). "
"'kl_k2': KL estimator k2 = 0.5 * (log r)^2, non-negative quadratic approximation. "
"'kl_k3': KL estimator k3 = r - log(r) - 1, non-negative exact forward KL estimator.",
"choices": ["ratio", "kl_k1", "kl_k2", "kl_k3"],
},
)
agg: str = field(
default="mean",
metadata={
"help": "Aggregation method for sequence-level filtering. "
"Only used when level='sequence'. "
"For 'ratio' metric, aggregation is in log space: "
"'sum' = exp(sum(log(r_i))), 'mean' = exp(mean(log(r_i))) = geometric mean "
"(length-invariant, consistent with GSPO). "
"For KL metrics, aggregation is arithmetic: "
"'sum' = sum(kl_i), 'mean' = mean(kl_i). "
"'max': max of per-token metric values (most conservative).",
"choices": ["sum", "mean", "max"],
},
)
upper: float = field(
default=5.0,
metadata={
"help": "Upper bound for filtering. "
"Tokens/sequences with metric > upper are filtered out (loss_mask zeroed). "
"For 'ratio' metric: must be > 1.0, typical values are 2.0 or 5.0. "
"For 'kl_k2'/'kl_k3' metrics: typical values are 0.5-2.0."
},
)
lower: float | None = field(
default=None,
metadata={
"help": "Lower bound for filtering (optional). "
"None means no lower bound. "
"For 'ratio' metric: typical value is 0.5 (filter out tokens where policy "
"probability dropped significantly). Must be > 0. "
"For 'kl_k1' metric: can be used to filter negative KL estimates."
},
)

def __post_init__(self):
"""Validate configuration."""
import warnings

_VALID_LEVELS = ("token", "sequence")
_VALID_ACTIONS = ("mask", "clamp")
_VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
_VALID_AGGS = ("sum", "mean", "max")

# Validate enum-like fields.
if self.level not in _VALID_LEVELS:
raise ValueError(
f"level must be one of {_VALID_LEVELS}, got '{self.level}'"
)
if self.action not in _VALID_ACTIONS:
raise ValueError(
f"action must be one of {_VALID_ACTIONS}, got '{self.action}'"
)
if self.metric not in _VALID_METRICS:
raise ValueError(
f"metric must be one of {_VALID_METRICS}, got '{self.metric}'"
)
if self.agg not in _VALID_AGGS:
raise ValueError(f"agg must be one of {_VALID_AGGS}, got '{self.agg}'")

# Validate lower <= upper when both are set.
if self.lower is not None and self.lower > self.upper:
raise ValueError(
f"lower ({self.lower}) cannot be greater than upper ({self.upper})"
)

# For ratio metric, upper must be > 1.0 (otherwise all non-identical policy tokens are filtered).
if self.metric == "ratio":
if self.upper <= 1.0:
raise ValueError(
f"upper must be > 1.0 for 'ratio' metric (otherwise all non-identical "
f"policy tokens will be filtered), got {self.upper}"
)
if self.lower is not None and self.lower <= 0:
raise ValueError(
f"lower must be positive for 'ratio' metric, got {self.lower}"
)
# For KL metrics, upper must be > 0.
# Note: kl_k1 is excluded because it is a forward KL unbiased estimator that
# can produce negative values, so requiring upper > 0 would be too restrictive.
if self.metric in ("kl_k2", "kl_k3") and self.upper <= 0:
raise ValueError(
f"upper must be positive for '{self.metric}' metric, got {self.upper}"
)
# Clamp action only supports ratio metric (direct importance weight truncation).
if self.action == "clamp" and self.metric != "ratio":
raise ValueError(
f"action='clamp' only supports metric='ratio' (direct importance weight "
f"truncation). Got metric='{self.metric}'. "
f"Use action='mask' for KL-based filtering."
)
# Clamp action defaults lower to 0.0 (consistent with old truncate behavior).
if self.action == "clamp" and self.lower is None:
self.lower = 0.0
# Validate sequence-level aggregation.
if self.level == "token" and self.agg != "mean":
warnings.warn(
f"agg='{self.agg}' is ignored when level='token'. "
"Aggregation is only used for sequence-level filtering.",
UserWarning,
stacklevel=2,
)


@dataclass
class PPOActorConfig(TrainEngineConfig):
"""Configuration for PPO actor model, a subclass of a TrainEngine."""
Expand Down Expand Up @@ -1294,32 +1459,12 @@ class PPOActorConfig(TrainEngineConfig):
"help": "Use the decoupled loss. Implicitly enables recompute_logprob."
},
)
behave_imp_weight_cap: float | None = field(
default=5.0,
metadata={
"help": "Filter out tokens/sequences where behave_imp_weight exceeds this cap when computing loss. "
"Only effective when use_decoupled_loss=True (decoupled/async training). "
"Must be > 1.0 when mode is not 'disabled'. "
"Mode controlled by behave_imp_weight_mode (mask/truncate/disabled)."
},
)
behave_imp_weight_mode: str = field(
default="token_mask",
rejection_sampling: RejectionSamplingConfig | None = field(
default=None,
metadata={
"help": "Mode for importance weight filtering. "
"Only effective when use_decoupled_loss=True (decoupled/async training). "
"'token_truncate': clamp token ratio to [0, cap]. "
"'token_mask': set token ratio to 0 where ratio > cap. "
"'sequence_truncate': clamp sequence ratio to [0, cap]. "
"'sequence_mask': set sequence ratio to 0 where ratio > cap. "
"'disabled': disable importance weight correction.",
"choices": [
"token_truncate",
"token_mask",
"sequence_truncate",
"sequence_mask",
"disabled",
],
"help": "Rejection sampling configuration for filtering stale samples. "
"None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). "
"Only effective when use_decoupled_loss=True."
},
)
importance_sampling_level: str = field(
Expand Down Expand Up @@ -1372,34 +1517,28 @@ def should_compute_prox_logp(self) -> bool:

def __post_init__(self):
"""Validate PPO actor configuration."""
# Validate MIS/TIS configuration
if self.behave_imp_weight_mode == "disabled":
if self.behave_imp_weight_cap is not None:
raise ValueError(
f"behave_imp_weight_cap must be None when behave_imp_weight_mode is 'disabled', "
f"got {self.behave_imp_weight_cap}."
)
else:
if (
self.behave_imp_weight_cap is not None
and self.behave_imp_weight_cap <= 1.0
):
raise ValueError(
f"behave_imp_weight_cap must be > 1.0 when behave_imp_weight_mode is not 'disabled', "
f"got {self.behave_imp_weight_cap}."
)

# Warn if behave_imp_weight settings are configured but use_decoupled_loss is False
if not self.use_decoupled_loss:
if (
self.behave_imp_weight_cap is not None
or self.behave_imp_weight_mode != "disabled"
):
logger.warning(
"behave_imp_weight_cap and behave_imp_weight_mode are configured but "
"use_decoupled_loss=False. These settings will be ignored. "
"Set use_decoupled_loss=True to enable decoupled loss with importance weight correction."
)
# Warn if rejection_sampling is configured but use_decoupled_loss is False
if not self.use_decoupled_loss and self.rejection_sampling is not None:
logger.warning(
"rejection_sampling is configured but use_decoupled_loss=False. "
"Filtering will be ignored. Set use_decoupled_loss=True to enable."
)
# Warn if decoupled loss is enabled but no rejection sampling configured.
# The old default (behave_imp_weight_cap=5.0, mode=token_mask) enabled
# filtering implicitly; the new default (rejection_sampling=None) disables
# it. This warning helps users who relied on the old defaults.
if self.use_decoupled_loss and self.rejection_sampling is None:
logger.warning(
"use_decoupled_loss=True with rejection_sampling=None: "
"staleness filtering is disabled. If you previously relied on "
"the default behave_imp_weight_cap=5.0 with token_mask mode, "
"restore equivalent behavior with:\n"
" rejection_sampling:\n"
" level: token\n"
" action: mask\n"
" metric: ratio\n"
" upper: 5.0"
)

# Validate SAPO configuration
if self.use_sapo_loss:
Expand Down Expand Up @@ -2562,7 +2701,44 @@ def parse_cli_args(argv: list[str]):
return cfg, config_file


_LEGACY_REJECTION_SAMPLING_KEYS = {
"behave_imp_weight_cap",
"behave_imp_weight_mode",
}

_LEGACY_MIGRATION_MESSAGE = (
"Config keys 'behave_imp_weight_cap' and 'behave_imp_weight_mode' have been "
"removed. Use 'rejection_sampling' sub-config instead.\n"
"Migration mapping:\n"
" behave_imp_weight_mode='disabled' -> rejection_sampling: null\n"
" behave_imp_weight_mode='token_mask', behave_imp_weight_cap=X\n"
" -> rejection_sampling: {level: token, action: mask, metric: ratio, upper: X}\n"
" behave_imp_weight_mode='token_truncate', behave_imp_weight_cap=X\n"
" -> rejection_sampling: {level: token, action: clamp, metric: ratio, upper: X}\n"
)


def _migrate_legacy_rejection_sampling(cfg: DictConfig) -> DictConfig:
    """Raise an actionable error if removed behave_imp_weight_* keys are present.

    Scans the known nested actor/teacher sections of ``cfg`` for legacy
    rejection-sampling keys and raises a ``ValueError`` carrying migration
    instructions when any are found. Returns ``cfg`` unchanged otherwise.
    """
    for section in ("actor", "teacher"):
        # Skip sections that are absent or marked MISSING in the config.
        if OmegaConf.is_missing(cfg, section) or section not in cfg:
            continue
        section_cfg = cfg[section]
        # Only dict-like sections can carry the legacy keys (None is not a DictConfig).
        if not isinstance(section_cfg, DictConfig):
            continue
        found = _LEGACY_REJECTION_SAMPLING_KEYS.intersection(section_cfg.keys())
        if found:
            raise ValueError(
                f"Found removed config key(s) {found} under '{section}'. "
                + _LEGACY_MIGRATION_MESSAGE
            )
    return cfg


def to_structured_cfg(cfg, config_cls):
# Intercept legacy config keys before merge to give actionable error.
_migrate_legacy_rejection_sampling(cfg)
# Merge with the default configuration.
# The yaml and commandline can omit some default values defined in python dataclasses.
default_cfg = OmegaConf.structured(config_cls)
Expand Down
17 changes: 13 additions & 4 deletions areal/experimental/inference_service/controller/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from dataclasses import dataclass, field

from areal.api.cli_args import OpenAIProxyConfig


@dataclass
class GatewayControllerConfig:
Expand All @@ -20,6 +18,7 @@ class GatewayControllerConfig:
# -- Model / tokenizer -------------------------------------------------
tokenizer_path: str = ""
model_path: str = ""
model: str = "default"

# -- Routing -----------------------------------------------------------
routing_strategy: str = "round_robin"
Expand Down Expand Up @@ -53,5 +52,15 @@ class GatewayControllerConfig:
pause_grace_period: float = 0.5
n_gpus_per_node: int | None = None # GPUs per physical node; None = single-node

# -- OpenAI proxy configuration (for agent-like workflows) ---------------
openai: OpenAIProxyConfig = field(default_factory=lambda: OpenAIProxyConfig())
# -- Admin / workflow --------------------------------------------------
admin_api_key: str | None = None
turn_discount: float = 1.0
export_style: str = "individual"
tool_call_parser: str = "qwen"
reasoning_parser: str = "qwen3"
engine_max_tokens: int | None = None
chat_template_type: str = "hf"

# -- External model API ------------------------------------------------
api_url: str | None = None
provider_api_key: str | None = None
Loading
Loading