Add CISPO and SAPO loss type support for Triton GRPO loss kernel#1074
Open
yukiu00 wants to merge 7 commits intolinkedin:mainfrom
Open
Add CISPO and SAPO loss type support for Triton GRPO loss kernel#1074yukiu00 wants to merge 7 commits intolinkedin:mainfrom
yukiu00 wants to merge 7 commits intolinkedin:mainfrom
Conversation
SAPO (Soft Adaptive Policy Optimization) replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful learning signals. Changes: - Add sapo_loss_fn helper function using sigmoid-based soft gating - Add sapo_temperature_pos and sapo_temperature_neg parameters - Support different temperatures for positive/negative advantages - Use GRPO-style normalization (per-sequence) for SAPO - Block SAPO in Triton path (chunked loss path only) - Add comprehensive tests for SAPO loss type Reference: https://huggingface.co/papers/2511.20347
- Add loss type constants (_LOSS_TYPE_GRPO, _LOSS_TYPE_CISPO, _LOSS_TYPE_SAPO) with constexpr branching for compile-time optimization - Implement CISPO loss in forward/backward kernels: - Upper-bound only clipping (no lower bound) - Detached coefficient (gradient only flows through logp) - Loss formula: -coef_2 * advantage * logp - Implement SAPO loss in forward/backward kernels: - Sigmoid-based soft gating instead of hard clipping - Different temperatures for positive/negative advantages - Loss formula: -sigmoid(τ*(ρ-1)) * 4/τ * advantage - Update GrpoLossFunction with loss_type and sapo_temperature parameters - Remove error blocking for CISPO/SAPO in transformers/grpo_loss.py - Add CISPO/SAPO reduction logic (CISPO uses DAPO norm, SAPO uses GRPO norm) - Add comprehensive tests with reference PyTorch implementations This is a follow-up to PR linkedin#1054 (CISPO) and PR linkedin#1073 (SAPO) which added chunked_loss support. Now the Triton ops path also supports these loss types.
- Add validation for unknown loss_type in GrpoLossFunction.forward - Add validation for SAPO temperature parameters (must be positive) - Improve SAPO gradient computation comments with detailed derivation - Remove unused atol/rtol parameters from test_cispo_loss and test_sapo_loss
3 tasks
Collaborator
|
Please resolve conflict, thanks. |
Contributor
Author
|
I've resolved the conflicts! Could you review this PR? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Add CISPO and SAPO loss type support to the Triton
ops/grpo_loss.pykernel.This is a follow-up to:
Background
PR #1054 and #1073 added CISPO and SAPO support to the
chunked_losspath, but theops(Triton kernel) path was marked as a follow-up. This PR implements that follow-up.Changes
src/liger_kernel/ops/grpo_loss.py_LOSS_TYPE_GRPO,_LOSS_TYPE_CISPO,_LOSS_TYPE_SAPO) withtl.constexprfor compile-time branching-coef_2 * advantage * logp-sigmoid(τ*(ρ-1)) * 4/τ * advantageGrpoLossFunctionwithloss_type,sapo_temperature_pos,sapo_temperature_negparameterssrc/liger_kernel/transformers/grpo_loss.pysapo_temperature_posandsapo_temperature_negparameters_reduce_grpo_lossto handle CISPO (DAPO normalization) and SAPO (GRPO normalization)test/transformers/test_grpo_loss.pytorch_cispo_loss,torch_sapo_loss)test_cispo_lossandtest_sapo_losstest functionsTesting Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence