
feat(api): add unified RejectionSamplingConfig for async training#1088

Open
guozhihao-224 wants to merge 2 commits into inclusionAI:main from guozhihao-224:feat/reject_sample

Conversation

@guozhihao-224
Contributor

Description

Replace behave_imp_weight_cap/behave_imp_weight_mode with unified RejectionSamplingConfig supporting multiple metrics (ratio, kl_k1, kl_k2, kl_k3), levels
(token/sequence), and actions (mask/clamp).

This provides a unified and principled approach for filtering stale samples in async RL training, addressing staleness issues caused by version gaps between behavior
policy and current policy.

Key changes:

  • Add RejectionSamplingConfig dataclass with comprehensive validation
  • Implement apply_rejection_sampling supporting both 1D packed and 2D padded formats
  • Update ppo_actor_loss_fn and grpo_loss_fn to use new config
  • Migrate 40 example configs to new rejection_sampling field
  • Add comprehensive unit tests (40 test cases)
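For orientation, the shape of such a dataclass might look like the sketch below. Field names and valid values are taken from this PR's description; the defaults shown are illustrative assumptions, not the actual library defaults:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RejectionSamplingConfig:
    # Valid values follow the PR description; defaults here are assumptions.
    metric: str = "ratio"          # one of: ratio, kl_k1, kl_k2, kl_k3
    level: str = "token"           # token or sequence
    action: str = "mask"           # mask or clamp
    agg: str = "mean"              # sequence-level aggregation: sum, mean, max
    upper: Optional[float] = None  # reject/clamp when the metric exceeds this
```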

Recommended config for async training:

actor:
  rejection_sampling:
    level: sequence
    agg: mean
    metric: kl_k2
    upper: 1.0
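In plain Python, what this recommended config does can be sketched roughly as follows. The helper names are illustrative, not the actual areal API; the real implementation operates on torch tensors:

```python
import math

def kl_k2(logp_behave, logp_current):
    # k2 estimator of per-token KL divergence: 0.5 * (log-ratio)^2.
    d = logp_current - logp_behave
    return 0.5 * d * d

def sequence_mask(logp_behave, logp_current, upper=1.0):
    # level=sequence, agg=mean: mean-aggregate the per-token metric over
    # the sequence, then mask the entire sequence when it exceeds `upper`.
    n = len(logp_behave)
    metric = sum(kl_k2(b, c) for b, c in zip(logp_behave, logp_current)) / n
    keep = metric <= upper
    return [1 if keep else 0] * n
```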

Related Issue

Fixes #1052

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the ../CONTRIBUTING.md
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

Users need to migrate from old config:

  • behave_imp_weight_cap=5.0, mode=token_mask → rejection_sampling: {metric: ratio, upper: 5.0}
  • behave_imp_weight_mode=disabled → rejection_sampling: null
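The migration mapping above could be expressed as a small helper. This function is hypothetical, not part of the PR; it only mirrors the two bullets:

```python
def migrate_legacy(behave_imp_weight_cap=None, behave_imp_weight_mode=None):
    # disabled mode maps to rejection_sampling: null
    if behave_imp_weight_mode == "disabled":
        return None
    # capped modes map to a ratio-metric config with the same upper bound
    return {"metric": "ratio", "upper": behave_imp_weight_cap}
```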

Additional Context

Test Results:

  • ✅ 40/40 rejection sampling unit tests passed
  • ✅ All related functional tests passed
  • ✅ Pre-commit hooks passed (ruff format/lint)

Files changed:

  • areal/api/cli_args.py: Add RejectionSamplingConfig dataclass
  • areal/utils/functional/functional.py: Implement apply_rejection_sampling
  • areal/trainer/ppo/actor.py: Update loss functions
  • docs/en/cli_reference.md, docs/zh/cli_reference.md: Update CLI docs
  • examples/**/*.yaml: Migrate 40 configs to new format
  • tests/test_rejection_sampling.py: Add 40 test cases

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the sample filtering mechanism within asynchronous Reinforcement Learning (RL) training. It transitions from a less flexible importance weight capping system to a new, unified RejectionSamplingConfig. This change provides a more principled and versatile approach to manage stale samples, which is crucial for maintaining training stability and performance in decoupled RL setups. The new configuration allows for fine-grained control over how samples are filtered or clamped based on policy divergence, supporting various metrics and granularities.

Highlights

  • Unified Rejection Sampling Configuration: Introduced a new RejectionSamplingConfig dataclass to provide a unified and principled approach for filtering stale samples in asynchronous RL training.
  • Deprecation of Legacy Fields: Replaced the older behave_imp_weight_cap and behave_imp_weight_mode with the comprehensive RejectionSamplingConfig.
  • Enhanced Filtering Capabilities: Implemented apply_rejection_sampling supporting multiple metrics (ratio, KL divergence), levels (token/sequence), and actions (mask/clamp) for policy divergence.
  • Loss Function Integration: Updated ppo_actor_loss_fn and grpo_loss_fn to utilize the new RejectionSamplingConfig for improved sample handling.
  • Configuration Migration: Migrated 40 example configurations to adopt the new rejection_sampling field, ensuring consistency across the codebase.
  • Comprehensive Testing: Added extensive unit tests (40 test cases) for the new rejection sampling functionality and its configuration validation.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the sample filtering mechanism by introducing a new RejectionSamplingConfig dataclass, replacing the older behave_imp_weight_cap and behave_imp_weight_mode parameters. This new configuration unifies token-level and sequence-level filtering with options for masking or clamping importance weights based on various divergence metrics (ratio, KL estimators) and aggregation methods. The core logic in areal/utils/functional/functional.py is updated with a new apply_rejection_sampling function, and PPOActorConfig is modified to use this new structure. Additionally, migration logic for legacy configurations, logging, documentation, and example YAML files are updated. Review feedback suggests moving the import warnings statement to the top of areal/api/cli_args.py for PEP 8 compliance, refactoring validation logic in RejectionSamplingConfig.__post_init__ for better maintainability, and optimizing apply_rejection_sampling by avoiding redundant computations of loss_mask sums for improved efficiency.

Comment thread areal/api/cli_args.py

def __post_init__(self):
    """Validate configuration."""
    import warnings


Severity: medium

Per PEP 8, imports should generally be at the top of the file. Moving import warnings to the top of areal/api/cli_args.py would be more conventional. Local imports are sometimes used to avoid circular dependencies, but that does not seem to be the case here for the warnings module.

Comment thread areal/api/cli_args.py
Comment on lines +1152 to +1171
_VALID_LEVELS = ("token", "sequence")
_VALID_ACTIONS = ("mask", "clamp")
_VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
_VALID_AGGS = ("sum", "mean", "max")

# Validate enum-like fields.
if self.level not in _VALID_LEVELS:
    raise ValueError(
        f"level must be one of {_VALID_LEVELS}, got '{self.level}'"
    )
if self.action not in _VALID_ACTIONS:
    raise ValueError(
        f"action must be one of {_VALID_ACTIONS}, got '{self.action}'"
    )
if self.metric not in _VALID_METRICS:
    raise ValueError(
        f"metric must be one of {_VALID_METRICS}, got '{self.metric}'"
    )
if self.agg not in _VALID_AGGS:
    raise ValueError(f"agg must be one of {_VALID_AGGS}, got '{self.agg}'")


Severity: medium

The validation constants are defined inside __post_init__ and the checks are repetitive. For better maintainability, you can refactor the validation logic into a loop. Also, consider moving the _VALID_* constants to be class-level attributes to follow Python conventions for constants.

Suggested change

Before:

    _VALID_LEVELS = ("token", "sequence")
    _VALID_ACTIONS = ("mask", "clamp")
    _VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
    _VALID_AGGS = ("sum", "mean", "max")

    # Validate enum-like fields.
    if self.level not in _VALID_LEVELS:
        raise ValueError(
            f"level must be one of {_VALID_LEVELS}, got '{self.level}'"
        )
    if self.action not in _VALID_ACTIONS:
        raise ValueError(
            f"action must be one of {_VALID_ACTIONS}, got '{self.action}'"
        )
    if self.metric not in _VALID_METRICS:
        raise ValueError(
            f"metric must be one of {_VALID_METRICS}, got '{self.metric}'"
        )
    if self.agg not in _VALID_AGGS:
        raise ValueError(f"agg must be one of {_VALID_AGGS}, got '{self.agg}'")

After:

    _VALID_LEVELS = ("token", "sequence")
    _VALID_ACTIONS = ("mask", "clamp")
    _VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
    _VALID_AGGS = ("sum", "mean", "max")

    # Validate enum-like fields.
    validations = {
        "level": _VALID_LEVELS,
        "action": _VALID_ACTIONS,
        "metric": _VALID_METRICS,
        "agg": _VALID_AGGS,
    }
    for field_name, valid_values in validations.items():
        value = getattr(self, field_name)
        if value not in valid_values:
            raise ValueError(
                f"{field_name} must be one of {valid_values}, got '{value}'"
            )

Comment on lines +295 to +298
raw_valid = torch.zeros(
    batch_size, device=loss_mask.device, dtype=torch.int32
).scatter_add_(0, sequence_idx, loss_mask.int())
all_masked = raw_valid == 0


Severity: medium

To improve efficiency, you can avoid re-computing the sum of loss_mask per sequence. The scatter_add_ operation is performed here to get raw_valid, but it was also done earlier (lines 268-271) to compute valid_count_per_seq.

Consider computing the raw counts once before the aggregation logic, and then reuse it to derive both the clamped counts for mean aggregation and for the all_masked check in max aggregation.
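The reuse the reviewer suggests could look roughly like this pure-Python sketch. The names are assumptions and the real code works on torch tensors via scatter_add_, but the flow is the same: count once, derive both consumers:

```python
def per_sequence_valid_counts(sequence_idx, loss_mask, batch_size):
    # Compute per-sequence valid-token counts once (the equivalent of the
    # scatter_add_ above), then reuse the result.
    raw = [0] * batch_size
    for seq, m in zip(sequence_idx, loss_mask):
        raw[seq] += m
    all_masked = [c == 0 for c in raw]  # for the max-aggregation check
    denom = [max(c, 1) for c in raw]    # clamped counts for mean aggregation
    return raw, denom, all_masked
```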

seq_metric = metric_for_max.max(dim=-1, keepdim=True)[0]
# All-masked sequences stay -inf; treat them as in-bounds (no valid
# tokens to filter, and their loss_mask is already all-zero).
all_masked = loss_mask.sum(dim=-1, keepdim=True) == 0


Severity: medium

To improve efficiency, you can avoid re-computing loss_mask.sum(dim=-1, keepdim=True). This sum is calculated here and also earlier on line 325 to compute valid_count. You could compute the sum once, store it in a variable, and reuse it in both places.

@guozhihao-224
Contributor Author

cc @garrett4wade @rchardx

Collaborator

@rchardx rchardx left a comment


Review comment on loss denominator scaling.

Comment thread areal/utils/functional/functional.py
Collaborator

@rchardx rchardx left a comment


Discussion on sequence-level ratio aggregation semantics.

Comment thread areal/utils/functional/functional.py Outdated
Collaborator

@rchardx rchardx left a comment


Discussion on sequence-level behave_imp_weight semantics.

Comment thread areal/utils/functional/functional.py
Collaborator

@rchardx rchardx left a comment


Suggestion on default behavior change for decoupled PPO users.

Comment thread areal/api/cli_args.py
@guozhihao-224
Contributor Author

@rchardx Thanks for the thorough review. All four issues have been addressed in the latest force-push (squashed into a single commit):

  1. Bug: mask mode changes loss denominator (functional.py:456)
    Moved loss_mask_count = loss_mask.count_nonzero() or 1 before rejection sampling, so the denominator stays N_original regardless of filtering. This keeps gradient scaling consistent with loss_weight_fn in actor.py.

  2. Arithmetic mean vs geometric mean for ratio aggregation (functional.py:267)
    Added _use_log_agg branch: when metric="ratio", aggregation now happens in log space (geometric mean), matching _compute_sequence_level_ratio_and_advantages (GSPO path). KL metrics remain arithmetic. Docstrings and help text updated accordingly.

  3. Per-token vs uniform behave_imp_weight at sequence level (functional.py:248)
    When level="sequence" and metric="ratio", behave_imp_weight is now the sequence-level geometric mean broadcast uniformly to all tokens — restoring the old sequence_mask/sequence_truncate semantics. For KL metrics, weight stays per-token (it's always the importance ratio, not the KL value). Added 3 new test cases validating both behaviors.

  4. Warn when decoupled loss enabled but no rejection sampling (cli_args.py:1379)
    Added symmetric logger.warning() when use_decoupled_loss=True and rejection_sampling is None, with migration config suggestion.
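Pure-Python sketches of fixes 1 through 3 above, for illustration only; the actual implementation operates on torch tensors and uses different names:

```python
import math

def masked_mean_loss(losses, loss_mask, reject_mask):
    # Fix 1: take the denominator from the original loss mask, *before*
    # rejection sampling, so filtering never rescales surviving gradients.
    denom = max(sum(loss_mask), 1)  # mirrors `count_nonzero() or 1`
    kept = [l * m * r for l, m, r in zip(losses, loss_mask, reject_mask)]
    return sum(kept) / denom

def sequence_ratio_uniform(log_ratios):
    # Fixes 2 and 3: aggregate per-token importance ratios in log space
    # (a geometric mean), then broadcast the sequence-level ratio
    # uniformly back to every token.
    seq_ratio = math.exp(sum(log_ratios) / len(log_ratios))
    return [seq_ratio] * len(log_ratios)
```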

Ready for re-review.

@rchardx
Collaborator

rchardx commented Apr 14, 2026

Please solve the conflicts first.

Squashed commit message:

Replace behave_imp_weight_cap/behave_imp_weight_mode with unified
RejectionSamplingConfig supporting multiple metrics (ratio, kl_k1,
kl_k2, kl_k3), levels (token/sequence), and actions (mask/clamp).

Key changes:
- Add RejectionSamplingConfig dataclass with comprehensive validation
- Implement apply_rejection_sampling for 1D packed and 2D padded formats
- Fix loss denominator scaling bug in mask mode (save count before filtering)
- Use geometric mean for sequence-level ratio aggregation (matching GSPO)
- Broadcast sequence-level geometric mean as uniform behave_imp_weight
- Warn when use_decoupled_loss=True but rejection_sampling is None
- Update ppo_actor_loss_fn and grpo_loss_fn to use new config
- Migrate 40 example configs to new rejection_sampling field
- Add 43 unit tests covering all modes, metrics, and edge cases

Refs: inclusionAI#1052
@guozhihao-224
Contributor Author


@rchardx The code conflicts have been resolved.

@guozhihao-224 guozhihao-224 requested a review from rchardx April 15, 2026 05:39


Development

Successfully merging this pull request may close these issues.

[Feature] Sequence/token level rejection sampling on async training

2 participants