Add CISPO and SAPO loss type support for Triton GRPO loss kernel #1074

Open
yukiu00 wants to merge 7 commits into linkedin:main from yukiu00:feat/ops-grpo-cispo-sapo

yukiu00 (Contributor) commented Feb 5, 2026

Summary

Add CISPO and SAPO loss type support to the Triton ops/grpo_loss.py kernel.

This is a follow-up to #1054 (CISPO) and #1073 (SAPO).

Note: This PR depends on #1073 (SAPO PR) being merged first, as it builds on top of that branch.

Background

PRs #1054 and #1073 added CISPO and SAPO support, respectively, to the chunked_loss path, but the ops (Triton kernel) path was left as a follow-up. This PR implements that follow-up.

Changes

src/liger_kernel/ops/grpo_loss.py

  • Add loss type constants (_LOSS_TYPE_GRPO, _LOSS_TYPE_CISPO, _LOSS_TYPE_SAPO) with tl.constexpr for compile-time branching
  • Implement CISPO in forward/backward kernels:
    • Upper-bound only clipping (no lower bound)
    • Detached coefficient (gradient only flows through logp)
    • Loss formula: -coef_2 * advantage * logp
  • Implement SAPO in forward/backward kernels:
    • Sigmoid-based soft gating instead of hard clipping
    • Different temperatures for positive/negative advantages
    • Loss formula: -sigmoid(τ*(ρ-1)) * 4/τ * advantage
  • Update GrpoLossFunction with loss_type, sapo_temperature_pos, sapo_temperature_neg parameters
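
For readers who want the two loss formulas above in executable form, here is a minimal scalar sketch in plain Python. It is not the Triton kernel. The function names, the `ratio_max` parameter name for the CISPO upper bound, the reading of ρ as `exp(logp - old_logp)`, and the choice to use the negative temperature when the advantage is exactly zero are all assumptions made for illustration.

```python
import math


def cispo_token_loss(logp, old_logp, advantage, ratio_max):
    """Per-token CISPO loss: -coef_2 * advantage * logp.

    `ratio_max` is a hypothetical name for the upper clipping bound.
    """
    ratio = math.exp(logp - old_logp)
    # Upper-bound-only clipping. In the kernel this coefficient is detached,
    # i.e. treated as a constant when differentiating with respect to logp.
    coef_2 = min(ratio, ratio_max)
    return -coef_2 * advantage * logp


def sapo_token_loss(logp, old_logp, advantage, temp_pos, temp_neg):
    """Per-token SAPO loss: -sigmoid(tau * (rho - 1)) * 4 / tau * advantage."""
    rho = math.exp(logp - old_logp)
    # Assumed convention: positive advantages use temp_pos, all others temp_neg.
    tau = temp_pos if advantage > 0 else temp_neg
    gate = 1.0 / (1.0 + math.exp(-tau * (rho - 1.0)))
    return -gate * (4.0 / tau) * advantage
```

When `logp == old_logp` the ratio is 1, so the sigmoid evaluates to 0.5 and the SAPO loss reduces to `-2 * advantage / tau`, while the CISPO loss reduces to `-advantage * logp` as long as `ratio_max >= 1`.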

src/liger_kernel/transformers/grpo_loss.py

  • Remove error blocking for CISPO and SAPO loss types
  • Add sapo_temperature_pos and sapo_temperature_neg parameters
  • Update _reduce_grpo_loss to handle CISPO (DAPO normalization) and SAPO (GRPO normalization)
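
To make the difference between the two normalizations concrete, the sketch below contrasts a token-level average over the whole batch (what is assumed here to be meant by DAPO normalization) with a per-sequence average followed by a mean over sequences (GRPO normalization). The function names are hypothetical, and the real `_reduce_grpo_loss` may differ in details such as how the normalizer is computed.

```python
def reduce_token_level(per_token_loss, mask):
    """DAPO-style (assumed): average over all unmasked tokens in the batch."""
    total = sum(
        loss * m
        for row_loss, row_mask in zip(per_token_loss, mask)
        for loss, m in zip(row_loss, row_mask)
    )
    count = sum(m for row_mask in mask for m in row_mask)
    return total / max(count, 1)


def reduce_per_sequence(per_token_loss, mask):
    """GRPO-style: average within each sequence, then across sequences."""
    seq_means = []
    for row_loss, row_mask in zip(per_token_loss, mask):
        n_tokens = max(sum(row_mask), 1)
        seq_sum = sum(loss * m for loss, m in zip(row_loss, row_mask))
        seq_means.append(seq_sum / n_tokens)
    return sum(seq_means) / len(seq_means)


# One 4-token completion and one 1-token completion (padding masked out).
losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 0.0, 0.0, 0.0]]
mask = [[1, 1, 1, 1], [1, 0, 0, 0]]
token_level = reduce_token_level(losses, mask)    # (4 + 4) / 5 = 1.6
per_sequence = reduce_per_sequence(losses, mask)  # (1 + 4) / 2 = 2.5
```

The two reductions agree only when all sequences have the same number of unmasked tokens. Otherwise the per-sequence form gives short completions more weight per token.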

test/transformers/test_grpo_loss.py

  • Add reference PyTorch implementations (torch_cispo_loss, torch_sapo_loss)
  • Add test_cispo_loss and test_sapo_loss test functions
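
The tests in this PR compare the Triton kernels against PyTorch reference implementations. As a smaller illustration of the same idea, the sketch below checks a hand-derived SAPO gradient against a central finite difference in plain Python. The helper names are hypothetical. The analytic form follows from differentiating the loss formula with ρ = exp(logp - old_logp), which is an assumption about how ρ is defined.

```python
import math


def sapo_loss(logp, old_logp, advantage, tau):
    """Scalar SAPO loss for a single token and a fixed temperature."""
    rho = math.exp(logp - old_logp)
    gate = 1.0 / (1.0 + math.exp(-tau * (rho - 1.0)))
    return -gate * (4.0 / tau) * advantage


def sapo_grad(logp, old_logp, advantage, tau):
    """Analytic d(loss)/d(logp).

    With s = sigmoid(tau * (rho - 1)):
      d(s)/d(rho) = tau * s * (1 - s)  and  d(rho)/d(logp) = rho,
    so the 4 / tau factor cancels tau and the result is
      -4 * s * (1 - s) * rho * advantage.
    """
    rho = math.exp(logp - old_logp)
    s = 1.0 / (1.0 + math.exp(-tau * (rho - 1.0)))
    return -4.0 * s * (1.0 - s) * rho * advantage


def central_difference(f, x, h=1e-5):
    """Second-order accurate numerical derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)
```

At ρ = 1 the sigmoid is 0.5, so the analytic gradient is exactly `-advantage`, the same as an unclipped policy-gradient term. Away from ρ = 1 the factor `4 * s * (1 - s)` shrinks smoothly toward zero, which is the soft-gating behavior described above.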

Testing Done

  • Hardware Type: NVIDIA GPU
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

SAPO (Soft Adaptive Policy Optimization) replaces hard clipping with a
smooth, temperature-controlled gate that adaptively attenuates off-policy
updates while preserving useful learning signals.

Changes:
- Add sapo_loss_fn helper function using sigmoid-based soft gating
- Add sapo_temperature_pos and sapo_temperature_neg parameters
- Support different temperatures for positive/negative advantages
- Use GRPO-style normalization (per-sequence) for SAPO
- Block SAPO in Triton path (chunked loss path only)
- Add comprehensive tests for SAPO loss type

Reference: https://huggingface.co/papers/2511.20347

- Add loss type constants (_LOSS_TYPE_GRPO, _LOSS_TYPE_CISPO, _LOSS_TYPE_SAPO)
  with constexpr branching for compile-time optimization
- Implement CISPO loss in forward/backward kernels:
  - Upper-bound only clipping (no lower bound)
  - Detached coefficient (gradient only flows through logp)
  - Loss formula: -coef_2 * advantage * logp
- Implement SAPO loss in forward/backward kernels:
  - Sigmoid-based soft gating instead of hard clipping
  - Different temperatures for positive/negative advantages
  - Loss formula: -sigmoid(τ*(ρ-1)) * 4/τ * advantage
- Update GrpoLossFunction with loss_type and sapo_temperature parameters
- Remove error blocking for CISPO/SAPO in transformers/grpo_loss.py
- Add CISPO/SAPO reduction logic (CISPO uses DAPO norm, SAPO uses GRPO norm)
- Add comprehensive tests with reference PyTorch implementations

This is a follow-up to PR linkedin#1054 (CISPO) and PR linkedin#1073 (SAPO) which added
chunked_loss support. Now the Triton ops path also supports these loss types.

- Add validation for unknown loss_type in GrpoLossFunction.forward
- Add validation for SAPO temperature parameters (must be positive)
- Improve SAPO gradient computation comments with detailed derivation
- Remove unused atol/rtol parameters from test_cispo_loss and test_sapo_loss
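
The two validation items above could look roughly like the following sketch. The function name, the string-to-id mapping, the id values, and the decision to check temperatures only when `loss_type` is `"sapo"` are all assumptions for illustration. The PR names the constants `_LOSS_TYPE_GRPO`, `_LOSS_TYPE_CISPO`, and `_LOSS_TYPE_SAPO`, but their values are not shown in this description.

```python
# Hypothetical mapping from loss_type strings to kernel constants.
_LOSS_TYPE_IDS = {"grpo": 0, "cispo": 1, "sapo": 2}


def resolve_loss_type(loss_type, sapo_temperature_pos, sapo_temperature_neg):
    """Validate the loss configuration and return the integer loss-type id."""
    if loss_type not in _LOSS_TYPE_IDS:
        supported = ", ".join(sorted(_LOSS_TYPE_IDS))
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; expected one of: {supported}"
        )
    if loss_type == "sapo":
        # SAPO divides by the temperature, so it must be strictly positive.
        for name, value in (
            ("sapo_temperature_pos", sapo_temperature_pos),
            ("sapo_temperature_neg", sapo_temperature_neg),
        ):
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
    return _LOSS_TYPE_IDS[loss_type]
```

Failing early in Python keeps invalid values from reaching the compiled kernel, where a zero temperature would otherwise surface as a division by zero in `4 / tau`.
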
yukiu00 and others added 2 commits February 5, 2026 23:21
Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Tcc0403 (Collaborator) commented Feb 7, 2026

Please resolve conflict, thanks.

yukiu00 (Contributor, Author) commented Feb 7, 2026

@Tcc0403

I've resolved the conflicts! Could you review this PR?
