feat(api): add unified RejectionSamplingConfig for async training #1088
guozhihao-224 wants to merge 2 commits into inclusionAI:main
Conversation
Summary of Changes: This pull request significantly refactors the sample filtering mechanism within asynchronous Reinforcement Learning (RL) training. It transitions from a less flexible importance weight capping system to a new, unified RejectionSamplingConfig.
Code Review
This pull request refactors the sample filtering mechanism by introducing a new RejectionSamplingConfig dataclass, replacing the older behave_imp_weight_cap and behave_imp_weight_mode parameters. This new configuration unifies token-level and sequence-level filtering with options for masking or clamping importance weights based on various divergence metrics (ratio, KL estimators) and aggregation methods. The core logic in areal/utils/functional/functional.py is updated with a new apply_rejection_sampling function, and PPOActorConfig is modified to use this new structure. Additionally, migration logic for legacy configurations, logging, documentation, and example YAML files are updated. Review feedback suggests moving the import warnings statement to the top of areal/api/cli_args.py for PEP 8 compliance, refactoring validation logic in RejectionSamplingConfig.__post_init__ for better maintainability, and optimizing apply_rejection_sampling by avoiding redundant computations of loss_mask sums for improved efficiency.
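To make the mask/clamp distinction concrete, here is a minimal token-level sketch of the filtering idea the review describes. This is illustrative only, not the PR's actual `apply_rejection_sampling` implementation; the function name, the `low`/`high` bounds, and their defaults are assumptions for this example.

```python
import torch

def reject_tokens_by_ratio(logprob, behave_logprob, loss_mask,
                           low=0.5, high=2.0, action="mask"):
    """Illustrative token-level rejection sampling on importance ratios.

    Tokens whose ratio pi_theta / pi_behave falls outside [low, high] are
    either dropped from the loss ("mask") or have their importance weight
    clipped into range ("clamp").
    """
    ratio = torch.exp(logprob - behave_logprob)
    out_of_bounds = (ratio < low) | (ratio > high)
    if action == "mask":
        # Drop out-of-range tokens from the loss entirely.
        return ratio, loss_mask & ~out_of_bounds
    # "clamp": keep every token but cap its importance weight.
    return ratio.clamp(low, high), loss_mask
```

With `action="mask"` the loss denominator shrinks as tokens are rejected, which is exactly why the commit below fixes the "save count before filtering" denominator bug.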
```python
def __post_init__(self):
    """Validate configuration."""
    import warnings

    _VALID_LEVELS = ("token", "sequence")
    _VALID_ACTIONS = ("mask", "clamp")
    _VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
    _VALID_AGGS = ("sum", "mean", "max")

    # Validate enum-like fields.
    if self.level not in _VALID_LEVELS:
        raise ValueError(
            f"level must be one of {_VALID_LEVELS}, got '{self.level}'"
        )
    if self.action not in _VALID_ACTIONS:
        raise ValueError(
            f"action must be one of {_VALID_ACTIONS}, got '{self.action}'"
        )
    if self.metric not in _VALID_METRICS:
        raise ValueError(
            f"metric must be one of {_VALID_METRICS}, got '{self.metric}'"
        )
    if self.agg not in _VALID_AGGS:
        raise ValueError(f"agg must be one of {_VALID_AGGS}, got '{self.agg}'")
```
The validation constants are defined inside __post_init__ and the checks are repetitive. For better maintainability, you can refactor the validation logic into a loop. Also, consider moving the _VALID_* constants to be class-level attributes to follow Python conventions for constants.
Suggested change:

```python
_VALID_LEVELS = ("token", "sequence")
_VALID_ACTIONS = ("mask", "clamp")
_VALID_METRICS = ("ratio", "kl_k1", "kl_k2", "kl_k3")
_VALID_AGGS = ("sum", "mean", "max")

# Validate enum-like fields.
validations = {
    "level": _VALID_LEVELS,
    "action": _VALID_ACTIONS,
    "metric": _VALID_METRICS,
    "agg": _VALID_AGGS,
}
for field_name, valid_values in validations.items():
    value = getattr(self, field_name)
    if value not in valid_values:
        raise ValueError(
            f"{field_name} must be one of {valid_values}, got '{value}'"
        )
```
```python
raw_valid = torch.zeros(
    batch_size, device=loss_mask.device, dtype=torch.int32
).scatter_add_(0, sequence_idx, loss_mask.int())
all_masked = raw_valid == 0
```
To improve efficiency, you can avoid re-computing the sum of loss_mask per sequence. The scatter_add_ operation is performed here to get raw_valid, but it was also done earlier (lines 268-271) to compute valid_count_per_seq.
Consider computing the raw counts once before the aggregation logic, and then reuse it to derive both the clamped counts for mean aggregation and for the all_masked check in max aggregation.
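A sketch of the reviewer's suggestion, on a tiny packed-1D example; the variable names follow the diff hunk above, but the sample tensors and the `clamp(min=1)` guard are illustrative assumptions.

```python
import torch

# Count valid tokens per sequence ONCE, then reuse the result.
batch_size = 2
sequence_idx = torch.tensor([0, 0, 0, 1, 1])  # packed-1D token -> sequence map
loss_mask = torch.tensor([1, 1, 0, 0, 0], dtype=torch.bool)

raw_valid = torch.zeros(
    batch_size, device=loss_mask.device, dtype=torch.int32
).scatter_add_(0, sequence_idx, loss_mask.int())

# Reused for mean aggregation (clamped to avoid division by zero)...
valid_count_per_seq = raw_valid.clamp(min=1)
# ...and for the all-masked check in max aggregation, with no second scatter.
all_masked = raw_valid == 0
```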
```python
seq_metric = metric_for_max.max(dim=-1, keepdim=True)[0]
# All-masked sequences stay -inf; treat them as in-bounds (no valid
# tokens to filter, and their loss_mask is already all-zero).
all_masked = loss_mask.sum(dim=-1, keepdim=True) == 0
```
rchardx left a comment: Review comment on loss denominator scaling.

rchardx left a comment: Discussion on sequence-level ratio aggregation semantics.

rchardx left a comment: Discussion on sequence-level behave_imp_weight semantics.

rchardx left a comment: Suggestion on default behavior change for decoupled PPO users.
Force-pushed from 25fc9b2 to 3f2eb34.
@rchardx Thanks for the thorough review. All four issues have been addressed in the latest force-push (squashed into a single commit). Ready for re-review.
Please resolve the conflicts first.
Replace behave_imp_weight_cap/behave_imp_weight_mode with unified RejectionSamplingConfig supporting multiple metrics (ratio, kl_k1, kl_k2, kl_k3), levels (token/sequence), and actions (mask/clamp).

Key changes:
- Add RejectionSamplingConfig dataclass with comprehensive validation
- Implement apply_rejection_sampling for 1D packed and 2D padded formats
- Fix loss denominator scaling bug in mask mode (save count before filtering)
- Use geometric mean for sequence-level ratio aggregation (matching GSPO)
- Broadcast sequence-level geometric mean as uniform behave_imp_weight
- Warn when use_decoupled_loss=True but rejection_sampling is None
- Update ppo_actor_loss_fn and grpo_loss_fn to use new config
- Migrate 40 example configs to new rejection_sampling field
- Add 43 unit tests covering all modes, metrics, and edge cases

Refs: inclusionAI#1052
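The geometric-mean aggregation mentioned in the commit message can be sketched as follows for the 2D padded format. This is an illustrative reconstruction, not the PR's code: the function name is hypothetical, and it shows only the GSPO-style identity that the geometric mean of per-token ratios equals exp of the mean log-ratio, broadcast uniformly across the sequence.

```python
import torch

def seq_geometric_mean_ratio(logprob, behave_logprob, loss_mask):
    """Geometric mean of per-token importance ratios over valid tokens,
    computed stably as exp(mean of log-ratios), then broadcast so every
    token in a sequence shares one uniform behave_imp_weight."""
    log_ratio = (logprob - behave_logprob) * loss_mask
    # Clamp avoids 0/0 for all-masked sequences (numerator is 0 anyway).
    count = loss_mask.sum(dim=-1, keepdim=True).clamp(min=1)
    seq_log_ratio = log_ratio.sum(dim=-1, keepdim=True) / count
    return torch.exp(seq_log_ratio).expand_as(logprob)
```

For example, token ratios of 1/2 and 1/8 yield a uniform sequence weight of sqrt(1/16) = 1/4.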
Force-pushed from 3f2eb34 to 31fb334.
@rchardx The code conflicts have been resolved.
Description

Replace `behave_imp_weight_cap`/`behave_imp_weight_mode` with unified `RejectionSamplingConfig` supporting multiple metrics (ratio, kl_k1, kl_k2, kl_k3), levels (token/sequence), and actions (mask/clamp).

This provides a unified and principled approach for filtering stale samples in async RL training, addressing staleness issues caused by version gaps between the behavior policy and the current policy.

Key changes:
- Add `RejectionSamplingConfig` dataclass with comprehensive validation
- Implement `apply_rejection_sampling` supporting both 1D packed and 2D padded formats
- Update `ppo_actor_loss_fn` and `grpo_loss_fn` to use new config
- Migrate example configs to the new `rejection_sampling` field

Recommended config for async training:
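The recommended configuration block itself did not survive the page scrape. As a stand-in, here is a minimal reconstruction of the dataclass restricted to the four enum-like fields visible in the reviewed validation code (the real class also carries numeric bound fields not shown here), plus one hypothetical way it might be instantiated:

```python
from dataclasses import dataclass

@dataclass
class RejectionSamplingConfig:
    """Minimal stand-in mirroring the fields seen in the review diff."""
    level: str = "token"    # "token" | "sequence"
    action: str = "mask"    # "mask" | "clamp"
    metric: str = "ratio"   # "ratio" | "kl_k1" | "kl_k2" | "kl_k3"
    agg: str = "mean"       # "sum" | "mean" | "max"

    def __post_init__(self):
        validations = {
            "level": ("token", "sequence"),
            "action": ("mask", "clamp"),
            "metric": ("ratio", "kl_k1", "kl_k2", "kl_k3"),
            "agg": ("sum", "mean", "max"),
        }
        for field_name, valid_values in validations.items():
            value = getattr(self, field_name)
            if value not in valid_values:
                raise ValueError(
                    f"{field_name} must be one of {valid_values}, got '{value}'"
                )

# Hypothetical usage: sequence-level masking on the ratio metric.
cfg = RejectionSamplingConfig(level="sequence", action="mask", metric="ratio")
```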
Related Issue
Fixes #1052
Type of Change
Checklist
Breaking Change Details (if applicable):
Users need to migrate from old config:
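The concrete migration snippet is missing from the scrape. A hypothetical before/after YAML sketch, using the old/new field names from this PR (the sub-key values and the `actor:` nesting are illustrative assumptions, not the project's documented defaults):

```yaml
# Before (fields removed by this PR); "..." stands for whatever values
# an existing config used:
actor:
  behave_imp_weight_cap: ...
  behave_imp_weight_mode: ...

# After: one unified field. Sub-keys follow the reviewed dataclass;
# the numeric bound fields not visible in this scrape are omitted.
actor:
  rejection_sampling:
    level: token      # or "sequence"
    metric: ratio     # or "kl_k1" / "kl_k2" / "kl_k3"
    action: mask      # or "clamp"
    agg: mean         # or "sum" / "max"
```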
Additional Context
Test Results:
Files changed: